From 613cf68fc3bde8a0557862b5bb1ecffb21a97a5b Mon Sep 17 00:00:00 2001 From: Mutzi Date: Wed, 11 Oct 2023 17:10:20 +0200 Subject: [PATCH] Update code to allow struct with one generic parameter --- src/data.rs | 7 +- src/generators/cpp_s.rs | 6 +- src/generators/ts_c.rs | 3 +- src/main.rs | 245 +++++++++++++----------------- src/parser.rs | 200 ++++++++++++++++++++++++ templates/typescript_client.rs.ts | 10 +- 6 files changed, 318 insertions(+), 153 deletions(-) create mode 100644 src/parser.rs diff --git a/src/data.rs b/src/data.rs index 3e53f2e..dff645f 100644 --- a/src/data.rs +++ b/src/data.rs @@ -7,7 +7,8 @@ pub enum Types { U8, U16, U32, U64, Array(Box), Optional(Box), - Named(String) + Named(String), + Generic(String, Box, proc_macro2::Span) } #[derive(Debug, Clone, Default)] @@ -25,7 +26,9 @@ pub struct FieldTy { #[derive(Debug, Clone, Default)] pub struct StructTy { pub name: String, - pub fields: Vec + pub fields: Vec, + pub generic_fields: Vec, + pub generic_name: Option } #[derive(Debug, Clone, Default)] diff --git a/src/generators/cpp_s.rs b/src/generators/cpp_s.rs index 14be085..51e9baa 100644 --- a/src/generators/cpp_s.rs +++ b/src/generators/cpp_s.rs @@ -33,7 +33,8 @@ pub fn ty_to_str(ty: &crate::data::Types) -> String { Types::U64 => "std::uint64_t".into(), Types::Named(name) => name.into(), Types::Optional(inner) => format!("std::optional<{}>", ty_to_str(inner)), - Types::Array(inner) => format!("std::vector<{}>", ty_to_str(inner)) + Types::Array(inner) => format!("std::vector<{}>", ty_to_str(inner)), + Types::Generic(_, _, _) => unreachable!() } } @@ -86,7 +87,8 @@ pub fn json_write(ty: &crate::data::FieldTy) -> String { {} }} __w.EndArray();", inner_var_name, ty.name, inner) - } + }, + Types::Generic(_, _, _) => unreachable!() } } diff --git a/src/generators/ts_c.rs b/src/generators/ts_c.rs index 5117918..8656e4f 100644 --- a/src/generators/ts_c.rs +++ b/src/generators/ts_c.rs @@ -12,7 +12,8 @@ pub fn ty_to_str(ty: &crate::data::Types) -> String { |Types::U8 | Types::U16 | Types::U32 | Types::U64 => "number".into(), Types::Named(name) => name.into(), Types::Optional(inner) => format!("({}|null)", ty_to_str(inner)), - Types::Array(inner) => format!("{}[]", ty_to_str(inner)) + Types::Array(inner) => format!("{}[]", ty_to_str(inner)), + Types::Generic(_, _, _) => unreachable!() } } diff --git a/src/main.rs b/src/main.rs index 1ceaa36..4df6a47 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,13 @@ mod data; mod generators; mod templates; +mod parser; +use std::collections::HashMap; +use std::fmt::Write; use std::fs::File; use std::io::Read; use clap::Parser; -use syn::spanned::Spanned; #[derive(Debug, clap::Parser)] struct Args { @@ -28,141 +30,93 @@ struct GenArgs { static SOURCE_FILE: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); static SOURCE: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); -fn parse_enum(item: &syn::ItemEnum) -> data::EnumTy { - data::EnumTy { - name: item.ident.to_string(), - values: item.variants.iter().enumerate().map(|(i, v) | (v.ident.to_string(), i)).collect() +fn has_generic(ty: &mut data::Types) -> Option<&mut data::Types> { + use data::Types; + match ty { + Types::String | Types::Bool | + Types::F32 | Types::F64 | + Types::I8 | Types::I16 | Types::I32 | Types::I64 | + Types::U8 | Types::U16 | Types::U32 | Types::U64 => None, + Types::Array(inner) => has_generic(inner), + Types::Optional(inner) => has_generic(inner), + Types::Named(_) => None, + Types::Generic(_, _, _) => Some(ty) } } -fn parse_type_string(ty: String) -> data::Types { - use data::Types::*; - match ty.as_str() { - "str" | "String" => String, - "bool" => Bool, - "f32" => F32, - "f64" => F64, - "i8" => I8, - "i16" => I16, - "i32" => I32, - "i64" => I64, - "u8" => U8, - "u16" => U16, - "u32" => U32, - "u64" => U64, - _ => Named(ty) +pub fn gen_generic_name(ty: &data::Types) -> String { + use data::Types; + match &ty { + Types::String => "string".into(), + Types::Bool => "bool".into(), + Types::F32 => "f32".into(), + Types::F64 => "f64".into(), + Types::I8 => "i8".into(), + Types::I16 => "i16".into(), + Types::I32 => "i32".into(), + Types::I64 => "i64".into(), + Types::U8 => "u8".into(), + Types::U16 => "u16".into(), + Types::U32 => "u32".into(), + Types::U64 => "u64".into(), + Types::Named(name) => name.clone(), + Types::Optional(inner) => format!("maybe_{}", gen_generic_name(inner)), + Types::Array(inner) => format!("array_{}", gen_generic_name(inner)), + Types::Generic(_, _, span) => emit_error(Some((span.clone(), "Nested custom generics are not allowed"))) } } - -fn parse_type(item: &syn::Type) -> data::Types { - match item { - syn::Type::Path(path) => { - let segments = &path.path.segments; - if segments.len() != 1 { - emit_error(item.span(), "Path segments with len != 1"); - } - let segment = &segments[0]; - if !segment.arguments.is_empty() { - if segment.ident.to_string() != "Option" { - emit_error(item.span(), "Only Option are currently allowed to have arguments"); - } - let args = match &segment.arguments { - syn::PathArguments::AngleBracketed(v) => v, - _ => emit_error(item.span(), "Angle bracketed arguments expected") - }; - if args.args.len() != 1 { - emit_error(item.span(), "Expected 1 argument"); - } - let ty = match &args.args[0] { - syn::GenericArgument::Type(v) => parse_type(v), - _ => emit_error(item.span(), "Type bracketed arguments expected") - }; - data::Types::Optional(ty.into()) - } else { - parse_type_string(segment.ident.to_string()) - } - } - syn::Type::Slice(slice) => { - data::Types::Array(parse_type(&slice.elem).into()) - } - _ => emit_error(item.span(), "Unsupported type") - } -} - -fn parse_struct(item: &syn::ItemStruct) -> data::StructTy { - let name = item.ident.to_string(); - let mut fields = vec![]; - for field in &item.fields { - if field.ident.is_none() { - emit_error(field.span(), "Missing field name"); - } - let name = field.ident.as_ref().unwrap().to_string(); - let ty = parse_type(&field.ty); - fields.push(data::FieldTy { name, ty }); - } - data::StructTy { name, fields } -} - -fn try_parse_iterator(ty: &syn::Type) -> Option { - if let syn::Type::Path(ty) = ty { - let seg = ty.path.segments.last()?; - if seg.ident.to_string() == "Iterator" { - if let syn::PathArguments::AngleBracketed(args) = &seg.arguments { - if let Some(syn::GenericArgument::Type(ty)) = args.args.first() { - Some(parse_type(ty)) - } else { None } - } else { None } - } else { None } - } else { None } -} - -fn parse_method(item: &syn::Signature) -> data::MethodTy { - let mut method = data::MethodTy::default(); - method.name = item.ident.to_string(); - - for arg in &item.inputs { - let arg = match arg { - syn::FnArg::Typed(v) => v, - _ => emit_error(arg.span(), "Unsupported argument") +fn replace_generics<'a>(it: impl Iterator, generics: &HashMap, new_structs: &mut HashMap) { + for ty in it.filter_map(|ty| has_generic(ty)) { + let (name, inner, span) = match ty { + data::Types::Generic(name, inner, span) => + (name.clone(), inner.clone(), span.clone()), + _ => unreachable!() }; - let ty = parse_type(&arg.ty); - let name = match &*arg.pat { - syn::Pat::Ident(v) => v.ident.to_string(), - _ => emit_error(arg.span(), "Unsupported argument") + let generic = match generics.get(&name) { + Some(v) => v, + None => emit_error(Some((span, "Type does not exists"))) }; - method.args.push(data::FieldTy { name, ty }); + let new_name = format!("{}_{}", name, gen_generic_name(&inner)); + *ty = data::Types::Named(new_name.clone()); + new_structs.insert(new_name, (generic.clone(), *inner)); } - - match &item.output { - syn::ReturnType::Default => { - method.ret = None; - method.ret_stream = false; - } - syn::ReturnType::Type(_, ty) => { - if let Some(ty) = try_parse_iterator(ty) { - method.ret_stream = true; - method.ret = Some(ty); - } else { - method.ret_stream = false; - method.ret = Some(parse_type(ty)); - } - } - }; - - method } -fn parse_service(item: &syn::ItemTrait) -> data::ServiceTy { - let name = item.ident.to_string(); - let mut methods = vec![]; - for item in &item.items { - let item = match item { - syn::TraitItem::Fn(v) => v, - _ => emit_error(item.span(), "Only functions are supported") - }; - methods.push(parse_method(&item.sig)); +fn replace_generic_type(ty: &mut data::Types, generic_name: &String, replacement: &data::Types) { + use data::Types; + match ty { + Types::Named(name) => if name == generic_name { *ty = replacement.clone(); } else { unreachable!() }, + Types::Optional(inner) => replace_generic_type(inner, generic_name, replacement), + Types::Array(inner) => replace_generic_type(inner, generic_name, replacement), + _ => unreachable!() + } +} + +fn resolve_generics(rpc: &mut data::RPC) { + let mut generics: HashMap = HashMap::new(); + rpc.structs.retain(|s| if s.generic_name.is_none() { true } else { + generics.insert(s.name.clone(), s.clone()); + false + }); + + let mut new_structs = HashMap::new(); + for s in &mut rpc.structs { + replace_generics(s.fields.iter_mut().map(|f| &mut f.ty), &generics, &mut new_structs); + } + for m in &mut rpc.services.iter_mut().map(|s| s.methods.iter_mut()).flatten() { + replace_generics(m.args.iter_mut().map(|f| &mut f.ty).chain(&mut m.ret), &generics, &mut new_structs); + } + + for (name, (mut generic, inner)) in new_structs { + generic.name = name; + let generic_name = generic.generic_name.take().unwrap(); + for mut field in generic.generic_fields { + replace_generic_type(&mut field.ty, &generic_name, &inner); + generic.fields.push(field); + } + generic.generic_fields = vec![]; + rpc.structs.push(generic); } - data::ServiceTy { name, methods } } fn main() { @@ -175,19 +129,21 @@ fn main() { SOURCE_FILE.set(args.file.to_string_lossy().to_string()).unwrap(); SOURCE.set(content).unwrap(); - let ast = syn::parse_file(SOURCE.get().unwrap()).unwrap(); + let ast = match syn::parse_file(SOURCE.get().unwrap()) { + Ok(v) => v, + Err(e) => emit_error(e.into_iter() + .map(|e| { + let span = e.span(); + let mut msg = String::new(); + write!(msg, "{e}").unwrap(); + (span, msg) + }) + ) + }; - let mut rpc = data::RPC::default(); + let mut rpc = parser::parse_file(&ast); - for item in &ast.items { - match item { - syn::Item::Enum(v) => rpc.enums.push(parse_enum(v)), - syn::Item::Struct(v) => rpc.structs.push(parse_struct(v)), - syn::Item::Trait(v) => rpc.services.push(parse_service(v)), - syn::Item::Use(_) => {} - _ => emit_error(item.span(), "Unsupported item") - } - } + resolve_generics(&mut rpc); for gen in &args.generators.clients { gen.generate(&args.rpc_name, &rpc); @@ -198,7 +154,7 @@ fn main() { } } -fn emit_error(span: proc_macro2::Span, msg: impl Into) -> ! { +pub fn emit_error(errors: impl IntoIterator)>) -> ! { use codespan_reporting::{ diagnostic::{Diagnostic, Label}, files::{SimpleFiles, Files}, @@ -214,14 +170,17 @@ fn emit_error(span: proc_macro2::Span, msg: impl Into) -> ! { let file = files.get(file_id).unwrap(); - let start = span.start(); - let start: usize = file.line_range((), start.line-1).unwrap().start + start.column; - let end = span.end(); - let end: usize = file.line_range((), end.line-1).unwrap().start + end.column; + let errors = errors.into_iter().map(|(span, msg)| { + let start = span.start(); + let start: usize = file.line_range((), start.line-1).unwrap().start + start.column; + let end = span.end(); + let end: usize = file.line_range((), end.line-1).unwrap().start + end.column; + Label::primary(file_id, start..end).with_message(msg) + }); let diagnostic = Diagnostic::error() - .with_labels(vec![Label::primary(file_id, start..end).with_message(msg)]); + .with_labels(errors.collect()); term::emit(&mut writer.lock(), &config, &files, &diagnostic).expect("cannot write error"); - std::process::abort(); + std::process::exit(1); } diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..8d8cbd3 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,200 @@ +use syn::spanned::Spanned; +use super::{data, emit_error}; + +fn parse_enum(item: &syn::ItemEnum) -> data::EnumTy { + data::EnumTy { + name: item.ident.to_string(), + values: item.variants.iter().enumerate().map(|(i, v) | (v.ident.to_string(), i)).collect() + } +} + +fn parse_type_string(ty: String) -> data::Types { + use data::Types::*; + match ty.as_str() { + "str" | "String" => String, + "bool" => Bool, + "f32" => F32, + "f64" => F64, + "i8" => I8, + "i16" => I16, + "i32" => I32, + "i64" => I64, + "u8" => U8, + "u16" => U16, + "u32" => U32, + "u64" => U64, + _ => Named(ty) + } +} + +fn parse_type(item: &syn::Type) -> data::Types { + match item { + syn::Type::Path(path) => { + let segments = &path.path.segments; + if segments.len() != 1 { + emit_error(vec![(item.span(), "Path segments with len != 1")]); + } + let segment = &segments[0]; + if !segment.arguments.is_empty() { + let args = match &segment.arguments { + syn::PathArguments::AngleBracketed(v) => v, + _ => emit_error(vec![(item.span(), "Angle bracketed arguments expected")]) + }; + if args.args.len() != 1 { + emit_error(vec![(item.span(), "Expected 1 argument")]); + } + let ty = match &args.args[0] { + syn::GenericArgument::Type(v) => parse_type(v), + _ => emit_error(vec![(item.span(), "Only types are supported")]) + }; + let name = segment.ident.to_string(); + if name == "Option" { + data::Types::Optional(ty.into()) + } else if name == "Vec" { + data::Types::Array(ty.into()) + } else { + data::Types::Generic(name, ty.into(), segment.ident.span()) + } + } else { + parse_type_string(segment.ident.to_string()) + } + } + syn::Type::Slice(slice) => { + data::Types::Array(parse_type(&slice.elem).into()) + } + _ => emit_error(vec![(item.span(), "Unsupported type")]) + } +} + +fn has_generic(ty: &data::Types, generic: &String) -> bool { + use data::Types; + match ty { + Types::String | Types::Bool | + Types::F32 | Types::F64 | + Types::I8 | Types::I16 | Types::I32 | Types::I64 | + Types::U8 | Types::U16 | Types::U32 | Types::U64 => false, + Types::Array(inner) => has_generic(inner, generic), + Types::Optional(inner) => has_generic(inner, generic), + Types::Named(name) => name == generic, + Types::Generic(name, inner, _) => name == generic || has_generic(inner, generic) + } +} + +fn parse_struct(item: &syn::ItemStruct) -> data::StructTy { + let name = item.ident.to_string(); + if let Some(v) = &item.generics.where_clause { + emit_error(Some((v.span(), "Where clauses are not allowed"))); + } + if item.generics.params.len() > 1 { + emit_error(Some((item.generics.params.span(), "Only one generic parameter is allowed for now"))); + } + let generic_name = item.generics.params.first().map(|g| { + match g { + syn::GenericParam::Const(_) | + syn::GenericParam::Lifetime(_) => emit_error(Some((g.span(), "Only generic types are allowed"))), + syn::GenericParam::Type(ty) => { + if !ty.bounds.is_empty() { + emit_error(Some((ty.span(), "Bounds are not allowed"))); + } + if let Some(d) = &ty.default { + emit_error(Some((d.span(), "Defaults are not allowed"))); + } + ty.ident.to_string() + } + } + }); + let mut generic_fields = vec![]; + let mut fields = vec![]; + for field in &item.fields { + if field.ident.is_none() { + emit_error(vec![(field.span(), "Missing field name")]); + } + let name = field.ident.as_ref().unwrap().to_string(); + let ty = parse_type(&field.ty); + let ty = data::FieldTy { name, ty }; + if generic_name.is_some() && has_generic(&ty.ty, generic_name.as_ref().unwrap()) { + generic_fields.push(ty); + } else { + fields.push(ty); + } + } + data::StructTy { name, fields, generic_fields, generic_name } +} + +fn try_parse_iterator(ty: &syn::Type) -> Option { + if let syn::Type::Path(ty) = ty { + let seg = ty.path.segments.last()?; + if seg.ident.to_string() == "Iterator" { + if let syn::PathArguments::AngleBracketed(args) = &seg.arguments { + if let Some(syn::GenericArgument::Type(ty)) = args.args.first() { + Some(parse_type(ty)) + } else { None } + } else { None } + } else { None } + } else { None } +} + +fn parse_method(item: &syn::Signature) -> data::MethodTy { + let mut method = data::MethodTy::default(); + method.name = item.ident.to_string(); + + for arg in &item.inputs { + let arg = match arg { + syn::FnArg::Typed(v) => v, + _ => emit_error(vec![(arg.span(), "Unsupported argument")]) + }; + let ty = parse_type(&arg.ty); + let name = match &*arg.pat { + syn::Pat::Ident(v) => v.ident.to_string(), + _ => emit_error(vec![(arg.span(), "Unsupported argument")]) + }; + method.args.push(data::FieldTy { name, ty }); + } + + match &item.output { + syn::ReturnType::Default => { + method.ret = None; + method.ret_stream = false; + } + syn::ReturnType::Type(_, ty) => { + if let Some(ty) = try_parse_iterator(ty) { + method.ret_stream = true; + method.ret = Some(ty); + } else { + method.ret_stream = false; + method.ret = Some(parse_type(ty)); + } + } + }; + + method +} + +fn parse_service(item: &syn::ItemTrait) -> data::ServiceTy { + let name = item.ident.to_string(); + let mut methods = vec![]; + for item in &item.items { + let item = match item { + syn::TraitItem::Fn(v) => v, + _ => emit_error(vec![(item.span(), "Only functions are supported")]) + }; + methods.push(parse_method(&item.sig)); + } + data::ServiceTy { name, methods } +} + +pub fn parse_file(ast: &syn::File) -> data::RPC { + let mut rpc = data::RPC::default(); + + for item in &ast.items { + match item { + syn::Item::Enum(v) => rpc.enums.push(parse_enum(v)), + syn::Item::Struct(v) => rpc.structs.push(parse_struct(v)), + syn::Item::Trait(v) => rpc.services.push(parse_service(v)), + syn::Item::Use(_) => {} + _ => emit_error(vec![(item.span(), "Unsupported item")]) + } + } + + rpc +} diff --git a/templates/typescript_client.rs.ts b/templates/typescript_client.rs.ts index 45fa032..9d8ea11 100644 --- a/templates/typescript_client.rs.ts +++ b/templates/typescript_client.rs.ts @@ -18,17 +18,17 @@ export interface @s.name @{ export class MRPCConnector @{ url: string; + private __create_msg(service: string, method: string, data: any) @{ + return @{service, method, data@}; + @} + public constructor(url: string) @{ this.url = url; @} @for s in &rpc.services { @for m in &s.methods { public @(s.name)_@(m.name)(@method_args(m))@method_ret(m) @{ - const __msg = @{ - service: '@s.name', - method: '@m.name', - data: @{@m.args.iter().map(|a| &a.name).join(",")@} - @}; + const __msg = this.__create_msg('@s.name', '@m.name', @{@m.args.iter().map(|a| &a.name).join(",")@}); @if m.ret.is_some() && !m.ret_stream {return fetch(this.url, @{ method: 'POST', body: JSON.stringify(__msg)