rustic_jsonrpc_macro/
lib.rs

1use proc_macro::TokenStream;
2
3use proc_macro2::Span;
4use quote::{format_ident, quote, quote_spanned};
5use syn::parse::{Parse, ParseStream};
6use syn::spanned::Spanned;
7use syn::visit_mut::{visit_lifetime_mut, visit_type_reference_mut, VisitMut};
8use syn::{
9    parse_macro_input, parse_quote, Attribute, Error, FnArg, GenericParam, Ident, ItemFn, Lifetime,
10    LitStr, Meta, Pat, ReturnType, Token, Type, TypeReference,
11};
12
13#[proc_macro]
14pub fn method_ident(input: TokenStream) -> TokenStream {
15    let ident = parse_macro_input!(input as Ident);
16    let ident = format_method_ident(&ident);
17    quote!(#ident).into()
18}
19
20#[proc_macro_attribute]
21pub fn method(attr: TokenStream, input: TokenStream) -> TokenStream {
22    let attr = parse_macro_input!(attr as Attr);
23    let mut func = parse_macro_input!(input as ItemFn);
24    expand_method(&attr, &mut func).unwrap_or_else(|e| e.to_compile_error().into())
25}
26
27fn expand_method(attr: &Attr, func: &mut ItemFn) -> syn::Result<TokenStream> {
28    check_generic(func)?;
29    let args = collect_args(func)?;
30    let params_struct = params_struct(&args);
31    let mut args_gen = Vec::with_capacity(args.len());
32    for arg in args {
33        args_gen.push(match arg.kind {
34            Kind::Param => {
35                let name = arg.ident.unwrap();
36                quote!(params.#name)
37            }
38            Kind::Inject  => match &*arg.ty {
39                Type::Reference(TypeReference { elem: ty, .. }) => {
40                    let name = arg.ident.unwrap();
41                    let mut ty = ty.clone();
42                    RemoveLifetime.visit_type_mut(&mut ty);
43                    quote_spanned! {arg.ty.span()=>
44                        container.get::<#ty>().ok_or(rustic_jsonrpc::InjectError::new::<#ty>(stringify!(#name)))?
45                    }
46                }
47                _ => return Err(Error::new_spanned(arg.ty, "method: expected reference")),
48            }
49            Kind::From((ref ident, param_ty)) => {
50                let ty = &arg.ty;
51                quote_spanned! {ty.span()=>
52                    <#ty as rustic_jsonrpc::FromArg::<#param_ty>>::from_arg(container, params.#ident).await?
53                }
54            }
55        });
56    }
57
58    let ident = &func.sig.ident;
59    let method_ident = format_method_ident(&func.sig.ident);
60    let method_name = attr
61        .method_name
62        .clone()
63        .unwrap_or_else(|| ident.to_string());
64
65    let await_ = func.sig.asyncness.map(|_| quote!(.await));
66    let handler = quote! {
67        |container, params| {
68            #params_struct
69            use rustic_jsonrpc::serde_json::{from_str, to_value};
70            Box::pin(async {
71                match from_str::<Params>(params) {
72                    Ok(params) => Ok(to_value(#ident(#(#args_gen),*)#await_?).expect("json serialize error")),
73                    Err(err) => Err(rustic_jsonrpc::Error::new(rustic_jsonrpc::INVALID_PARAMS, err, None))?,
74                }
75            })
76        }
77    };
78
79    let output_assert = output_assert(&func);
80    func.block.stmts.insert(0, parse_quote!(#output_assert));
81    let vis = &func.vis;
82    let gen = quote! {
83        #func
84        #vis const #method_ident: rustic_jsonrpc::Method = rustic_jsonrpc::Method::new(#method_name, #handler);
85    };
86    Ok(gen.into())
87}
88
89fn format_method_ident(i: &Ident) -> Ident {
90    format_ident!(
91        "RUSTIC_JSONRPC_METHOD_{}",
92        i.to_string().to_ascii_uppercase()
93    )
94}
95
96fn check_generic(func: &ItemFn) -> syn::Result<()> {
97    let generics = &func.sig.generics;
98    for v in &generics.params {
99        match v {
100            GenericParam::Lifetime(_) => (),
101            _ => {
102                return Err(Error::new_spanned(
103                    generics,
104                    "method: generic type is not allowed",
105                ));
106            }
107        }
108    }
109    Ok(())
110}
111
112struct ResetLifetime(bool);
113
114impl ResetLifetime {
115    fn new() -> Self {
116        Self(false)
117    }
118}
119
120impl VisitMut for ResetLifetime {
121    fn visit_lifetime_mut(&mut self, i: &mut Lifetime) {
122        self.0 = true;
123        if i.ident != "param" {
124            i.ident = Ident::new("param", i.ident.span())
125        }
126        visit_lifetime_mut(self, i);
127    }
128
129    fn visit_type_reference_mut(&mut self, i: &mut TypeReference) {
130        if i.lifetime.is_none() {
131            i.lifetime = Some(Lifetime::new("'param", Span::call_site()))
132        }
133        visit_type_reference_mut(self, i);
134    }
135}
136
137struct RemoveLifetime;
138
139impl VisitMut for RemoveLifetime {
140    fn visit_type_reference_mut(&mut self, i: &mut TypeReference) {
141        i.lifetime = None;
142        visit_type_reference_mut(self, i);
143    }
144}
145
146fn borrow(tpe: &Type) -> bool {
147    match tpe {
148        Type::Path(tpe) => match tpe.path.segments.last() {
149            Some(v) => v.ident == "Cow",
150            _ => false,
151        },
152        _ => false,
153    }
154}
155
156fn params_struct(args: &[Argument]) -> proc_macro2::TokenStream {
157    let mut ident = Vec::with_capacity(args.len());
158    let mut ty = Vec::with_capacity(args.len());
159    let mut attr: Vec<Option<Attribute>> = Vec::with_capacity(args.len());
160    let mut reset_lifetime = ResetLifetime::new();
161    for arg in args {
162        let (arg_ident, arg_ty) = match arg.kind {
163            Kind::Param => (arg.ident.as_ref().unwrap(), &arg.ty),
164            Kind::Inject => continue,
165            Kind::From((ref arg_ident, ref arg_ty)) => (arg_ident, arg_ty),
166        };
167
168        ident.push(arg_ident);
169        let mut ty_clone = arg_ty.clone();
170        reset_lifetime.visit_type_mut(&mut ty_clone);
171        ty.push(ty_clone);
172
173        if borrow(&arg.ty) {
174            attr.push(Some(parse_quote!(#[serde(borrow)])));
175        } else {
176            attr.push(None)
177        }
178    }
179
180    if reset_lifetime.0 {
181        quote! {
182            #[derive(rustic_jsonrpc::serde::Deserialize)]
183            #[serde(crate = "rustic_jsonrpc::serde")]
184            struct Params<'param> {
185                #(#attr #ident: #ty,)*
186            }
187        }
188    } else {
189        quote! {
190            #[derive(rustic_jsonrpc::serde::Deserialize)]
191            #[serde(crate = "rustic_jsonrpc::serde")]
192            struct Params {
193                #(#ident: #ty,)*
194            }
195        }
196    }
197}
198
199fn output_assert(item_fn: &ItemFn) -> proc_macro2::TokenStream {
200    match item_fn.sig.output {
201        ReturnType::Default => {
202            Error::new_spanned(&item_fn.sig, "method: expected return type").to_compile_error()
203        }
204        ReturnType::Type(_, ref ty) => quote_spanned! {ty.span()=>
205            { let _ = <#ty as rustic_jsonrpc::MethodResult>::ASSERT; }
206        },
207    }
208}
209
210enum Kind {
211    Param,
212    Inject,
213    From((Ident, Box<Type>)),
214}
215
216struct Argument {
217    ident: Option<Ident>,
218    ty: Box<Type>,
219    kind: Kind,
220}
221
222fn collect_args(func: &mut ItemFn) -> syn::Result<Vec<Argument>> {
223    let mut args = Vec::with_capacity(func.sig.inputs.len());
224    for arg in &mut func.sig.inputs {
225        match arg {
226            FnArg::Typed(arg) => {
227                let kind = remove_arg_attr(&mut arg.attrs)?;
228                match *arg.pat {
229                    Pat::Ident(ref pat) => args.push(Argument {
230                        ident: Some(pat.ident.clone()),
231                        ty: arg.ty.clone(),
232                        kind,
233                    }),
234                    Pat::Wild(_) if matches!(kind, Kind::From(_)) => args.push(Argument {
235                        ident: None,
236                        ty: arg.ty.clone(),
237                        kind,
238                    }),
239                    _ => {
240                        return Err(Error::new_spanned(
241                            arg,
242                            "method: non identifier pattern is not allowed",
243                        ));
244                    }
245                };
246            }
247            FnArg::Receiver(_) => {
248                return Err(Error::new_spanned(
249                    arg,
250                    "method: self parameter is not allowed",
251                ));
252            }
253        }
254    }
255    Ok(args)
256}
257
258fn remove_arg_attr(attrs: &mut Vec<Attribute>) -> syn::Result<Kind> {
259    let mut kind = Kind::Param;
260    if attrs.is_empty() {
261        return Ok(kind);
262    }
263
264    let mut delete = vec![];
265    for i in 0..attrs.len() {
266        match attrs[i].meta {
267            Meta::Path(ref path) if path.is_ident("inject") => match kind {
268                Kind::Param => {
269                    kind = Kind::Inject;
270                    delete.push(i);
271                }
272                _ => return Err(Error::new_spanned(path, "method: unexpected attribute")),
273            },
274            Meta::List(ref list) if list.path.is_ident("from") => {
275                let from = list.parse_args::<From>()?;
276                match kind {
277                    Kind::Param => {
278                        kind = Kind::From(from.0);
279                        delete.push(i);
280                    }
281                    _ => return Err(Error::new_spanned(list, "method: unexpected attribute")),
282                };
283            }
284            _ => (),
285        }
286    }
287
288    for i in delete.into_iter().rev() {
289        attrs.remove(i);
290    }
291    Ok(kind)
292}
293
294struct From((Ident, Box<Type>));
295
296impl Parse for From {
297    fn parse(input: ParseStream) -> syn::Result<Self> {
298        let ident = input.parse::<Ident>()?;
299        input.parse::<Token![:]>()?;
300        let ty = input.parse::<Box<Type>>()?;
301        Ok(Self((ident, ty)))
302    }
303}
304
305#[derive(Default)]
306struct Attr {
307    method_name: Option<String>,
308}
309
310impl Parse for Attr {
311    fn parse(input: ParseStream) -> syn::Result<Self> {
312        let mut attr = Attr::default();
313        while !input.is_empty() {
314            let ident: Ident = input.parse()?;
315            match ident.to_string().as_str() {
316                "name" => {
317                    input.parse::<Token![=]>()?;
318                    let name: LitStr = input.parse()?;
319                    attr.method_name = Some(name.value());
320                    if input.peek(Token![,]) {
321                        input.parse::<Token![,]>()?;
322                    }
323                }
324                _ => {
325                    return Err(Error::new_spanned(
326                        &ident,
327                        format!("method: unknown attribute `{}`", ident),
328                    ));
329                }
330            }
331        }
332        Ok(attr)
333    }
334}