295 lines
9.6 KiB
Rust
295 lines
9.6 KiB
Rust
|
use std::io::Write;
|
||
|
use itertools::Itertools;
|
||
|
use crate::data::RPC;
|
||
|
use super::IndentedWriter;
|
||
|
|
||
|
fn field_ty_to_ty_str(ty: &crate::data::FieldTy, ignore_optional: bool) -> String {
|
||
|
use crate::data::Types;
|
||
|
let inner = match &ty.ty {
|
||
|
Types::String => "std::string",
|
||
|
Types::Bool => "bool",
|
||
|
Types::F32 => "std::float_t",
|
||
|
Types::F64 => "std::double_t",
|
||
|
Types::I8 => "std::int8_t",
|
||
|
Types::I16 => "std::int16_t",
|
||
|
Types::I32 => "std::int32_t",
|
||
|
Types::I64 => "std::int64_t",
|
||
|
Types::U8 => "std::uint8_t",
|
||
|
Types::U16 => "std::uint16_t",
|
||
|
Types::U32 => "std::uint32_t",
|
||
|
Types::U64 => "std::uint64_t",
|
||
|
Types::Named(name) => name
|
||
|
};
|
||
|
|
||
|
if ty.array {
|
||
|
format!("std::vector<{inner}>")
|
||
|
} else if ty.optional && !ignore_optional {
|
||
|
format!("std::optional<{inner}>")
|
||
|
} else {
|
||
|
inner.to_string()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn method_signature(service_name: &String, m: &crate::data::MethodTy) -> String {
|
||
|
let mut ret = String::new();
|
||
|
|
||
|
let ret_type = m.ret.as_ref().map_or("void".to_string(), |ty| field_ty_to_ty_str(ty, false));
|
||
|
|
||
|
if m.ret_stream {
|
||
|
ret += "void";
|
||
|
} else {
|
||
|
ret += &ret_type;
|
||
|
}
|
||
|
|
||
|
ret += &format!(" {}_{}(", service_name, m.name);
|
||
|
|
||
|
ret += &m.args.iter()
|
||
|
.map(|arg| field_ty_to_ty_str(arg, false))
|
||
|
.chain(m.ret_stream.then(|| format!("std::shared_ptr<MRPCStream<{ret_type}>>")))
|
||
|
.map(|arg| arg + "&&")
|
||
|
.join(", ");
|
||
|
|
||
|
ret += ")";
|
||
|
|
||
|
ret
|
||
|
}
|
||
|
|
||
|
fn output_header(f: &mut IndentedWriter, rpc: &RPC) {
|
||
|
f.f.write_all(
|
||
|
b"#pragma once
|
||
|
#ifndef MRPC_GEN_H
|
||
|
#define MRPC_GEN_H
|
||
|
|
||
|
#include <unordered_map>
|
||
|
#include <memory>
|
||
|
#include <mutex>
|
||
|
#include <iosfwd>
|
||
|
#include <string>
|
||
|
#include <cstdint>
|
||
|
#include <crow.h>
|
||
|
#include <json.hpp>
|
||
|
|
||
|
namespace mrpc {\n").unwrap();
|
||
|
|
||
|
for e in &rpc.enums {
|
||
|
writeln!(f, "enum struct {} : std::uint64_t {{", e.name).unwrap();
|
||
|
if let Some((last, vals)) = e.values.split_last() {
|
||
|
for v in vals {
|
||
|
writeln!(f, "{} = {},", v.0, v.1).unwrap();
|
||
|
}
|
||
|
writeln!(f, "{} = {}", last.0, last.1).unwrap();
|
||
|
}
|
||
|
f.write_all(b"};\n\n").unwrap();
|
||
|
}
|
||
|
|
||
|
for s in &rpc.structs {
|
||
|
writeln!(f, "struct {};", s.name).unwrap();
|
||
|
writeln!(f, "void to_json(nlohmann::json&, const {}&);", s.name).unwrap();
|
||
|
writeln!(f, "void from_json(const nlohmann::json&, {}&);\n", s.name).unwrap();
|
||
|
}
|
||
|
|
||
|
f.f.write_all(b"\n").unwrap();
|
||
|
|
||
|
for s in &rpc.structs {
|
||
|
writeln!(f, "struct {} {{", s.name).unwrap();
|
||
|
for field in &s.fields {
|
||
|
writeln!(f, "{} {};", field_ty_to_ty_str(field, false), field.name).unwrap();
|
||
|
}
|
||
|
f.write_all(b"};\n\n").unwrap();
|
||
|
}
|
||
|
|
||
|
f.f.write_all(
|
||
|
b"struct MRPCStreamImpl {
|
||
|
virtual void close() noexcept final;
|
||
|
virtual void abort() noexcept final;
|
||
|
virtual bool is_open() noexcept final;
|
||
|
protected:
|
||
|
MRPCStreamImpl(crow::websocket::connection *conn, uint64_t id) : conn(conn), id(id) {}
|
||
|
crow::websocket::connection* conn;
|
||
|
std::uint64_t id;
|
||
|
};
|
||
|
|
||
|
template<class T>
|
||
|
struct MRPCStream final : MRPCStreamImpl {
|
||
|
MRPCStream(crow::websocket::connection *conn, uint64_t id) : MRPCStreamImpl(conn, id) {}
|
||
|
bool send(const T &v) noexcept {
|
||
|
if (!conn) return false;
|
||
|
try {
|
||
|
conn->send_text(nlohmann::json{{\"id\", id},{\"data\", v}}.dump());
|
||
|
} catch (const std::exception &_) {
|
||
|
abort();
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
struct MRPCServer {
|
||
|
virtual void install(crow::SimpleApp &app, std::string &&route) final;
|
||
|
private:\n").unwrap();
|
||
|
f.ident = 1;
|
||
|
|
||
|
for service in &rpc.services {
|
||
|
for method in &service.methods {
|
||
|
writeln!(f, "virtual {} = 0;", method_signature(&service.name, method)).unwrap();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
f.f.write_all(b"
|
||
|
virtual void msg_handler(crow::websocket::connection&, const std::string&, bool) final;
|
||
|
|
||
|
std::mutex __streams_mutex;
|
||
|
std::unordered_multimap<crow::websocket::connection*, std::shared_ptr<MRPCStreamImpl>> __streams;
|
||
|
};
|
||
|
}
|
||
|
#endif // MRPC_GEN_H\n").unwrap();
|
||
|
}
|
||
|
|
||
|
fn output_struct_json_stuff(f: &mut IndentedWriter, s: &crate::data::StructTy) {
|
||
|
writeln!(f, "void to_json(json &j, const {} &v) {{", s.name).unwrap();
|
||
|
for field in &s.fields {
|
||
|
if field.optional {
|
||
|
writeln!(f, "json_set_opt(j, \"{0}\", v.{0});", field.name).unwrap();
|
||
|
} else {
|
||
|
writeln!(f, "j[\"{0}\"] = v.{0};", field.name).unwrap();
|
||
|
}
|
||
|
}
|
||
|
f.write_all(b"}\n\n").unwrap();
|
||
|
writeln!(f, "void from_json(const json &j, {} &v) {{", s.name).unwrap();
|
||
|
for field in &s.fields {
|
||
|
if field.optional {
|
||
|
writeln!(f, "v.{0} = json_get_opt<{1}>(j, \"{0}\");", field.name, field_ty_to_ty_str(field, true)).unwrap();
|
||
|
} else {
|
||
|
writeln!(f, "j.at(\"{0}\").get_to(v.{0});", field.name).unwrap();
|
||
|
}
|
||
|
}
|
||
|
f.write_all(b"}\n\n").unwrap();
|
||
|
}
|
||
|
|
||
|
fn output_cpp(f: &mut IndentedWriter, header_name: String, rpc: &RPC) {
|
||
|
writeln!(f, "#include \"{header_name}\"").unwrap();
|
||
|
f.f.write_all(
|
||
|
b"using json = nlohmann::json;
|
||
|
|
||
|
template<class T>
|
||
|
inline std::optional<T> json_get_opt(const json &j, std::string &&k) {
|
||
|
if (j.contains(k) && !j.at(k).is_null())
|
||
|
return j.at(k).get<T>();
|
||
|
else
|
||
|
return std::nullopt;
|
||
|
}
|
||
|
|
||
|
template<class T>
|
||
|
inline void json_set_opt(json &j, std::string &&k, const std::optional<T> &v) {
|
||
|
if (v.has_value())
|
||
|
j[k] = v.value();
|
||
|
else
|
||
|
j[k] = nullptr;
|
||
|
}
|
||
|
|
||
|
namespace mrpc {\n").unwrap();
|
||
|
|
||
|
for s in &rpc.structs {
|
||
|
output_struct_json_stuff(f, s);
|
||
|
}
|
||
|
|
||
|
f.f.write_all(
|
||
|
b"}
|
||
|
|
||
|
template<class T>
|
||
|
void send_msg(crow::websocket::connection &c, uint64_t id, const T &v) {
|
||
|
c.send_text(json{{\"id\", id},{\"data\", v}}.dump());
|
||
|
}
|
||
|
|
||
|
void mrpc::MRPCStreamImpl::close() noexcept {
|
||
|
if (conn != nullptr) {
|
||
|
send_msg(*conn, id, nullptr);
|
||
|
conn = nullptr;
|
||
|
}
|
||
|
}
|
||
|
void mrpc::MRPCStreamImpl::abort() noexcept { conn = nullptr; }
|
||
|
bool mrpc::MRPCStreamImpl::is_open() noexcept { return conn != nullptr; }
|
||
|
|
||
|
void mrpc::MRPCServer::install(crow::SimpleApp &app, std::string &&route) {
|
||
|
app.route_dynamic(std::move(route))
|
||
|
.websocket()
|
||
|
.onclose([&](crow::websocket::connection &c, const std::string&){
|
||
|
std::lock_guard guard{__streams_mutex};
|
||
|
auto range = __streams.equal_range(&c);
|
||
|
for (auto it = range.first; it != range.second; ++it)
|
||
|
it->second->abort();
|
||
|
__streams.erase(&c);
|
||
|
})
|
||
|
.onmessage([this](auto &&a, auto &&b, auto &&c) {
|
||
|
try { msg_handler(a, b, c); }
|
||
|
catch (const std::exception &_) {}
|
||
|
});
|
||
|
}
|
||
|
void mrpc::MRPCServer::msg_handler(crow::websocket::connection &__c, const std::string &__msg, bool) {
|
||
|
json __j = json::parse(__msg);
|
||
|
std::uint64_t __id = __j.at(\"id\");
|
||
|
std::string __service = __j.at(\"service\"), __method = __j.at(\"method\");
|
||
|
try {\n").unwrap();
|
||
|
f.ident = 2;
|
||
|
|
||
|
f.write_all(b"json __data = __j.at(\"data\");\n").unwrap();
|
||
|
|
||
|
let mut first_service = true;
|
||
|
for service in &rpc.services {
|
||
|
if first_service { first_service = false; }
|
||
|
else { f.write_all(b"else ").unwrap(); }
|
||
|
writeln!(f, "if (__service == \"{}\") {{", service.name).unwrap();
|
||
|
let mut first_method = true;
|
||
|
for method in &service.methods {
|
||
|
if first_method { first_method = false; }
|
||
|
else { f.write_all(b"else ").unwrap(); }
|
||
|
writeln!(f, "if (__method == \"{}\") {{", method.name).unwrap();
|
||
|
if method.ret_stream {
|
||
|
writeln!(f, "auto __stream = std::make_shared<MRPCStream<{}>>(&__c, __id);",
|
||
|
field_ty_to_ty_str(method.ret.as_ref().unwrap(), false)).unwrap();
|
||
|
f.write_all(b"{ std::lock_guard guard{__streams_mutex}; __streams.emplace(&__c, __stream); }\n").unwrap();
|
||
|
}
|
||
|
|
||
|
for arg in &method.args {
|
||
|
let ty = field_ty_to_ty_str(&arg, false);
|
||
|
if arg.optional {
|
||
|
writeln!(f, "{0} {1} = json_get_opt<{2}>(__data, \"{1}\");", ty, arg.name, field_ty_to_ty_str(arg, true)).unwrap();
|
||
|
} else {
|
||
|
writeln!(f, "{0} {1} = __data.at(\"{1}\");", ty, arg.name).unwrap();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
let args = method.args.iter()
|
||
|
.map(|arg| format!("std::move({})", arg.name))
|
||
|
.chain(method.ret_stream.then_some(String::from("std::move(__stream)")))
|
||
|
.collect::<Vec<_>>();
|
||
|
|
||
|
if method.ret.is_none() || method.ret_stream {
|
||
|
writeln!(f, "{}_{}({});", service.name, method.name, args.join(", ")).unwrap();
|
||
|
} else {
|
||
|
writeln!(f, "auto __ret = {}_{}({});", service.name, method.name, args.join(", ")).unwrap();
|
||
|
f.write_all(b"send_msg(__c, __id, __ret);\n").unwrap();
|
||
|
}
|
||
|
|
||
|
f.write_all(b"} ").unwrap();
|
||
|
}
|
||
|
|
||
|
f.write_all(b"else { throw std::exception{}; }\n} ").unwrap();
|
||
|
}
|
||
|
|
||
|
f.f.write_all(b"else { throw std::exception{}; }
|
||
|
} catch (const std::exception &_) {
|
||
|
std::cerr << \"Got invalid request \" << __id << \" for \" << __service << \"::\" << __method << std::endl;
|
||
|
}
|
||
|
}\n\n").unwrap();
|
||
|
}
|
||
|
|
||
|
pub fn gen(file_base_name: &std::path::PathBuf, rpc: &RPC) {
|
||
|
output_header(&mut IndentedWriter::new(std::fs::File::create(file_base_name.with_extension("h")).unwrap()), rpc);
|
||
|
output_cpp(
|
||
|
&mut IndentedWriter::new(std::fs::File::create(file_base_name.with_extension("cpp")).unwrap()),
|
||
|
file_base_name.with_extension("h").file_name().unwrap().to_string_lossy().to_string(),
|
||
|
rpc
|
||
|
);
|
||
|
}
|