226 lines
12 KiB
Rust

mod config;
mod db;
mod dto;
mod metrics;
mod routes;
mod schema;
use std::fs::File;
use rayon_core::ThreadPoolBuilder;
use rustracing::tag::Tag;
use rustracing_jaeger::Span;
use tiny_http::{Method, Request, Response, ResponseBox, Server, StatusCode};
use crate::{
db::DBConnection,
metrics::TRACER,
routes::{get_reply, header, AppError}
};
/// Number of worker threads in the rayon pool that handles incoming HTTP requests.
const THREAD_COUNT: usize = 10;
/// Builds a response that streams `./static/index.html` with an HTML content type.
///
/// `handle_request` uses it as the fallback page for non-API GET requests.
///
/// # Panics
/// If `./static/index.html` cannot be opened; the panic message names the missing file.
fn index_resp() -> Response<File> {
    let index = File::open("./static/index.html").expect("Failed to open ./static/index.html");
    Response::from_file(index).with_header(header("content-type", "text/html; charset=utf-8"))
}
/// Deserializes the JSON body of `req` into `S`.
///
/// The parsing is recorded in a `parse_body` child span of `span`. Any read or JSON error
/// is reported as `AppError::BadRequest("Invalid query data")`.
fn parse_body<S>(span: &mut Span, req: &mut Request) -> Result<S, AppError>
where
    S: serde::de::DeserializeOwned,
{
    let parent = span.context().unwrap();
    // Kept alive until the body has been parsed so the span covers the whole operation.
    let _span = TRACER.span("parse_body").child_of(parent).start();
    let body = req.as_reader();
    serde_json::from_reader(body).map_err(|_| AppError::BadRequest("Invalid query data"))
}
fn handle_request(mut req: Request, db: db::DBPool) {
let path = req.url().to_string();
let resp = if !path.starts_with("/api") {
match req.method() {
&Method::Get =>
if !(path.contains('\\') || path.contains("..") || path.contains(':')) {
let path_str = "./static".to_owned() + &path;
let path = std::path::Path::new(&path_str);
if path.exists() {
let resp = Response::from_file(File::open(path).unwrap());
match path.extension().map(|s| s.to_str()).unwrap_or(None) {
Some("html") => resp.with_header(header("content-type", "text/html; charset=utf-8")),
Some("css") => resp.with_header(header("content-type", "text/css; charset=utf-8")),
Some("js") => resp.with_header(header(
"content-type",
"application/x-javascript; charset=utf-8"
)),
Some("svg") => resp.with_header(header("content-type", "image/svg+xml")),
_ => resp
}
.boxed()
} else {
index_resp().boxed()
}
} else {
index_resp().boxed()
},
_ => Response::empty(StatusCode::from(405)).boxed()
}
} else {
let meth = req.method().clone();
let mut span = TRACER
.span("handle_api_request")
.tag(Tag::new("http.target", path))
.tag(Tag::new("http.method", meth.to_string()))
.start();
let resp = match handle_api_request(&mut span, &mut req, db) {
Ok(v) => v,
Err(v) => {
let code = match v {
AppError::BadRequest(_) => 400,
AppError::Unauthorized(_) => 401,
AppError::Forbidden(_) => 403,
AppError::NotFound => 404,
AppError::InternalError(_) => 500
};
Response::from_data(
serde_json::to_vec(&dto::responses::Error {
statusCode: code,
message: match v {
AppError::BadRequest(v) => v.to_string(),
AppError::Unauthorized(v) => v.to_string(),
AppError::Forbidden(v) => v.to_string(),
AppError::NotFound => "Not found".to_owned(),
AppError::InternalError(v) => v.to_string()
}
})
.unwrap()
)
.with_header(header("content-type", "application/json; charset=utf-8"))
.with_status_code(code)
.boxed()
}
};
span.set_tag(|| Tag::new("http.status_code", resp.status_code().0 as i64));
resp
};
req.respond(resp).expect("Failed to send response");
}
/// Routes an `/api` request to its handler.
///
/// The following public routes are matched on the exact path (query string removed):
/// - metrics
/// - login/signup
/// - GitLab OAuth
/// - downloads
///
/// Every other route requires an `Authorization: Bearer <jwt>` header, which is checked
/// with `routes::filters::authorize_jwt`. For those routes a trailing numeric path segment
/// (e.g. `/api/fs/node/42`) is split off and passed to the handler as an id.
///
/// # Errors
/// - `AppError::InternalError` if no database connection can be taken from `pool`.
/// - `AppError::Unauthorized` if the `Authorization` header is missing or lacks the
///   `Bearer ` prefix.
/// - `AppError::NotFound` for unknown routes.
/// - Whatever error the matched route handler returns.
#[rustfmt::skip]
fn handle_api_request(span: &mut Span, req: &mut Request, pool: db::DBPool) -> Result<ResponseBox, AppError> {
    metrics::REQUEST.inc();
    // Pool exhaustion/timeouts become a 500 response instead of panicking the worker thread.
    let conn = pool.get().map_err(|_| AppError::InternalError("Failed to get database connection".into()))?;
    let db = &mut db::DBConnection::from(conn);
    let (path, query) = match req.url().split_once('?') {
        Some((path, query)) => (path.to_string(), query.to_string()),
        None => (req.url().to_string(), String::new()),
    };
    match (path.as_str(), req.method()) {
        ("/api/metrics", Method::Get) => metrics::get_metrics(),
        ("/api/auth/login", Method::Post) => parse_body(span, req).and_then(|v| routes::auth::basic::login(span, req, db, v)),
        ("/api/auth/signup", Method::Post) => parse_body(span, req).and_then(|v| routes::auth::basic::signup(span, req, db, v)),
        ("/api/auth/gitlab", Method::Get) => routes::auth::gitlab::gitlab(span, req, db),
        ("/api/auth/gitlab_callback", Method::Get) => routes::auth::gitlab::gitlab_callback(span, req, db, &query),
        ("/api/fs/download", Method::Post) => routes::fs::routes::download(span, req, db),
        ("/api/fs/download_multi", Method::Post) => routes::fs::routes::download_multi(span, req, db),
        _ => {
            let span_auth = TRACER.span("parse_auth_and_path").child_of(span.context().unwrap()).start();
            let auth_header = req.headers().iter().find(|h| h.field.as_str().as_str().eq_ignore_ascii_case("Authorization"))
                .ok_or(AppError::Unauthorized("Unauthorized"))?;
            // `strip_prefix` removes exactly one "Bearer " (trim_start_matches would strip repeats).
            let token = auth_header.value.as_str().strip_prefix("Bearer ")
                .ok_or(AppError::Unauthorized("Invalid auth header"))?;
            let info = routes::filters::authorize_jwt(span, token, db)?;
            // "/api/fs/node/42" -> ("/api/fs/node", Some(42)); non-numeric tails keep the full path.
            let (path, last_id) = path.rsplit_once('/')
                .and_then(|(short_path, last)| last.parse::<i32>().ok().map(|id| (short_path.to_string(), Some(id))))
                .unwrap_or_else(|| (path.clone(), None));
            drop(span_auth);
            let span = &mut TRACER.span("handle_auth_request").child_of(span.context().unwrap()).start();
            match (path.as_str(), req.method(), last_id) {
                ("/api/admin/users", Method::Get, None) => routes::admin::users(span, req, db, info),
                ("/api/admin/set_role", Method::Post, None) => parse_body(span, req).and_then(|v| routes::admin::set_role(span, req, db, info, v)),
                ("/api/admin/logout", Method::Post, None) => parse_body(span, req).and_then(|v| routes::admin::logout(span, req, db, info, v)),
                ("/api/admin/delete", Method::Post, None) => parse_body(span, req).and_then(|v| routes::admin::delete_user(span, req, db, info, v)),
                ("/api/admin/disable_2fa", Method::Post, None) => parse_body(span, req).and_then(|v| routes::admin::disable_2fa(span, req, db, info, v)),
                // NOTE(review): this arm replies 200 to any authenticated user with no visible role
                // check here — confirm admin-only access is enforced elsewhere.
                ("/api/admin/is_admin", Method::Post, None) => get_reply(&dto::responses::Success { statusCode: 200 }),
                ("/api/admin/get_token", Method::Get, Some(v)) => routes::admin::get_token(span, req, db, info, v),
                ("/api/auth/refresh", Method::Post, None) => routes::auth::basic::refresh(span, req, db, info),
                ("/api/auth/logout_all", Method::Post, None) => routes::auth::basic::logout_all(span, req, db, info),
                ("/api/auth/change_password", Method::Post, None) => parse_body(span, req).and_then(|v| routes::auth::basic::change_password(span, req, db, info, v)),
                ("/api/auth/2fa/setup", Method::Post, None) => parse_body(span, req).and_then(|v| routes::auth::tfa::tfa_setup(span, req, db, info, v)),
                ("/api/auth/2fa/complete", Method::Post, None) => parse_body(span, req).and_then(|v| routes::auth::tfa::tfa_complete(span, req, db, info, v)),
                ("/api/auth/2fa/disable", Method::Post, None) => routes::auth::tfa::tfa_disable(span, req, db, info),
                ("/api/user/info", Method::Get, None) => routes::user::info(span, req, db, info),
                ("/api/user/delete", Method::Post, None) => routes::user::delete_user(span, req, db, info),
                ("/api/fs/root", Method::Get, None) => routes::fs::routes::root(span, req, db, info),
                ("/api/fs/node", Method::Get, Some(v)) => routes::fs::routes::node(span, req, db, info, v),
                ("/api/fs/path", Method::Get, Some(v)) => routes::fs::routes::path(span, req, db, info, v),
                ("/api/fs/create_folder", Method::Post, None) => parse_body(span, req).and_then(|v| routes::fs::routes::create_node(span, req, db, info, v, false)),
                ("/api/fs/create_file", Method::Post, None) => parse_body(span, req).and_then(|v| routes::fs::routes::create_node(span, req, db, info, v, true)),
                ("/api/fs/delete", Method::Post, Some(v)) => routes::fs::routes::delete_node(span, req, db, info, v, &pool),
                ("/api/fs/upload", Method::Post, Some(v)) => routes::fs::routes::upload(span, req, db, info, v),
                ("/api/fs/create_zip", Method::Post, None) => parse_body(span, req).and_then(|v| routes::fs::routes::create_zip(span, req, db, info, v, &pool)),
                ("/api/fs/download_preview", Method::Get, Some(v)) => routes::fs::routes::download_preview(span, req, db, info, v),
                ("/api/fs/get_type", Method::Get, Some(v)) => routes::fs::routes::get_type(span, req, db, info, v),
                _ => AppError::NotFound.err()
            }
        }
    }
}
/// Server entry point.
///
/// Startup steps, in order:
/// - run database migrations;
/// - make sure the `files` and `temp` directories exist, emptying `temp`;
/// - initialise metrics.
///
/// It then serves HTTP on `0.0.0.0:2345`, handing every request to a pool of
/// `THREAD_COUNT` worker threads until Ctrl-C is pressed.
fn main() {
    // NOTE(review): `let _ = X;` neither moves nor dereferences `X`. If `CONFIG` is a lazily
    // initialised static (lazy_static / LazyLock), this line does not force it to load — confirm.
    let _ = config::CONFIG;
    let db_pool: db::DBPool = db::build_pool();
    db::run_migrations(&mut db_pool.get().unwrap());
    // `create_dir_all` succeeds when the directory already exists.
    std::fs::create_dir_all("files").expect("Failed to create files directory");
    std::fs::create_dir_all("temp").expect("Failed to create temp dir");
    // Discard leftovers from a previous run. Nested directories are removed too, rather than
    // making start-up fail.
    for entry in std::fs::read_dir("temp").expect("Failed to iter temp dir") {
        let entry_path = entry.expect("Failed to retrieve temp dir entry").path();
        let removed = if entry_path.is_dir() {
            std::fs::remove_dir_all(&entry_path)
        } else {
            std::fs::remove_file(&entry_path)
        };
        removed.expect("Failed to delete file in temp dir");
    }
    metrics::init(DBConnection::from(db_pool.get().unwrap()));

    let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let ctrlc_shutdown = shutdown.clone();
    let server = std::sync::Arc::new(Server::http("0.0.0.0:2345").unwrap());
    let ctrlc_server = server.clone();
    ctrlc::set_handler(move || {
        ctrlc_shutdown.store(true, std::sync::atomic::Ordering::Relaxed);
        // Wakes the accept loop out of `recv()` so it can observe `shutdown`.
        ctrlc_server.unblock();
    })
    .expect("Could not set ctrl-c handler");

    let pool = ThreadPoolBuilder::new()
        .num_threads(THREAD_COUNT)
        .thread_name(|i| format!("Http listener {}", i))
        // Without a handler, rayon aborts the whole process when a `spawn`ed job panics.
        // The panic hook has already printed the panic message by the time this runs.
        .panic_handler(|_| eprintln!("A request handler panicked; continuing"))
        .build()
        .unwrap();
    loop {
        match server.recv() {
            Ok(req) => {
                let inner_pool = db_pool.clone();
                pool.spawn(move || handle_request(req, inner_pool));
            }
            Err(_) if shutdown.load(std::sync::atomic::Ordering::Relaxed) => break,
            // Keep serving, but leave a trace of unexpected accept errors.
            Err(e) => eprintln!("Failed to receive request: {}", e),
        }
    }
    // NOTE(review): dropping `pool` does not block on jobs still running, so requests in
    // flight when `main` returns may be cut off.
    println!("Quitting");
}