Rewrote the server in cpp with the frontend in svelte

This commit is contained in:
2023-10-20 13:02:21 +02:00
commit 03b22ebb61
4168 changed files with 831370 additions and 0 deletions

78
src/data/data.cxx Normal file
View File

@@ -0,0 +1,78 @@
#include <fstream>
#include <inipp.h>
#include "data_internal.hxx"
static constexpr std::string ini_name = "config.ini";
std::shared_ptr<spdlog::logger> data_logger;
Data::Data() {
if (!data_logger)
data_logger = spdlog::default_logger()->clone("data");
if (!std::filesystem::exists(files_dir))
std::filesystem::create_directory(files_dir);
load_config();
load();
validate();
start_save_thread();
}
Data::~Data() {
shutdown();
}
void Data::shutdown() {
data_logger->info("Stopping data saver");
shutdown_flag.test_and_set();
save();
save_thread.join();
data_logger->info("Data saver stopped");
}
void Data::load_config() {
inipp::Ini<char> ini;
if (!std::filesystem::exists(ini_name)) {
ini.sections["web"] = {
{"port", "80"}
};
ini.sections["smtp"] = {
{"host", "127.0.0.1"},
{"port", "25"},
{"user", "username"},
{"pass", "password"},
{"from", "fileserver@example.com"}
};
ini.sections["admin"] = {
{"mail", "admin@example.com"}
};
std::ofstream f{ini_name};
ini.generate(f);
data_logger->critical("Missing config, generated example");
crash();
}
data_logger->info("Loading config");
std::ifstream f{ini_name};
ini.parse(f);
if (!ini.errors.empty()) {
data_logger->critical("Failed to parse config:");
for (const auto &err : ini.errors)
data_logger->critical(err);
crash();
}
if (!ini.sections.contains("web")) { data_logger->critical("Missing section web"); crash(); }
if (!ini.sections.contains("smtp")) { data_logger->critical("Missing section smtp"); crash(); }
if (!ini.sections.contains("admin")) { data_logger->critical("Missing section admin"); crash(); }
std::string entry;
#define get_entry(section, key) if (!ini.sections[#section].contains(#key)) {data_logger->critical("Missing " #section ":" #key); crash();} entry = ini.sections[#section][#key]
get_entry(web, port); config.server_port = std::stoul(entry);
get_entry(smtp, host); config.smtp_host = entry;
get_entry(smtp, port); config.smtp_port = std::stoul(entry);
get_entry(smtp, user); config.smtp_user = entry;
get_entry(smtp, pass); config.smtp_pass = entry;
get_entry(smtp, from); config.smtp_from = entry;
get_entry(admin, mail); config.admin_mail = entry;
#undef get_entry
}

81
src/data/data.hxx Normal file
View File

@@ -0,0 +1,81 @@
#ifndef FILESERVER_DATA_HXX
#define FILESERVER_DATA_HXX
#include <optional>
#include <list>
#include <unordered_map>
#include <string>
#include <chrono>
#include <mutex>
#include <shared_mutex>
#include <filesystem>
#include <thread>
#include <atomic>
#include <cstdint>
static const std::filesystem::path files_dir = "files";
struct Node {
std::uint64_t id;
std::string name;
bool file, preview;
std::shared_ptr<Node> parent;
std::uint64_t size;
std::list<std::shared_ptr<Node>> children;
};
struct User {
std::uint64_t id;
std::string name, password, tfa_secret;
bool enabled, admin, tfa_enabled, tfa_mail;
std::uint64_t next_node_id = 1;
std::unordered_map<std::uint64_t, std::shared_ptr<Node>> nodes;
std::filesystem::path user_dir;
std::shared_mutex node_lock;
};
struct Token {
using clock = std::chrono::steady_clock;
static constexpr clock::duration token_lifetime = std::chrono::minutes{60};
explicit Token(const std::shared_ptr<User> &user) : user(user) { refresh(); }
void refresh() { expire = clock::now() + token_lifetime; }
clock::time_point expire;
std::shared_ptr<User> user, sudo_original_user{nullptr};
};
struct Config {
std::string smtp_host, smtp_user, smtp_pass, smtp_from, admin_mail; // TODO: Send mail to admin on crash
std::uint16_t smtp_port, server_port;
};
struct Data {
static constexpr std::uint64_t current_version = 1;
Data();
~Data();
void load_config();
void load();
void validate();
void save();
void start_save_thread();
void shutdown();
Config config;
std::uint64_t version = current_version, next_user_id = 0;
std::unordered_map<std::uint64_t, std::shared_ptr<User>> users;
std::unordered_map<std::string, std::shared_ptr<Token>> tokens;
std::unordered_map<std::string, std::pair<std::uint64_t, Token::clock::time_point>> mail_otp;
std::unordered_map<std::string, std::pair<std::uint64_t, Token::clock::time_point>> recovery_keys;
std::thread save_thread;
std::atomic_flag save_flag, shutdown_flag;
std::shared_mutex user_lock, token_lock;
std::mutex mail_otp_lock, recovery_lock;
};
#endif //FILESERVER_DATA_HXX

View File

@@ -0,0 +1,34 @@
#ifndef FILESERVER_DATA_INTERNAL_HXX
#define FILESERVER_DATA_INTERNAL_HXX
#include <filesystem>
#include <cstdio>
#include <spdlog/spdlog.h>
#include "data.hxx"
#include "../util/crash.hxx"
extern std::shared_ptr<spdlog::logger> data_logger;
static constexpr std::string data_new_file = "data.new.json";
static constexpr std::string data_cur_file = "data.json";
static constexpr std::string data_old_file = "data.old.json";
struct SaveNode {
std::unique_ptr<Node> node;
std::uint64_t id, parent_id;
std::list<std::uint64_t> children_ids;
};
struct FileWrapper {
FileWrapper(const char *file, const char *mode) {
f = std::fopen(file, mode);
}
~FileWrapper() {
std::fclose(f);
}
std::FILE *f;
};
#endif //FILESERVER_DATA_INTERNAL_HXX

143
src/data/data_load.cxx Normal file
View File

@@ -0,0 +1,143 @@
#define RAPIDJSON_HAS_STDSTRING 1
#include <rapidjson/filereadstream.h>
#include <rapidjson/document.h>
#include <rapidjson/error/en.h>
#include "data_internal.hxx"
#define FIND_MEMBER(name, ty) m = doc.FindMember(#name); \
if (m == doc.MemberEnd() || !m->value.Is##ty()) { data_logger->error("Missing or invalid "#name); throw std::exception{}; }
#define ASSIGN_MEMBER(lhs, name, ty) FIND_MEMBER(name, ty) lhs = m->value.Get##ty()
SaveNode load_node(const rapidjson::Value &doc) {
rapidjson::Value::ConstMemberIterator m;
auto node = std::make_unique<Node>();
std::uint64_t id, parent;
node->parent = nullptr;
ASSIGN_MEMBER(node->id, id, Uint64);
id = node->id;
data_logger->debug("Loading node {}", id);
ASSIGN_MEMBER(node->name, name, String);
ASSIGN_MEMBER(node->file, file, Bool);
ASSIGN_MEMBER(node->preview, preview, Bool);
ASSIGN_MEMBER(node->size, size, Uint64);
ASSIGN_MEMBER(parent, parent, Uint64);
std::list<std::uint64_t> children;
FIND_MEMBER(children, Array)
for (const auto &v : m->value.GetArray())
if (!v.IsUint64()) { data_logger->error("Invalid child id"); throw std::exception{}; }
else children.push_back(v.GetUint64());
return {
.node = std::move(node),
.id = id,
.parent_id = parent,
.children_ids = children
};
}
std::unique_ptr<User> load_user(const rapidjson::Value &doc) {
rapidjson::Value::ConstMemberIterator m;
auto user = std::make_unique<User>();
ASSIGN_MEMBER(user->id, id, Uint64);
data_logger->debug("Loading user {}", user->id);
ASSIGN_MEMBER(user->name, name, String);
ASSIGN_MEMBER(user->password, password, String);
ASSIGN_MEMBER(user->tfa_secret, tfa_secret, String);
ASSIGN_MEMBER(user->enabled, enabled, Bool);
ASSIGN_MEMBER(user->admin, admin, Bool);
ASSIGN_MEMBER(user->tfa_enabled, tfa_enabled, Bool);
ASSIGN_MEMBER(user->tfa_mail, tfa_mail, Bool);
ASSIGN_MEMBER(user->next_node_id, next_node_id, Uint64);
std::list<SaveNode> nodes;
FIND_MEMBER(nodes, Array)
for (const auto &v : m->value.GetArray())
nodes.push_back(load_node(v));
for (auto &node : nodes)
user->nodes.emplace(node.id, node.node.release());
for (const auto &snode : nodes) {
auto &node = user->nodes.at(snode.id);
if (node->id != 0)
node->parent = user->nodes.at(snode.parent_id);
for (const auto &child : snode.children_ids)
node->children.push_back(user->nodes.at(child));
}
user->user_dir = files_dir / std::to_string(user->id);
return user;
}
bool load_from_file(Data &data, const std::string &file) {
char buf[65536];
bool success = false;
data_logger->info("Loading data from {}", file);
data.users.clear();
auto start = std::chrono::high_resolution_clock::now();
try {
FileWrapper f{file.c_str(), "r"};
rapidjson::FileReadStream fstream{f.f, buf, sizeof(buf)};
rapidjson::Document doc;
rapidjson::Document::MemberIterator m;
doc.ParseStream(fstream);
if (doc.HasParseError()) {
data_logger->error("Failed to parse data");
auto e = rapidjson::GetParseError_En(doc.GetParseError());
data_logger->error("{}", e);
throw std::exception{};
}
ASSIGN_MEMBER(data.version, version, Uint64);
if (data.version != Data::current_version) {
data_logger->error("Version is {}, current version is {}. Refusing to load!", data.version, Data::current_version);
throw std::exception{};
}
ASSIGN_MEMBER(data.next_user_id, next_user_id, Uint64);
FIND_MEMBER(users, Array)
for (const auto &v : m->value.GetArray()) {
auto user = load_user(v);
data.users.emplace(user->id, std::move(user));
}
success = true;
} catch (std::exception &_) {}
auto dur = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - start);
data_logger->info("Data loading took {} ms", dur.count());
return success;
}
void Data::load() {
if (std::filesystem::exists(data_cur_file)) {
bool ok = load_from_file(*this, data_cur_file);
if (!ok) {
data_logger->warn("Retrying from old file");
if (!std::filesystem::exists(data_old_file)) {
data_logger->critical("Old file missing");
crash();
}
ok = load_from_file(*this, data_old_file);
if (!ok) {
data_logger->critical("Failed to load data");
crash();
}
}
} else if (std::filesystem::exists(data_old_file)) {
data_logger->warn("Data file missing, loading from old file");
bool ok = load_from_file(*this, data_old_file);
if (!ok) {
data_logger->critical("Failed to load data");
crash();
}
} else data_logger->info("No data file found for loading");
}

109
src/data/data_save.cxx Normal file
View File

