// C++ backend of the RPC code generator.
//
// `gen` writes two files next to each other: a header (`.h`) produced by
// `output_header` and an implementation file (`.cpp`) produced by `output_cpp`.
// The emitted C++ refers to nlohmann::json and crow websocket types.
//
// NOTE(review): this file looks like it went through a tool that removed
// `<...>` segments and collapsed line breaks. Evidence: `.collect::>()` below is
// not valid Rust, one `writeln!` passes an argument to a format string with no
// placeholder, and the embedded C++ contains `#include #include ...`,
// `template struct ...` and `std::shared_ptr>`. The multi-line byte-string
// literals also appear to have had their internal newlines replaced by spaces
// (preprocessor directives now share a line). Restore the literals from version
// control rather than reconstructing them by hand.
use std::io::Write;
use itertools::Itertools;
use crate::data::RPC;
use super::IndentedWriter;

/// Renders a field type as the C++ type string used in the generated code.
///
/// Primitive `Types` variants map to fixed C++ spellings; `Types::Named` uses
/// the name unchanged.
///
/// * `ty.array` is checked first and yields `std::vector<inner>`; `ty.optional`
///   is not consulted in that case.
/// * Otherwise, `ty.optional && !ignore_optional` yields `std::optional<inner>`.
/// * Otherwise the bare inner type is returned.
///
/// NOTE(review): because arrays take precedence, an optional array is rendered
/// as a plain `std::vector<...>`, while callers below still branch on
/// `field.optional` and emit `json_get_opt`/`json_set_opt` for it — confirm
/// that combination is either impossible or intended.
fn field_ty_to_ty_str(ty: &crate::data::FieldTy, ignore_optional: bool) -> String {
    use crate::data::Types;
    let inner = match &ty.ty {
        Types::String => "std::string",
        Types::Bool => "bool",
        Types::F32 => "std::float_t",
        Types::F64 => "std::double_t",
        Types::I8 => "std::int8_t",
        Types::I16 => "std::int16_t",
        Types::I32 => "std::int32_t",
        Types::I64 => "std::int64_t",
        Types::U8 => "std::uint8_t",
        Types::U16 => "std::uint16_t",
        Types::U32 => "std::uint32_t",
        Types::U64 => "std::uint64_t",
        Types::Named(name) => name
    };
    if ty.array {
        format!("std::vector<{inner}>")
    } else if ty.optional && !ignore_optional {
        format!("std::optional<{inner}>")
    } else {
        inner.to_string()
    }
}

/// Builds the C++ declaration text `RET SERVICE_METHOD(ARGS)` for one RPC
/// method, without a trailing `;`.
///
/// * The return type is `void` for stream methods (`m.ret_stream`) and for
///   methods without a return type; otherwise it is the rendered `m.ret`.
/// * Every parameter type has `&&` appended; parameters are joined with `", "`.
/// * Stream methods get one extra trailing parameter for the stream handle.
///
/// NOTE(review): `service_name` is `&String`; `&str` would be the idiomatic
/// parameter type and existing call sites would still compile.
fn method_signature(service_name: &String, m: &crate::data::MethodTy) -> String {
    let mut ret = String::new();
    // Computed unconditionally, but only appended in the non-stream branch below.
    let ret_type = m.ret.as_ref().map_or("void".to_string(), |ty| field_ty_to_ty_str(ty, false));
    if m.ret_stream {
        ret += "void";
    } else {
        ret += &ret_type;
    }
    ret += &format!(" {}_{}(", service_name, m.name);
    // NOTE(review): the `format!` below has no placeholder and the literal ends in
    // an unbalanced `>`; it presumably interpolated `ret_type` into a
    // `std::shared_ptr` of the stream type before the text was damaged — confirm.
    ret += &m.args.iter()
        .map(|arg| field_ty_to_ty_str(arg, false))
        .chain(m.ret_stream.then(|| format!("std::shared_ptr>")))
        .map(|arg| arg + "&&")
        .join(", ");
    ret += ")";
    ret
}

