diff --git a/src/data.rs b/src/data.rs index dff645f..c613800 100644 --- a/src/data.rs +++ b/src/data.rs @@ -8,7 +8,7 @@ pub enum Types { Array(Box), Optional(Box), Named(String), - Generic(String, Box, proc_macro2::Span) + Generic(String, Vec) } #[derive(Debug, Clone, Default)] @@ -27,8 +27,7 @@ pub struct FieldTy { pub struct StructTy { pub name: String, pub fields: Vec, - pub generic_fields: Vec, - pub generic_name: Option + pub generic_names: Vec } #[derive(Debug, Clone, Default)] diff --git a/src/generators/cpp_s.rs b/src/generators/cpp_s.rs index 51e9baa..b9e21a6 100644 --- a/src/generators/cpp_s.rs +++ b/src/generators/cpp_s.rs @@ -18,7 +18,7 @@ pub const JSON_INNER_IMPLS: &[(&str, &str)] = &[ pub fn ty_to_str(ty: &crate::data::Types) -> String { use crate::data::Types; - match &ty { + match ty { Types::String => "std::string".into(), Types::Bool => "bool".into(), Types::F32 => "std::float_t".into(), @@ -34,10 +34,35 @@ pub fn ty_to_str(ty: &crate::data::Types) -> String { Types::Named(name) => name.into(), Types::Optional(inner) => format!("std::optional<{}>", ty_to_str(inner)), Types::Array(inner) => format!("std::vector<{}>", ty_to_str(inner)), - Types::Generic(_, _, _) => unreachable!() + Types::Generic(name, types) => + format!("{}<{}>", name, types.iter().map(|ty| ty_to_str(ty)).join(", ")) } } +pub fn get_struct_generics(s: &crate::data::StructTy) -> String { + if s.generic_names.is_empty() { + "".into() + } else { + format!("template<{}>\n", generics_brace_inner_typename(s)) + } +} + +pub fn generics_brace_inner_typename(s: &crate::data::StructTy) -> String { + s.generic_names.iter().map(|n| String::from("typename ") + n).join(", ") +} + +pub fn generics_brace(s: &crate::data::StructTy) -> String { + if s.generic_names.is_empty() { + "".into() + } else { + format!("<{}>", generics_brace_inner(s)) + } +} + +pub fn generics_brace_inner(s: &crate::data::StructTy) -> String { + s.generic_names.iter().join(", ") +} + pub fn method_args(method: &crate::data::MethodTy) -> String { method.args.iter() .map(|arg| format!("{} &&{}", ty_to_str(&arg.ty), arg.name)) @@ -60,38 +85,6 @@ pub fn call_args(method: &crate::data::MethodTy) -> String { .join(", ") } -pub fn json_write(ty: &crate::data::FieldTy) -> String { - use crate::data::Types; - match &ty.ty { - Types::String => format!("__w.String({});", ty.name), - Types::Bool => format!("__w.Bool({});", ty.name), - Types::F32 | Types::F64 => format!("__w.Double({});", ty.name), - Types::I8 | Types::I16 | Types::I32 | Types::I64 => format!("__w.Int64({});", ty.name), - Types::U8 | Types::U16 | Types::U32 | Types::U64 => format!("__w.Uint64({});", ty.name), - Types::Named(_) => format!("{} >> __w;", ty.name), - Types::Optional(inner) => { - let inner = crate::data::FieldTy { name: format!("({}.value())", ty.name), ty: (**inner).clone() }; - let inner = json_write(&inner); - format!( -"if ({}.has_value()) {{ - {} - }} else __w.Null();", ty.name, inner) - }, - Types::Array(inner) => { - let inner_var_name = format!("__{}__entry", ty.name); - let inner = crate::data::FieldTy { name: inner_var_name.clone(), ty: (**inner).clone() }; - let inner = json_write(&inner); - format!( -"__w.StartArray(); - for (const auto &{} : {}) {{ - {} - }} - __w.EndArray();", inner_var_name, ty.name, inner) - }, - Types::Generic(_, _, _) => unreachable!() - } -} - pub fn streams_required(rpc: &RPC) -> Vec { let mut streams = std::collections::HashSet::new(); for s in &rpc.services { @@ -105,10 +98,10 @@ pub fn streams_required(rpc: &RPC) -> Vec { } pub fn gen(file_base_name: &std::path::PathBuf, rpc: &RPC) { - let header_name = file_base_name.with_extension("h"); + let header_name = file_base_name.with_extension("hxx"); + let h = std::fs::File::create(&header_name).unwrap(); let header_name = header_name.file_name().unwrap().to_str().unwrap(); - let h = std::fs::File::create(file_base_name.with_extension("h")).unwrap(); - let c = std::fs::File::create(file_base_name.with_extension("cpp")).unwrap(); - crate::templates::cpp_server_h(h, rpc).unwrap(); - crate::templates::cpp_server_cpp(c, header_name, rpc).unwrap(); + let c = std::fs::File::create(file_base_name.with_extension("cxx")).unwrap(); + crate::templates::cpp_server_hxx(h, rpc).unwrap(); + crate::templates::cpp_server_cxx(c, header_name, rpc).unwrap(); } diff --git a/src/generators/ts_c.rs b/src/generators/ts_c.rs index 8656e4f..3cb0276 100644 --- a/src/generators/ts_c.rs +++ b/src/generators/ts_c.rs @@ -13,7 +13,16 @@ pub fn ty_to_str(ty: &crate::data::Types) -> String { Types::Named(name) => name.into(), Types::Optional(inner) => format!("({}|null)", ty_to_str(inner)), Types::Array(inner) => format!("{}[]", ty_to_str(inner)), - Types::Generic(_, _, _) => unreachable!() + Types::Generic(name, types) => + format!("{}<{}>", name, types.iter().map(|ty| ty_to_str(ty)).join(", ")) + } +} + +pub fn get_struct_generics(s: &crate::data::StructTy) -> String { + if s.generic_names.is_empty() { + "".into() + } else { + format!("<{}>", s.generic_names.join(", ")) } } diff --git a/src/main.rs b/src/main.rs index 4df6a47..4aaf981 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,6 @@ mod generators; mod templates; mod parser; -use std::collections::HashMap; use std::fmt::Write; use std::fs::File; use std::io::Read; @@ -30,95 +29,6 @@ struct GenArgs { static SOURCE_FILE: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); static SOURCE: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); -fn has_generic(ty: &mut data::Types) -> Option<&mut data::Types> { - use data::Types; - match ty { - Types::String | Types::Bool | - Types::F32 | Types::F64 | - Types::I8 | Types::I16 | Types::I32 | Types::I64 | - Types::U8 | Types::U16 | Types::U32 | Types::U64 => None, - Types::Array(inner) => has_generic(inner), - Types::Optional(inner) => has_generic(inner), - Types::Named(_) => None, - Types::Generic(_, _, _) => Some(ty) - } -} - -pub fn gen_generic_name(ty: &data::Types) -> String { - use data::Types; - match &ty { - Types::String => "string".into(), - Types::Bool => "bool".into(), - Types::F32 => "f32".into(), - Types::F64 => "f64".into(), - Types::I8 => "i8".into(), - Types::I16 => "i16".into(), - Types::I32 => "i32".into(), - Types::I64 => "i64".into(), - Types::U8 => "u8".into(), - Types::U16 => "u16".into(), - Types::U32 => "u32".into(), - Types::U64 => "u64".into(), - Types::Named(name) => name.clone(), - Types::Optional(inner) => format!("maybe_{}", gen_generic_name(inner)), - Types::Array(inner) => format!("array_{}", gen_generic_name(inner)), - Types::Generic(_, _, span) => emit_error(Some((span.clone(), "Nested custom generics are not allowed"))) - } -} -fn replace_generics<'a>(it: impl Iterator, generics: &HashMap, new_structs: &mut HashMap) { - for ty in it.filter_map(|ty| has_generic(ty)) { - let (name, inner, span) = match ty { - data::Types::Generic(name, inner, span) => - (name.clone(), inner.clone(), span.clone()), - _ => unreachable!() - }; - let generic = match generics.get(&name) { - Some(v) => v, - None => emit_error(Some((span, "Type does not exists"))) - }; - let new_name = format!("{}_{}", name, gen_generic_name(&inner)); - *ty = data::Types::Named(new_name.clone()); - new_structs.insert(new_name, (generic.clone(), *inner)); - } -} - -fn replace_generic_type(ty: &mut data::Types, generic_name: &String, replacement: &data::Types) { - use data::Types; - match ty { - Types::Named(name) => if name == generic_name { *ty = replacement.clone(); } else { unreachable!() }, - Types::Optional(inner) => replace_generic_type(inner, generic_name, replacement), - Types::Array(inner) => replace_generic_type(inner, generic_name, replacement), - _ => unreachable!() - } -} - -fn resolve_generics(rpc: &mut data::RPC) { - let mut generics: HashMap = HashMap::new(); - rpc.structs.retain(|s| if s.generic_name.is_none() { true } else { - generics.insert(s.name.clone(), s.clone()); - false - }); - - let mut new_structs = HashMap::new(); - for s in &mut rpc.structs { - replace_generics(s.fields.iter_mut().map(|f| &mut f.ty), &generics, &mut new_structs); - } - for m in &mut rpc.services.iter_mut().map(|s| s.methods.iter_mut()).flatten() { - replace_generics(m.args.iter_mut().map(|f| &mut f.ty).chain(&mut m.ret), &generics, &mut new_structs); - } - - for (name, (mut generic, inner)) in new_structs { - generic.name = name; - let generic_name = generic.generic_name.take().unwrap(); - for mut field in generic.generic_fields { - replace_generic_type(&mut field.ty, &generic_name, &inner); - generic.fields.push(field); - } - generic.generic_fields = vec![]; - rpc.structs.push(generic); - } -} - fn main() { let args = Args::parse(); @@ -141,9 +51,7 @@ fn main() { ) }; - let mut rpc = parser::parse_file(&ast); - - resolve_generics(&mut rpc); + let rpc = parser::parse_file(&ast); for gen in &args.generators.clients { gen.generate(&args.rpc_name, &rpc); diff --git a/src/parser.rs b/src/parser.rs index 8d8cbd3..d81047b 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -40,20 +40,19 @@ fn parse_type(item: &syn::Type) -> data::Types { syn::PathArguments::AngleBracketed(v) => v, _ => emit_error(vec![(item.span(), "Angle bracketed arguments expected")]) }; - if args.args.len() != 1 { - emit_error(vec![(item.span(), "Expected 1 argument")]); - } - let ty = match &args.args[0] { + let types = args.args.iter().map(|arg| match arg { syn::GenericArgument::Type(v) => parse_type(v), _ => emit_error(vec![(item.span(), "Only types are supported")]) - }; + }).collect::>(); let name = segment.ident.to_string(); if name == "Option" { - data::Types::Optional(ty.into()) + if types.len() != 1 { emit_error(Some((segment.span(), "Option needs exactly one argument"))); } + data::Types::Optional(types[0].clone().into()) } else if name == "Vec" { - data::Types::Array(ty.into()) + if types.len() != 1 { emit_error(Some((segment.span(), "Vec needs exactly one argument"))); } + data::Types::Array(types[0].clone().into()) } else { - data::Types::Generic(name, ty.into(), segment.ident.span()) + data::Types::Generic(name, types) } } else { parse_type_string(segment.ident.to_string()) @@ -66,20 +65,6 @@ fn parse_type(item: &syn::Type) -> data::Types { } } -fn has_generic(ty: &data::Types, generic: &String) -> bool { - use data::Types; - match ty { - Types::String | Types::Bool | - Types::F32 | Types::F64 | - Types::I8 | Types::I16 | Types::I32 | Types::I64 | - Types::U8 | Types::U16 | Types::U32 | Types::U64 => false, - Types::Array(inner) => has_generic(inner, generic), - Types::Optional(inner) => has_generic(inner, generic), - Types::Named(name) => name == generic, - Types::Generic(name, inner, _) => name == generic || has_generic(inner, generic) - } -} - fn parse_struct(item: &syn::ItemStruct) -> data::StructTy { let name = item.ident.to_string(); if let Some(v) = &item.generics.where_clause { @@ -88,7 +73,7 @@ fn parse_struct(item: &syn::ItemStruct) -> data::StructTy { if item.generics.params.len() > 1 { emit_error(Some((item.generics.params.span(), "Only one generic parameter is allowed for now"))); } - let generic_name = item.generics.params.first().map(|g| { + let generic_names = item.generics.params.iter().map(|g| { match g { syn::GenericParam::Const(_) | syn::GenericParam::Lifetime(_) => emit_error(Some((g.span(), "Only generic types are allowed"))), @@ -102,23 +87,16 @@ fn parse_struct(item: &syn::ItemStruct) -> data::StructTy { ty.ident.to_string() } } - }); - let mut generic_fields = vec![]; - let mut fields = vec![]; - for field in &item.fields { + }).collect(); + let fields = item.fields.iter().map(|field| { if field.ident.is_none() { emit_error(vec![(field.span(), "Missing field name")]); } let name = field.ident.as_ref().unwrap().to_string(); let ty = parse_type(&field.ty); - let ty = data::FieldTy { name, ty }; - if generic_name.is_some() && has_generic(&ty.ty, generic_name.as_ref().unwrap()) { - generic_fields.push(ty); - } else { - fields.push(ty); - } - } - data::StructTy { name, fields, generic_fields, generic_name } + data::FieldTy { name, ty } + }).collect(); + data::StructTy { name, fields, generic_names } } fn try_parse_iterator(ty: &syn::Type) -> Option { diff --git a/templates/cpp_server.rs.cpp b/templates/cpp_server.rs.cxx similarity index 90% rename from templates/cpp_server.rs.cpp rename to templates/cpp_server.rs.cxx index 4fdb6cc..8866f64 100644 --- a/templates/cpp_server.rs.cpp +++ b/templates/cpp_server.rs.cxx @@ -1,6 +1,6 @@ @use crate::data::RPC; @use crate::generators::cpp_s::*; -@use super::cpp_server_json_cpp; +@use super::cpp_server_json_cxx; @(header_name: &str, rpc: &RPC) #include "@header_name" @@ -8,7 +8,9 @@ #include #include -@:cpp_server_json_cpp(rpc) +using namespace mrpc; + +@:cpp_server_json_cxx(rpc) template void send_msg(const std::shared_ptr &c, const T &v) @{ @@ -55,7 +57,7 @@ mrpc::MRPCStreamImpl::MRPCStreamImpl(const std::shared_ptr &co void mrpc::MRPCStreamImpl::close() const noexcept @{ conn->close("data:null\n\n"); @} bool mrpc::MRPCStreamImpl::is_open() const noexcept @{ return conn->is_open(); @} -@for s in streams_required(rpc) {template<> void MRPCStream::send(const mrpc::@s &v) const noexcept @{ send_sse_msg(conn, v); @} +@for s in streams_required(rpc) {template<> void MRPCStream<@s>::send(const @s &v) const noexcept @{ send_sse_msg(conn, v); @} } mrpc::MRPCServer::MRPCServer(std::shared_ptr &r) @{ @@ -75,7 +77,8 @@ void mrpc::MRPCServer::msg_handler(const std::shared_ptr __c, if (__j.HasParseError()) throw std::exception@{@}; std::string __service, __method; - json_get(__j, "service", __service); json_get(__j, "method", __method); + __service << json_get(__j, "service"); + __method << json_get(__j, "method"); auto __data_member = __j.FindMember("data"); if (__data_member == __j.MemberEnd() || !__data_member->value.IsObject()) throw std::exception@{@}; @@ -84,7 +87,7 @@ void mrpc::MRPCServer::msg_handler(const std::shared_ptr __c, @for (mi, m) in s.methods.iter().enumerate() {@if mi > 0 { else }else{ }if (__method == "@m.name") @{ @if m.ret_stream {auto __stream = MRPCStream<@ty_to_str(m.ret.as_ref().unwrap())>@{__c@}; } -@for (name, ty) in m.args.iter().map(|a| (&a.name, ty_to_str(&a.ty))) { @ty @name; json_get<@ty>(__data, "@name", @name); +@for (name, ty) in m.args.iter().map(|a| (&a.name, ty_to_str(&a.ty))) { @ty @name; @name << json_get(__data, "@name"); } @if m.ret_stream || m.ret.is_none() {@(s.name)_@(m.name)(@call_args(m));} else {send_msg(__c, @(s.name)_@(m.name)(@call_args(m)));} diff --git a/templates/cpp_server.rs.h b/templates/cpp_server.rs.hxx similarity index 93% rename from templates/cpp_server.rs.h rename to templates/cpp_server.rs.hxx index de90709..a27c148 100644 --- a/templates/cpp_server.rs.h +++ b/templates/cpp_server.rs.hxx @@ -31,10 +31,10 @@ enum struct @e.name : std::uint64_t @{ @e.values.iter().map(|(k,v)| format!("{k} = {v}")).join(",\n ") @}; } -@for s in &rpc.structs {struct @s.name; +@for s in &rpc.structs {@get_struct_generics(s)struct @s.name; } @for s in &rpc.structs { -struct @s.name @{ +@get_struct_generics(s)struct @s.name @{ @for f in &s.fields { @ty_to_str(&f.ty) @f.name; } MRPCJWriter& operator >>(MRPCJWriter&) const; @@ -58,6 +58,7 @@ struct MRPCStream final : MRPCStreamImpl @{ @for s in streams_required(rpc) {template struct MRPCStream<@(s)>; }} struct MRPCServer @{ + MRPCServer() = delete; explicit MRPCServer(std::shared_ptr&); private: @for s in &rpc.services {@for m in &s.methods { virtual @method_ret(m) @(s.name)_@(m.name)(@method_args(m)) = 0; diff --git a/templates/cpp_server_json.rs.cpp b/templates/cpp_server_json.rs.cpp deleted file mode 100644 index 189409d..0000000 --- a/templates/cpp_server_json.rs.cpp +++ /dev/null @@ -1,73 +0,0 @@ -@use crate::data::RPC; -@use crate::generators::cpp_s::*; - -@(rpc: &RPC) -template -void json_get(const rapidjson::Value &j, const char *key, T &v); -template -void json_get_inner(const rapidjson::Value&, T &v) = delete; -@for (ty, jty) in JSON_INNER_IMPLS { -template<> inline void json_get_inner(const rapidjson::Value &member, @ty &v) @{ - if (!member.Is@(jty)()) - throw std::exception@{@}; - v = member.Get@(jty)(); -@}} - -template -inline void json_get_inner(const rapidjson::Value &member, std::optional &v) @{ - if (member.IsNull()) - v = std::nullopt; - else @{ - T t; - json_get_inner(member, t); - v = std::move(t); - @} -@} - -template -inline void json_get_inner(const rapidjson::Value &member, std::vector &v) @{ - if (!member.IsArray()) - throw std::exception@{@}; - for (const auto &j : member.GetArray()) @{ - T t; - json_get_inner(j, t); - v.push_back(std::move(t)); - @} -@} - -@for s in &rpc.structs { -template<> inline void json_get_inner(const rapidjson::Value &__j, mrpc::@s.name &v) @{ - using namespace mrpc; -@for f in &s.fields { json_get<@ty_to_str(&f.ty)>(__j, "@f.name", v.@f.name); -}@} -} - -@for e in &rpc.enums { -template<> inline void json_get_inner(const rapidjson::Value &j, mrpc::@e.name &v) @{ - json_get_inner(j, (std::uint64_t&)v); -@} -mrpc::MRPCJWriter& operator >>(const mrpc::@e.name &v, mrpc::MRPCJWriter &w) @{ - w.Uint64((std::uint64_t)v); - return w; -@} -} - -template -inline void json_get(const rapidjson::Value &j, const char *key, T &v) @{ - auto member = j.FindMember(key); - if (member == j.MemberEnd()) - throw std::exception@{@}; - json_get_inner(member->value, v); -@} - -namespace mrpc @{ -@for s in &rpc.structs { -MRPCJWriter& @(s.name)::operator >>(MRPCJWriter &__w) const @{ - __w.StartObject(); -@for f in &s.fields { __w.Key("@f.name", @f.name.len()); - @json_write(&f) -} __w.EndObject(); - return __w; -@} -@(s.name)& @(s.name)::operator <<(const rapidjson::Value &__j) @{ json_get_inner<@(s.name)>(__j, *this); return *this; @} -} \ No newline at end of file diff --git a/templates/cpp_server_json.rs.cxx b/templates/cpp_server_json.rs.cxx new file mode 100644 index 0000000..5360aac --- /dev/null +++ b/templates/cpp_server_json.rs.cxx @@ -0,0 +1,95 @@ +@use crate::data::RPC; +@use crate::generators::cpp_s::*; + +@(rpc: &RPC) +@for (ty, jty) in JSON_INNER_IMPLS { +inline @(ty)& operator<<(@ty &v, const rapidjson::Value &j) @{ + if (!j.Is@(jty)()) + throw std::exception@{@}; + v = j.Get@(jty)(); + return v; +@} +inline mrpc::MRPCJWriter& operator>>(const @ty &v, mrpc::MRPCJWriter &w) @{ + w.@(jty)(v); + return w; +@}} + +@for e in &rpc.enums { +inline mrpc::@e.name& operator<<(mrpc::@e.name &v, const rapidjson::Value &j) @{ + ((std::uint64_t&)v) << j; + return v; +@} +mrpc::MRPCJWriter& operator>>(const mrpc::@e.name &v, mrpc::MRPCJWriter &w) @{ + w.Uint64((std::uint64_t)v); + return w; +@} +} + +template +inline std::vector& operator<<(std::vector &v, const rapidjson::Value &j); +template +inline std::optional& operator<<(std::optional &v, const rapidjson::Value &j) @{ + if (j.IsNull()) + v = std::nullopt; + else @{ + T t; + t << j; + v = std::move(t); + @} + return v; +@} + +template +inline std::vector& operator<<(std::vector &v, const rapidjson::Value &j) @{ + if (!j.IsArray()) + throw std::exception@{@}; + for (const auto &e : j.GetArray()) @{ + T t; + t << e; + v.push_back(std::move(t)); + @} + return v; +@} + +template +inline mrpc::MRPCJWriter& operator>>(const std::vector &v, mrpc::MRPCJWriter &w); +template +inline mrpc::MRPCJWriter& operator>>(const std::optional &v, mrpc::MRPCJWriter &w) @{ + if (v.has_value()) + v.value() >> w; + else + w.Null(); + return w; +@} + +template +inline mrpc::MRPCJWriter& operator>>(const std::vector &v, mrpc::MRPCJWriter &w) @{ + w.StartArray(); + for (const auto &e : v) + e >> w; + w.EndArray(); + return w; +@} + +inline const rapidjson::Value& json_get(const rapidjson::Value &j, const char *key) @{ + auto member = j.FindMember(key); + if (member == j.MemberEnd()) + throw std::exception@{@}; + return member->value; +@} + +namespace mrpc @{ +@for s in &rpc.structs { +@get_struct_generics(s)MRPCJWriter& @(s.name)@(generics_brace(s))::operator>>(MRPCJWriter &__w) const @{ + __w.StartObject(); +@for f in &s.fields { __w.Key("@f.name", @f.name.len()); + @f.name >> __w; +} __w.EndObject(); + return __w; +@} +@get_struct_generics(s)@(s.name)@(generics_brace(s))& @(s.name)@(generics_brace(s))::operator<<(const rapidjson::Value &__j) @{ + using namespace mrpc; +@for f in &s.fields { @f.name << json_get(__j, "@f.name"); +} return *this; +@} +} \ No newline at end of file diff --git a/templates/typescript_client.rs.ts b/templates/typescript_client.rs.ts index 9d8ea11..505918b 100644 --- a/templates/typescript_client.rs.ts +++ b/templates/typescript_client.rs.ts @@ -10,7 +10,7 @@ export enum @e.name @{ }@} } @for s in &rpc.structs { -export interface @s.name @{ +export interface @s.name@get_struct_generics(s) @{ @for f in &s.fields { @f.name: @ty_to_str(&f.ty); }@} }