@@ -0,0 +1,109 @@
#define RAPIDJSON_HAS_STDSTRING 1
#include <rapidjson/filewritestream.h>
#include <rapidjson/writer.h>
#include "data_internal.hxx"
#define KEY(x) writer.Key(#x, sizeof(#x)-1)
using Writer = rapidjson::Writer<rapidjson::FileWriteStream>;
void save_node(Writer &writer, Node *node) {
writer.StartObject();
KEY(id); writer.Uint64(node->id);
KEY(name); writer.String(node->name);
KEY(file); writer.Bool(node->file);
KEY(preview); writer.Bool(node->preview);
KEY(size); writer.Uint64(node->size);
KEY(parent); writer.Uint64(node->parent == nullptr ? 0 : node->parent->id);
KEY(children);
writer.StartArray();
for (const auto &child : node->children)
writer.Uint64(child->id);
writer.EndArray();
writer.EndObject();
}
void save_user(Writer &writer, User *user) {
std::shared_lock lock{user->node_lock};
writer.StartObject();
KEY(id); writer.Uint64(user->id);
KEY(name); writer.String(user->name);
KEY(password); writer.String(user->password);
KEY(tfa_secret); writer.String(user->tfa_secret);
KEY(enabled); writer.Bool(user->enabled);
KEY(admin); writer.Bool(user->admin);
KEY(tfa_enabled); writer.Bool(user->tfa_enabled);
KEY(tfa_mail); writer.Bool(user->tfa_mail);
KEY(next_node_id); writer.Uint64(user->next_node_id);
KEY(nodes);
writer.StartArray();
for (const auto &node : user->nodes)
save_node(writer, node.second.get());
writer.EndArray();
writer.EndObject();
}
char save_buf[65536];
void save(Data* data) {
data_logger->info("Saving data");
try {
{
FileWrapper f{data_new_file.c_str(), "w"};
rapidjson::FileWriteStream stream{f.f, save_buf, sizeof(save_buf)};
Writer writer{stream};
std::shared_lock lock{data->user_lock};
writer.StartObject();
KEY(version); writer.Uint64(data->version);
KEY(next_user_id); writer.Uint64(data->next_user_id);
KEY(users);
writer.StartArray();
for (const auto &user : data->users)
save_user(writer, user.second.get());
writer.EndArray();
writer.EndObject();
}
data_logger->info("Finished writing data");
if (std::filesystem::exists(data_cur_file))
std::filesystem::copy_file(data_cur_file, data_old_file, std::filesystem::copy_options::overwrite_existing);
std::filesystem::rename(data_new_file, data_cur_file);
data_logger->info("Save done");
} catch (std::exception &e) {
data_logger->error("Error while saving: {}, retrying...", e.what());
data->save_flag.test_and_set();
}
}
void save_worker(Data *data) {
while (!data->shutdown_flag.test()) {
data->save_flag.wait(false);
do {
data->save_flag.clear();
std::this_thread::sleep_for(std::chrono::seconds{2});
} while (data->save_flag.test());
save(data);
}
data_logger->info("Data saver stopping");
save(data);
}
void Data::start_save_thread() {
save();
shutdown_flag.clear();
this->save_thread = std::thread{save_worker, this};
}
void Data::save() {
save_flag.test_and_set();
save_flag.notify_all();
}

View File

@@ -0,0 +1,71 @@
#include "data_internal.hxx"
template<typename... Args>
inline void ok(bool ok, spdlog::format_string_t<Args...> fmt, Args &&...args) {
if (!ok) {
data_logger->critical("Data validation failed:");
data_logger->critical(fmt, std::forward<Args>(args)...);
crash();
}
}
void validate_node(const std::string &username, const std::filesystem::path &user_dir, Node *node) {
auto name = "User " + username + " node " + std::to_string(node->id) + ":" + node->name;
if (node->id != 0) {
ok(!node->name.empty(), "{} is missing name", name);
ok(node->parent != nullptr, "{} is missing parent", name);
}
if (node->file) {
auto file = user_dir / std::to_string(node->id);
ok(std::filesystem::exists(file), "{} is missing file on disk", name);
auto size = std::filesystem::file_size(file);
ok(node->size == size, "{} size does not match (Node,Disk): {} != {}", name, node->size, size);
if (node->preview)
ok(std::filesystem::exists(file.replace_extension("png")), "{} is missing preview on disk", name);
}
}
void validate_user(User *user) {
auto name = std::to_string(user->id) + ":" + user->name;
ok(!user->name.empty(), "User {} is missing name", user->id);
ok(!user->password.empty(), "User {} is missing password", name);
ok(user->nodes.contains(0), "User {} is missing root", name);
if (user->tfa_enabled && !user->tfa_mail)
ok(!user->tfa_secret.empty(), "User {} is missing tfa secret", name);
ok(!user->user_dir.empty(), "User {} user_dir is not set", name);
ok(std::filesystem::exists(user->user_dir), "User {} is missing files directory", name);
std::uint64_t last_node_id = 0;
for (const auto &node : user->nodes) {
validate_node(name, user->user_dir, node.second.get());
last_node_id = std::max(last_node_id, node.first);
}
ok(user->next_node_id > last_node_id, "User {}: Next node id {} must be larger than the last used id {}", name, user->next_node_id, last_node_id);
}
void Data::validate() {
data_logger->info("Validating data");
auto start = std::chrono::high_resolution_clock::now();
try {
std::uint64_t last_user_id = 0;
for (const auto &user: users) {
validate_user(user.second.get());
last_user_id = std::max(last_user_id, user.first);
}
if (!users.empty())
ok(next_user_id > last_user_id, "Next user id {} must be larger than the last used id {}", next_user_id, last_user_id);
} catch (std::exception &e) {
data_logger->critical("Data validation failed with error: {}", e.what());
crash();
}
auto dur = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - start);
data_logger->info("Data validation took {} ms", dur.count());
data_logger->info("Data validated");
}

85
src/main.cxx Normal file
View File

@@ -0,0 +1,85 @@
#include <memory>
#include <csignal>
#include <corvusoft/restbed/resource.hpp>
#include <corvusoft/restbed/session.hpp>
#include <corvusoft/restbed/settings.hpp>
#include <corvusoft/restbed/service.hpp>
#include <spdlog/sinks/basic_file_sink.h>
#include "util/logging.hxx"
#include "server/server.hxx"
#include "index_html.h"
#include "favicon_svg.h"
std::shared_ptr<restbed::Service> g_service = nullptr;
const static restbed::Bytes index_html_bytes{index_html, index_html + index_html_len};
const static restbed::Bytes favicon_bytes{favicon_svg, favicon_svg + favicon_svg_len};
void signal_shutdown(const int) {
spdlog::info("Recieved stop signal");
g_service->stop();
}
int main() {
auto file_sink = std::make_shared<spdlog::sinks::basic_file_sink_mt>("log.txt");
spdlog::default_logger()->sinks().push_back(file_sink);
spdlog::set_level(spdlog::level::trace);
auto mrpc_resource = std::make_shared<restbed::Resource>();
mrpc_resource->set_path("/mrpc");
Server server{mrpc_resource};
#define mk_res(url) auto url##_resource = std::make_shared<restbed::Resource>(); \
url##_resource->set_path("/" #url); \
url##_resource->set_method_handler("POST", [&server](auto s){ server.url(s); })
mk_res(download);
mk_res(download_multi);
mk_res(upload);
#undef mk_res
auto index_resource = std::make_shared<restbed::Resource>();
index_resource->set_path("/");
index_resource->set_method_handler("GET", [](const std::shared_ptr<restbed::Session>& s){
s->yield(
200,
index_html_bytes,
std::multimap<std::string, std::string>{
{"Content-Type", "text/html"},
{"Content-Length", std::to_string(index_html_len)}
}
);
});
auto favicon_resource = std::make_shared<restbed::Resource>();
favicon_resource->set_path("/favicon.svg");
favicon_resource->set_method_handler("GET", [](const std::shared_ptr<restbed::Session>& s){
s->yield(
200,
favicon_bytes,
std::multimap<std::string, std::string>{
{"Content-Type", "image/svg+xml"},
{"Content-Length", std::to_string(favicon_svg_len)}
}
);
});
auto settings = std::make_shared<restbed::Settings>();
settings->set_port(server.config.server_port);
settings->set_default_header("Connection", "keep-alive");
g_service = std::make_shared<restbed::Service>();
g_service->set_logger(std::make_shared<logging::RestbedLogger>());
g_service->set_signal_handler(SIGINT, signal_shutdown);
g_service->set_signal_handler(SIGTERM, signal_shutdown);
g_service->publish(mrpc_resource);
g_service->publish(download_resource);
g_service->publish(download_multi_resource);
g_service->publish(upload_resource);
g_service->publish(index_resource);
g_service->publish(favicon_resource);
g_service->start(settings);
g_service.reset();
return 0;
}

102
src/server/admin.cxx Normal file
View File

@@ -0,0 +1,102 @@
#include <spdlog/spdlog.h>
#include "server_internal.hxx"
#define check_admin_response() check_user_response(); if (!user->admin) return { .e = "Forbidden" };
#define check_admin_optional() check_user_optional(); if (!user->admin) return "Forbidden";
// TODO Log admin action
mrpc::Response<std::vector<mrpc::UserInfo>> Server::Admin_list_users(std::string &&token) {
check_admin_response();
{
std::shared_lock lock{user_lock};
std::vector<mrpc::UserInfo> info;
info.reserve(users.size());
for (const auto &us : users) {
const auto u = us.second.get();
info.push_back(mrpc::UserInfo {
.id = u->id,
.name = u->name,
.tfa = u->tfa_enabled,
.admin = u->admin,
.enabled = u->enabled
});
}
return { .o = std::move(info) };
}
}
std::optional<std::string> Server::Admin_delete_user(std::string &&token, std::uint64_t &&user_id) {
check_admin_optional();
auto target = get_user(user_id);
if (!target) return "Invalid user";
delete_user(target);
save();
return std::nullopt;
}
std::optional<std::string> Server::Admin_logout(std::string &&token, std::uint64_t &&user_id) {
check_admin_optional();
logout_user(user_id);
return std::nullopt;
}
std::optional<std::string> Server::Admin_disable_tfa(std::string &&token, std::uint64_t &&user_id) {
check_admin_optional();
auto u = get_user(user_id);
if (u) u->tfa_enabled = false;
save();
return std::nullopt;
}
std::optional<std::string> Server::Admin_set_admin(std::string &&token, std::uint64_t &&user_id, bool &&admin) {
check_admin_optional();
auto u = get_user(user_id);
if (u) u->admin = admin;
save();
return std::nullopt;
}
std::optional<std::string> Server::Admin_set_enabled(std::string &&token, std::uint64_t &&user_id, bool &&enabled) {
check_admin_optional();
auto u = get_user(user_id);
if (u) u->enabled = enabled;
save();
return std::nullopt;
}
std::optional<std::string> Server::Admin_sudo(std::string &&token, std::uint64_t &&user_id) {
check_admin_optional();
auto u = get_user(user_id);
if (!u)
return "Invalid user";
{
std::unique_lock tlock{token_lock};
auto &t = tokens.at(token);
t->sudo_original_user = user;
t->user = u;
t->refresh();
}
return std::nullopt;
}
std::optional<std::string> Server::Admin_unsudo(std::string &&token) {
check_user_optional();
{
std::unique_lock lock{token_lock};
auto &t = tokens.at(token);
if (t->sudo_original_user == nullptr)
return "Unauthorized";
t->user = t->sudo_original_user;
t->sudo_original_user = nullptr;
t->refresh();
}
return std::nullopt;
}
std::optional<std::string> Server::Admin_shutdown(std::string &&token) {
check_admin_optional();
spdlog::info("Received rpc shutdown request from admin user {}", user->name);
g_service->stop();
return std::nullopt;
}

