tiny_rpc_macros/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::borrow::Cow;
4
5use proc_macro::TokenStream;
6use proc_macro_error::*;
7use quote::{format_ident, quote, ToTokens};
8use syn::{spanned::Spanned, *};
9
10#[proc_macro_error]
11#[proc_macro_attribute]
12pub fn rpc_trait(_args: TokenStream, trait_body: TokenStream) -> TokenStream {
13    rpc_define(trait_body)
14}
15
16// TODO generics
17#[proc_macro_error]
18#[proc_macro]
19pub fn rpc_define(trait_body: TokenStream) -> TokenStream {
20    let root: Path = parse_quote! { ::tiny_rpc::rpc::re_export }; // TODO extract from meta
21
22    let trait_body = parse_macro_input!(trait_body as ItemTrait);
23    let ident = &trait_body.ident;
24    let func_list = gen_func_list(&trait_body);
25    let (req_rsp_body, req_ident, rsp_ident) =
26        gen_req_rsp(&root, &trait_body.vis, ident, &func_list);
27    let server_body = gen_server(&root, ident, &req_ident, &rsp_ident, &func_list);
28    let client_body = gen_client(&root, ident, &req_ident, &rsp_ident, &func_list);
29
30    let ret = quote! {
31        #trait_body
32        #req_rsp_body
33        #server_body
34        #client_body
35    };
36    if option_env!("RUST_TRACE_MACROS").is_some() {
37        println!("{}", ret);
38    }
39    ret.into()
40}
41
42/// Check if `arg` is receiver we want, i.e., `&self`, `&'a self`, `self: &Self` and `self: &'a Self`
43fn is_ref_receiver(arg: Option<&FnArg>) -> bool {
44    let arg = match arg {
45        Some(arg) => arg,
46        None => return false,
47    };
48
49    match arg {
50        FnArg::Receiver(receiver) => receiver.reference.is_some() && receiver.mutability.is_none(), // `&self`, not `self` or `&mut self`
51        FnArg::Typed(PatType { pat, ty, .. }) => {
52            matches!(pat.as_ref(), Pat::Ident(ident) if ident.ident == "self") // `self: T`
53                && matches!(
54                    ty.as_ref(),
55                    Type::Reference(TypeReference{ mutability, elem, ..})
56                        if mutability.is_none() && matches!(elem.as_ref(), Type::Path(path) if path.qself.is_none() && path.path.is_ident("Self")) // `self: T` where T is a reference to `Self`
57                )
58        }
59    }
60}
61
62/// Generate a list of trait method.
63///
64/// This function emit error and generate dummy for following cases:
65///  - A trait method which has a default implementation.
66///  - The first input is not `&self` or its equivalent.
67///  - An input is given in pattern, e.g., `(x, y): (f32, f32)`.
68///  - A trait method which defined at least one lifetime generic parameter except `'req`
69fn gen_func_list(trait_body: &ItemTrait) -> Vec<Cow<'_, TraitItemMethod>> {
70    let ref_receiver: FnArg = parse_quote!(&self); // const
71
72    trait_body
73        .items
74        .iter()
75        .filter_map(|item| match item {
76            TraitItem::Method(method) => {
77                // check signatures and prepare dummy for bad method
78
79                let mut method = Cow::Borrowed(method);
80                if method.default.is_some() {
81                    emit_error!(
82                        method.default,
83                        "trait method can't have default implementation"
84                    );
85
86                    // create dummy by remove default and append semicolon(`;`)
87                    let mut dummy = method.into_owned();
88                    dummy.semi_token = Some(Token![;](dummy.default.span()));
89                    dummy.default = None;
90                    method = Cow::Owned(dummy);
91                }
92
93                if !is_ref_receiver(method.sig.inputs.first()) {
94                    emit_error!(method, "trait method must have `&self` receiver");
95
96                    // create dummy by replace bad receiver, append one if none is provided
97                    let mut dummy = method.into_owned();
98                    match dummy.sig.inputs.first() {
99                        Some(FnArg::Receiver(_)) => {
100                            // e.g., `&mut self`
101                            *(dummy
102                                .sig
103                                .inputs
104                                .first_mut()
105                                .expect("infallible: non-mutable use before")) =
106                                ref_receiver.clone();
107                        }
108                        Some(FnArg::Typed(PatType { pat, .. })) => match &**pat {
109                            Pat::Ident(PatIdent { ident, .. }) if ident == "self" => {
110                                // e.g., `self: Box<Self>`
111                                *(dummy
112                                    .sig
113                                    .inputs
114                                    .first_mut()
115                                    .expect("infallible: non-mutable use before")) =
116                                    ref_receiver.clone();
117                            }
118                            _ => {
119                                // no receiver, just argument
120                                dummy.sig.inputs.insert(0, ref_receiver.clone());
121                            }
122                        },
123                        None => {
124                            // no receiver and argument
125                            dummy.sig.inputs.insert(0, ref_receiver.clone());
126                        }
127                    }
128                    method = Cow::Owned(dummy);
129                }
130
131                for i in 0..(method.sig.inputs.len()) {
132                    if let FnArg::Typed(PatType { ref pat, .. }) = method.sig.inputs[i] {
133                        match pat.as_ref() {
134                            Pat::Ident(_) => {}
135                            other => {
136                                emit_error!(other, "trait method cannot use pattern as argument");
137
138                                let dummy_ident = format_ident!("__dummy_{:x}", {
139                                    use std::hash::{Hash, Hasher};
140
141                                    let mut h =
142                                        std::collections::hash_map::DefaultHasher::default();
143                                    other.hash(&mut h);
144                                    h.finish()
145                                });
146
147                                let new_pat = Box::new(Pat::Ident(PatIdent {
148                                    ident: dummy_ident,
149                                    attrs: Default::default(),
150                                    by_ref: None,
151                                    mutability: None,
152                                    subpat: None,
153                                }));
154
155                                let mut dummy = method.into_owned();
156                                match dummy.sig.inputs[i] {
157                                    FnArg::Typed(PatType { ref mut pat, .. }) => *pat = new_pat,
158                                    _ => unreachable!(),
159                                }
160                                method = Cow::Owned(dummy);
161                            }
162                        }
163                    }
164                }
165
166                for lifetime in method.sig.generics.lifetimes() {
167                    if lifetime.lifetime.ident != "req" {
168                        emit_error!(
169                            lifetime.lifetime.span(),
170                            "trait method may only have one lifetime parameter called `'req`"
171                        );
172                    }
173                }
174
175                Some(method)
176            }
177            item => {
178                emit_error!(
179                    item,
180                    "#[rpc_define] trait cannot have any item other than function"
181                );
182                None
183            }
184        })
185        .collect::<Vec<_>>()
186}
187
188fn gen_req_rsp<'a>(
189    root: &Path,
190    vis: &Visibility,
191    ident: &Ident,
192    func_list: &[Cow<'a, TraitItemMethod>],
193) -> (proc_macro2::TokenStream, Ident, Ident) {
194    let unit_type = parse_quote!(()); // const
195    let serde_borrow_attr: Attribute = parse_quote!(#[serde(borrow)]); // const
196
197    let req_ident = format_ident!("{}Request", ident);
198    let rsp_ident = format_ident!("{}Response", ident);
199    let serde_path = format!("{}::serde", root.to_token_stream());
200    let serde_path = LitStr::new(serde_path.as_str(), root.span());
201
202    let func_ident = func_list
203        .iter()
204        .map(|method| &method.sig.ident)
205        .collect::<Vec<_>>();
206    let input_type = func_list.iter().map(|method| {
207        method
208            .sig
209            .inputs
210            .iter()
211            .skip(1) // Skip the receiver, which must exist.
212            .map(|input| match input {
213                FnArg::Typed(PatType { ty, .. }) => ty,
214                FnArg::Receiver(_) => unreachable!(),
215            })
216            .collect::<Vec<_>>()
217    });
218    let input_borrow = func_list.iter().map(|method| {
219        if method.sig.generics.lifetimes().next().is_some() {
220            Some(&serde_borrow_attr)
221        } else {
222            None
223        }
224    });
225    let output_type = func_list.iter().map(|method| match method.sig.output {
226        ReturnType::Default => &unit_type,
227        ReturnType::Type(_, ref ty) => ty.as_ref(),
228    });
229
230    let req_rsp = quote! {
231        #[derive(#root::Serialize, #root::Deserialize)]
232        #[serde(crate = #serde_path)]
233        #[serde(deny_unknown_fields)]
234        #[allow(non_camel_case_types)]
235        #vis enum #req_ident<'req> {
236            #( #func_ident ( #input_borrow ( #(#input_type,)* ) ), )*
237            ___tiny_rpc_marker((#root::Never, #root::PhantomData<&'req ()>))
238        }
239
240        #[derive(#root::Serialize, #root::Deserialize)]
241        #[serde(crate = #serde_path)]
242        #[serde(deny_unknown_fields)]
243        #[allow(non_camel_case_types)]
244        #vis enum #rsp_ident {
245            #( #func_ident ( #output_type ), )*
246        }
247    };
248
249    (req_rsp, req_ident, rsp_ident)
250}
251
252fn gen_server<'a>(
253    root: &Path,
254    ident: &Ident,
255    req_ident: &Ident,
256    rsp_ident: &Ident,
257    func_list: &[Cow<'a, TraitItemMethod>],
258) -> proc_macro2::TokenStream {
259    let null_stream = quote! {}; // const
260    let keyword_await = quote! { .await }; // const
261
262    let server_ident = format_ident!("{}Server", ident);
263    let func_ident = func_list
264        .iter()
265        .map(|method| &method.sig.ident)
266        .collect::<Vec<_>>();
267    let input_ident = func_list
268        .iter()
269        .map(|method| {
270            method
271                .sig
272                .inputs
273                .iter()
274                .filter_map(|input| match input {
275                    FnArg::Receiver(_) => None, // NOTE all receivers are now `&self`
276                    FnArg::Typed(PatType { pat, .. }) => match &**pat {
277                        Pat::Ident(ident) => Some(&ident.ident),
278                        _ => unreachable!(),
279                    },
280                })
281                .collect::<Vec<_>>()
282        })
283        .collect::<Vec<_>>();
284    let await_if_async = func_list.iter().map(|method| {
285        method
286            .sig
287            .asyncness
288            .map_or(&null_stream, |_| &keyword_await)
289    });
290
291    quote! {
292        pub struct #server_ident<T: #ident + #root::Send + #root::Sync + 'static>(#root::Arc<T>);
293
294        impl<T: #ident + #root::Send + #root::Sync + 'static> #server_ident<T> {
295            pub fn serve(server_impl: T, transport: #root::Transport) -> #root::BoxStream<'static, #root::BoxFuture<'static, ()>> {
296                Self::__internal_serve(Self(#root::Arc::new(server_impl)), transport)
297            }
298
299            pub fn serve_arc(server_impl: #root::Arc<T>, transport: #root::Transport) -> #root::BoxStream<'static, #root::BoxFuture<'static, ()>> {
300                Self::__internal_serve(Self(server_impl), transport)
301            }
302
303            fn __internal_serve(self, transport: #root::Transport) -> #root::BoxStream<'static, #root::BoxFuture<'static, ()>> {
304                #root::Server::serve(self, transport)
305            }
306        }
307
308        impl<T: #ident + #root::Send + #root::Sync + 'static> #root::Clone for #server_ident<T> {
309            fn clone(&self) -> Self {
310                Self(#root::Clone::clone(&self.0))
311            }
312        }
313
314        impl<T: #ident + #root::Send + #root::Sync + 'static> #root::Server for #server_ident<T> {
315            fn make_response(
316                self,
317                req: #root::RpcFrame,
318            ) -> #root::BoxFuture<'static, #root::Result<#root::RpcFrame>> {
319                #root::FutureExt::boxed(
320                    async move {
321                        let id = req.id()?;
322                        let req = req.data()?;
323                        let rsp = match req {
324                            #(
325                                #req_ident::#func_ident( ( #(#input_ident,)* ) ) => {
326                                    #rsp_ident::#func_ident(self.0.#func_ident(#(#input_ident),*) #await_if_async)
327                                }
328                            )*
329                            #req_ident::___tiny_rpc_marker(_) => #root::unreachable!(),
330                        };
331                        let rsp = #root::RpcFrame::new(id, rsp)?;
332                        Ok(rsp)
333                    }
334                )
335            }
336        }
337    }
338}
339
340fn gen_client<'a>(
341    root: &Path,
342    ident: &Ident,
343    req_ident: &Ident,
344    rsp_ident: &Ident,
345    func_list: &[Cow<'a, TraitItemMethod>],
346) -> proc_macro2::TokenStream {
347    let unit_type: Type = parse_quote!(()); // const
348
349    let client_ident = format_ident!("{}Client", ident);
350    let signature = func_list
351        .iter()
352        .cloned()
353        .map(|method| {
354            let method = method.into_owned();
355            let span = method.span();
356            let mut sig = method.sig;
357
358            if sig.asyncness.is_none() {
359                sig.asyncness = Some(Token![async](span));
360            }
361
362            let ty = match sig.output {
363                ReturnType::Type(_, ty) => *ty,
364                ReturnType::Default => unit_type.clone(),
365            };
366            sig.output = parse_quote! { -> #root::Result<#ty> };
367            sig
368        })
369        .collect::<Vec<_>>();
370    let arg_ident = signature.iter().map(|sig| {
371        sig.inputs
372            .iter()
373            .filter_map(|arg| match arg {
374                FnArg::Receiver(_) => None,
375                FnArg::Typed(PatType { pat, .. }) => match &**pat {
376                    Pat::Ident(ident) => Some(&ident.ident),
377                    _ => unreachable!(),
378                },
379            })
380            .collect::<Vec<_>>()
381    });
382    let func_ident = signature.iter().map(|sig| &sig.ident);
383
384    quote! {
385        #[derive(Clone)]
386        pub struct #client_ident(#root::IdGenerator, #root::ClientDriverHandle);
387
388        impl #root::Client for #client_ident {
389            fn from_handle(handle: #root::ClientDriverHandle) -> Self {
390                Self(#root::IdGenerator::new(), handle)
391            }
392
393            fn handle(&self) -> &#root::ClientDriverHandle {
394                &self.1
395            }
396        }
397
398        impl #client_ident {
399            pub fn new(transport: #root::Transport) -> (Self, #root::BoxFuture<'static, ()>) {
400                #root::Client::new(transport)
401            }
402
403            #(
404                pub #signature {
405                    let args = ( #(#arg_ident,)* );
406                    let id = self.0.next();
407                    let span = info_span!(#root::stringify!(#func_ident), %id);
408
409                    #root::Instrument::instrument(
410                        async move {
411                            let req = #req_ident::#func_ident(args);
412                            let req = #root::RpcFrame::new(id, req)?;
413                            let rsp = <Self as #root::Client>::make_request(self, req).await?;
414                            let rsp = rsp.data()?;
415                            match rsp {
416                                #rsp_ident::#func_ident(ret) => Ok(ret),
417                                _ => Err(#root::Into::into(#root::ProtocolError::ResponseMismatch(id))),
418                            }
419                        },
420                        span,
421                    )
422                    .await
423                }
424            )*
425        }
426    }
427}
428
#[test]
fn test_is_ref_receiver() {
    // No first argument at all is never a receiver.
    assert!(!is_ref_receiver(None));

    // (candidate first argument, whether it counts as a shared-reference receiver)
    let cases: [(FnArg, bool); 10] = [
        (parse_quote!(self), false),
        (parse_quote!(&self), true),
        (parse_quote!(&'a self), true),
        (parse_quote!(&mut self), false),
        (parse_quote!(&'a mut self), false),
        (parse_quote!(self: Self), false),
        (parse_quote!(self: &Self), true),
        (parse_quote!(self: &'a Self), true),
        (parse_quote!(self: &mut Self), false),
        (parse_quote!(self: &'a mut Self), false),
    ];

    for (arg, expected) in cases.iter() {
        assert_eq!(
            is_ref_receiver(Some(arg)),
            *expected,
            "receiver: {}",
            arg.to_token_stream()
        );
    }
}