/// Writes the generated C++ header to `f`.
///
/// Output order: fixed prologue (include guard, includes, `namespace mrpc {`),
/// one `enum struct` per entry of `rpc.enums`, forward declarations plus
/// `to_json`/`from_json` prototypes for every struct, the struct definitions,
/// the fixed `MRPCStreamImpl`/`MRPCStream`/`MRPCServer` declarations, one pure
/// virtual declaration per service method, and the fixed epilogue.
///
/// Fixed text is written through the inner writer `f.f`; generated lines go
/// through `f` itself (presumably so that `IndentedWriter` applies `f.ident` —
/// confirm against its implementation).
///
/// # Panics
///
/// Panics if any write fails, because every write result is unwrapped.
fn output_header(f: &mut IndentedWriter, rpc: &RPC) {
    // NOTE(review): the include targets are missing from this literal and all
    // directives sit on one line — see the note at the top of the file.
    f.f.write_all(b"#pragma once #ifndef MRPC_GEN_H #define MRPC_GEN_H #include #include #include #include #include #include #include #include namespace mrpc {\n").unwrap();

    for e in &rpc.enums {
        writeln!(f, "enum struct {} : std::uint64_t {{", e.name).unwrap();
        // The last enumerator is written separately so it carries no trailing
        // comma; an enum without values produces an empty body.
        if let Some((last, vals)) = e.values.split_last() {
            for v in vals {
                writeln!(f, "{} = {},", v.0, v.1).unwrap();
            }
            writeln!(f, "{} = {}", last.0, last.1).unwrap();
        }
        f.write_all(b"};\n\n").unwrap();
    }

    // First pass over the structs: forward declarations and JSON prototypes
    // only, emitted before any struct body.
    for s in &rpc.structs {
        writeln!(f, "struct {};", s.name).unwrap();
        writeln!(f, "void to_json(nlohmann::json&, const {}&);", s.name).unwrap();
        writeln!(f, "void from_json(const nlohmann::json&, {}&);\n", s.name).unwrap();
    }
    f.f.write_all(b"\n").unwrap();

    // Second pass: the struct definitions, one member per field.
    for s in &rpc.structs {
        writeln!(f, "struct {} {{", s.name).unwrap();
        for field in &s.fields {
            writeln!(f, "{} {};", field_ty_to_ty_str(field, false), field.name).unwrap();
        }
        f.write_all(b"};\n\n").unwrap();
    }

    // NOTE(review): `template struct MRPCStream` uses `T` without declaring a
    // template parameter list — likely part of the same text damage.
    f.f.write_all(b"struct MRPCStreamImpl { virtual void close() noexcept final; virtual void abort() noexcept final; virtual bool is_open() noexcept final; protected: MRPCStreamImpl(crow::websocket::connection *conn, uint64_t id) : conn(conn), id(id) {} crow::websocket::connection* conn; std::uint64_t id; }; template struct MRPCStream final : MRPCStreamImpl { MRPCStream(crow::websocket::connection *conn, uint64_t id) : MRPCStreamImpl(conn, id) {} bool send(const T &v) noexcept { if (!conn) return false; try { conn->send_text(nlohmann::json{{\"id\", id},{\"data\", v}}.dump()); } catch (const std::exception &_) { abort(); return false; } return true; } }; struct MRPCServer { virtual void install(crow::SimpleApp &app, std::string &&route) final; private:\n").unwrap();

    // The method declarations are written with `ident` set to 1.
    f.ident = 1;
    for service in &rpc.services {
        for method in &service.methods {
            writeln!(f, "virtual {} = 0;", method_signature(&service.name, method)).unwrap();
        }
    }

    // NOTE(review): `std::unordered_multimap>` has lost its template arguments.
    f.f.write_all(b" virtual void msg_handler(crow::websocket::connection&, const std::string&, bool) final; std::mutex __streams_mutex; std::unordered_multimap> __streams; }; } #endif // MRPC_GEN_H\n").unwrap();
}