226
src/server/auth.cxx Normal file
View File

@@ -0,0 +1,226 @@
#include <botan_all.h>
#include "server_internal.hxx"
std::string hash_password(const std::string &password) {
return Botan::argon2_generate_pwhash(password.c_str(), password.size(), *auth_rng, 1, 64*1024, 4);
}
std::optional<std::string> Server::Auth_signup(std::string &&username, std::string &&password) {
std::unique_lock lock{user_lock};
for (const auto &user : users) {
if (user.second->name == username)
return "User already exists";
}
if (password.length() < 6)
return "Password must be 6 characters or longer";
auto hash = hash_password(password);
auto id = next_user_id++;
std::shared_ptr<Node> root{new Node{
.id = 0,
.name = "",
.file = false,
.preview = false,
.parent = nullptr,
.size = 0
}};
std::shared_ptr<User> user{new User{
.id = id,
.name = username,
.password = hash,
.enabled = false,
.admin = false,
.tfa_enabled = false
}};
user->nodes.emplace(0, std::move(root));
user->user_dir = files_dir / std::to_string(user->id);
std::filesystem::create_directory(user->user_dir);
users.emplace(id, std::move(user));
save();
return std::nullopt;
}
mrpc::Response<mrpc::LoginResponse> Server::Auth_login(std::string &&username, std::string &&password, std::optional<std::string> &&otp) {
std::shared_lock lock{user_lock};
std::shared_ptr<User> user = nullptr;
for (const auto &u : users) {
if (u.second->name == username) {
user = u.second;
break;
}
}
if (user == nullptr)
return { .e = "Invalid username or password" };
if (!Botan::argon2_check_pwhash(password.c_str(), password.size(), user->password))
return { .e = "Invalid username or password" };
if (!user->enabled)
return { .e = "User is disabled" };
if (user->tfa_enabled) {
if (otp.has_value()) {
bool ok = user->tfa_mail ? check_mail_code(user, otp.value()) : check_tfa_code(user, otp.value());
if (!ok)
return { .e = "Invalid code" };
} else {
if (user->tfa_mail)
send_tfa_mail(user);
mrpc::LoginResponse response { .otp_needed = true };
return { .o = std::move(response) };
}
}
auto token = Botan::hex_encode(auth_rng->random_vec(16), false);
{
std::unique_lock tlock{token_lock};
tokens.emplace(token, std::make_unique<Token>(user));
}
mrpc::LoginResponse response {
.otp_needed = false,
.token = token
};
return { .o = std::move(response) };
}
void Server::Auth_logout(std::string &&token) {
std::unique_lock lock{token_lock};
tokens.erase(token);
}
std::optional<std::string> Server::Auth_logout_all(std::string &&token) {
check_user_optional();
logout_user(user->id);
return std::nullopt;
}
void Server::Auth_send_recovery_key(std::string &&username) {
std::shared_lock lock{user_lock};
User *user = nullptr;
for (const auto &u : users) {
if (u.second->name == username) {
user = u.second.get();
break;
}
}
if (user == nullptr)
return;
auto code = Botan::hex_encode(auth_rng->random_vec(5), true);
{
std::lock_guard rlock{recovery_lock};
recovery_keys.emplace(code, std::make_pair(user->id, Token::clock::now() + std::chrono::minutes{5}));
}
send_mail(user->name, "MFileserver - Password recovery", "Your recovery key is: " + code + "\r\nIt is valid for 5 minutes");
}
std::optional<std::string> Server::Auth_reset_password(std::string &&key, std::string &&password) {
std::lock_guard lock{recovery_lock};
auto now = Token::clock::now();
for (auto it = recovery_keys.begin(); it != recovery_keys.end();) {
if (now >= it->second.second)
recovery_keys.erase(it++);
else
++it;
}
auto entry = recovery_keys.find(key);
if (entry == recovery_keys.end())
return "Invalid key";
auto user = get_user(entry->second.first);
if (!user)
return "Invalid key";
if (password.length() < 6)
return "Password must be 6 characters or longer";
recovery_keys.erase(key);
logout_user(user->id);
{
std::shared_lock ulock{user_lock};
user->password = hash_password(password);
}
save();
return std::nullopt;
}
std::optional<std::string> Server::Auth_change_password(std::string &&token, std::string &&old_pw, std::string &&new_pw) {
check_user_optional();
std::shared_lock lock{user_lock};
if (!Botan::argon2_check_pwhash(old_pw.c_str(), old_pw.size(), user->password))
return "Old password is wrong";
if (new_pw.length() < 6)
return "Password must be 6 characters or longer";
user->password = hash_password(new_pw);
save();
return std::nullopt;
}
std::optional<std::string> Server::Auth_tfa_setup_mail(std::string &&token) {
check_user_optional();
if (user->tfa_enabled)
return "Tfa is already enabled";
user->tfa_mail = true;
send_tfa_mail(user);
return std::nullopt;
}
mrpc::Response<std::string> Server::Auth_tfa_setup_totp(std::string &&token) {
check_user_response();
if (user->tfa_enabled)
return { .e = "Tfa is already enabled" };
user->tfa_mail = false;
Botan::OctetString secret{*auth_rng, 16};
user->tfa_secret = secret.to_string();
return { .o = Botan::base32_encode(secret) };
}
std::optional<std::string> Server::Auth_tfa_complete(std::string &&token, std::string &&otp) {
check_user_optional();
bool ok = user->tfa_mail ? check_mail_code(user, otp) : check_tfa_code(user, otp);
if (!ok)
return "Invalid code";
user->tfa_enabled = true;
logout_user(user->id);
save();
return std::nullopt;
}
std::optional<std::string> Server::Auth_tfa_disable(std::string &&token) {
check_user_optional();
user->tfa_enabled = false;
logout_user(user->id);
save();
return std::nullopt;
}
std::optional<std::string> Server::Auth_delete_user(std::string &&token) {
check_user_optional();
delete_user(user);
save();
return std::nullopt;
}
mrpc::Response<mrpc::Session> Server::Auth_session_info(std::string &&token) {
check_user_response();
auto t = get_token(token);
if (!t)
return { .e = "Invalid token" };
mrpc::Session info {
.name = user->name,
.tfa_enabled = user->tfa_enabled,
.admin = user->admin,
.sudo = t->sudo_original_user != nullptr
};
return { .o = info };
}

183
src/server/download.cxx Normal file
View File

@@ -0,0 +1,183 @@
#include <filesystem>
#include <string_view>
#include <ranges>
#include <fstream>
#include <deque>
#include <charconv>
#include <miniz.h>
#include <corvusoft/restbed/session.hpp>
#include <corvusoft/restbed/request.hpp>
#include <corvusoft/restbed/response.hpp>
#include "server_internal.hxx"
void Server::download(const std::shared_ptr<restbed::Session> &s) {
const auto req = s->get_request();
std::size_t body_len = req->get_header("Content-Length", 0);
s->fetch(body_len, [this](const std::shared_ptr<restbed::Session> &s, const restbed::Bytes &b){
std::string body{b.cbegin(), b.cend()};
if (body.empty())
return s->close(400, "empty body");
std::string node_str, token;
for (const auto part : std::views::split(body, '&')) {
std::string_view part_view{part};
auto equal_pos = part_view.find_first_of('=');
auto key = part_view.substr(0, equal_pos);
if (key == "node")
node_str = part_view.substr(equal_pos+1);
else if (key == "token")
token = part_view.substr(equal_pos+1);
}
if (node_str.empty())
return s->close(400, "Missing node");
if (token.empty())
return s->close(400, "Missing token");
std::uint64_t node_id;
auto res = std::from_chars(node_str.data(), node_str.data() + node_str.size(), node_id);
if (res.ec != std::errc{})
return s->close(400, "Invalid node");
check_user() return s->close(400, "Invalid user");
{
std::shared_lock lock{user->node_lock};
auto node = get_node(user, node_id);
if (!node) return s->close(400, "Invalid node");
auto mime = get_mime_type(node->name);
s->yield(
200,
"",
std::multimap<std::string, std::string>{
{"Content-Type", mime},
{"Content-Length", std::to_string(node->size)},
{"Content-Disposition", "attachment; filename=\"" + node->name + "\""}
},
[&](const std::shared_ptr<restbed::Session>& s) {
std::shared_lock lock{user->node_lock};
restbed::Bytes buf(1024*1024*4, 0);
std::ifstream f{user->user_dir / std::to_string(node->id)};
while (!f.eof()) {
f.read((char*)buf.data(), buf.size());
buf.resize(f.gcount());
s->yield(buf);
}
s->close();
}
);
}
});
}
size_t zip_write_func(void *pOpaque, mz_uint64 _file_ofs, const void *pBuf, size_t n) {
auto s = (restbed::Session*)pOpaque;
if (n > 0) {
restbed::Bytes buf(n, 0);
std::memcpy(buf.data(), pBuf, n);
std::stringstream ss;
ss << std::hex << n;
s->yield(ss.str() + "\r\n");
s->yield(buf);
s->yield("\r\n");
}
return n;
}
void Server::download_multi(const std::shared_ptr<restbed::Session> &s) {
const auto req = s->get_request();
const auto body_len = req->get_header("Content-Length", 0);
s->fetch(body_len, [this](const std::shared_ptr<restbed::Session> &s, const restbed::Bytes &b){
std::string body{b.cbegin(), b.cend()};
if (body.empty())
return s->close(400, "empty body");
std::string nodes_str, token;
for (const auto part : std::views::split(body, '&')) {
std::string_view part_view{part};
auto equal_pos = part_view.find_first_of('=');
auto key = part_view.substr(0, equal_pos);
if (key == "nodes")
nodes_str = part_view.substr(equal_pos+1);
else if (key == "token")
token = part_view.substr(equal_pos+1);
}
if (nodes_str.empty())
return s->close(400, "Missing nodes");
if (token.empty())
return s->close(400, "Missing token");
std::vector<std::uint64_t> node_ids;
for (const auto part : std::views::split(nodes_str, '.')) {
std::uint64_t node_id;
auto res = std::from_chars(part.data(), part.data() + part.size(), node_id);
if (res.ec != std::errc{})
return s->close(400, "Invalid node " + std::string{std::string_view{part}});
node_ids.push_back(node_id);
}
check_user() return s->close(400, "Invalid user");
{
std::shared_lock lock{user->node_lock};
std::vector<std::shared_ptr<Node>> nodes;
for (auto node_id : node_ids) {
auto node = get_node(user, node_id);
if (!node) return s->close(400, "Invalid node " + std::to_string(node_id));
nodes.push_back(node);
}
s->yield(
200,
"",
std::multimap<std::string, std::string>{
{"Content-Type", "application/zip"},
{"Content-Disposition", "attachment; filename=\"files.zip\""},
{"Transfer-Encoding", "chunked"}
}
);
std::thread zip_thread{[nodes = nodes, user = user, s = s] {
std::shared_lock lock{user->node_lock};
mz_zip_archive archive;
mz_zip_zero_struct(&archive);
archive.m_pWrite = zip_write_func;
archive.m_pIO_opaque = s.get();
mz_zip_writer_init_v2(&archive, 0, MZ_ZIP_FLAG_WRITE_ZIP64);
std::deque<std::pair<std::shared_ptr<Node>, std::filesystem::path>> todo;
for (const auto &node : nodes)
todo.emplace_back(node, std::filesystem::path{});
auto handle_file = [&user, &archive](const std::pair<std::shared_ptr<Node>, std::filesystem::path> &i) {
auto path = i.second / i.first->name;
auto real_path = user->user_dir / std::to_string(i.first->id);
mz_zip_writer_add_file(&archive, path.c_str(), real_path.c_str(), nullptr, 0, MZ_DEFAULT_COMPRESSION);
};
while (!todo.empty()) {
const auto &node = todo.front();
if (node.first->file) {
handle_file(node);
} else {
auto path = node.second / node.first->name;
auto dir_path = path.string() + "/";
mz_zip_writer_add_mem(&archive, dir_path.c_str(), nullptr, 0, 0);
for (const auto &child : node.first->children) {
auto p = std::make_pair(child, path);
if (child->file)
handle_file(p);
else
todo.push_back(p);
}
}
todo.pop_front();
}
mz_zip_writer_finalize_archive(&archive);
mz_zip_writer_end(&archive);
s->close("0\r\n\r\n");
}};
zip_thread.detach();
}
});
}

