mod config; mod db; mod dto; mod metrics; mod routes; mod schema; use std::fs::File; use rayon_core::ThreadPoolBuilder; use rustracing::tag::Tag; use rustracing_jaeger::Span; use tiny_http::{Method, Request, Response, ResponseBox, Server, StatusCode}; use crate::{ db::DBConnection, metrics::TRACER, routes::{header, AppError} }; static THREAD_COUNT: usize = 10; fn index_resp() -> Response { Response::from_file(File::open(std::path::Path::new("./static/index.html")).unwrap()) .with_header(header("content-type", "text/html; charset=utf-8")) } fn parse_body(span: &mut Span, req: &mut Request) -> Result where S: for<'a> serde::Deserialize<'a> { let _span = TRACER.span("parse_body").child_of(span.context().unwrap()).start(); serde_json::from_reader(req.as_reader()).map_err(|_| AppError::BadRequest("Invalid query data")) } fn handle_request(mut req: Request, db: db::DBPool) { let path = req.url().to_string(); let resp = if !path.starts_with("/api") { match req.method() { &Method::Get => if !(path.contains('\\') || path.contains("..") || path.contains(':')) { let path_str = "./static".to_owned() + &path; let path = std::path::Path::new(&path_str); if path.is_file() { let resp = Response::from_file(File::open(path).unwrap()); match path.extension().map(|s| s.to_str()).unwrap_or(None) { Some("html") => resp.with_header(header("content-type", "text/html; charset=utf-8")), Some("css") => resp.with_header(header("content-type", "text/css; charset=utf-8")), Some("js") => resp.with_header(header( "content-type", "application/x-javascript; charset=utf-8" )), Some("svg") => resp.with_header(header("content-type", "image/svg+xml")), _ => resp } .boxed() } else { index_resp().boxed() } } else { index_resp().boxed() }, _ => Response::empty(StatusCode::from(405)).boxed() } } else { let meth = req.method().clone(); let mut span = TRACER .span("handle_api_request") .tag(Tag::new("http.target", path)) .tag(Tag::new("http.method", meth.to_string())) .start(); let resp = match handle_api_request(&mut span, &mut req, db) { Ok(v) => v, Err(v) => { let code = match v { AppError::BadRequest(_) => 400, AppError::Unauthorized(_) => 401, AppError::Forbidden(_) => 403, AppError::NotFound => 404, AppError::InternalError(_) => 500 }; let msg = match v { AppError::BadRequest(v) => v.to_string(), AppError::Unauthorized(v) => v.to_string(), AppError::Forbidden(v) => v.to_string(), AppError::NotFound => "Not found".to_owned(), AppError::InternalError(v) => v.to_string() }; span.set_tag(|| Tag::new("http.error_msg", msg.clone())); Response::from_data( serde_json::to_vec(&dto::responses::Error { statusCode: code, message: msg }) .unwrap() ) .with_header(header("content-type", "application/json; charset=utf-8")) .with_status_code(code) .boxed() } }; span.set_tag(|| Tag::new("http.status_code", resp.status_code().0 as i64)); resp }; req.respond(resp).expect("Failed to send response"); } #[rustfmt::skip] fn handle_api_request(span: &mut Span, req: &mut Request, pool: db::DBPool) -> Result { metrics::REQUEST.inc(); let db = &mut db::DBConnection::from(pool.get().unwrap()); let (path, query) = { let url = req.url().to_string(); let mut splits = url.splitn(2, '?'); ( splits.next().unwrap().to_string(), splits.next().unwrap_or("").to_string() ) }; match (path.as_str(), req.method()) { ("/api/metrics", Method::Get) => metrics::get_metrics(), ("/api/auth/login", Method::Post) => parse_body(span, req).and_then(|v|routes::auth::basic::login(span, req, db, v)), ("/api/auth/signup", Method::Post) => parse_body(span, req).and_then(|v| routes::auth::basic::signup(span, req, db, v)), ("/api/auth/send_key", Method::Post) => parse_body(span, req).and_then(|v| routes::auth::basic::send_key(span, req, db, v)), ("/api/auth/reset", Method::Post) => parse_body(span, req).and_then(|v| routes::auth::basic::reset(span, req, db, v)), ("/api/auth/gitlab", Method::Get) => routes::auth::gitlab::gitlab(span, req, db), ("/api/auth/gitlab_callback", Method::Get) => routes::auth::gitlab::gitlab_callback(span, req, db, &query), ("/api/fs/download", Method::Post) => routes::fs::routes::download(span, req, db), ("/api/fs/download_multi", Method::Post) => routes::fs::routes::download_multi(span, req, db), _ => { let span_auth = TRACER.span("parse_auth_and_path").child_of(span.context().unwrap()).start(); let header = req.headers().iter().find(|h| h.field.as_str().as_str().eq_ignore_ascii_case("Authorization")) .ok_or(AppError::Unauthorized("Unauthorized"))?; let auth = header.value.as_str(); let token = auth.starts_with("Bearer ").then(|| auth.trim_start_matches("Bearer ")) .ok_or(AppError::Unauthorized("Invalid auth header"))?; let info = routes::filters::authorize_jwt(span, token, db)?; let (path, last_id) = path.to_string().rsplit_once('/') .map(|(short_path, last)| last.parse::() .map_or((path.clone(), None), |i| (short_path.to_string(), Some(i))) ) .unwrap_or((path.to_string(), None)); drop(span_auth); let span = &mut TRACER.span("handle_auth_request").child_of(span.context().unwrap()).start(); match (path.as_str(), req.method(), last_id) { ("/api/admin/users", Method::Get, None) => routes::admin::users(span, req, db, info), ("/api/admin/set_role", Method::Post, None) => parse_body(span, req).and_then(|v| routes::admin::set_role(span, req, db, info, v)), ("/api/admin/logout", Method::Post, None) => parse_body(span, req).and_then(|v| routes::admin::logout(span, req, db, info, v)), ("/api/admin/delete", Method::Post, None) => parse_body(span, req).and_then(|v| routes::admin::delete_user(span, req, db, info, v)), ("/api/admin/disable_2fa", Method::Post, None) => parse_body(span, req).and_then(|v| routes::admin::disable_2fa(span, req, db, info, v)), ("/api/admin/is_admin", Method::Get, None) => routes::admin::is_admin(span, req, db, info), ("/api/admin/get_token", Method::Get, Some(v)) => routes::admin::get_token(span, req, db, info, v), ("/api/auth/refresh", Method::Post, None) => routes::auth::basic::refresh(span, req, db, info), ("/api/auth/logout_all", Method::Post, None) => routes::auth::basic::logout_all(span, req, db, info), ("/api/auth/change_password", Method::Post, None) => parse_body(span, req).and_then(|v| routes::auth::basic::change_password(span, req, db, info, v)), ("/api/auth/2fa/setup", Method::Post, None) => parse_body(span, req).and_then(|v| routes::auth::tfa::tfa_setup(span, req, db, info, v)), ("/api/auth/2fa/complete", Method::Post, None) => parse_body(span, req).and_then(|v| routes::auth::tfa::tfa_complete(span, req, db, info, v)), ("/api/auth/2fa/disable", Method::Post, None) => routes::auth::tfa::tfa_disable(span, req, db, info), ("/api/user/info", Method::Get, None) => routes::user::info(span, req, db, info), ("/api/user/delete", Method::Post, None) => routes::user::delete_user(span, req, db, info), ("/api/fs/root", Method::Get, None) => routes::fs::routes::root(span, req, db, info), ("/api/fs/node", Method::Get, Some(v)) => routes::fs::routes::node(span, req, db, info, v), ("/api/fs/path", Method::Get, Some(v)) => routes::fs::routes::path(span, req, db, info, v), ("/api/fs/create_folder", Method::Post, None) => parse_body(span, req).and_then(|v| routes::fs::routes::create_node(span, req, db, info, v, false)), ("/api/fs/create_file", Method::Post, None) => parse_body(span, req).and_then(|v| routes::fs::routes::create_node(span, req, db, info, v, true)), ("/api/fs/delete", Method::Post, Some(v)) => routes::fs::routes::delete_node(span, req, db, info, v, &pool), ("/api/fs/upload", Method::Post, Some(v)) => routes::fs::routes::upload(span, req, db, info, v), ("/api/fs/create_zip", Method::Post, None) => parse_body(span, req).and_then(|v| routes::fs::routes::create_zip(span, req, db, info, v, &pool)), ("/api/fs/download_preview", Method::Get, Some(v)) => routes::fs::routes::download_preview(span, req, db, info, v), ("/api/fs/get_type", Method::Get, Some(v)) => routes::fs::routes::get_type(span, req, db, info, v), ("/api/fs/move", Method::Post, None) => parse_body(span, req).and_then(|v| routes::fs::routes::move_node(span, req, db, info, v)), _ => AppError::NotFound.err() } } } } fn main() { println!("Loading config..."); let _ = config::CONFIG; let db_pool: db::DBPool = db::build_pool(); println!("Running migrations..."); db::run_migrations(&mut db_pool.get().unwrap()); if !std::path::Path::new("files").exists() { std::fs::create_dir("files").expect("Failed to create files directory"); } if !std::path::Path::new("temp").is_dir() { std::fs::create_dir("temp").expect("Failed to create temp dir"); } println!("Cleaning up temp..."); std::fs::read_dir("temp").expect("Failed to iter temp dir").for_each(|dir| { std::fs::remove_file(dir.expect("Failed to retrieve temp dir entry").path()) .expect("Failed to delete file in temp dir"); }); println!("Loading metrics..."); metrics::init(DBConnection::from(db_pool.get().unwrap())); let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); let ctrlc_shutdown = shutdown.clone(); let server = std::sync::Arc::new(Server::http("0.0.0.0:2345").unwrap()); let ctrlc_server = server.clone(); ctrlc::set_handler(move || { ctrlc_shutdown.store(true, std::sync::atomic::Ordering::Relaxed); ctrlc_server.unblock(); }) .expect("Could not set ctrl-c handler"); let pool = ThreadPoolBuilder::new() .num_threads(THREAD_COUNT) .thread_name(|i| format!("Http listener {}", i)) .build() .unwrap(); println!("Listening on 0.0.0.0:2345"); 'server: loop { match server.recv() { Ok(req) => { let inner_pool = db_pool.clone(); pool.spawn(move || handle_request(req, inner_pool)) } Err(_) => if shutdown.load(std::sync::atomic::Ordering::Relaxed) { break 'server; }, } } println!("Goodbye"); }