xy_rpc_macro/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use proc_macro2::{Ident, Span};
4use quote::{format_ident, quote};
5use std::iter::once;
6use syn::punctuated::Punctuated;
7use syn::{
8    parse_macro_input, parse_quote, Field, FieldMutability, Fields, FieldsNamed, FieldsUnnamed,
9    FnArg, GenericArgument, GenericParam, Generics, ItemEnum, ItemTrait, Lifetime, LifetimeParam,
10    PathArguments, ReturnType, Token, TraitItem, Type, TypeParamBound, TypeReference, Variant,
11    Visibility,
12};
13
14#[proc_macro_attribute]
15pub fn rpc_service(_attr: TokenStream, item: TokenStream) -> TokenStream {
16    let mut ast = parse_macro_input!(item as ItemTrait);
17    let vis = ast.vis.clone();
18    {
19        for item in ast.items.iter_mut() {
20            if let TraitItem::Fn(f) = item {
21                if let Some(_r) = f.sig.asyncness.take() {
22                    f.sig.output = match &f.sig.output {
23                        ReturnType::Default => parse_quote! {
24                            -> impl core::future::Future<Output = ()> + xy_rpc::maybe_send::MaybeSend
25                        },
26                        ReturnType::Type(_, ty) => parse_quote! {
27                            -> impl core::future::Future<Output = #ty> + xy_rpc::maybe_send::MaybeSend
28                        },
29                    };
30                }
31            }
32        }
33    }
34    let trait_ident = &ast.ident;
35    let msg_enum_ident = format_ident!("{}Msg", ast.ident);
36    let msg_ref_enum_ident = format_ident!("{}RefMsg", ast.ident);
37    let msg_reply_enum_ident = format_ident!("{}ReplyMsg", ast.ident);
38    // let msg_reply_ref_enum_ident = format_ident!("{}ReplyRefMsg", ast.ident);
39    let _handler_ident = format_ident!("{}Handler", ast.ident);
40    let schema_ident = format_ident!("{}Schema", ast.ident);
41    let caller_ident = format_ident!("{}Caller", ast.ident);
42    let ref_lifetime = Lifetime::new("'a", Span::call_site());
43    let (variants, ref_variants, reply_variants) = ast
44        .items
45        .iter()
46        .filter_map(|n| match n {
47            syn::TraitItem::Fn(n) => Some(n),
48            _ => return None,
49        })
50        .map(|n| {
51            let no_async = n.sig.asyncness.is_none();
52            // let found_input_trans_stream = get_input_trans_stream(n);
53            let ident = format_ident!("{}", n.sig.ident.to_string().to_case(Case::UpperCamel));
54            // match found_input_trans_stream {
55            //     None => {
56            let (named_fields, named_ref_fields) = n
57                .sig
58                .inputs
59                .iter()
60                .filter_map(|n| match n {
61                    FnArg::Typed(n) => Some(n),
62                    _ => None,
63                })
64                .map(|n| {
65                    (
66                        Field {
67                            attrs: vec![],
68                            vis: Visibility::Inherited,
69                            mutability: FieldMutability::None,
70                            ident: match n.pat.as_ref() {
71                                syn::Pat::Ident(n) => Some(n.ident.clone()),
72                                _ => panic!("fm parameters pat: only support named fields"),
73                            },
74                            colon_token: Default::default(),
75                            ty: *n.ty.clone(),
76                        },
77                        Field {
78                            attrs: vec![],
79                            vis: Visibility::Inherited,
80                            mutability: FieldMutability::None,
81                            ident: match n.pat.as_ref() {
82                                syn::Pat::Ident(n) => Some(n.ident.clone()),
83                                _ => panic!("fm parameters pat: only support named fields"),
84                            },
85                            colon_token: Default::default(),
86                            ty: Type::Reference(TypeReference {
87                                and_token: Default::default(),
88                                lifetime: Some(ref_lifetime.clone()),
89                                mutability: None,
90                                elem: Box::new(*n.ty.clone()),
91                            }),
92                        },
93                    )
94                })
95                .collect();
96            let base = Variant {
97                attrs: vec![],
98                ident: ident.clone(),
99                fields: Fields::Named(FieldsNamed {
100                    brace_token: Default::default(),
101                    named: named_fields,
102                }),
103                discriminant: None,
104            };
105            let base_ref = Variant {
106                attrs: vec![],
107                ident: ident.clone(),
108                fields: Fields::Named(FieldsNamed {
109                    brace_token: Default::default(),
110                    named: named_ref_fields,
111                }),
112                discriminant: None,
113            };
114            let base_reply = Variant {
115                attrs: vec![],
116                ident: ident.clone(),
117                fields: Fields::Unnamed({
118                    let ty = match &n.sig.output {
119                        ReturnType::Default => parse_quote!(()),
120                        ReturnType::Type(_, ty) => *ty.clone(),
121                    };
122                    let ty = get_future_output(no_async, &ty);
123                    FieldsUnnamed {
124                        paren_token: Default::default(),
125                        unnamed: once(Field {
126                            attrs: vec![],
127                            vis: Visibility::Inherited,
128                            mutability: FieldMutability::None,
129                            ident: None,
130                            colon_token: None,
131                            ty,
132                        })
133                        .collect(),
134                    }
135                }),
136                discriminant: None,
137            };
138            (base, base_ref, base_reply)
139            // }
140            // Some((_found_input_trans_stream, found_input_trans_stream_index)) => {
141            //     let start_variant = Variant {
142            //         attrs: vec![],
143            //         ident: ident.clone(),
144            //         fields: Fields::Named(FieldsNamed {
145            //             brace_token: Default::default(),
146            //             named: n
147            //                 .sig
148            //                 .inputs
149            //                 .iter()
150            //                 .enumerate()
151            //                 .filter(|n| n.0 != found_input_trans_stream_index)
152            //                 .filter_map(|(_, n)| match n {
153            //                     FnArg::Typed(n) => Some(n),
154            //                     _ => None,
155            //                 })
156            //                 .map(|n| Field {
157            //                     attrs: vec![],
158            //                     vis: Visibility::Inherited,
159            //                     mutability: FieldMutability::None,
160            //                     ident: match n.pat.as_ref() {
161            //                         syn::Pat::Ident(n) => Some(n.ident.clone()),
162            //                         _ => panic!("fm parameters pat: only support named fields"),
163            //                     },
164            //                     colon_token: Default::default(),
165            //                     ty: *n.ty.clone(),
166            //                 })
167            //                 .collect(),
168            //         }),
169            //         discriminant: None,
170            //     };
171            //     let base_reply = Variant {
172            //         attrs: vec![],
173            //         ident: ident.clone(),
174            //         fields: Fields::Unnamed({
175            //             let ty = match &n.sig.output {
176            //                 ReturnType::Default => parse_quote!(()),
177            //                 ReturnType::Type(_, ty) => *ty.clone(),
178            //             };
179            //             let ty = if no_async {
180            //                 let future_output_type = match &ty {
181            //                     Type::ImplTrait(type_impl) => {
182            //                         type_impl.bounds.iter().find_map(|n| match n {
183            //                             TypeParamBound::Trait(t) => {
184            //                                 let x = t
185            //                                     .path
186            //                                     .segments
187            //                                     .iter()
188            //                                     .find(|n| n.ident == "Future");
189            //                                 if let Some(x) = x {
190            //                                     let PathArguments::AngleBracketed(args) =
191            //                                         &x.arguments
192            //                                     else {
193            //                                         panic!("invalid return type")
194            //                                     };
195            //                                     args.args.iter().find_map(|n| match n {
196            //                                         GenericArgument::AssocType(a) => {
197            //                                             if a.ident == "Output" {
198            //                                                 Some(a.ty.clone())
199            //                                             } else {
200            //                                                 None
201            //                                             }
202            //                                         }
203            //                                         _ => None,
204            //                                     })
205            //                                 } else {
206            //                                     None
207            //                                 }
208            //                             }
209            //                             _ => None,
210            //                         })
211            //                     }
212            //                     _ => None,
213            //                 };
214            //                 if let Some(rt) = future_output_type {
215            //                     rt
216            //                 } else {
217            //                     parse_quote! {
218            //                         <#ty as core::future::Future>::Output
219            //                     }
220            //                 }
221            //             } else {
222            //                 ty
223            //             };
224            //             FieldsUnnamed {
225            //                 paren_token: Default::default(),
226            //                 unnamed: once(Field {
227            //                     attrs: vec![],
228            //                     vis: Visibility::Inherited,
229            //                     mutability: FieldMutability::None,
230            //                     ident: None,
231            //                     colon_token: None,
232            //                     ty,
233            //                 })
234            //                 .collect(),
235            //             }
236            //         }),
237            //         discriminant: None,
238            //     };
239            //     Either::Right(vec![(start_variant, base_reply)].into_iter())
240            // }
241            // }
242        })
243        .collect();
244    let msg_enum = ItemEnum {
245        attrs: parse_quote!(#[derive(Debug,serde::Serialize, serde::Deserialize)]),
246        vis: vis.clone(),
247        enum_token: Default::default(),
248        ident: msg_enum_ident.clone(),
249        generics: Default::default(),
250        brace_token: Default::default(),
251        variants,
252    };
253    let msg_ref_enum = ItemEnum {
254        attrs: parse_quote!(#[derive(Debug,serde::Serialize)]),
255        vis: vis.clone(),
256        enum_token: Default::default(),
257        ident: msg_ref_enum_ident.clone(),
258        generics: Generics {
259            lt_token: None,
260            params: once(GenericParam::Lifetime(LifetimeParam {
261                attrs: vec![],
262                lifetime: ref_lifetime.clone(),
263                colon_token: None,
264                bounds: Default::default(),
265            }))
266            .collect(),
267            gt_token: None,
268            where_clause: None,
269        },
270        brace_token: Default::default(),
271        variants: ref_variants,
272    };
273    let msg_reply_enum = ItemEnum {
274        attrs: parse_quote!(#[derive(Debug,serde::Serialize, serde::Deserialize)]),
275        vis: vis.clone(),
276        enum_token: Default::default(),
277        ident: msg_reply_enum_ident.clone(),
278        generics: Default::default(),
279        brace_token: Default::default(),
280        variants: reply_variants,
281    };
282
283    let (rpc_call_fn, rpc_call_fn_impl, match_expr, msg_info_matches, msg_reply_info_matches, ref_msg_info_matches): (Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>) = ast
284      .items
285      .iter()
286      .filter_map(|n| match n {
287         syn::TraitItem::Fn(n) => Some(n),
288         _ => return None,
289      })
290      .enumerate()
291      .map(|(i, n)| {
292         let id = i + 1;
293         let fields: Punctuated<Ident, Token![,]> = n
294            .sig
295            .inputs
296            .iter()
297            .filter_map(|n| match n {
298               FnArg::Typed(n) => Some(match n.pat.as_ref() {
299                  syn::Pat::Ident(n) => n.ident.clone(),
300                  _ => panic!("only support named fields"),
301               }),
302               _ => None,
303            })
304            .collect();
305          let mut rpc_call_fn = n.sig.clone();
306          for arg in rpc_call_fn.inputs.iter_mut() {
307              if let FnArg::Typed(arg) = arg {
308                 arg.ty = Box::new(Type::Reference(TypeReference {
309                    and_token: Default::default(),
310                    lifetime: Some(ref_lifetime.clone()),
311                    mutability: None,
312                    elem: Box::new(*arg.ty.clone()),
313                 }));
314              };
315          }
316          rpc_call_fn.generics = Generics {
317              lt_token: Some(Default::default()),
318              params: once(GenericParam::Lifetime(LifetimeParam {
319                  attrs: vec![],
320                  lifetime: ref_lifetime.clone(),
321                  colon_token: None,
322                  bounds: Default::default(),
323              })).collect(),
324              gt_token: Some(Default::default()),
325              where_clause: None,
326          };
327         match rpc_call_fn.output {
328            ReturnType::Default => {
329               rpc_call_fn = parse_quote! {
330                  -> impl core::future::Future<Output = Result<(), xy_rpc::RpcError>> + xy_rpc::maybe_send::MaybeSend +'static
331               }
332            }
333            ReturnType::Type(_, ty) => {
334               let is_async = rpc_call_fn.asyncness.is_some();
335               let output_ty = get_future_output(!is_async, &*ty);
336               if is_async {
337                  rpc_call_fn.asyncness = None;
338               }
339               rpc_call_fn.output = parse_quote! {
340                  -> impl core::future::Future<Output = Result<#output_ty, xy_rpc::RpcError>> + xy_rpc::maybe_send::MaybeSend +'static
341               };
342            }
343         }
344         let fn_ident = &rpc_call_fn.ident;
345         let item_name = fn_ident.to_string().to_case(Case::UpperCamel);
346         let enum_item_ident = format_ident!("{}", item_name);
347         (
348            quote! {
349                #rpc_call_fn
350            },
351            quote! {
352                #rpc_call_fn {
353                    let future = self.call(#msg_ref_enum_ident::#enum_item_ident { #fields });
354                    async move {
355                        let reply = future.await?;
356                        let #msg_reply_enum_ident::#enum_item_ident(reply) = reply.msg else {
357                            return Err(xy_rpc::RpcError::InvalidMsg)
358                        };
359                        Ok(reply)
360                    }
361                }
362            },
363            quote! {
364                 #msg_enum_ident::#enum_item_ident { #fields } => {
365                       let r = self.service.#fn_ident(#fields).await;
366                       #msg_reply_enum_ident::#enum_item_ident(r)
367                 }
368            },
369            quote! {
370                 #msg_enum_ident::#enum_item_ident { .. } => {
371                    xy_rpc::RpcMsgInfo {
372                        id: #id as _,
373                        name: #item_name
374                    }
375                 }
376            },
377            quote! {
378                 #msg_reply_enum_ident::#enum_item_ident(_) => {
379                    xy_rpc::RpcMsgInfo {
380                        id: #id as _,
381                        name: #item_name
382                    }
383                 }
384            },
385            quote! {
386                 #msg_ref_enum_ident::#enum_item_ident { .. } => {
387                    xy_rpc::RpcMsgInfo {
388                        id: #id as _,
389                        name: #item_name
390                    }
391                 }
392            },
393         )
394      })
395      .collect();
396
397    let schema = quote! {
398        #[derive(Clone, Debug, Default)]
399        #vis struct #schema_ident;
400        impl xy_rpc::RpcServiceSchema for #schema_ident
401        {
402            type Msg = #msg_enum_ident;
403            type Reply = #msg_reply_enum_ident;
404        }
405    };
406
407    let impls = quote! {
408        impl<T> xy_rpc::RpcMsgHandler<#schema_ident> for xy_rpc::RpcMsgHandlerWrapper<T>
409        where
410            T: #trait_ident,
411        {
412            fn handle(
413                &self,
414                msg: #msg_enum_ident,
415            ) -> impl core::future::Future<Output = #msg_reply_enum_ident> + xy_rpc::maybe_send::MaybeSend {
416                async move {
417                    match msg {
418                        #(#match_expr)*
419                    }
420                }
421            }
422        }
423         impl<'a> xy_rpc::RpcRefMsg for #msg_ref_enum_ident<'a> {
424               fn info(&self) -> xy_rpc::RpcMsgInfo {
425                  match self {
426                     #(#ref_msg_info_matches)*
427                  }
428               }
429         }
430         impl xy_rpc::RpcRefMsg for #msg_enum_ident {
431               fn info(&self) -> xy_rpc::RpcMsgInfo {
432                  match self {
433                     #(#msg_info_matches)*
434                  }
435               }
436         }
437         impl xy_rpc::RpcMsg for #msg_enum_ident {
438            type Ref<'a>  = #msg_ref_enum_ident<'a>;
439         }
440         impl<'a> xy_rpc::RpcRefMsg for &'a #msg_reply_enum_ident {
441               fn info(&self) -> xy_rpc::RpcMsgInfo {
442                  match self {
443                     #(#msg_reply_info_matches)*
444                  }
445               }
446         }
447         impl xy_rpc::RpcRefMsg for #msg_reply_enum_ident {
448               fn info(&self) -> xy_rpc::RpcMsgInfo {
449                  match self {
450                     #(#msg_reply_info_matches)*
451                  }
452               }
453         }
454         impl xy_rpc::RpcMsg for #msg_reply_enum_ident {
455               type Ref<'a>  = &'a #msg_reply_enum_ident;
456         }
457        #vis trait #caller_ident {
458            #(#rpc_call_fn;)*
459        }
460        impl<CF> #caller_ident for xy_rpc::XyRpcChannel<CF,#schema_ident> where CF: xy_rpc::formats::SerdeFormat {
461            #(#rpc_call_fn_impl)*
462        }
463    };
464
465    quote! {
466        #ast
467        #schema
468        #msg_enum
469        #msg_ref_enum
470        #msg_reply_enum
471        #impls
472    }
473    .into()
474}
475
476fn get_future_output(no_async: bool, ty: &Type) -> Type {
477    if no_async {
478        let future_output_type = match &ty {
479            Type::ImplTrait(type_impl) => type_impl.bounds.iter().find_map(|n| match n {
480                TypeParamBound::Trait(t) => {
481                    let x = t.path.segments.iter().find(|n| n.ident == "Future");
482                    if let Some(x) = x {
483                        let PathArguments::AngleBracketed(args) = &x.arguments else {
484                            panic!("invalid return type")
485                        };
486                        args.args.iter().find_map(|n| match n {
487                            GenericArgument::AssocType(a) => {
488                                if a.ident == "Output" {
489                                    Some(a.ty.clone())
490                                } else {
491                                    None
492                                }
493                            }
494                            _ => None,
495                        })
496                    } else {
497                        None
498                    }
499                }
500                _ => None,
501            }),
502            _ => None,
503        };
504        if let Some(rt) = future_output_type {
505            rt
506        } else {
507            parse_quote! {
508                <#ty as core::future::Future>::Output
509            }
510        }
511    } else {
512        ty.clone()
513    }
514}
515/*
516fn get_input_trans_stream(n: &TraitItemFn) -> Option<(&PathSegment, usize)> {
517    n.sig.inputs.iter().enumerate().find_map(|(i, n)| {
518        let ty = match n {
519            FnArg::Typed(n) => n,
520            _ => return None,
521        };
522        let Type::Path(ty) = &*ty.ty else { return None };
523        let segment = ty.path.segments.last()?;
524        (segment.ident == "TransStream").then_some((segment, i))
525    })
526}
527*/