281
src/server/fs.cxx Normal file
View File

@@ -0,0 +1,281 @@
#include <memory>
#include <fstream>
#include <stack>
#include <unordered_set>
#include "server_internal.hxx"
mrpc::Node node_to_node(const std::shared_ptr<Node>& node) {
return {
.id = node->id,
.name = node->name,
.file = node->file,
.preview = node->preview,
.parent = node->id == 0 ? std::nullopt : std::optional{node->parent->id},
.size = node->file ? std::optional{node->size} : std::nullopt,
.children = std::nullopt
};
}
std::shared_ptr<Node> get_node(const std::shared_ptr<User>& user, std::uint64_t id) {
std::shared_lock lock{user->node_lock};
const auto &entry = user->nodes.find(id);
return entry == user->nodes.end() ? nullptr : entry->second;
}
std::string get_path(std::shared_ptr<Node> node) {
std::string ret;
while(node) {
if (node->id != 0) {
ret.insert(0, node->name);
ret.insert(0, "/");
}
node = node->parent;
}
return ret.empty() ? "/" : ret;
}
void Server::delete_node(const std::shared_ptr<User> &user, std::uint64_t id, const std::function<void(std::string)>& log) {
std::unique_lock lock{user->node_lock};
std::stack<std::shared_ptr<Node>> todo;
{
auto start = user->nodes.find(id);
if (start == user->nodes.end()) return;
todo.push(start->second);
}
while (!todo.empty()) {
auto node = todo.top();
auto log_path = get_path(node);
if (!node->children.empty()) {
log("Entering " + log_path + "\n");
for (const auto &child : node->children)
todo.push(child);
node->children.clear();
continue;
}
log("Deleting " + log_path + "...");
if (node->file) {
auto path = user->user_dir / std::to_string(node->id);
std::filesystem::remove(path);
if (node->preview)
std::filesystem::remove(path.replace_extension("png"));
}
if (node->parent)
node->parent->children.remove(node);
node->parent.reset();
user->nodes.erase(node->id);
log(" Done\n");
todo.pop();
}
}
std::uint64_t Server::nodes_size(const std::shared_ptr<User> &user, const std::vector<std::uint64_t> &ids) {
std::uint64_t total = 0;
std::deque<Node*> todo;
for (const auto &id : ids) {
auto node = get_node(user, id);
if (node)
todo.push_back(node.get());
}
while (!todo.empty()) {
auto node = todo.front();
if (node->file) {
total += node->size;
} else {
for (const auto &child : node->children)
todo.push_back(child.get());
}
todo.pop_front();
}
return total;
}
mrpc::Response<mrpc::Node> Server::FS_get_node(std::string &&token, std::uint64_t &&node_id) {
check_user_response();
auto node = get_node(user, node_id);
if (!node) return { .e = "Invalid node" };
mrpc::Node n = node_to_node(node);
if (!node->file) {
std::vector<mrpc::Node> children;
children.reserve(node->children.size());
for (const auto &child : node->children)
children.push_back(node_to_node(child));
n.children = std::move(children);
}
return { .o = std::move(n) };
}
mrpc::Response<std::vector<mrpc::PathSegment>> Server::FS_get_path(std::string &&token, std::uint64_t &&node_id) {
check_user_response();
std::shared_lock lock{user->node_lock};
auto node = get_node(user, node_id);
if (!node) return { .e = "Invalid node" };
std::vector<mrpc::PathSegment> segments;
while (node) {
if (node->file) {
segments.push_back(std::move(mrpc::PathSegment { .name = node->name }));
} else {
segments.push_back(std::move(mrpc::PathSegment { .name = node->id == 0 ? "/" : node->name, .id = node->id }));
}
node = node->parent;
}
std::reverse(segments.begin(), segments.end());
return { .o = std::move(segments) };
}
mrpc::Response<std::uint64_t> Server::FS_get_nodes_size(std::string &&token, std::vector<std::uint64_t> &&nodes) {
check_user_response();
std::shared_lock lock{user->node_lock};
return { .o = nodes_size(user, nodes) };
}
mrpc::Response<mrpc::CreateNodeInfo> Server::FS_create_node(std::string &&token, bool &&file, std::uint64_t &&parent_id, std::string &&name) {
check_user_response();
auto parent = get_node(user, parent_id);
if (!parent) return { .e = "Invalid parent" };
{
std::unique_lock lock{user->node_lock};
std::shared_ptr<Node> child = nullptr;
for (const auto &c: parent->children) {
if (c->name == name) {
child = c;
break;
}
}
if (child)
return {.o = mrpc::CreateNodeInfo{
.id = child->id,
.exists = true,
.file = child->file
}};
auto id = user->next_node_id++;
child = std::make_shared<Node>(Node{
.id = id,
.name = name,
.file = file,
.preview = false,
.parent = parent,
.size = 0
});
if (file) {
auto path = files_dir / std::to_string(user->id) / std::to_string(id);
std::ofstream{path};
}
user->nodes.emplace(id, child);
parent->children.push_back(child);
save();
return {.o = mrpc::CreateNodeInfo{
.id = id,
.exists = false,
.file = file
}};
}
}
std::optional<std::string> Server::FS_move_nodes(std::string &&token, std::vector<std::uint64_t> &&node_ids_vec, std::uint64_t &&parent_id) {
check_user_optional();
auto parent = get_node(user, parent_id);
if (!parent) return "Invalid parent";
if (node_ids_vec.empty())
return std::nullopt;
std::unordered_set<std::uint64_t> node_ids{node_ids_vec.begin(), node_ids_vec.end()};
{
std::unique_lock lock{user->node_lock};
{
auto node_parent = parent;
while (parent) {
if (node_ids.contains(node_parent->id))
return "Tried to move node into one of it's subfolders";
node_parent = node_parent->parent;
}
}
std::vector<std::shared_ptr<Node>> nodes;
std::unordered_set<std::string> node_names;
nodes.reserve(node_ids.size());
node_names.reserve(node_ids.size());
for (auto id : node_ids) {
auto node = get_node(user, id);
if (!node) return "Invalid node " + std::to_string(id);
if (!node_names.emplace(node->name).second)
return "Tried to move multiple nodes with the name " + node->name;
nodes.push_back(node);
}
for (const auto &child : parent->children) {
if (node_names.contains(child->name))
return "Tried to overwrite existing folder/file";
}
for (const auto &node : nodes) {
node->parent->children.remove(node);
node->parent = parent;
parent->children.push_back(node);
}
}
save();
return std::nullopt;
}
void Server::FS_delete_nodes(std::string &&token, std::vector<std::uint64_t> &&nodes, mrpc::MRPCStream<std::string> &&stream) {
check_user() {
stream.close();
return;
}
std::thread deleter{[this, nodes = std::move(nodes), user = std::move(user), stream = std::move(stream)](){
for (const auto& node : nodes) {
if (node == 0)
continue;
delete_node(user, node, [&stream](const std::string &log){ stream.send(log); });
}
stream.close();
save();
}};
deleter.detach();
}
mrpc::Response<std::string> Server::FS_download_preview(std::string &&token, std::uint64_t &&node_id) {
check_user_response();
std::shared_lock lock{user->node_lock};
auto node = get_node(user, node_id);
if (!node) return { .e = "Invalid node" };
if (!node->preview)
return { .e = "No preview" };
std::vector<std::uint8_t> preview;
auto path = (user->user_dir / std::to_string(node_id)).replace_extension("png");
auto size = std::filesystem::file_size(path);
preview.resize(size);
std::FILE *f = std::fopen(path.c_str(), "rb");
std::fread(preview.data(), sizeof(std::uint8_t), size, f);
std::fclose(f);
return { .o = Botan::base64_encode(preview) };
}
mrpc::Response<std::string> Server::FS_get_mime(std::string &&token, std::uint64_t &&node_id) {
check_user_response();
auto node = get_node(user, node_id);
if (!node) return { .e = "Invalid node" };
if (!node->file) return { .e = "Node is a directory" };
return { .o = get_mime_type(node->name) };
}

136
src/server/mail.cxx Normal file
View File

