diff --git a/src/generators/cpp_s.rs b/src/generators/cpp_s.rs index 51b83f7..14be085 100644 --- a/src/generators/cpp_s.rs +++ b/src/generators/cpp_s.rs @@ -1,6 +1,21 @@ use itertools::Itertools; use crate::data::RPC; +pub const JSON_INNER_IMPLS: &[(&str, &str)] = &[ + ("std::string", "String"), + ("std::int8_t", "Int"), + ("std::int16_t", "Int"), + ("std::int32_t", "Int"), + ("std::int64_t", "Int64"), + ("std::uint8_t", "Uint"), + ("std::uint16_t", "Uint"), + ("std::uint32_t", "Uint"), + ("std::uint64_t", "Uint64"), + ("bool", "Bool"), + ("std::float_t", "Double"), + ("std::double_t", "Double") +]; + pub fn ty_to_str(ty: &crate::data::Types) -> String { use crate::data::Types; match &ty { @@ -22,11 +37,10 @@ pub fn ty_to_str(ty: &crate::data::Types) -> String { } } - pub fn method_args(method: &crate::data::MethodTy) -> String { method.args.iter() .map(|arg| format!("{} &&{}", ty_to_str(&arg.ty), arg.name)) - .chain(method.ret_stream.then(|| format!("std::shared_ptr>&&", ty_to_str(method.ret.as_ref().unwrap())))) + .chain(method.ret_stream.then(|| format!("MRPCStream<{}>&&", ty_to_str(method.ret.as_ref().unwrap())))) .join(", ") } @@ -76,6 +90,18 @@ pub fn json_write(ty: &crate::data::FieldTy) -> String { } } +pub fn streams_required(rpc: &RPC) -> Vec { + let mut streams = std::collections::HashSet::new(); + for s in &rpc.services { + for m in &s.methods { + if m.ret_stream { + streams.insert(ty_to_str(m.ret.as_ref().unwrap())); + } + } + } + streams.into_iter().collect() +} + pub fn gen(file_base_name: &std::path::PathBuf, rpc: &RPC) { let header_name = file_base_name.with_extension("h"); let header_name = header_name.file_name().unwrap().to_str().unwrap(); diff --git a/src/generators/ts_c.rs b/src/generators/ts_c.rs index 8ad42f4..5117918 100644 --- a/src/generators/ts_c.rs +++ b/src/generators/ts_c.rs @@ -19,7 +19,7 @@ pub fn ty_to_str(ty: &crate::data::Types) -> String { pub fn method_args(method: &crate::data::MethodTy) -> String { method.args.iter() .map(|arg| format!("{}: {}", arg.name, ty_to_str(&arg.ty))) - .chain(method.ret_stream.then(|| format!("__cbk: (v: {}) => void", ty_to_str(method.ret.as_ref().unwrap())))) + .chain(method.ret_stream.then(|| format!("__cbk: (v: {}|null) => void", ty_to_str(method.ret.as_ref().unwrap())))) .join(", ") } diff --git a/templates/cpp_server.rs.cpp b/templates/cpp_server.rs.cpp index 6f57066..4fdb6cc 100644 --- a/templates/cpp_server.rs.cpp +++ b/templates/cpp_server.rs.cpp @@ -4,80 +4,93 @@ @(header_name: &str, rpc: &RPC) #include "@header_name" +#include +#include +#include + @:cpp_server_json_cpp(rpc) -template -void send_msg(crow::websocket::connection &c, std::uint64_t id, const T &v) @{ +template +void send_msg(const std::shared_ptr &c, const T &v) @{ + if (c->is_closed()) + return; rapidjson::StringBuffer s; mrpc::MRPCJWriter writer@{s@}; - writer.StartObject(); - writer.Key("id"); - writer.Uint64(id); - writer.Key("data"); - if constexpr (std::is_same_v) - writer.Null(); - else - v >> writer; - writer.EndObject(); - c.send_text(s.GetString()); + v >> writer; + const auto body_ptr = s.GetString(); + const auto body = restbed::Bytes@{body_ptr, body_ptr+s.GetLength()@}; + c->yield( + 200, + body, + std::multimap@{ + @{"Content-Type", "application/json"@}, + @{"Content-Length", std::to_string(body.size())@} + @} + ); @} -void mrpc::MRPCStreamImpl::close() noexcept @{ - if (conn != nullptr) @{ - send_msg(*conn, id, nullptr); - conn = nullptr; - @} +template +void send_sse_msg(const std::shared_ptr &c, const T &v) @{ + if (c->is_closed()) + return; + rapidjson::StringBuffer s; + std::memcpy(s.Push(5), "data:", 5); + mrpc::MRPCJWriter writer@{s@}; + v >> writer; + std::memcpy(s.Push(2), "\n\n", 2); + const auto body_ptr = s.GetString(); + const auto body = restbed::Bytes@{body_ptr, body_ptr+s.GetLength()@}; + c->yield(body); @} -void mrpc::MRPCStreamImpl::abort() noexcept @{ conn = nullptr; @} -bool mrpc::MRPCStreamImpl::is_open() noexcept @{ return conn != nullptr; @} -void mrpc::MRPCServer::install(crow::SimpleApp &app, std::string &&route) @{ - app.route_dynamic(std::move(route)) - .websocket() - .onclose([&](crow::websocket::connection &c, const std::string&)@{ - std::lock_guard guard@{__streams_mutex@}; - auto range = __streams.equal_range(&c); - for (auto it = range.first; it != range.second; ++it) - it->second->abort(); - __streams.erase(&c); - @}) - .onmessage([this](auto &&a, auto &&b, auto &&c) @{ - try @{ msg_handler(a, b, c); @} - catch (const std::exception &_) @{@} +mrpc::MRPCStreamImpl::MRPCStreamImpl(const std::shared_ptr &conn) : conn(conn) @{ + conn->yield( + 200, + std::multimap@{ + @{"Cache-Control", "no-cache"@}, + @{"Content-Type", "text/event-stream"@} + @} + ); +@} + +void mrpc::MRPCStreamImpl::close() const noexcept @{ conn->close("data:null\n\n"); @} +bool mrpc::MRPCStreamImpl::is_open() const noexcept @{ return conn->is_open(); @} +@for s in streams_required(rpc) {template<> void MRPCStream::send(const mrpc::@s &v) const noexcept @{ send_sse_msg(conn, v); @} +} + +mrpc::MRPCServer::MRPCServer(std::shared_ptr &r) @{ + r->set_method_handler("POST", [this](const std::shared_ptr& s) @{ + const auto req = s->get_request(); + const auto body_len = req->get_header("Content-Length", 0); + s->fetch(body_len, [this](const std::shared_ptr &s, auto &&body) @{ + try @{ msg_handler(s, body); @} + catch (const std::exception &_) @{ s->close(400); @} @}); + @}); @} -void mrpc::MRPCServer::msg_handler(crow::websocket::connection &__c, std::string __msg, bool) @{ + +void mrpc::MRPCServer::msg_handler(const std::shared_ptr __c, const restbed::Bytes &__msg) @{ rapidjson::Document __j; - __j.ParseInsitu(__msg.data()); + __j.Parse((const char*)__msg.data(), __msg.size()); if (__j.HasParseError()) throw std::exception@{@}; - std::uint64_t __id; std::string __service, __method; - json_get(__j, "id", __id); json_get(__j, "service", __service); json_get(__j, "method", __method); - try @{ - auto __data_member = __j.FindMember("data"); - if (__data_member == __j.MemberEnd() || !__data_member->value.IsObject()) - throw std::exception@{@}; - auto &__data = __data_member->value; -@for (si, s) in rpc.services.iter().enumerate() { - @if si > 0 {else }if (__service == "@s.name") @{ -@for (mi, m) in s.methods.iter().enumerate() { - @if mi > 0 {else }if (__method == "@m.name") @{ - @if m.ret_stream { - auto __stream = std::make_shared>(&__c, __id); - @{ std::lock_guard guard@{__streams_mutex@}; __streams.emplace(&__c, __stream); @} - } -@for (name, ty) in m.args.iter().map(|a| (&a.name, ty_to_str(&a.ty))) { @ty @name; json_get<@ty>(__data, "@name", @name); - } - @if m.ret_stream || m.ret.is_none() {@(s.name)_@(m.name)(@call_args(m));} - else {send_msg(__c, __id, @(s.name)_@(m.name)(@call_args(m)));} - @} -} - else @{ throw std::exception@{@}; @} - @} -} + std::string __service, __method; + json_get(__j, "service", __service); json_get(__j, "method", __method); + auto __data_member = __j.FindMember("data"); + if (__data_member == __j.MemberEnd() || !__data_member->value.IsObject()) + throw std::exception@{@}; + auto &__data = __data_member->value; +@for (si, s) in rpc.services.iter().enumerate() {@if si > 0 { else }else{ }if (__service == "@s.name") @{ +@for (mi, m) in s.methods.iter().enumerate() {@if mi > 0 { else }else{ }if (__method == "@m.name") @{ + @if m.ret_stream {auto __stream = MRPCStream<@ty_to_str(m.ret.as_ref().unwrap())>@{__c@}; + } +@for (name, ty) in m.args.iter().map(|a| (&a.name, ty_to_str(&a.ty))) { @ty @name; json_get<@ty>(__data, "@name", @name); + } + @if m.ret_stream || m.ret.is_none() {@(s.name)_@(m.name)(@call_args(m));} + else {send_msg(__c, @(s.name)_@(m.name)(@call_args(m)));} + @}} else @{ throw std::exception@{@}; @} - @} catch (const std::exception &_) @{ - std::cerr << "Got invalid request " << __id << " for " << __service << "::" << __method << std::endl; - @} + @}} + else @{ throw std::exception@{@}; @} @} @} diff --git a/templates/cpp_server.rs.h b/templates/cpp_server.rs.h index 8a43e6c..de90709 100644 --- a/templates/cpp_server.rs.h +++ b/templates/cpp_server.rs.h @@ -7,27 +7,31 @@ #ifndef MRPC_GEN_H #define MRPC_GEN_H -#include #include -#include -#include #include +#include +#include #include -#include +#include +#include #define RAPIDJSON_HAS_STDSTRING 1 #include #include #include +namespace restbed @{ + class Resource; + class Session; +@} + namespace mrpc @{ - using MRPCJWriter = rapidjson::Writer; +using MRPCJWriter = rapidjson::Writer; @for e in &rpc.enums { enum struct @e.name : std::uint64_t @{ @e.values.iter().map(|(k,v)| format!("{k} = {v}")).join(",\n ") @}; } -@for s in &rpc.structs { -struct @s.name; +@for s in &rpc.structs {struct @s.name; } @for s in &rpc.structs { struct @s.name @{ @@ -35,51 +39,30 @@ struct @s.name @{ } MRPCJWriter& operator >>(MRPCJWriter&) const; @(s.name)& operator <<(const rapidjson::Value&); -@}; -} - +@};} +@if streams_required(rpc).len() > 0 { struct MRPCStreamImpl @{ - virtual void close() noexcept final; - virtual void abort() noexcept final; - virtual bool is_open() noexcept final; + void close() const noexcept; + bool is_open() const noexcept; protected: - MRPCStreamImpl(crow::websocket::connection *conn, uint64_t id) : conn(conn), id(id) @{@} - crow::websocket::connection* conn; - std::uint64_t id; + explicit MRPCStreamImpl(const std::shared_ptr &conn); + std::shared_ptr conn; @}; -template +template struct MRPCStream final : MRPCStreamImpl @{ - MRPCStream(crow::websocket::connection *conn, uint64_t id) : MRPCStreamImpl(conn, id) @{@} - bool send(const T &v) noexcept @{ - if (!conn) return false; - try @{ - rapidjson::StringBuffer s; - mrpc::MRPCJWriter writer@{s@}; - writer.StartObject(); - writer.Key("id"); - writer.Uint64(id); - writer.Key("data"); - v >> writer; - writer.EndObject(); - conn->send_text(s.GetString()); - @} catch (const std::exception &_) @{ - abort(); - return false; - @} - return true; - @} + explicit MRPCStream(const std::shared_ptr &conn) : MRPCStreamImpl(conn) @{@} + void send(const T &v) const noexcept; @}; +@for s in streams_required(rpc) {template struct MRPCStream<@(s)>; +}} struct MRPCServer @{ - virtual void install(crow::SimpleApp &app, std::string &&route) final; + explicit MRPCServer(std::shared_ptr&); private: @for s in &rpc.services {@for m in &s.methods { virtual @method_ret(m) @(s.name)_@(m.name)(@method_args(m)) = 0; }} - virtual void msg_handler(crow::websocket::connection&, std::string, bool) final; - - std::mutex __streams_mutex; - std::unordered_multimap> __streams; + virtual void msg_handler(std::shared_ptr, const restbed::Bytes&) final; @}; @} diff --git a/templates/cpp_server_json.rs.cpp b/templates/cpp_server_json.rs.cpp index a67f3ec..189409d 100644 --- a/templates/cpp_server_json.rs.cpp +++ b/templates/cpp_server_json.rs.cpp @@ -6,57 +6,12 @@ template void json_get(const rapidjson::Value &j, const char *key, T &v); template void json_get_inner(const rapidjson::Value&, T &v) = delete; - -template<> -inline void json_get_inner(const rapidjson::Value &member, std::string &v) @{ - if (!member.IsString()) +@for (ty, jty) in JSON_INNER_IMPLS { +template<> inline void json_get_inner(const rapidjson::Value &member, @ty &v) @{ + if (!member.Is@(jty)()) throw std::exception@{@}; - v = member.GetString(); -@} - -@for i in [8, 16, 32] { -template<> -inline void json_get_inner(const rapidjson::Value &member, std::int@(i)_t &v) @{ - if (!member.IsInt()) - throw std::exception@{@}; - v = member.GetInt(); -@} - -template<> -inline void json_get_inner(const rapidjson::Value &member, std::uint@(i)_t& v) @{ - if (!member.IsUint()) - throw std::exception@{@}; - v = member.GetUint(); -@} -} - -template<> -inline void json_get_inner(const rapidjson::Value &member, std::int64_t &v) @{ - if (!member.IsInt64()) - throw std::exception@{@}; - v = member.GetInt64(); -@} - -template<> -inline void json_get_inner(const rapidjson::Value &member, std::uint64_t& v) @{ - if (!member.IsUint64()) - throw std::exception@{@}; - v = member.GetUint64(); -@} - -template<> -inline void json_get_inner(const rapidjson::Value &member, bool &v) @{ - if (!member.IsBool()) - throw std::exception@{@}; - v = member.GetBool(); -@} - -template<> -inline void json_get_inner(const rapidjson::Value &member, double &v) @{ - if (!member.IsDouble()) - throw std::exception@{@}; - v = member.GetDouble(); -@} + v = member.Get@(jty)(); +@}} template inline void json_get_inner(const rapidjson::Value &member, std::optional &v) @{ @@ -80,18 +35,15 @@ inline void json_get_inner(const rapidjson::Value &member, std::vector &v) @{ @} @} - @for s in &rpc.structs { -template<> -inline void json_get_inner(const rapidjson::Value &__j, mrpc::@s.name &v) @{ +template<> inline void json_get_inner(const rapidjson::Value &__j, mrpc::@s.name &v) @{ using namespace mrpc; @for f in &s.fields { json_get<@ty_to_str(&f.ty)>(__j, "@f.name", v.@f.name); }@} } @for e in &rpc.enums { -template<> -inline void json_get_inner(const rapidjson::Value &j, mrpc::@e.name &v) @{ +template<> inline void json_get_inner(const rapidjson::Value &j, mrpc::@e.name &v) @{ json_get_inner(j, (std::uint64_t&)v); @} mrpc::MRPCJWriter& operator >>(const mrpc::@e.name &v, mrpc::MRPCJWriter &w) @{ @@ -112,10 +64,10 @@ namespace mrpc @{ @for s in &rpc.structs { MRPCJWriter& @(s.name)::operator >>(MRPCJWriter &__w) const @{ __w.StartObject(); -@for f in &s.fields { __w.Key("@f.name"); +@for f in &s.fields { __w.Key("@f.name", @f.name.len()); @json_write(&f) } __w.EndObject(); return __w; @} @(s.name)& @(s.name)::operator <<(const rapidjson::Value &__j) @{ json_get_inner<@(s.name)>(__j, *this); return *this; @} -} +} \ No newline at end of file diff --git a/templates/typescript_client.rs.ts b/templates/typescript_client.rs.ts index f959d7d..45fa032 100644 --- a/templates/typescript_client.rs.ts +++ b/templates/typescript_client.rs.ts @@ -3,79 +3,41 @@ @use crate::generators::ts_c::*; @(rpc: &RPC) +import @{ fetchEventSource @} from '@@microsoft/fetch-event-source'; @for e in &rpc.enums { export enum @e.name @{ @for (k,v) in &e.values { @k = @v, -} -@} +}@} } @for s in &rpc.structs { export interface @s.name @{ @for f in &s.fields { @f.name: @ty_to_str(&f.ty); +}@} } -@} -} - -interface _WSResponse @{ - id: number; - data: any; -@} - -interface _WSWaitingEntry @{ - ok: (v: any) => void; - err: (reason?: any) => void; -@} export class MRPCConnector @{ url: string; - socket: WebSocket; - nmi: number; - waiting: @{ [id: number]: _WSWaitingEntry @}; - streams: @{ [id: number]: (v: any) => void @}; - - private open() @{ - this.socket = new WebSocket(this.url); - this.socket.onmessage = ev => @{ - const data = JSON.parse(ev.data) as _WSResponse; - if (data.id in this.streams) @{ - this.streams[data.id](data.data); - if (data.data == null) - delete this.streams[data.id]; - @} else if (data.id in this.waiting) @{ - this.waiting[data.id].ok(data.data); - delete this.waiting[data.id]; - @} else @{ - console.log(`Got unexpected message: $@{data@}`); - @} - @} - this.socket.onerror = () => setTimeout(() => @{this.open();@}, 2500); - this.socket.onclose = () => setTimeout(() => @{this.open();@}, 2500); - @} - - private get_prom(id: number): Promise @{ - return new Promise((ok, err) => @{ this.waiting[id] = @{ok, err@}; @}); - @} public constructor(url: string) @{ this.url = url; - this.nmi = 0; - this.waiting = @{@}; - this.streams = @{@}; - this.open(); @} @for s in &rpc.services { @for m in &s.methods { public @(s.name)_@(m.name)(@method_args(m))@method_ret(m) @{ const __msg = @{ - id: this.nmi++, service: '@s.name', method: '@m.name', data: @{@m.args.iter().map(|a| &a.name).join(",")@} @}; - @if m.ret.is_some() && !m.ret_stream {const __p = this.get_prom<@ty_to_str(m.ret.as_ref().unwrap())>(__msg.id);} - else if m.ret_stream {this.streams[__msg.id] = __cbk;} - this.socket.send(JSON.stringify(__msg)); - @if m.ret.is_some() && !m.ret_stream {return __p;} + @if m.ret.is_some() && !m.ret_stream {return fetch(this.url, @{ + method: 'POST', + body: JSON.stringify(__msg) + @}).then((__r) => __r.json());} + else if m.ret_stream {fetchEventSource(this.url, @{ + method: 'POST', + body: JSON.stringify(__msg), + onmessage: __e => __cbk(JSON.parse(__e.data)) + @});} else {fetch(this.url, @{method: 'POST', body: JSON.stringify(__msg)@});} @} }} @}