/// Writes the C++ `to_json` and `from_json` definitions for one struct.
///
/// Optional fields are written with `json_set_opt` and read with
/// `json_get_opt<T>`, where `T` is the field type rendered with
/// `ignore_optional = true`. Non-optional fields use plain assignment and
/// `get_to`.
///
/// # Panics
///
/// Panics if any write fails.
fn output_struct_json_stuff(f: &mut IndentedWriter, s: &crate::data::StructTy) {
    writeln!(f, "void to_json(json &j, const {} &v) {{", s.name).unwrap();
    for field in &s.fields {
        if field.optional {
            writeln!(f, "json_set_opt(j, \"{0}\", v.{0});", field.name).unwrap();
        } else {
            writeln!(f, "j[\"{0}\"] = v.{0};", field.name).unwrap();
        }
    }
    f.write_all(b"}\n\n").unwrap();

    writeln!(f, "void from_json(const json &j, {} &v) {{", s.name).unwrap();
    for field in &s.fields {
        if field.optional {
            writeln!(f, "v.{0} = json_get_opt<{1}>(j, \"{0}\");", field.name, field_ty_to_ty_str(field, true)).unwrap();
        } else {
            writeln!(f, "j.at(\"{0}\").get_to(v.{0});", field.name).unwrap();
        }
    }
    f.write_all(b"}\n\n").unwrap();
}