@@ -0,0 +1,136 @@
#include <asio.hpp>
#include <botan_all.h>
#include <botan_asio/asio_stream.h>
#include <spdlog/spdlog.h>
#include "server_internal.hxx"
struct SocketData {
asio::streambuf socket_buf;
std::istream socket_istream{&socket_buf};
std::string last_send;
};
struct Exception : public std::exception {
explicit Exception(std::string msg) : msg(std::move(msg)) {}
[[nodiscard]] const char* what() const noexcept override { return msg.c_str(); }
std::string msg;
};
template<typename Socket>
void expect_code(SocketData &data, Socket &s, const std::string& code) {
std::string line;
do {
std::string buf;
asio::read_until(s, data.socket_buf, "\n");
std::getline(data.socket_istream, line, '\n');
} while (line[3] == '-');
if (std::string_view{line}.substr(0, 3) != code)
throw Exception{"Request: '" + data.last_send + "' Expected: " + code + " got: " + line};
}
template<typename Socket>
void send(SocketData &data, Socket &s, std::string l) {
data.last_send = l.substr(0, l.find_first_of('\r'));
l += "\r\n";
asio::write(s, asio::buffer(l, l.size()));
}
std::string get_date() {
char buf[80];
auto t = std::time(nullptr);
auto ti = localtime (&t);
std::strftime(buf, sizeof(buf), "%a, %d %B %Y %T %z", ti);//Format time as string
return std::string{buf};
}
std::string get_hostname() {
try {
return asio::ip::host_name();
} catch (std::exception &_) {
return "";
}
}
struct CredMan : public Botan::Credentials_Manager {
Botan::System_Certificate_Store str;
CredMan() = default;
std::vector<Botan::Certificate_Store*> trusted_certificate_authorities(const std::string&, const std::string&) override {
return {&str};
}
};
struct Policy : public Botan::TLS::Strict_Policy {
[[nodiscard]] bool require_cert_revocation_info() const override { return false; }
};
void Server::send_mail(const std::string &email, const std::string &title, const std::string &body) {
static std::string host_name = get_hostname();
try {
std::string msg = fmt::format(
"Date: {}\r\n"
"To: <{}>\r\n"
"From: MFileserver <{}>\r\n"
"Subject: {}\r\n"
"\r\n"
"{}\r\n.",
get_date(), email, config.smtp_from, title, body
);
asio::io_service ctx;
auto ssl_ctx = std::make_shared<Botan::TLS::Context>(
std::make_shared<CredMan>(),
std::make_shared<Botan::AutoSeeded_RNG>(),
std::make_shared<Botan::TLS::Session_Manager_Noop>(),
std::make_shared<Policy>()
);
asio::ip::tcp::socket s{ctx};
asio::ip::tcp::resolver res{ctx};
SocketData data;
asio::connect(s, res.resolve(config.smtp_host, std::to_string(config.smtp_port)));
expect_code(data, s, "220");
send(data, s, "EHLO " + host_name);
expect_code(data, s, "250");
send(data, s, "STARTTLS");
expect_code(data, s, "220");
// switch_to_ssl
Botan::TLS::Stream<asio::ip::tcp::socket> ss(std::move(s), ssl_ctx);
ss.handshake(Botan::TLS::Connection_Side::Client);
send(data, ss, "EHLO " + host_name);
expect_code(data, ss, "250");
send(data, ss, "AUTH LOGIN");
expect_code(data, ss, "334");
send(data, ss, Botan::base64_encode((std::uint8_t*)config.smtp_user.data(), config.smtp_user.size()));
expect_code(data, ss, "334");
send(data, ss, Botan::base64_encode((std::uint8_t*)config.smtp_pass.data(), config.smtp_pass.size()));
expect_code(data, ss, "235");
send(data, ss, "MAIL FROM:<" + config.smtp_from + ">");
expect_code(data, ss, "250");
send(data, ss, "RCPT TO:<" + email + ">");
expect_code(data, ss, "250");
send(data, ss, "DATA");
expect_code(data, ss, "354");
send(data, ss, msg);
expect_code(data, ss, "250");
send(data, ss, "QUIT");
expect_code(data, ss, "221");
} catch (const std::exception &e) {
spdlog::error("Failed to send mail");
spdlog::error(e.what());
}
}

View File

