commit 2c1e684e33d76c673edb574b1d51c7cf8f69db62 Author: Mutzi Date: Wed Sep 27 17:51:03 2023 +0200 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6e858b2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# Created by https://www.toptal.com/developers/gitignore/api/rust,clion +# Edit at https://www.toptal.com/developers/gitignore?templates=rust,clion + +### CLion ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### CLion Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint +.idea/**/sonarlint/ + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin +.idea/**/sonarIssues.xml + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced +.idea/**/markdown-navigator.xml +.idea/**/markdown-navigator-enh.xml +.idea/**/markdown-navigator/ + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 +.idea/$CACHE_FILE$ + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream +.idea/codestream.xml + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij +.idea/**/azureSettings.xml + +### Rust ### +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + +# End of https://www.toptal.com/developers/gitignore/api/rust,clion diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..c46b28e --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/mrpc.iml b/.idea/mrpc.iml new file mode 100644 index 0000000..cf84ae4 --- /dev/null +++ b/.idea/mrpc.iml @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..56bad3e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "mrpc" +version = "0.1.0" +edition = "2021" + +[dependencies] +codespan-reporting = "0.11.1" +once_cell = "1.18.0" +itertools = "0.11.0" + +[dependencies.proc-macro2] +version = "1.0.67" +features = ["span-locations"] + +[dependencies.syn] +version = "2.0.37" +features = ["full", "extra-traits"] + +[dependencies.clap] +version = "4.4.5" +features = ["derive"] diff --git a/src/data.rs b/src/data.rs new file mode 100644 index 0000000..04bb83a --- /dev/null +++ b/src/data.rs @@ -0,0 +1,61 @@ +#[derive(Debug, Clone)] +pub enum Types { + String, + Bool, + F32, F64, + I8, I16, I32, I64, + U8, U16, U32, U64, + Named(String) +} + +#[derive(Debug, Clone, Default)] +pub struct EnumTy { + pub name: String, + pub values: Vec<(String, usize)> +} + +#[derive(Debug, Clone)] +pub struct FieldTy { + pub name: String, + pub ty: Types, + pub optional: bool, + pub array: bool +} + +#[derive(Debug, Clone, Default)] +pub struct StructTy { + pub name: String, + pub fields: Vec +} + +#[derive(Debug, Clone, Default)] +pub struct MethodTy { + pub name: String, + pub args: Vec, + pub ret: Option, + pub ret_stream: bool +} + +#[derive(Debug, Clone, Default)] +pub struct ServiceTy { + pub name: String, + pub methods: Vec +} + +#[derive(Debug, Clone, Default)] +pub struct RPC { + pub enums: Vec, + pub structs: Vec, + pub services: Vec +} + +impl FieldTy { + pub fn new(ty: Types) -> Self { + Self { + name: String::new(), + ty, + optional: false, + array: false + } + } +} diff --git a/src/generators/cpp_s.rs b/src/generators/cpp_s.rs new file mode 100644 index 0000000..ad37372 --- /dev/null +++ b/src/generators/cpp_s.rs @@ -0,0 +1,294 @@ +use std::io::Write; +use itertools::Itertools; +use crate::data::RPC; +use super::IndentedWriter; + +fn field_ty_to_ty_str(ty: &crate::data::FieldTy, ignore_optional: bool) -> String { + use crate::data::Types; + let inner = match &ty.ty { + Types::String => "std::string", + Types::Bool => "bool", + Types::F32 => "std::float_t", + Types::F64 => "std::double_t", + Types::I8 => "std::int8_t", + Types::I16 => "std::int16_t", + Types::I32 => "std::int32_t", + Types::I64 => "std::int64_t", + Types::U8 => "std::uint8_t", + Types::U16 => "std::uint16_t", + Types::U32 => "std::uint32_t", + Types::U64 => "std::uint64_t", + Types::Named(name) => name + }; + + if ty.array { + format!("std::vector<{inner}>") + } else if ty.optional && !ignore_optional { + format!("std::optional<{inner}>") + } else { + inner.to_string() + } +} + +fn method_signature(service_name: &String, m: &crate::data::MethodTy) -> String { + let mut ret = String::new(); + + let ret_type = m.ret.as_ref().map_or("void".to_string(), |ty| field_ty_to_ty_str(ty, false)); + + if m.ret_stream { + ret += "void"; + } else { + ret += &ret_type; + } + + ret += &format!(" {}_{}(", service_name, m.name); + + ret += &m.args.iter() + .map(|arg| field_ty_to_ty_str(arg, false)) + .chain(m.ret_stream.then(|| format!("std::shared_ptr>"))) + .map(|arg| arg + "&&") + .join(", "); + + ret += ")"; + + ret +} + +fn output_header(f: &mut IndentedWriter, rpc: &RPC) { + f.f.write_all( +b"#pragma once +#ifndef MRPC_GEN_H +#define MRPC_GEN_H + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mrpc {\n").unwrap(); + + for e in &rpc.enums { + writeln!(f, "enum struct {} : std::uint64_t {{", e.name).unwrap(); + if let Some((last, vals)) = e.values.split_last() { + for v in vals { + writeln!(f, "{} = {},", v.0, v.1).unwrap(); + } + writeln!(f, "{} = {}", last.0, last.1).unwrap(); + } + f.write_all(b"};\n\n").unwrap(); + } + + for s in &rpc.structs { + writeln!(f, "struct {};", s.name).unwrap(); + writeln!(f, "void to_json(nlohmann::json&, const {}&);", s.name).unwrap(); + writeln!(f, "void from_json(const nlohmann::json&, {}&);\n", s.name).unwrap(); + } + + f.f.write_all(b"\n").unwrap(); + + for s in &rpc.structs { + writeln!(f, "struct {} {{", s.name).unwrap(); + for field in &s.fields { + writeln!(f, "{} {};", field_ty_to_ty_str(field, false), field.name).unwrap(); + } + f.write_all(b"};\n\n").unwrap(); + } + + f.f.write_all( +b"struct MRPCStreamImpl { + virtual void close() noexcept final; + virtual void abort() noexcept final; + virtual bool is_open() noexcept final; +protected: + MRPCStreamImpl(crow::websocket::connection *conn, uint64_t id) : conn(conn), id(id) {} + crow::websocket::connection* conn; + std::uint64_t id; +}; + +template +struct MRPCStream final : MRPCStreamImpl { + MRPCStream(crow::websocket::connection *conn, uint64_t id) : MRPCStreamImpl(conn, id) {} + bool send(const T &v) noexcept { + if (!conn) return false; + try { + conn->send_text(nlohmann::json{{\"id\", id},{\"data\", v}}.dump()); + } catch (const std::exception &_) { + abort(); + return false; + } + return true; + } +}; + +struct MRPCServer { + virtual void install(crow::SimpleApp &app, std::string &&route) final; +private:\n").unwrap(); + f.ident = 1; + + for service in &rpc.services { + for method in &service.methods { + writeln!(f, "virtual {} = 0;", method_signature(&service.name, method)).unwrap(); + } + } + + f.f.write_all(b" + virtual void msg_handler(crow::websocket::connection&, const std::string&, bool) final; + + std::mutex __streams_mutex; + std::unordered_multimap> __streams; +}; +} +#endif // MRPC_GEN_H\n").unwrap(); +} + +fn output_struct_json_stuff(f: &mut IndentedWriter, s: &crate::data::StructTy) { + writeln!(f, "void to_json(json &j, const {} &v) {{", s.name).unwrap(); + for field in &s.fields { + if field.optional { + writeln!(f, "json_set_opt(j, \"{0}\", v.{0});", field.name).unwrap(); + } else { + writeln!(f, "j[\"{0}\"] = v.{0};", field.name).unwrap(); + } + } + f.write_all(b"}\n\n").unwrap(); + writeln!(f, "void from_json(const json &j, {} &v) {{", s.name).unwrap(); + for field in &s.fields { + if field.optional { + writeln!(f, "v.{0} = json_get_opt<{1}>(j, \"{0}\");", field.name, field_ty_to_ty_str(field, true)).unwrap(); + } else { + writeln!(f, "j.at(\"{0}\").get_to(v.{0});", field.name).unwrap(); + } + } + f.write_all(b"}\n\n").unwrap(); +} + +fn output_cpp(f: &mut IndentedWriter, header_name: String, rpc: &RPC) { + writeln!(f, "#include \"{header_name}\"").unwrap(); + f.f.write_all( +b"using json = nlohmann::json; + +template +inline std::optional json_get_opt(const json &j, std::string &&k) { + if (j.contains(k) && !j.at(k).is_null()) + return j.at(k).get(); + else + return std::nullopt; +} + +template +inline void json_set_opt(json &j, std::string &&k, const std::optional &v) { + if (v.has_value()) + j[k] = v.value(); + else + j[k] = nullptr; +} + +namespace mrpc {\n").unwrap(); + + for s in &rpc.structs { + output_struct_json_stuff(f, s); + } + + f.f.write_all( +b"} + +template +void send_msg(crow::websocket::connection &c, uint64_t id, const T &v) { + c.send_text(json{{\"id\", id},{\"data\", v}}.dump()); +} + +void mrpc::MRPCStreamImpl::close() noexcept { + if (conn != nullptr) { + send_msg(*conn, id, nullptr); + conn = nullptr; + } +} +void mrpc::MRPCStreamImpl::abort() noexcept { conn = nullptr; } +bool mrpc::MRPCStreamImpl::is_open() noexcept { return conn != nullptr; } + +void mrpc::MRPCServer::install(crow::SimpleApp &app, std::string &&route) { + app.route_dynamic(std::move(route)) + .websocket() + .onclose([&](crow::websocket::connection &c, const std::string&){ + std::lock_guard guard{__streams_mutex}; + auto range = __streams.equal_range(&c); + for (auto it = range.first; it != range.second; ++it) + it->second->abort(); + __streams.erase(&c); + }) + .onmessage([this](auto &&a, auto &&b, auto &&c) { + try { msg_handler(a, b, c); } + catch (const std::exception &_) {} + }); +} +void mrpc::MRPCServer::msg_handler(crow::websocket::connection &__c, const std::string &__msg, bool) { + json __j = json::parse(__msg); + std::uint64_t __id = __j.at(\"id\"); + std::string __service = __j.at(\"service\"), __method = __j.at(\"method\"); + try {\n").unwrap(); + f.ident = 2; + + f.write_all(b"json __data = __j.at(\"data\");\n").unwrap(); + + let mut first_service = true; + for service in &rpc.services { + if first_service { first_service = false; } + else { f.write_all(b"else ").unwrap(); } + writeln!(f, "if (__service == \"{}\") {{", service.name).unwrap(); + let mut first_method = true; + for method in &service.methods { + if first_method { first_method = false; } + else { f.write_all(b"else ").unwrap(); } + writeln!(f, "if (__method == \"{}\") {{", method.name).unwrap(); + if method.ret_stream { + writeln!(f, "auto __stream = std::make_shared>(&__c, __id);", + field_ty_to_ty_str(method.ret.as_ref().unwrap(), false)).unwrap(); + f.write_all(b"{ std::lock_guard guard{__streams_mutex}; __streams.emplace(&__c, __stream); }\n").unwrap(); + } + + for arg in &method.args { + let ty = field_ty_to_ty_str(&arg, false); + if arg.optional { + writeln!(f, "{0} {1} = json_get_opt<{2}>(__data, \"{1}\");", ty, arg.name, field_ty_to_ty_str(arg, true)).unwrap(); + } else { + writeln!(f, "{0} {1} = __data.at(\"{1}\");", ty, arg.name).unwrap(); + } + } + + let args = method.args.iter() + .map(|arg| format!("std::move({})", arg.name)) + .chain(method.ret_stream.then_some(String::from("std::move(__stream)"))) + .collect::>(); + + if method.ret.is_none() || method.ret_stream { + writeln!(f, "{}_{}({});", service.name, method.name, args.join(", ")).unwrap(); + } else { + writeln!(f, "auto __ret = {}_{}({});", service.name, method.name, args.join(", ")).unwrap(); + f.write_all(b"send_msg(__c, __id, __ret);\n").unwrap(); + } + + f.write_all(b"} ").unwrap(); + } + + f.write_all(b"else { throw std::exception{}; }\n} ").unwrap(); + } + + f.f.write_all(b"else { throw std::exception{}; } + } catch (const std::exception &_) { + std::cerr << \"Got invalid request \" << __id << \" for \" << __service << \"::\" << __method << std::endl; + } +}\n\n").unwrap(); +} + +pub fn gen(file_base_name: &std::path::PathBuf, rpc: &RPC) { + output_header(&mut IndentedWriter::new(std::fs::File::create(file_base_name.with_extension("h")).unwrap()), rpc); + output_cpp( + &mut IndentedWriter::new(std::fs::File::create(file_base_name.with_extension("cpp")).unwrap()), + file_base_name.with_extension("h").file_name().unwrap().to_string_lossy().to_string(), + rpc + ); +} diff --git a/src/generators/mod.rs b/src/generators/mod.rs new file mode 100644 index 0000000..6828890 --- /dev/null +++ b/src/generators/mod.rs @@ -0,0 +1,75 @@ +mod cpp_s; +mod ts_c; + +#[derive(Debug, Clone, clap::ValueEnum)] +pub enum ServerGenerators { + Cpp +} + +#[derive(Debug, Clone, clap::ValueEnum)] +pub enum ClientGenerators { + Ts +} + +impl ServerGenerators { + pub fn generate(&self, file_base_name: &std::path::PathBuf, rpc: &crate::data::RPC) { + match self { + Self::Cpp => cpp_s::gen(file_base_name, rpc) + } + } +} + +impl ClientGenerators { + pub fn generate(&self, file_base_name: &std::path::PathBuf, rpc: &crate::data::RPC) { + match self { + Self::Ts => ts_c::gen(file_base_name, rpc) + } + } +} + +pub struct IndentedWriter { + pub ident: usize, + pub f: std::fs::File, + indent_next: bool +} + +impl IndentedWriter { + pub fn new(f: std::fs::File) -> Self { + Self { + ident: 0, + f, + indent_next: false + } + } +} + +impl std::io::Write for IndentedWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + for b in buf { + if b == &b'}' && self.ident > 0 { + self.ident -= 1; + } + + if b == &b'\n' { + self.indent_next = true; + } else if self.indent_next { + self.indent_next = false; + for _ in 0..self.ident { + self.f.write_all(b" ")?; + } + } + + if b == &b'{' { + self.ident += 1; + } + + self.f.write_all(&[*b])?; + } + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.f.flush() + } +} + diff --git a/src/generators/ts_c.rs b/src/generators/ts_c.rs new file mode 100644 index 0000000..d3d1c56 --- /dev/null +++ b/src/generators/ts_c.rs @@ -0,0 +1,157 @@ +use std::io::Write; +use itertools::Itertools; +use crate::data::RPC; +use super::IndentedWriter; + +fn output_common(f: &mut IndentedWriter) { + f.f.write_all( +b"interface _WSResponse { + id: number; + data: any; +} + +interface _WSWaitingEntry { + ok: (v: any) => void; + err: (reason?: any) => void; +} + +export class MRPCConnector { + url: string; + socket: WebSocket; + next_message_id: number; + waiting: { [id: number]: _WSWaitingEntry }; + streams: { [id: number]: (v: any) => void }; + + private open() { + this.socket = new WebSocket(this.url); + this.socket.onmessage = ev => { + const data = JSON.parse(ev.data) as _WSResponse; + if (data.id in this.streams) { + this.streams[data.id](data.data); + if (data.data == null) + delete this.streams[data.id]; + } else if (data.id in this.waiting) { + this.waiting[data.id].ok(data.data); + delete this.waiting[data.id]; + } else { + console.log(`Got unexpected message: ${data}`); + } + } + this.socket.onerror = () => setTimeout(() => {this.open();}, 2500); + this.socket.onclose = () => setTimeout(() => {this.open();}, 2500); + } + + public constructor(url: string) { + this.url = url; + this.next_message_id = 0; + this.waiting = {}; + this.streams = {}; + this.open(); + }\n\n").unwrap(); + f.ident = 1; +} + +fn field_ty_to_ty_str(ty: &crate::data::FieldTy, with_name: bool) -> String { + use crate::data::Types; + + let mut ret = String::new(); + if with_name { + ret += &ty.name; + if ty.optional { + ret += "?"; + } + ret += ": "; + } + ret += match &ty.ty { + Types::String => "string", + Types::Bool => "boolean", + Types::F32 | Types::F64 + |Types::I8 | Types::I16 | Types::I32 | Types::I64 + |Types::U8 | Types::U16 | Types::U32 | Types::U64 => "number", + Types::Named(name) => name + }; + if ty.array { + ret += "[]"; + } + ret +} + +fn output_services(f: &mut IndentedWriter, rpc: &RPC) { + for service in &rpc.services { + for method in &service.methods { + write!(f, "public {}_{}(", service.name, method.name).unwrap(); + f.write_all(method.args.iter() + .map(|arg| field_ty_to_ty_str(arg, true)) + .chain(method.ret_stream.then(|| format!("cbk: (v: {}) => void", field_ty_to_ty_str(method.ret.as_ref().unwrap(), false)))) + .join(", ") + .as_bytes() + ).unwrap(); + f.write_all(b")").unwrap(); + if let Some(ret) = &method.ret { + if ret.optional { + unimplemented!("Optional return value is current not supported in typescript client"); + } + if !method.ret_stream { + write!(f, ": Promise<{}>", field_ty_to_ty_str(ret, false)).unwrap(); + } + } + f.write_all(b" {\nconst msg = {id:this.next_message_id++,").unwrap(); + write!(f, "service:'{}',method:'{}',data:{{", service.name, method.name).unwrap(); + f.write_all(method.args.iter() + .map(|arg| { + if arg.optional { + format!("{0}:{0}||null", arg.name) + } else { + arg.name.clone() + } + }) + .join(",") + .as_bytes() + ).unwrap(); + f.write_all(b"}};\n").unwrap(); + if let Some(ret) = &method.ret { + if !method.ret_stream { + writeln!(f, "const p = new Promise<{}>((ok,err) => {{ this.waiting[msg.id] = {{ ok, err }}; }});", field_ty_to_ty_str(ret, false)).unwrap(); + } else { + f.write_all(b"this.streams[msg.id] = cbk;\n").unwrap(); + } + } + f.write_all(b"this.socket.send(JSON.stringify(msg));\n").unwrap(); + if method.ret.is_some() && !method.ret_stream{ + f.write_all(b"return p;\n").unwrap(); + } + f.write_all(b"}\n\n").unwrap(); + } + } + f.write_all(b"}").unwrap(); +} + +fn output_enums(f: &mut IndentedWriter, rpc: &RPC) { + for e in &rpc.enums { + writeln!(f, "export enum {} {{", e.name).unwrap(); + for (k, v) in &e.values { + writeln!(f, "{k} = {v},").unwrap(); + } + f.write_all(b"}\n\n").unwrap(); + } +} + +fn output_structs(f: &mut IndentedWriter, rpc: &RPC) { + for s in &rpc.structs { + writeln!(f, "export interface {} {{", s.name).unwrap(); + for field in &s.fields { + writeln!(f, "{};", field_ty_to_ty_str(field, true)).unwrap(); + } + f.write_all(b"}\n\n").unwrap(); + } +} + +pub fn gen(file_base_name: &std::path::PathBuf, rpc: &RPC) { + let f = std::fs::File::create(file_base_name.with_extension("ts")).unwrap(); + let mut f = IndentedWriter::new(f); + let f = &mut f; + output_enums(f, rpc); + output_structs(f, rpc); + output_common(f); + output_services(f, rpc); +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..2f5fff1 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,235 @@ +mod data; +mod generators; + +use std::fs::File; +use std::io::Read; +use clap::Parser; +use syn::spanned::Spanned; + +#[derive(Debug, clap::Parser)] +struct Args { + pub file: std::path::PathBuf, + #[arg(short='n', long="name", help="base name for output files", required=true)] + pub rpc_name: std::path::PathBuf, + #[command(flatten)] + pub generators: GenArgs +} + +#[derive(Debug, clap::Args)] +#[group(required = true, multiple = true)] +struct GenArgs { + #[arg(value_enum, short='c', long="client")] + pub clients: Vec, + #[arg(value_enum, short='s', long="server")] + pub servers: Vec +} + +static SOURCE_FILE: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); +static SOURCE: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); + +fn parse_enum(item: &syn::ItemEnum) -> data::EnumTy { + data::EnumTy { + name: item.ident.to_string(), + values: item.variants.iter().enumerate().map(|(i, v) | (v.ident.to_string(), i)).collect() + } +} + +fn parse_type_string(ty: String) -> data::Types { + use data::Types::*; + match ty.as_str() { + "str" | "String" => String, + "bool" => Bool, + "f32" => F32, + "f64" => F64, + "i8" => I8, + "i16" => I16, + "i32" => I32, + "i64" => I64, + "u8" => U8, + "u16" => U16, + "u32" => U32, + "u64" => U64, + _ => Named(ty) + } +} + +fn parse_type(item: &syn::Type) -> data::FieldTy { + match item { + syn::Type::Path(path) => { + let segments = &path.path.segments; + if segments.len() != 1 { + emit_error(item.span(), "Path segments with len != 1"); + } + let segment = &segments[0]; + if !segment.arguments.is_empty() { + if segment.ident.to_string() != "Option" { + emit_error(item.span(), "Only Option are currently allowed to have arguments"); + } + let args = match &segment.arguments { + syn::PathArguments::AngleBracketed(v) => v, + _ => emit_error(item.span(), "Angle bracketed arguments expected") + }; + if args.args.len() != 1 { + emit_error(item.span(), "Expected 1 argument"); + } + let mut ty = match &args.args[0] { + syn::GenericArgument::Type(v) => parse_type(v), + _ => emit_error(item.span(), "Type bracketed arguments expected") + }; + ty.optional = true; + ty + } else { + data::FieldTy::new(parse_type_string(segment.ident.to_string())) + } + } + syn::Type::Slice(slice) => { + let mut ty = parse_type(&slice.elem); + if ty.array { + emit_error(item.span(), "Double array found"); + } + ty.array = true; + ty + } + _ => emit_error(item.span(), "Unsupported type") + } +} + +fn parse_struct(item: &syn::ItemStruct) -> data::StructTy { + let name = item.ident.to_string(); + let mut fields = vec![]; + for field in &item.fields { + if field.ident.is_none() { + emit_error(field.span(), "Missing field name"); + } + let name = field.ident.as_ref().unwrap().to_string(); + let ty = parse_type(&field.ty); + fields.push(data::FieldTy { + name, + ..ty + }); + } + data::StructTy { name, fields } +} + +fn try_parse_iterator(ty: &syn::Type) -> Option { + if let syn::Type::Path(ty) = ty { + let seg = ty.path.segments.last()?; + if seg.ident.to_string() == "Iterator" { + if let syn::PathArguments::AngleBracketed(args) = &seg.arguments { + if let Some(syn::GenericArgument::Type(ty)) = args.args.first() { + Some(parse_type(ty)) + } else { None } + } else { None } + } else { None } + } else { None } +} + +fn parse_method(item: &syn::Signature) -> data::MethodTy { + let mut method = data::MethodTy::default(); + method.name = item.ident.to_string(); + + for arg in &item.inputs { + let arg = match arg { + syn::FnArg::Typed(v) => v, + _ => emit_error(arg.span(), "Unsupported argument") + }; + let mut ty = parse_type(&arg.ty); + ty.name = match &*arg.pat { + syn::Pat::Ident(v) => v.ident.to_string(), + _ => emit_error(arg.span(), "Unsupported argument") + }; + method.args.push(ty); + } + + match &item.output { + syn::ReturnType::Default => { + method.ret = None; + method.ret_stream = false; + } + syn::ReturnType::Type(_, ty) => { + if let Some(ty) = try_parse_iterator(ty) { + method.ret_stream = true; + method.ret = Some(ty); + } else { + method.ret_stream = false; + method.ret = Some(parse_type(ty)); + } + } + }; + + method +} + +fn parse_service(item: &syn::ItemTrait) -> data::ServiceTy { + let name = item.ident.to_string(); + let mut methods = vec![]; + for item in &item.items { + let item = match item { + syn::TraitItem::Fn(v) => v, + _ => emit_error(item.span(), "Only functions are supported") + }; + methods.push(parse_method(&item.sig)); + } + data::ServiceTy { name, methods } +} + +fn main() { + let args = Args::parse(); + + let mut file = File::open(args.file.clone()).unwrap(); + let mut content = String::new(); + file.read_to_string(&mut content).unwrap(); + + SOURCE_FILE.set(args.file.to_string_lossy().to_string()).unwrap(); + SOURCE.set(content).unwrap(); + + let ast = syn::parse_file(SOURCE.get().unwrap()).unwrap(); + + let mut rpc = data::RPC::default(); + + for item in &ast.items { + match item { + syn::Item::Enum(v) => rpc.enums.push(parse_enum(v)), + syn::Item::Struct(v) => rpc.structs.push(parse_struct(v)), + syn::Item::Trait(v) => rpc.services.push(parse_service(v)), + syn::Item::Use(_) => {} + _ => emit_error(item.span(), "Unsupported item") + } + } + + for gen in &args.generators.clients { + gen.generate(&args.rpc_name, &rpc); + } + + for gen in &args.generators.servers { + gen.generate(&args.rpc_name, &rpc); + } +} + +fn emit_error(span: proc_macro2::Span, msg: impl Into) -> ! { + use codespan_reporting::{ + diagnostic::{Diagnostic, Label}, + files::{SimpleFiles, Files}, + term::{ + self, + termcolor::{ColorChoice, StandardStream}, + }, + }; + let mut files = SimpleFiles::new(); + let config = term::Config::default(); + let writer = StandardStream::stderr(ColorChoice::Auto); + let file_id = files.add(SOURCE_FILE.get().unwrap().clone(), SOURCE.get().unwrap().clone()); + + let file = files.get(file_id).unwrap(); + + let start = span.start(); + let start: usize = file.line_range((), start.line-1).unwrap().start + start.column; + let end = span.end(); + let end: usize = file.line_range((), end.line-1).unwrap().start + end.column; + + let diagnostic = Diagnostic::error() + .with_labels(vec![Label::primary(file_id, start..end).with_message(msg)]); + + term::emit(&mut writer.lock(), &config, &files, &diagnostic).expect("cannot write error"); + std::process::abort(); +}