web_rpc_macro/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{format_ident, quote, ToTokens};
6use syn::{
7    braced,
8    ext::IdentExt,
9    parenthesized,
10    parse::{Parse, ParseStream},
11    parse_macro_input, parse_quote,
12    punctuated::Punctuated,
13    spanned::Spanned,
14    token::Comma,
15    Attribute, FnArg, Ident, Meta, NestedMeta, Pat, PatType, ReturnType, Token, Type, Visibility,
16};
17
/// Folds the error value `$e` into the accumulator named by `$errors`.
///
/// `$errors` must be a mutable `Result` binding. While it is still `Ok`, it is
/// replaced by `Err($e)`; once it already holds an `Err`, `$e` is appended to
/// the existing error value through `Extend::extend`.
macro_rules! extend_errors {
    ($errors: ident, $e: expr) => {
        if let Err(ref mut accumulated) = $errors {
            accumulated.extend($e);
        } else {
            $errors = Err($e);
        }
    };
}
26
27struct Service {
28    attrs: Vec<Attribute>,
29    vis: Visibility,
30    ident: Ident,
31    rpcs: Vec<RpcMethod>,
32}
33
34struct RpcMethod {
35    is_async: bool,
36    attrs: Vec<Attribute>,
37    ident: Ident,
38    args: Vec<PatType>,
39    transfer: HashSet<Ident>,
40    post: HashSet<Ident>,
41    output: ReturnType,
42}
43
44struct ServiceGenerator<'a> {
45    trait_ident: &'a Ident,
46    service_ident: &'a Ident,
47    client_ident: &'a Ident,
48    request_ident: &'a Ident,
49    response_ident: &'a Ident,
50    vis: &'a Visibility,
51    attrs: &'a [Attribute],
52    rpcs: &'a [RpcMethod],
53    camel_case_idents: &'a [Ident],
54}
55
56impl<'a> ServiceGenerator<'a> {
57    fn enum_request(&self) -> TokenStream2 {
58        let &Self {
59            vis,
60            request_ident,
61            camel_case_idents,
62            rpcs,
63            ..
64        } = self;
65        let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
66            |(RpcMethod { args, post, .. }, camel_case_ident)| {
67                let args_filtered = args.iter().filter(
68                    |arg| matches!(&*arg.pat, Pat::Ident(ident) if !post.contains(&ident.ident)),
69                );
70                quote! {
71                    #camel_case_ident { #( #args_filtered ),* }
72                }
73            },
74        );
75        quote! {
76            #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
77            #vis enum #request_ident {
78                #( #variants ),*
79            }
80        }
81    }
82
83    fn enum_response(&self) -> TokenStream2 {
84        let &Self {
85            vis,
86            response_ident,
87            camel_case_idents,
88            rpcs,
89            ..
90        } = self;
91        let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
92            |(RpcMethod { output, post, .. }, camel_case_ident)| match output {
93                ReturnType::Type(_, ty) if !post.contains(&Ident::new("return", output.span())) => {
94                    quote! {
95                        #camel_case_ident ( #ty )
96                    }
97                }
98                _ => quote! {
99                    #camel_case_ident ( () )
100                },
101            },
102        );
103        quote! {
104            #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
105            #vis enum #response_ident {
106                #( #variants ),*
107            }
108        }
109    }
110
111    fn trait_service(&self) -> TokenStream2 {
112        let &Self {
113            attrs,
114            rpcs,
115            vis,
116            trait_ident,
117            ..
118        } = self;
119
120        let unit_type: &Type = &parse_quote!(());
121        let rpc_fns = rpcs.iter().map(
122            |RpcMethod {
123                 attrs,
124                 args,
125                 ident,
126                 is_async,
127                 output,
128                 ..
129             }| {
130                let output = match output {
131                    ReturnType::Type(_, ref ty) => ty,
132                    ReturnType::Default => unit_type,
133                };
134                let is_async = match is_async {
135                    true => quote!(async),
136                    false => quote!(),
137                };
138                quote! {
139                    #( #attrs )*
140                    #is_async fn #ident(&self, #( #args ),*) -> #output;
141                }
142            },
143        );
144
145        let forward_fns = rpcs
146            .iter()
147            .map(
148                |RpcMethod {
149                     attrs,
150                     args,
151                     ident,
152                     is_async,
153                     output,
154                     ..
155                 }| {
156                    let output = match output {
157                        ReturnType::Type(_, ref ty) => ty,
158                        ReturnType::Default => unit_type,
159                    };
160                    let do_await = match is_async {
161                        true => quote!(.await),
162                        false => quote!(),
163                    };
164                    let is_async = match is_async {
165                        true => quote!(async),
166                        false => quote!(),
167                    };
168                    let forward_args = args.iter().filter_map(|arg| match &*arg.pat {
169                        Pat::Ident(ident) => Some(&ident.ident),
170                        _ => None,
171                    });
172                    quote! {
173                        #( #attrs )*
174                        #is_async fn #ident(&self, #( #args ),*) -> #output {
175                            T::#ident(self, #( #forward_args ),*)#do_await
176                        }
177                    }
178                },
179            )
180            .collect::<Vec<_>>();
181
182        quote! {
183            #( #attrs )*
184            #vis trait #trait_ident {
185                #( #rpc_fns )*
186            }
187
188            impl<T> #trait_ident for std::sync::Arc<T> where T: #trait_ident {
189                #( #forward_fns )*
190            }
191            impl<T> #trait_ident for std::boxed::Box<T> where T: #trait_ident {
192                #( #forward_fns )*
193            }
194            impl<T> #trait_ident for std::rc::Rc<T> where T: #trait_ident {
195                #( #forward_fns )*
196            }
197        }
198    }
199
200    fn struct_client(&self) -> TokenStream2 {
201        let &Self {
202            vis,
203            client_ident,
204            request_ident,
205            response_ident,
206            camel_case_idents,
207            rpcs,
208            ..
209        } = self;
210
211        let rpc_fns = rpcs
212            .iter()
213            .zip(camel_case_idents.iter())
214            .map(|(RpcMethod { attrs, args, transfer, post, ident, output, .. }, camel_case_ident)| {
215                /* sort arguments based on post and transfer attributes */
216                let serialize_arg_idents = args.iter()
217                    .filter_map(|arg| match &*arg.pat {
218                        Pat::Ident(ident) if !post.contains(&ident.ident) => Some(&ident.ident),
219                        _ => None
220                    });
221                let post_arg_idents = args.iter()
222                    .filter_map(|arg| match &*arg.pat {
223                        Pat::Ident(ident) if post.contains(&ident.ident) => Some(&ident.ident),
224                        _ => None
225                    });
226                let transfer_arg_idents = args.iter()
227                    .filter_map(|arg| match &*arg.pat {
228                        Pat::Ident(ident) if transfer.contains(&ident.ident) => Some(&ident.ident),
229                        _ => None
230                    });
231
232                let return_type = match output {
233                    ReturnType::Type(_, ref ty) => quote! {
234                        web_rpc::client::RequestFuture<#ty>
235                    },
236                    _ => quote!(())
237                };
238                let maybe_register_callback = match output {
239                    ReturnType::Type(_, _) => quote! {
240                        let (__response_tx, __response_rx) =
241                            web_rpc::futures_channel::oneshot::channel();
242                        self.callback_map.borrow_mut().insert(__seq_id, __response_tx);
243                    },
244                    _ => Default::default()
245                };
246
247                let unpack_response = if post.contains(&Ident::new("return", output.span())) {
248                    let unit_output: &Type = &parse_quote!(());
249                    let output = match output {
250                        ReturnType::Type(_, ref ty) => ty,
251                        _ => unit_output
252                    };
253                    quote! {
254                        let (_, __post_response) = response;
255                        web_rpc::wasm_bindgen::JsCast::dyn_into::<#output>(__post_response.shift())
256                            .unwrap()
257                    }
258                } else {
259                    quote! {
260                        let (__serialize_response, _) = response;
261                        let #response_ident::#camel_case_ident(__inner) = __serialize_response else {
262                            panic!("received incorrect response variant")
263                        };
264                        __inner
265                    }
266                };
267
268                let maybe_unpack_and_return_future = match output {
269                    ReturnType::Type(_, _) => quote! {
270                        let __response_future = web_rpc::futures_util::FutureExt::map(
271                            __response_rx,
272                            |response| {
273                                let response = response.unwrap();
274                                #unpack_response
275                            }
276                        );
277                        let __abort_sender = self.abort_sender.clone();
278                        let __dispatcher = self.dispatcher.clone();
279                        web_rpc::client::RequestFuture::new(
280                            __response_future,
281                            __dispatcher,
282                            std::boxed::Box::new(move || (__abort_sender)(__seq_id)))
283                    },
284                    _ => Default::default()
285                };
286
287                quote! {
288                    #( #attrs )*
289                    #vis fn #ident(
290                        &self,
291                        #( #args ),*
292                    ) -> #return_type {
293                        let __seq_id = self.seq_id.replace_with(|seq_id| seq_id.wrapping_add(1));
294                        let __request = #request_ident::#camel_case_ident {
295                            #( #serialize_arg_idents ),*
296                        };
297                        let __serialized = (self.request_serializer)(__seq_id, __request);
298                        let __serialized = js_sys::Uint8Array::from(&__serialized[..]).buffer();
299                        let __post: &[&wasm_bindgen::JsValue] =
300                            &[__serialized.as_ref(), #( #post_arg_idents.as_ref() ),*];
301                        let __post = web_rpc::js_sys::Array::from_iter(__post);
302                        let __transfer: &[&wasm_bindgen::JsValue] =
303                            &[__serialized.as_ref(), #( #transfer_arg_idents.as_ref() ),*];
304                        let __transfer = web_rpc::js_sys::Array::from_iter(__transfer);
305                        #maybe_register_callback
306                        self.port.post_message(&__post, &__transfer).unwrap();
307                        #maybe_unpack_and_return_future
308                    }
309                }
310            });
311
312        quote! {
313            #[derive(core::clone::Clone)]
314            #vis struct #client_ident {
315                callback_map: std::rc::Rc<
316                    std::cell::RefCell<
317                        web_rpc::client::CallbackMap<#response_ident>
318                    >
319                >,
320                port: web_rpc::port::Port,
321                listener: std::rc::Rc<web_rpc::gloo_events::EventListener>,
322                dispatcher: web_rpc::futures_util::future::Shared<
323                    web_rpc::futures_core::future::LocalBoxFuture<'static, ()>
324                >,
325                request_serializer: std::rc::Rc<
326                    dyn std::ops::Fn(usize, #request_ident) -> std::vec::Vec<u8>
327                >,
328                abort_sender: std::rc::Rc<dyn std::ops::Fn(usize)>,
329                seq_id: std::rc::Rc<std::cell::RefCell<usize>>
330            }
331            impl std::fmt::Debug for #client_ident {
332                fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333                    formatter.debug_struct(std::stringify!(#client_ident))
334                        .finish()
335                }
336            }
337            impl web_rpc::client::Client for #client_ident {
338                type Request = #request_ident;
339                type Response = #response_ident;
340            }
341            impl From<web_rpc::client::Configuration<#request_ident, #response_ident>>
342                for #client_ident {
343                fn from((callback_map, port, listener, dispatcher, request_serializer, abort_sender):
344                    web_rpc::client::Configuration<#request_ident, #response_ident>) -> Self {
345                    Self {
346                        callback_map,
347                        port,
348                        listener,
349                        dispatcher,
350                        request_serializer,
351                        abort_sender,
352                        seq_id: std::default::Default::default()
353                    }
354                }
355            }
356            impl #client_ident {
357                #( #rpc_fns )*
358            }
359        }
360    }
361
362    fn struct_server(&self) -> TokenStream2 {
363        let &Self {
364            vis,
365            trait_ident,
366            service_ident,
367            request_ident,
368            response_ident,
369            camel_case_idents,
370            rpcs,
371            ..
372        } = self;
373
374        let handlers = rpcs.iter()
375            .zip(camel_case_idents.iter())
376            .map(|(RpcMethod { is_async, ident, args, transfer, post, output, .. }, camel_case_ident)| {
377                let serialize_arg_idents = args.iter()
378                    .filter_map(|arg| match &*arg.pat {
379                        Pat::Ident(ident) if !post.contains(&ident.ident) => Some(&ident.ident),
380                        _ => None
381                    });
382                let extract_js_args = args.iter()
383                    .filter_map(|arg| match &*arg.pat {
384                        Pat::Ident(ident) if post.contains(&ident.ident) => {
385                            let arg_pat = &arg.pat;
386                            let arg_ty = &arg.ty;
387                            Some(quote! {
388                                let #arg_pat = web_rpc::wasm_bindgen::JsCast::dyn_into::<#arg_ty>(__js_args.shift())
389                                    .unwrap();
390                            })
391                        },
392                        _ => None
393                    });
394                let return_ident = Ident::new("return", output.span());
395                let return_response = match (post.contains(&return_ident), transfer.contains(&return_ident)) {
396                    (false, _) => quote! {
397                        let __post = web_rpc::js_sys::Array::new();
398                        let __transfer = web_rpc::js_sys::Array::new();
399                        (Self::Response::#camel_case_ident(__response), __post, __transfer)
400                    },
401                    (true, false) => quote! {
402                        let __post = web_rpc::js_sys::Array::of1(__response.as_ref());
403                        let __transfer = web_rpc::js_sys::Array::new();
404                        (Self::Response::#camel_case_ident(()), __post, __transfer)
405                    },
406                    (true, true) => quote! {
407                        let __post = web_rpc::js_sys::Array::of1(__response.as_ref());
408                        let __transfer = web_rpc::js_sys::Array::of1(__response.as_ref());
409                        (Self::Response::#camel_case_ident(()), __post, __transfer)
410                    }
411                };
412                let args = args.iter().filter_map(|arg| match &*arg.pat {
413                    Pat::Ident(ident) => Some(&ident.ident),
414                    _ => None
415                });
416                match is_async {
417                    true => quote! {
418                        Self::Request::#camel_case_ident { #( #serialize_arg_idents ),* } => {
419                            #( #extract_js_args )*
420                            let __task =
421                                web_rpc::futures_util::FutureExt::fuse(self.server_impl.#ident(#( #args ),*));
422                            web_rpc::pin_utils::pin_mut!(__task);
423                            web_rpc::futures_util::select! {
424                                _ = __abort_rx => None,
425                                __response = __task => Some({
426                                    #return_response
427                                })
428                            }
429                        }
430                    },
431                    false => quote! {
432                        Self::Request::#camel_case_ident { #( #serialize_arg_idents ),* } => {
433                            #( #extract_js_args )*
434                            let __response = self.server_impl.#ident(#( #args ),*);
435                            Some({
436                                #return_response
437                            })
438                        }
439                    }
440                }
441            });
442
443        quote! {
444            #vis struct #service_ident<T> {
445                server_impl: T
446            }
447            impl<T: #trait_ident> web_rpc::service::Service for #service_ident<T> {
448                type Request = #request_ident;
449                type Response = #response_ident;
450                async fn execute(
451                    &self,
452                    __seq_id: usize,
453                    mut __abort_rx: web_rpc::futures_channel::oneshot::Receiver<()>,
454                    __request: Self::Request,
455                    __js_args: web_rpc::js_sys::Array
456                ) -> (usize, Option<(Self::Response, web_rpc::js_sys::Array, web_rpc::js_sys::Array)>) {
457                    let __result = match __request {
458                        #( #handlers )*
459                    };
460                    (__seq_id, __result)
461                }
462            }
463            impl<T: #trait_ident> std::convert::From<T> for #service_ident<T> {
464                fn from(server_impl: T) -> Self {
465                    Self { server_impl }
466                }
467            }
468        }
469    }
470}
471
472impl<'a> ToTokens for ServiceGenerator<'a> {
473    fn to_tokens(&self, output: &mut TokenStream2) {
474        output.extend(vec![
475            self.enum_request(),
476            self.enum_response(),
477            self.trait_service(),
478            self.struct_client(),
479            self.struct_server(),
480        ])
481    }
482}
483
484impl Parse for Service {
485    fn parse(input: ParseStream) -> syn::Result<Self> {
486        let attrs = input.call(Attribute::parse_outer)?;
487        let vis = input.parse()?;
488        input.parse::<Token![trait]>()?;
489        let ident: Ident = input.parse()?;
490        let content;
491        braced!(content in input);
492        let mut rpcs = Vec::<RpcMethod>::new();
493        while !content.is_empty() {
494            rpcs.push(content.parse()?);
495        }
496
497        Ok(Self {
498            attrs,
499            vis,
500            ident,
501            rpcs,
502        })
503    }
504}
505
506impl Parse for RpcMethod {
507    fn parse(input: ParseStream) -> syn::Result<Self> {
508        let mut errors = Ok(());
509        let attrs = input.call(Attribute::parse_outer)?;
510        let (post_attrs, attrs): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|attr| {
511            attr.path
512                .segments
513                .last()
514                .is_some_and(|last_segment| last_segment.ident == "post")
515        });
516        let mut transfer: HashSet<Ident> = HashSet::new();
517        let mut post: HashSet<Ident> = HashSet::new();
518        for post_attr in post_attrs {
519            let parsed_args =
520                post_attr.parse_args_with(Punctuated::<NestedMeta, Token![,]>::parse_terminated)?;
521            for parsed_arg in parsed_args {
522                match &parsed_arg {
523                    NestedMeta::Meta(meta) => match meta {
524                        Meta::Path(path) => {
525                            if let Some(segment) = path.segments.last() {
526                                post.insert(segment.ident.clone());
527                            }
528                        }
529                        Meta::List(list) => match list.path.segments.last() {
530                            Some(last_segment) if last_segment.ident == "transfer" => {
531                                if list.nested.len() != 1 {
532                                    extend_errors!(
533                                        errors,
534                                        syn::Error::new(
535                                            parsed_arg.span(),
536                                            "Syntax error in post attribute"
537                                        )
538                                    );
539                                }
540                                match list.nested.first() {
541                                    Some(NestedMeta::Meta(Meta::Path(path))) => {
542                                        match path.segments.last() {
543                                            Some(segment) => {
544                                                post.insert(segment.ident.clone());
545                                                transfer.insert(segment.ident.clone());
546                                            }
547                                            _ => extend_errors!(
548                                                errors,
549                                                syn::Error::new(
550                                                    parsed_arg.span(),
551                                                    "Syntax error in post attribute"
552                                                )
553                                            ),
554                                        }
555                                    }
556                                    _ => extend_errors!(
557                                        errors,
558                                        syn::Error::new(
559                                            parsed_arg.span(),
560                                            "Syntax error in post attribute"
561                                        )
562                                    ),
563                                }
564                            }
565                            _ => extend_errors!(
566                                errors,
567                                syn::Error::new(
568                                    parsed_arg.span(),
569                                    "Syntax error in post attribute"
570                                )
571                            ),
572                        },
573                        _ => extend_errors!(
574                            errors,
575                            syn::Error::new(parsed_arg.span(), "Syntax error in post attribute")
576                        ),
577                    },
578                    _ => extend_errors!(
579                        errors,
580                        syn::Error::new(parsed_arg.span(), "Syntax error in post attribute")
581                    ),
582                }
583            }
584        }
585
586        let is_async = input.parse::<Token![async]>().is_ok();
587        input.parse::<Token![fn]>()?;
588        let ident = input.parse()?;
589        let content;
590        parenthesized!(content in input);
591        let mut args = Vec::new();
592        for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
593            match arg {
594                FnArg::Typed(captured) => match &*captured.pat {
595                    Pat::Ident(_) => args.push(captured),
596                    _ => {
597                        extend_errors!(
598                            errors,
599                            syn::Error::new(
600                                captured.pat.span(),
601                                "patterns are not allowed in RPC arguments"
602                            )
603                        )
604                    }
605                },
606                FnArg::Receiver(_) => {
607                    extend_errors!(
608                        errors,
609                        syn::Error::new(arg.span(), "receivers are not allowed in RPC arguments")
610                    );
611                }
612            }
613        }
614        errors?;
615        let output = input.parse()?;
616        input.parse::<Token![;]>()?;
617
618        Ok(Self {
619            is_async,
620            attrs,
621            ident,
622            args,
623            post,
624            transfer,
625            output,
626        })
627    }
628}
629
630/// This attribute macro should applied to traits that need to be turned into RPCs. The
631/// macro will consume the trait and output three items in its place. For example,
632/// a trait `Calculator` will be replaced with two structs `CalculatorClient` and
633/// `CalculatorService` and a new trait by the same name with the methods which have had
634/// the a `&self` receiver added to them.
635#[proc_macro_attribute]
636pub fn service(_attr: TokenStream, input: TokenStream) -> TokenStream {
637    let Service {
638        ref attrs,
639        ref vis,
640        ref ident,
641        ref rpcs,
642    } = parse_macro_input!(input as Service);
643
644    let camel_case_fn_names: &Vec<_> = &rpcs
645        .iter()
646        .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
647        .collect();
648
649    ServiceGenerator {
650        trait_ident: ident,
651        service_ident: &format_ident!("{}Service", ident),
652        client_ident: &format_ident!("{}Client", ident),
653        request_ident: &format_ident!("{}Request", ident),
654        response_ident: &format_ident!("{}Response", ident),
655        vis,
656        attrs,
657        rpcs,
658        camel_case_idents: &rpcs
659            .iter()
660            .zip(camel_case_fn_names.iter())
661            .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
662            .collect::<Vec<_>>(),
663    }
664    .into_token_stream()
665    .into()
666}
667
/// Converts a snake_case identifier to UpperCamelCase.
///
/// Underscores are dropped; the first character of the string and every
/// character that follows an underscore is uppercased, and all other
/// characters are lowercased.
fn snake_to_camel(ident_str: &str) -> String {
    let mut camel = String::with_capacity(ident_str.len());

    for word in ident_str.split('_') {
        let mut letters = word.chars();
        // Empty words come from leading, trailing or doubled underscores.
        if let Some(initial) = letters.next() {
            camel.extend(initial.to_uppercase());
            camel.extend(letters.flat_map(char::to_lowercase));
        }
    }

    camel.shrink_to_fit();
    camel
}