@@ -0,0 +1,618 @@
#include "fileserver.hxx"
#include <corvusoft/restbed/session.hpp>
#include <corvusoft/restbed/resource.hpp>
#include <corvusoft/restbed/request.hpp>
using namespace mrpc;
inline std::string& operator<<(std::string &v, const rapidjson::Value &j) {
if (!j.IsString())
throw std::exception{};
v = j.GetString();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::string &v, mrpc::MRPCJWriter &w) {
w.String(v);
return w;
}
inline std::int8_t& operator<<(std::int8_t &v, const rapidjson::Value &j) {
if (!j.IsInt())
throw std::exception{};
v = j.GetInt();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::int8_t &v, mrpc::MRPCJWriter &w) {
w.Int(v);
return w;
}
inline std::int16_t& operator<<(std::int16_t &v, const rapidjson::Value &j) {
if (!j.IsInt())
throw std::exception{};
v = j.GetInt();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::int16_t &v, mrpc::MRPCJWriter &w) {
w.Int(v);
return w;
}
inline std::int32_t& operator<<(std::int32_t &v, const rapidjson::Value &j) {
if (!j.IsInt())
throw std::exception{};
v = j.GetInt();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::int32_t &v, mrpc::MRPCJWriter &w) {
w.Int(v);
return w;
}
inline std::int64_t& operator<<(std::int64_t &v, const rapidjson::Value &j) {
if (!j.IsInt64())
throw std::exception{};
v = j.GetInt64();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::int64_t &v, mrpc::MRPCJWriter &w) {
w.Int64(v);
return w;
}
inline std::uint8_t& operator<<(std::uint8_t &v, const rapidjson::Value &j) {
if (!j.IsUint())
throw std::exception{};
v = j.GetUint();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::uint8_t &v, mrpc::MRPCJWriter &w) {
w.Uint(v);
return w;
}
inline std::uint16_t& operator<<(std::uint16_t &v, const rapidjson::Value &j) {
if (!j.IsUint())
throw std::exception{};
v = j.GetUint();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::uint16_t &v, mrpc::MRPCJWriter &w) {
w.Uint(v);
return w;
}
inline std::uint32_t& operator<<(std::uint32_t &v, const rapidjson::Value &j) {
if (!j.IsUint())
throw std::exception{};
v = j.GetUint();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::uint32_t &v, mrpc::MRPCJWriter &w) {
w.Uint(v);
return w;
}
inline std::uint64_t& operator<<(std::uint64_t &v, const rapidjson::Value &j) {
if (!j.IsUint64())
throw std::exception{};
v = j.GetUint64();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::uint64_t &v, mrpc::MRPCJWriter &w) {
w.Uint64(v);
return w;
}
inline bool& operator<<(bool &v, const rapidjson::Value &j) {
if (!j.IsBool())
throw std::exception{};
v = j.GetBool();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const bool &v, mrpc::MRPCJWriter &w) {
w.Bool(v);
return w;
}
inline std::float_t& operator<<(std::float_t &v, const rapidjson::Value &j) {
if (!j.IsDouble())
throw std::exception{};
v = j.GetDouble();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::float_t &v, mrpc::MRPCJWriter &w) {
w.Double(v);
return w;
}
inline std::double_t& operator<<(std::double_t &v, const rapidjson::Value &j) {
if (!j.IsDouble())
throw std::exception{};
v = j.GetDouble();
return v;
}
inline mrpc::MRPCJWriter& operator>>(const std::double_t &v, mrpc::MRPCJWriter &w) {
w.Double(v);
return w;
}
inline mrpc::MRPCJWriter& operator>>(const std::nullptr_t &, mrpc::MRPCJWriter &w) {
w.Null();
return w;
}
template<typename T>
inline std::vector<T>& operator<<(std::vector<T> &v, const rapidjson::Value &j);
template<typename T>
inline std::optional<T>& operator<<(std::optional<T> &v, const rapidjson::Value &j) {
if (j.IsNull())
v = std::nullopt;
else {
T t;
t << j;
v = std::move(t);
}
return v;
}
template<typename T>
inline std::vector<T>& operator<<(std::vector<T> &v, const rapidjson::Value &j) {
if (!j.IsArray())
throw std::exception{};
for (const auto &e : j.GetArray()) {
T t;
t << e;
v.push_back(std::move(t));
}
return v;
}
template<typename T>
inline mrpc::MRPCJWriter& operator>>(const std::vector<T> &v, mrpc::MRPCJWriter &w);
template<typename T>
inline mrpc::MRPCJWriter& operator>>(const std::optional<T> &v, mrpc::MRPCJWriter &w) {
if (v.has_value())
v.value() >> w;
else
w.Null();
return w;
}
template<typename T>
inline mrpc::MRPCJWriter& operator>>(const std::vector<T> &v, mrpc::MRPCJWriter &w) {
w.StartArray();
for (const auto &e : v)
e >> w;
w.EndArray();
return w;
}
inline const rapidjson::Value& json_get(const rapidjson::Value &j, const char *key) {
auto member = j.FindMember(key);
if (member == j.MemberEnd())
throw std::exception{};
return member->value;
}
namespace mrpc {
template<typename T>
MRPCJWriter& Response<T>::operator>>(MRPCJWriter &__w) const {
__w.StartObject();
__w.Key("e", 1);
e >> __w;
__w.Key("o", 1);
o >> __w;
__w.EndObject();
return __w;
}
template<typename T>
Response<T>& Response<T>::operator<<(const rapidjson::Value &__j) {
using namespace mrpc;
e << json_get(__j, "e");
o << json_get(__j, "o");
return *this;
}
MRPCJWriter& LoginResponse::operator>>(MRPCJWriter &__w) const {
__w.StartObject();
__w.Key("otp_needed", 10);
otp_needed >> __w;
__w.Key("token", 5);
token >> __w;
__w.EndObject();
return __w;
}
LoginResponse& LoginResponse::operator<<(const rapidjson::Value &__j) {
using namespace mrpc;
otp_needed << json_get(__j, "otp_needed");
token << json_get(__j, "token");
return *this;
}
MRPCJWriter& Session::operator>>(MRPCJWriter &__w) const {
__w.StartObject();
__w.Key("name", 4);
name >> __w;
__w.Key("tfa_enabled", 11);
tfa_enabled >> __w;
__w.Key("admin", 5);
admin >> __w;
__w.Key("sudo", 4);
sudo >> __w;
__w.EndObject();
return __w;
}
Session& Session::operator<<(const rapidjson::Value &__j) {
using namespace mrpc;
name << json_get(__j, "name");
tfa_enabled << json_get(__j, "tfa_enabled");
admin << json_get(__j, "admin");
sudo << json_get(__j, "sudo");
return *this;
}
MRPCJWriter& UserInfo::operator>>(MRPCJWriter &__w) const {
__w.StartObject();
__w.Key("id", 2);
id >> __w;
__w.Key("name", 4);
name >> __w;
__w.Key("tfa", 3);
tfa >> __w;
__w.Key("admin", 5);
admin >> __w;
__w.Key("enabled", 7);
enabled >> __w;
__w.EndObject();
return __w;
}
UserInfo& UserInfo::operator<<(const rapidjson::Value &__j) {
using namespace mrpc;
id << json_get(__j, "id");
name << json_get(__j, "name");
tfa << json_get(__j, "tfa");
admin << json_get(__j, "admin");
enabled << json_get(__j, "enabled");
return *this;
}
MRPCJWriter& Node::operator>>(MRPCJWriter &__w) const {
__w.StartObject();
__w.Key("id", 2);
id >> __w;
__w.Key("name", 4);
name >> __w;
__w.Key("file", 4);
file >> __w;
__w.Key("preview", 7);
preview >> __w;
__w.Key("parent", 6);
parent >> __w;
__w.Key("size", 4);
size >> __w;
__w.Key("children", 8);
children >> __w;
__w.EndObject();
return __w;
}
Node& Node::operator<<(const rapidjson::Value &__j) {
using namespace mrpc;
id << json_get(__j, "id");
name << json_get(__j, "name");
file << json_get(__j, "file");
preview << json_get(__j, "preview");
parent << json_get(__j, "parent");
size << json_get(__j, "size");
children << json_get(__j, "children");
return *this;
}
MRPCJWriter& CreateNodeInfo::operator>>(MRPCJWriter &__w) const {
__w.StartObject();
__w.Key("id", 2);
id >> __w;
__w.Key("exists", 6);
exists >> __w;
__w.Key("file", 4);
file >> __w;
__w.EndObject();
return __w;
}
CreateNodeInfo& CreateNodeInfo::operator<<(const rapidjson::Value &__j) {
using namespace mrpc;
id << json_get(__j, "id");
exists << json_get(__j, "exists");
file << json_get(__j, "file");
return *this;
}
MRPCJWriter& ZipInfo::operator>>(MRPCJWriter &__w) const {
__w.StartObject();
__w.Key("done", 4);
done >> __w;
__w.Key("progress", 8);
progress >> __w;
__w.Key("total", 5);
total >> __w;
__w.EndObject();
return __w;
}
ZipInfo& ZipInfo::operator<<(const rapidjson::Value &__j) {
using namespace mrpc;
done << json_get(__j, "done");
progress << json_get(__j, "progress");
total << json_get(__j, "total");
return *this;
}
MRPCJWriter& PathSegment::operator>>(MRPCJWriter &__w) const {
__w.StartObject();
__w.Key("name", 4);
name >> __w;
__w.Key("id", 2);
id >> __w;
__w.EndObject();
return __w;
}
PathSegment& PathSegment::operator<<(const rapidjson::Value &__j) {
using namespace mrpc;
name << json_get(__j, "name");
id << json_get(__j, "id");
return *this;
}
template<typename T>
void send_msg(const std::shared_ptr<restbed::Session> &c, const T &v) {
if (c->is_closed())
return;
rapidjson::StringBuffer s;
mrpc::MRPCJWriter writer{s};
v >> writer;
const auto body_ptr = s.GetString();
const auto body = restbed::Bytes{body_ptr, body_ptr+s.GetLength()};
c->yield(
200,
body,
std::multimap<std::string, std::string>{
{"Content-Type", "application/json"},
{"Content-Length", std::to_string(body.size())}
}
);
}
template<typename T>
void send_sse_msg(const std::shared_ptr<restbed::Session> &c, const T &v) {
if (c->is_closed())
return;
rapidjson::StringBuffer s;
std::memcpy(s.Push(5), "data:", 5);
mrpc::MRPCJWriter writer{s};
v >> writer;
std::memcpy(s.Push(2), "\n\n", 2);
const auto body_ptr = s.GetString();
const auto body = restbed::Bytes{body_ptr, body_ptr+s.GetLength()};
c->yield(body);
}
mrpc::MRPCStreamImpl::MRPCStreamImpl(const std::shared_ptr<restbed::Session> &conn) : conn(conn) {
conn->yield(
200,
std::multimap<std::string, std::string>{
{"Cache-Control", "no-cache"},
{"Content-Type", "text/event-stream"}
}
);
}
void mrpc::MRPCStreamImpl::close() const noexcept { conn->close("data:null\n\n"); }
bool mrpc::MRPCStreamImpl::is_open() const noexcept { return conn->is_open(); }
template<> void MRPCStream<std::string>::send(const std::string &v) const noexcept { send_sse_msg(conn, v); }
mrpc::MRPCServer::MRPCServer(std::shared_ptr<restbed::Resource> &r) {
r->set_method_handler("POST", [this](const std::shared_ptr<restbed::Session>& s) {
const auto req = s->get_request();
const auto body_len = req->get_header("Content-Length", 0);
s->fetch(body_len, [this](const std::shared_ptr<restbed::Session> &s, auto &&body) {
try { msg_handler(s, body); }
catch (const std::exception &_) { s->close(400); }
});
});
}
void mrpc::MRPCServer::msg_handler(const std::shared_ptr<restbed::Session> __c, const restbed::Bytes &__msg) {
rapidjson::Document __j;
__j.Parse((const char*)__msg.data(), __msg.size());
if (__j.HasParseError())
throw std::exception{};
std::string __service, __method;
__service << json_get(__j, "service");
__method << json_get(__j, "method");
auto __data_member = __j.FindMember("data");
if (__data_member == __j.MemberEnd() || !__data_member->value.IsObject())
throw std::exception{};
auto &__data = __data_member->value;
if (__service == "Auth") {
if (__method == "signup") {
std::string username; username << json_get(__data, "username");
std::string password; password << json_get(__data, "password");
send_msg(__c, Auth_signup(std::move(username), std::move(password)));
} else if (__method == "login") {
std::string username; username << json_get(__data, "username");
std::string password; password << json_get(__data, "password");
std::optional<std::string> otp; otp << json_get(__data, "otp");
send_msg(__c, Auth_login(std::move(username), std::move(password), std::move(otp)));
} else if (__method == "send_recovery_key") {
std::string username; username << json_get(__data, "username");
Auth_send_recovery_key(std::move(username)); send_msg(__c, nullptr);
} else if (__method == "reset_password") {
std::string key; key << json_get(__data, "key");
std::string password; password << json_get(__data, "password");
send_msg(__c, Auth_reset_password(std::move(key), std::move(password)));
} else if (__method == "change_password") {
std::string token; token << json_get(__data, "token");
std::string old_pw; old_pw << json_get(__data, "old_pw");
std::string new_pw; new_pw << json_get(__data, "new_pw");
send_msg(__c, Auth_change_password(std::move(token), std::move(old_pw), std::move(new_pw)));
} else if (__method == "logout") {
std::string token; token << json_get(__data, "token");
Auth_logout(std::move(token)); send_msg(__c, nullptr);
} else if (__method == "logout_all") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Auth_logout_all(std::move(token)));
} else if (__method == "tfa_setup_mail") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Auth_tfa_setup_mail(std::move(token)));
} else if (__method == "tfa_setup_totp") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Auth_tfa_setup_totp(std::move(token)));
} else if (__method == "tfa_complete") {
std::string token; token << json_get(__data, "token");
std::string otp; otp << json_get(__data, "otp");
send_msg(__c, Auth_tfa_complete(std::move(token), std::move(otp)));
} else if (__method == "tfa_disable") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Auth_tfa_disable(std::move(token)));
} else if (__method == "delete_user") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Auth_delete_user(std::move(token)));
} else if (__method == "session_info") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Auth_session_info(std::move(token)));
}
else { throw std::exception{}; }
} else if (__service == "Admin") {
if (__method == "list_users") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Admin_list_users(std::move(token)));
} else if (__method == "delete_user") {
std::string token; token << json_get(__data, "token");
std::uint64_t user; user << json_get(__data, "user");
send_msg(__c, Admin_delete_user(std::move(token), std::move(user)));
} else if (__method == "logout") {
std::string token; token << json_get(__data, "token");
std::uint64_t user; user << json_get(__data, "user");
send_msg(__c, Admin_logout(std::move(token), std::move(user)));
} else if (__method == "disable_tfa") {
std::string token; token << json_get(__data, "token");
std::uint64_t user; user << json_get(__data, "user");
send_msg(__c, Admin_disable_tfa(std::move(token), std::move(user)));
} else if (__method == "set_admin") {
std::string token; token << json_get(__data, "token");
std::uint64_t user; user << json_get(__data, "user");
bool admin; admin << json_get(__data, "admin");
send_msg(__c, Admin_set_admin(std::move(token), std::move(user), std::move(admin)));
} else if (__method == "set_enabled") {
std::string token; token << json_get(__data, "token");
std::uint64_t user; user << json_get(__data, "user");
bool enabled; enabled << json_get(__data, "enabled");
send_msg(__c, Admin_set_enabled(std::move(token), std::move(user), std::move(enabled)));
} else if (__method == "sudo") {
std::string token; token << json_get(__data, "token");
std::uint64_t user; user << json_get(__data, "user");
send_msg(__c, Admin_sudo(std::move(token), std::move(user)));
} else if (__method == "unsudo") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Admin_unsudo(std::move(token)));
} else if (__method == "shutdown") {
std::string token; token << json_get(__data, "token");
send_msg(__c, Admin_shutdown(std::move(token)));
}
else { throw std::exception{}; }
} else if (__service == "FS") {
if (__method == "get_node") {
std::string token; token << json_get(__data, "token");
std::uint64_t node; node << json_get(__data, "node");
send_msg(__c, FS_get_node(std::move(token), std::move(node)));
} else if (__method == "get_path") {
std::string token; token << json_get(__data, "token");
std::uint64_t node; node << json_get(__data, "node");
send_msg(__c, FS_get_path(std::move(token), std::move(node)));
} else if (__method == "get_nodes_size") {
std::string token; token << json_get(__data, "token");
std::vector<std::uint64_t> nodes; nodes << json_get(__data, "nodes");
send_msg(__c, FS_get_nodes_size(std::move(token), std::move(nodes)));
} else if (__method == "create_node") {
std::string token; token << json_get(__data, "token");
bool file; file << json_get(__data, "file");
std::uint64_t parent; parent << json_get(__data, "parent");
std::string name; name << json_get(__data, "name");
send_msg(__c, FS_create_node(std::move(token), std::move(file), std::move(parent), std::move(name)));
} else if (__method == "move_nodes") {
std::string token; token << json_get(__data, "token");
std::vector<std::uint64_t> nodes; nodes << json_get(__data, "nodes");
std::uint64_t parent; parent << json_get(__data, "parent");
send_msg(__c, FS_move_nodes(std::move(token), std::move(nodes), std::move(parent)));
} else if (__method == "delete_nodes") {
auto __stream = MRPCStream<std::string>{__c};
std::string token; token << json_get(__data, "token");
std::vector<std::uint64_t> nodes; nodes << json_get(__data, "nodes");
FS_delete_nodes(std::move(token), std::move(nodes), std::move(__stream));
} else if (__method == "download_preview") {
std::string token; token << json_get(__data, "token");
std::uint64_t node; node << json_get(__data, "node");
send_msg(__c, FS_download_preview(std::move(token), std::move(node)));
} else if (__method == "get_mime") {
std::string token; token << json_get(__data, "token");
std::uint64_t node; node << json_get(__data, "node");
send_msg(__c, FS_get_mime(std::move(token), std::move(node)));
}
else { throw std::exception{}; }
}
else { throw std::exception{}; }
}
}

View File