/// Writes the generated C++ implementation file to `f`.
///
/// Output order: `#include` of `header_name`, the fixed JSON helper templates,
/// the per-struct JSON functions inside `namespace mrpc`, the fixed stream and
/// server member definitions, and finally the body of
/// `MRPCServer::msg_handler`, which is an `if`/`else if` chain over service
/// names, each containing an `if`/`else if` chain over method names.
///
/// For each method the generated code reads every argument from `__data`,
/// calls `SERVICE_METHOD(...)` with the arguments wrapped in `std::move`, and
/// sends the result back with `send_msg` unless the method has no return type
/// or is a stream method. Stream methods additionally create `__stream` and
/// register it in `__streams` before the call.
///
/// # Panics
///
/// Panics if any write fails, or if a method has `ret_stream` set but `ret` is
/// `None` (the return type is unwrapped in that branch).
fn output_cpp(f: &mut IndentedWriter, header_name: String, rpc: &RPC) {
    writeln!(f, "#include \"{header_name}\"").unwrap();
    // NOTE(review): `template inline ...`, `std::optional json_get_opt` and
    // `.get()` below appear to be missing their template parameters/arguments.
    f.f.write_all(b"using json = nlohmann::json; template inline std::optional json_get_opt(const json &j, std::string &&k) { if (j.contains(k) && !j.at(k).is_null()) return j.at(k).get(); else return std::nullopt; } template inline void json_set_opt(json &j, std::string &&k, const std::optional &v) { if (v.has_value()) j[k] = v.value(); else j[k] = nullptr; } namespace mrpc {\n").unwrap();

    for s in &rpc.structs {
        output_struct_json_stuff(f, s);
    }

    // Closes `namespace mrpc`, defines the stream/server members and opens
    // `msg_handler` up to its `try {`.
    f.f.write_all(b"} template void send_msg(crow::websocket::connection &c, uint64_t id, const T &v) { c.send_text(json{{\"id\", id},{\"data\", v}}.dump()); } void mrpc::MRPCStreamImpl::close() noexcept { if (conn != nullptr) { send_msg(*conn, id, nullptr); conn = nullptr; } } void mrpc::MRPCStreamImpl::abort() noexcept { conn = nullptr; } bool mrpc::MRPCStreamImpl::is_open() noexcept { return conn != nullptr; } void mrpc::MRPCServer::install(crow::SimpleApp &app, std::string &&route) { app.route_dynamic(std::move(route)) .websocket() .onclose([&](crow::websocket::connection &c, const std::string&){ std::lock_guard guard{__streams_mutex}; auto range = __streams.equal_range(&c); for (auto it = range.first; it != range.second; ++it) it->second->abort(); __streams.erase(&c); }) .onmessage([this](auto &&a, auto &&b, auto &&c) { try { msg_handler(a, b, c); } catch (const std::exception &_) {} }); } void mrpc::MRPCServer::msg_handler(crow::websocket::connection &__c, const std::string &__msg, bool) { json __j = json::parse(__msg); std::uint64_t __id = __j.at(\"id\"); std::string __service = __j.at(\"service\"), __method = __j.at(\"method\"); try {\n").unwrap();

    // The dispatch body is written with `ident` set to 2.
    f.ident = 2;
    f.write_all(b"json __data = __j.at(\"data\");\n").unwrap();

    // Every service after the first is prefixed with `else ` so the generated
    // blocks form a single if/else-if chain.
    let mut first_service = true;
    for service in &rpc.services {
        if first_service {
            first_service = false;
        } else {
            f.write_all(b"else ").unwrap();
        }
        writeln!(f, "if (__service == \"{}\") {{", service.name).unwrap();

        // Same chaining scheme for the methods of this service.
        let mut first_method = true;
        for method in &service.methods {
            if first_method {
                first_method = false;
            } else {
                f.write_all(b"else ").unwrap();
            }
            writeln!(f, "if (__method == \"{}\") {{", method.name).unwrap();

            if method.ret_stream {
                // NOTE(review): this format string has no `{}` placeholder although a
                // type string is passed as an argument, which `writeln!` rejects at
                // compile time; the template argument of `make_shared` was presumably
                // lost — confirm against version control.
                writeln!(f, "auto __stream = std::make_shared>(&__c, __id);", field_ty_to_ty_str(method.ret.as_ref().unwrap(), false)).unwrap();
                f.write_all(b"{ std::lock_guard guard{__streams_mutex}; __streams.emplace(&__c, __stream); }\n").unwrap();
            }

            // One local C++ variable per argument, read from `__data` by name.
            for arg in &method.args {
                let ty = field_ty_to_ty_str(&arg, false);
                if arg.optional {
                    writeln!(f, "{0} {1} = json_get_opt<{2}>(__data, \"{1}\");", ty, arg.name, field_ty_to_ty_str(arg, true)).unwrap();
                } else {
                    writeln!(f, "{0} {1} = __data.at(\"{1}\");", ty, arg.name).unwrap();
                }
            }

            // Call arguments: each argument moved, plus the stream handle last for
            // stream methods.
            // NOTE(review): `collect::>()` is not valid Rust; the turbofish type was
            // presumably a `Vec` of strings, given the `args.join(", ")` calls below.
            let args = method.args.iter()
                .map(|arg| format!("std::move({})", arg.name))
                .chain(method.ret_stream.then_some(String::from("std::move(__stream)")))
                .collect::>();

            if method.ret.is_none() || method.ret_stream {
                // No direct reply: either nothing to return or the stream carries it.
                writeln!(f, "{}_{}({});", service.name, method.name, args.join(", ")).unwrap();
            } else {
                writeln!(f, "auto __ret = {}_{}({});", service.name, method.name, args.join(", ")).unwrap();
                f.write_all(b"send_msg(__c, __id, __ret);\n").unwrap();
            }
            f.write_all(b"} ").unwrap();
        }
        // Unknown method name: the generated code throws.
        // NOTE(review): for a service without methods this `else` follows the
        // opening `{` of the service block with no preceding `if`.
        f.write_all(b"else { throw std::exception{}; }\n} ").unwrap();
    }

    // Unknown service name: the generated code throws; the surrounding `catch`
    // logs the request to `std::cerr`.
    // NOTE(review): with an empty `rpc.services` this `else` has no matching `if`.
    f.f.write_all(b"else { throw std::exception{}; } } catch (const std::exception &_) { std::cerr << \"Got invalid request \" << __id << \" for \" << __service << \"::\" << __method << std::endl; } }\n\n").unwrap();
}

/// Generates the C++ header and implementation files for `rpc`.
///
/// The outputs are `file_base_name` with its extension replaced by `h` and
/// `cpp`. The `.cpp` file includes the header by its file name only (no
/// directory part).
///
/// # Panics
///
/// Panics if either file cannot be created, if the header path has no file
/// name component, or if any write fails inside the output functions.
///
/// NOTE(review): `&std::path::Path` would be the idiomatic parameter type, and
/// the file-creation errors could be returned instead of unwrapped.
pub fn gen(file_base_name: &std::path::PathBuf, rpc: &RPC) {
    output_header(&mut IndentedWriter::new(std::fs::File::create(file_base_name.with_extension("h")).unwrap()), rpc);
    output_cpp(
        &mut IndentedWriter::new(std::fs::File::create(file_base_name.with_extension("cpp")).unwrap()),
        file_base_name.with_extension("h").file_name().unwrap().to_string_lossy().to_string(),
        rpc
    );
}