@@ -0,0 +1,161 @@
#pragma once
#ifndef MRPC_GEN_H
#define MRPC_GEN_H
#include <memory>
#include <string>
#include <vector>
#include <optional>
#include <cstdint>
#include <cmath>
#include <corvusoft/restbed/byte.hpp>
#define RAPIDJSON_HAS_STDSTRING 1
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#include <rapidjson/document.h>
namespace restbed {
class Resource;
class Session;
}
namespace mrpc {
using MRPCJWriter = rapidjson::Writer<rapidjson::StringBuffer>;
template<typename T>
struct Response;
struct LoginResponse;
struct Session;
struct UserInfo;
struct Node;
struct CreateNodeInfo;
struct ZipInfo;
struct PathSegment;
template<typename T>
struct Response {
std::optional<std::string> e;
std::optional<T> o;
MRPCJWriter& operator >>(MRPCJWriter&) const;
Response& operator <<(const rapidjson::Value&);
};
struct LoginResponse {
bool otp_needed;
std::optional<std::string> token;
MRPCJWriter& operator >>(MRPCJWriter&) const;
LoginResponse& operator <<(const rapidjson::Value&);
};
struct Session {
std::string name;
bool tfa_enabled;
bool admin;
bool sudo;
MRPCJWriter& operator >>(MRPCJWriter&) const;
Session& operator <<(const rapidjson::Value&);
};
struct UserInfo {
std::uint64_t id;
std::string name;
bool tfa;
bool admin;
bool enabled;
MRPCJWriter& operator >>(MRPCJWriter&) const;
UserInfo& operator <<(const rapidjson::Value&);
};
struct Node {
std::uint64_t id;
std::string name;
bool file;
bool preview;
std::optional<std::uint64_t> parent;
std::optional<std::uint64_t> size;
std::optional<std::vector<Node>> children;
MRPCJWriter& operator >>(MRPCJWriter&) const;
Node& operator <<(const rapidjson::Value&);
};
struct CreateNodeInfo {
std::uint64_t id;
bool exists;
bool file;
MRPCJWriter& operator >>(MRPCJWriter&) const;
CreateNodeInfo& operator <<(const rapidjson::Value&);
};
struct ZipInfo {
bool done;
std::uint64_t progress;
std::uint64_t total;
MRPCJWriter& operator >>(MRPCJWriter&) const;
ZipInfo& operator <<(const rapidjson::Value&);
};
struct PathSegment {
std::string name;
std::optional<std::uint64_t> id;
MRPCJWriter& operator >>(MRPCJWriter&) const;
PathSegment& operator <<(const rapidjson::Value&);
};
struct MRPCStreamImpl {
void close() const noexcept;
bool is_open() const noexcept;
protected:
explicit MRPCStreamImpl(const std::shared_ptr<restbed::Session> &conn);
std::shared_ptr<restbed::Session> conn;
};
template<typename T>
struct MRPCStream final : MRPCStreamImpl {
explicit MRPCStream(const std::shared_ptr<restbed::Session> &conn) : MRPCStreamImpl(conn) {}
void send(const T &v) const noexcept;
};
template struct MRPCStream<std::string>;
struct MRPCServer {
MRPCServer() = delete;
explicit MRPCServer(std::shared_ptr<restbed::Resource>&);
private:
virtual std::optional<std::string> Auth_signup(std::string &&username, std::string &&password) = 0;
virtual Response<LoginResponse> Auth_login(std::string &&username, std::string &&password, std::optional<std::string> &&otp) = 0;
virtual void Auth_send_recovery_key(std::string &&username) = 0;
virtual std::optional<std::string> Auth_reset_password(std::string &&key, std::string &&password) = 0;
virtual std::optional<std::string> Auth_change_password(std::string &&token, std::string &&old_pw, std::string &&new_pw) = 0;
virtual void Auth_logout(std::string &&token) = 0;
virtual std::optional<std::string> Auth_logout_all(std::string &&token) = 0;
virtual std::optional<std::string> Auth_tfa_setup_mail(std::string &&token) = 0;
virtual Response<std::string> Auth_tfa_setup_totp(std::string &&token) = 0;
virtual std::optional<std::string> Auth_tfa_complete(std::string &&token, std::string &&otp) = 0;
virtual std::optional<std::string> Auth_tfa_disable(std::string &&token) = 0;
virtual std::optional<std::string> Auth_delete_user(std::string &&token) = 0;
virtual Response<Session> Auth_session_info(std::string &&token) = 0;
virtual Response<std::vector<UserInfo>> Admin_list_users(std::string &&token) = 0;
virtual std::optional<std::string> Admin_delete_user(std::string &&token, std::uint64_t &&user) = 0;
virtual std::optional<std::string> Admin_logout(std::string &&token, std::uint64_t &&user) = 0;
virtual std::optional<std::string> Admin_disable_tfa(std::string &&token, std::uint64_t &&user) = 0;
virtual std::optional<std::string> Admin_set_admin(std::string &&token, std::uint64_t &&user, bool &&admin) = 0;
virtual std::optional<std::string> Admin_set_enabled(std::string &&token, std::uint64_t &&user, bool &&enabled) = 0;
virtual std::optional<std::string> Admin_sudo(std::string &&token, std::uint64_t &&user) = 0;
virtual std::optional<std::string> Admin_unsudo(std::string &&token) = 0;
virtual std::optional<std::string> Admin_shutdown(std::string &&token) = 0;
virtual Response<Node> FS_get_node(std::string &&token, std::uint64_t &&node) = 0;
virtual Response<std::vector<PathSegment>> FS_get_path(std::string &&token, std::uint64_t &&node) = 0;
virtual Response<std::uint64_t> FS_get_nodes_size(std::string &&token, std::vector<std::uint64_t> &&nodes) = 0;
virtual Response<CreateNodeInfo> FS_create_node(std::string &&token, bool &&file, std::uint64_t &&parent, std::string &&name) = 0;
virtual std::optional<std::string> FS_move_nodes(std::string &&token, std::vector<std::uint64_t> &&nodes, std::uint64_t &&parent) = 0;
virtual void FS_delete_nodes(std::string &&token, std::vector<std::uint64_t> &&nodes, MRPCStream<std::string>&&) = 0;
virtual Response<std::string> FS_download_preview(std::string &&token, std::uint64_t &&node) = 0;
virtual Response<std::string> FS_get_mime(std::string &&token, std::uint64_t &&node) = 0;
virtual void msg_handler(std::shared_ptr<restbed::Session>, const restbed::Bytes&) final;
};
}
#endif // MRPC_GEN_H

87
src/server/server.cxx Normal file
View File

@@ -0,0 +1,87 @@
#include "server_internal.hxx"
std::shared_ptr<Token> Server::get_token(const std::string &token) {
std::shared_lock lock{token_lock};
const auto &entry = tokens.find(token);
if (entry == tokens.end())
return nullptr;
return entry->second;
}
std::shared_ptr<User> Server::get_user(std::uint64_t id) {
std::shared_lock lock{user_lock};
const auto &entry = users.find(id);
if (entry == users.end())
return nullptr;
return entry->second;
}
std::shared_ptr<User> Server::is_token_valid(const std::string &token) {
auto t = get_token(token);
if (!t)
return nullptr;
if (Token::clock::now() <= t->expire) {
t->refresh();
return t->user;
}
{
std::unique_lock lock{token_lock};
tokens.erase(token);
}
return nullptr;
}
void Server::logout_user(std::uint64_t id) {
std::unique_lock lock{token_lock};
for (auto it = tokens.begin(); it != tokens.end();) {
if (it->second->user->id == id)
tokens.erase(it++);
else
++it;
}
}
void Server::delete_user(const std::shared_ptr<User> &user) {
std::unique_lock lock{user_lock};
logout_user(user->id);
delete_node(user, 0, [](const std::string&){});
users.erase(user->id);
}
void Server::send_tfa_mail(const std::shared_ptr<User> &user) {
std::lock_guard lock{mail_otp_lock};
std::string code; code.reserve(10);
for (int i = 0; i < 10; ++i) {
auto j = auth_rng->next_byte();
while (j > 249) j = auth_rng->next_byte();
code.push_back('0' + (j%10));
}
mail_otp.emplace(code, std::make_pair(user->id, Token::clock::now() + std::chrono::minutes{10}));
send_mail(user->name, "MFileserver - TFA code", "Your code is: " + code + "\r\nIt is valid for 10 minutes");
}
bool Server::check_mail_code(const std::shared_ptr<User> &user, const std::string &code) {
std::lock_guard lock{mail_otp_lock};
auto now = Token::clock::now();
for (auto it = mail_otp.begin(); it != mail_otp.end();) {
if (now >= it->second.second)
mail_otp.erase(it++);
else
++it;
}
const auto &entry = mail_otp.find(code);
bool ok = entry != mail_otp.end() && entry->second.first == user->id;
if (ok)
mail_otp.erase(code);
return ok;
}
bool Server::check_tfa_code(const std::shared_ptr<User> &user, const std::string &code_str) {
Botan::OctetString secret{user->tfa_secret};
Botan::TOTP totp{secret};
try {
std::uint32_t code = std::stoul(code_str);
return totp.verify_totp(code, std::chrono::system_clock::now(), 1);
} catch (std::exception &_) {}
return false;
}

66
src/server/server.hxx Normal file
View File

@@ -0,0 +1,66 @@
#ifndef FILESERVER_SERVER_HXX
#define FILESERVER_SERVER_HXX
#include <corvusoft/restbed/service.hpp>
#include "mrpc/fileserver.hxx"
#include "../data/data.hxx"
extern std::shared_ptr<restbed::Service> g_service;
struct Server final : public mrpc::MRPCServer, public Data {
explicit Server(std::shared_ptr<restbed::Resource> &ptr) : MRPCServer(ptr), Data() {}
std::shared_ptr<Token> get_token(const std::string&);
std::shared_ptr<User> is_token_valid(const std::string&);
std::shared_ptr<User> get_user(std::uint64_t id);
static void delete_node(const std::shared_ptr<User> &user, std::uint64_t id, const std::function<void(std::string)>& log);
void logout_user(std::uint64_t id);
void delete_user(const std::shared_ptr<User> &user);
void send_tfa_mail(const std::shared_ptr<User> &user);
static bool check_tfa_code(const std::shared_ptr<User> &user, const std::string &code);
bool check_mail_code(const std::shared_ptr<User> &user, const std::string &code);
void send_mail(const std::string& email, const std::string& title, const std::string& body);
std::uint64_t nodes_size(const std::shared_ptr<User> &user, const std::vector<std::uint64_t> &ids);
void download(const std::shared_ptr<restbed::Session>&);
void download_multi(const std::shared_ptr<restbed::Session>&);
void upload(const std::shared_ptr<restbed::Session>&);
private:
std::optional<std::string> Auth_signup(std::string &&username, std::string &&password) override;
mrpc::Response<mrpc::LoginResponse> Auth_login(std::string &&username, std::string &&password, std::optional<std::string> &&otp) override;
void Auth_send_recovery_key(std::string &&username) override;
std::optional<std::string> Auth_reset_password(std::string &&key, std::string &&password) override;
std::optional<std::string> Auth_change_password(std::string &&token, std::string &&old_pw, std::string &&new_pw) override;
void Auth_logout(std::string &&token) override;
std::optional<std::string> Auth_logout_all(std::string &&token) override;
std::optional<std::string> Auth_tfa_setup_mail(std::string &&token) override;
mrpc::Response<std::string> Auth_tfa_setup_totp(std::string &&token) override;
std::optional<std::string> Auth_tfa_complete(std::string &&token, std::string &&otp) override;
std::optional<std::string> Auth_tfa_disable(std::string &&token) override;
std::optional<std::string> Auth_delete_user(std::string &&token) override;
mrpc::Response<mrpc::Session> Auth_session_info(std::string &&token) override;
mrpc::Response<std::vector<mrpc::UserInfo>> Admin_list_users(std::string &&token) override;
std::optional<std::string> Admin_delete_user(std::string &&token, std::uint64_t &&user) override;
std::optional<std::string> Admin_logout(std::string &&token, std::uint64_t &&user) override;
std::optional<std::string> Admin_disable_tfa(std::string &&token, std::uint64_t &&user) override;
std::optional<std::string> Admin_set_admin(std::string &&token, std::uint64_t &&user, bool &&admin) override;
std::optional<std::string> Admin_set_enabled(std::string &&token, std::uint64_t &&user, bool &&enabled) override;
std::optional<std::string> Admin_sudo(std::string &&token, std::uint64_t &&user) override;
std::optional<std::string> Admin_unsudo(std::string &&token) override;
std::optional<std::string> Admin_shutdown(std::string &&token) override;
mrpc::Response<mrpc::Node> FS_get_node(std::string &&token, std::uint64_t &&node) override;
mrpc::Response<std::vector<mrpc::PathSegment>> FS_get_path(std::string &&token, std::uint64_t &&node) override;
mrpc::Response<std::uint64_t> FS_get_nodes_size(std::string &&token, std::vector<std::uint64_t> &&nodes) override;
mrpc::Response<mrpc::CreateNodeInfo> FS_create_node(std::string &&token, bool &&file, std::uint64_t &&parent, std::string &&name) override;
std::optional<std::string> FS_move_nodes(std::string &&token, std::vector<std::uint64_t> &&nodes, std::uint64_t &&parent) override;
void FS_delete_nodes(std::string &&token, std::vector<std::uint64_t> &&nodes, mrpc::MRPCStream<std::string> &&stream) override;
mrpc::Response<std::string> FS_download_preview(std::string &&token, std::uint64_t &&node) override;
mrpc::Response<std::string> FS_get_mime(std::string &&token, std::uint64_t &&node) override;
};
#endif //FILESERVER_SERVER_HXX

View File

@@ -0,0 +1,70 @@
#ifndef FILESERVER_SERVER_INTERNAL_HXX
#define FILESERVER_SERVER_INTERNAL_HXX
#include <botan_all.h>
#include "server.hxx"
// TODO log user action with __FUNC__
#define check_user() auto user = is_token_valid(token); if (!user || !user->enabled)
#define check_user_response() check_user() return { .e = "Unauthorized" }
#define check_user_optional() check_user() return "Unauthorized"
#if defined(BOTAN_HAS_SYSTEM_RNG)
static std::unique_ptr<Botan::RNG> auth_rng = std::make_unique<Botan::System_RNG>();
#else
static std::unique_ptr<Botan::RNG> auth_rng = std::make_unique<Botan::AutoSeeded_RNG>();
#endif
// https://developer.mozilla.org/en-US/docs/Web/Media/Formats/Image_types#common_image_file_types
static const std::unordered_map<std::string, std::string> mime_type_map = {
{".apng" , "image/apng"},
{".avif" , "image/avif"},
{".bmp" , "image/bmp"},
{".gif" , "image/gif"},
{".jpg" , "image/jpeg"},
{".jpeg" , "image/jpeg"},
{".jfif" , "image/jpeg"},
{".pjpeg", "image/jpeg"},
{".pjp" , "image/jpeg"},
{".png" , "image/png"},
{".svg" , "image/svg"},
{".webp" , "image/webp"},
{".aac" , "audio/aac"},
{".flac" , "audio/flac"},
{".mp3" , "audio/mp3"},
{".m4a" , "audio/mp4"},
{".oga" , "audio/ogg"},
{".ogg" , "audio/ogg"},
{".wav" , "audio/wav"},
{".3gp" , "video/3gpp"},
{".mpg" , "video/mpeg"},
{".mpeg" , "video/mpeg"},
{".mp4" , "video/mp4"},
{".m4v" , "video/mp4"},
{".m4p" , "video/mp4"},
{".ogv" , "video/ogg"},
{".mov" , "video/quicktime"},
{".webm" , "video/webm"},
{".mkv" , "video/x-matroska"},
{".mk3d" , "video/x-matroska"},
{".mks" , "video/x-matroska"},
{".pdf" , "application/pdf"}
};
static const std::string& get_mime_type(const std::filesystem::path &filename) {
static const std::string octet = "application/octet-stream";
const auto &entry = mime_type_map.find(filename.extension());
if (entry != mime_type_map.end())
return entry->second;
else
return octet;
}
std::shared_ptr<Node> get_node(const std::shared_ptr<User>& user, std::uint64_t id);
#endif //FILESERVER_SERVER_INTERNAL_HXX

113
src/server/upload.cxx Normal file
View File

@@ -0,0 +1,113 @@
#include <fstream>
#include <corvusoft/restbed/session.hpp>
#include <corvusoft/restbed/request.hpp>
#include <corvusoft/restbed/response.hpp>
#include <spdlog/spdlog.h>
#include <stb_image.h>
#include <stb_image_resize2.h>
#include <stb_image_write.h>
#include "server_internal.hxx"
static constexpr std::size_t chunk_size = 1024*1024, max_image_size = 1024*1024*50, preview_size=480;
static const std::set<std::string> image_extension = {".png", ".jpg", ".jpeg", ".tga", ".bmp", ".psd", ".gif", ".jfif", ".pjpeg", ".pjp"};
struct UploadInfo {
Server *server;
std::shared_lock<std::shared_mutex> node_lock;
std::size_t to_read;
std::filesystem::path path;
std::ofstream file;
std::shared_ptr<Node> node;
};
void make_preview(const std::shared_ptr<UploadInfo>& info) {
int x, y, channels;
auto img = std::unique_ptr<stbi_uc, decltype(&free)>
{stbi_load(info->path.c_str(), &x, &y, &channels, 0), &free};
if (!img)
return;
float x_ration = (float)preview_size / (float)x, y_ration = (float)preview_size / (float)y;
float ratio = std::min(x_ration, y_ration);
int new_x = (int)((float)(x)*ratio), new_y = (int)((float)(y)*ratio);
stbir_pixel_layout layout;
switch (channels) {
case 1: layout = STBIR_1CHANNEL; break;
case 2: layout = STBIR_2CHANNEL; break;
case 3: layout = STBIR_RGB; break;
case 4: layout = STBIR_RGBA; break;
default: return;
}
auto rimg = std::unique_ptr<unsigned char, decltype(&free)>
{stbir_resize_uint8_linear(img.get(), x, y, 0, nullptr, new_x, new_y, 0, layout), &free};
if (!rimg)
return;
auto png_path = info->path.replace_extension("png");
if (!stbi_write_png(png_path.c_str(), new_x, new_y, channels, rimg.get(), 0))
return;
info->node->preview = true;
}
void fetch_handler(const std::shared_ptr<restbed::Session> &s, const restbed::Bytes &bytes) {
std::shared_ptr<UploadInfo> info = s->get("upload");
std::size_t read = bytes.size();
info->to_read -= std::min(read, info->to_read);
info->file.write((char*)bytes.data(), bytes.size());
if (info->to_read > 0)
return s->fetch(std::min(info->to_read, chunk_size), fetch_handler);
info->file.close();
s->close(200);
std::size_t real_size = std::filesystem::file_size(info->path);
info->node->size = real_size;
auto ext = std::filesystem::path{info->node->name}.extension().string();
if (real_size < max_image_size && image_extension.contains(ext))
make_preview(info);
info->node_lock.unlock();
info->server->save();
}
void Server::upload(const std::shared_ptr<restbed::Session> &s) {
const auto req = s->get_request();
if (!req->has_header("X-Node"))
return s->close(400, "Missing node");
if (!req->has_header("X-Token"))
return s->close(400, "Missing token");
if (req->get_header("Transfer-Encoding") == "chunked") {
spdlog::error("Encountered a chunked upload!");
return s->close(500, "Sorry but your browser is not supported yet");
}
std::uint64_t node_id = req->get_header("X-Node", 0);
std::string token = req->get_header("X-Token");
check_user() return s->close(400, "Invalid user");
{
std::shared_lock lock{user->node_lock};
auto node = get_node(user, node_id);
if (!node) return s->close(400, "Invalid node");
if (!node->file) return s->close(400, "Can't upload to a directory");
std::size_t to_read = req->get_header("Content-Length", 0);
auto path = user->user_dir / std::to_string(node->id);
if (node->preview) {
node->preview = false;
std::filesystem::remove(path.replace_extension("png"));
}
std::shared_ptr<UploadInfo> info{new UploadInfo{
.server = this,
.node_lock = std::shared_lock{user->node_lock},
.to_read = to_read,
.path = path,
.file = std::ofstream{path, std::ios_base::out|std::ios_base::trunc|std::ios_base::binary},
.node = node
}};
s->set("upload", info);
s->fetch(std::min(to_read, chunk_size), fetch_handler);
}
}

15
src/util/crash.hxx Normal file
View File

@@ -0,0 +1,15 @@
#ifndef FILESERVER_CRASH_HXX
#define FILESERVER_CRASH_HXX
#include <source_location>
#include <spdlog/spdlog.h>
// TODO implement backtrace
[[noreturn]]
static void crash(std::source_location loc = std::source_location::current()) {
spdlog::critical("crash called from: {}:{} `{}`", loc.file_name(), loc.line(), loc.function_name());
spdlog::shutdown();
std::abort();
}
#endif //FILESERVER_CRASH_HXX

60
src/util/logging.hxx Normal file
View File

@@ -0,0 +1,60 @@
#ifndef FILESERVER_LOGGING_HXX
#define FILESERVER_LOGGING_HXX
#include <cstdarg>
#include <spdlog/spdlog.h>
#include <corvusoft/restbed/logger.hpp>
namespace {
}
namespace logging {
struct RestbedLogger : public restbed::Logger {
void stop() override {}
void start(const std::shared_ptr<const restbed::Settings>&) override {
logger = spdlog::default_logger()->clone("restbed");
}
void log(Level level, const char *format, ...) override {
std::va_list args;
va_start(args, format);
restbed_log(level, format, args);
va_end(args);
}
void log_if(bool expression, Level level, const char *format, ...) override {
if (expression) {
va_list args;
va_start(args, format);
restbed_log(level, format, args);
va_end(args);
}
}
private:
std::shared_ptr<spdlog::logger> logger;
void restbed_log(const restbed::Logger::Level restbed_level, const char* format, va_list args) {
spdlog::level::level_enum level;
switch (restbed_level) {
case restbed::Logger::DEBUG: level = spdlog::level::level_enum::debug; break;
case restbed::Logger::INFO: level = spdlog::level::level_enum::info; break;
case restbed::Logger::WARNING: level = spdlog::level::level_enum::warn; break;
case restbed::Logger::ERROR: level = spdlog::level::level_enum::err; break;
case restbed::Logger::SECURITY:
case restbed::Logger::FATAL: level = spdlog::level::level_enum::critical; break;
}
std::string buf;
buf.resize(1024);
int written = vsnprintf(buf.data(), 1024, format, args);
//if (std::string_view{buf.cbegin(), buf.cbegin()+10} == "Incoming '")
// return;
if (written >= 1024) {
buf.resize(written + 10);
written = vsnprintf(buf.data(), written + 10, format, args);
}
buf.resize(written);
logger->log(level, buf);
}
};
}
#endif //FILESERVER_LOGGING_HXX

7
src/util/stb.cxx Normal file
View File

@@ -0,0 +1,7 @@
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include <stb_image.h>
#include <stb_image_resize2.h>
#include <stb_image_write.h>

30
src/util/timed_mutex.hxx Normal file
View File

@@ -0,0 +1,30 @@
#ifndef FILESERVER_TIMED_MUTEX_HXX
#define FILESERVER_TIMED_MUTEX_HXX
#include <spdlog/spdlog.h>
#include <shared_mutex>
struct TimedSharedMutex {
std::shared_mutex m;
using clock = std::chrono::high_resolution_clock;
void lock_shared() {
auto start = clock::now();
m.lock_shared();
auto end = clock::now();
auto d = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
spdlog::info("Lock s took {} ms", d.count());
}
void unlock_shared() { m.unlock_shared(); spdlog::info("Unlock s"); }
void lock() {
auto start = clock::now();
m.lock();
auto end = clock::now();
auto d = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
spdlog::info("Lock took {} ms", d.count());
}
void unlock() { m.unlock(); spdlog::info("Unlock"); }
};
#endif //FILESERVER_TIMED_MUTEX_HXX