// web_rpc_macro/lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use syn::{
7    braced,
8    ext::IdentExt,
9    parenthesized,
10    parse::{Parse, ParseStream},
11    parse_macro_input, parse_quote,
12    punctuated::Punctuated,
13    spanned::Spanned,
14    token::Comma,
15    Attribute, FnArg, Ident, Lifetime, Meta, NestedMeta, Pat, PatType, ReturnType, Token, Type,
16    Visibility,
17};
18
/// Accumulates an error into `$errors` (a `Result<_, E>` variable).
///
/// If `$errors` is still `Ok`, it becomes `Err($e)`; if it already holds an
/// error, `$e` is appended to it via `Extend`.
macro_rules! extend_errors {
    ($errors: ident, $e: expr) => {
        if let Err(ref mut accumulated) = $errors {
            accumulated.extend($e);
        } else {
            $errors = Err($e);
        }
    };
}
27
28/// If `ty` is `impl Stream<Item = T>`, returns Some(T).
29fn stream_item_type(ty: &Type) -> Option<&Type> {
30    if let Type::ImplTrait(impl_trait) = ty {
31        for bound in &impl_trait.bounds {
32            if let syn::TypeParamBound::Trait(trait_bound) = bound {
33                let last_segment = trait_bound.path.segments.last()?;
34                if last_segment.ident == "Stream" {
35                    if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
36                        for arg in &args.args {
37                            if let syn::GenericArgument::Binding(binding) = arg {
38                                if binding.ident == "Item" {
39                                    return Some(&binding.ty);
40                                }
41                            }
42                        }
43                    }
44                }
45            }
46        }
47    }
48    None
49}
50
51struct Service {
52    attrs: Vec<Attribute>,
53    vis: Visibility,
54    ident: Ident,
55    rpcs: Vec<RpcMethod>,
56}
57
58struct RpcMethod {
59    is_async: Option<Token![async]>,
60    attrs: Vec<Attribute>,
61    receiver: syn::Receiver,
62    ident: Ident,
63    args: Vec<PatType>,
64    transfer: HashSet<Ident>,
65    post: HashSet<Ident>,
66    output: ReturnType,
67}
68
69struct ServiceGenerator<'a> {
70    trait_ident: &'a Ident,
71    service_ident: &'a Ident,
72    client_ident: &'a Ident,
73    request_ident: &'a Ident,
74    response_ident: &'a Ident,
75    vis: &'a Visibility,
76    attrs: &'a [Attribute],
77    rpcs: &'a [RpcMethod],
78    camel_case_idents: &'a [Ident],
79    has_borrowed_args: bool,
80    has_streaming_methods: bool,
81}
82
83impl<'a> ServiceGenerator<'a> {
84    fn enum_request(&self) -> TokenStream2 {
85        let &Self {
86            vis,
87            request_ident,
88            camel_case_idents,
89            rpcs,
90            has_borrowed_args,
91            ..
92        } = self;
93        let lifetime = if has_borrowed_args {
94            quote!(<'a>)
95        } else {
96            quote!()
97        };
98        let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
99            |(RpcMethod { args, post, .. }, camel_case_ident)| {
100                let fields = args
101                    .iter()
102                    .filter(|arg| {
103                        matches!(&*arg.pat, Pat::Ident(ident) if !post.contains(&ident.ident))
104                    })
105                    .map(|arg| {
106                        if has_borrowed_args {
107                            if let Type::Reference(type_ref) = &*arg.ty {
108                                let mut type_ref = type_ref.clone();
109                                type_ref.lifetime = Some(Lifetime::new(
110                                    "'a",
111                                    type_ref.and_token.span(),
112                                ));
113                                let pat = &arg.pat;
114                                return quote! { #pat: #type_ref };
115                            }
116                        }
117                        quote! { #arg }
118                    });
119                quote! {
120                    #camel_case_ident { #( #fields ),* }
121                }
122            },
123        );
124        quote! {
125            #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
126            #vis enum #request_ident #lifetime {
127                #( #variants ),*
128            }
129        }
130    }
131
132    fn enum_response(&self) -> TokenStream2 {
133        let &Self {
134            vis,
135            response_ident,
136            camel_case_idents,
137            rpcs,
138            ..
139        } = self;
140        let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
141            |(RpcMethod { output, post, .. }, camel_case_ident)| match output {
142                ReturnType::Type(_, ty) if !post.contains(&Ident::new("return", output.span())) => {
143                    // For streaming methods, use the inner item type
144                    if let Some(item_ty) = stream_item_type(ty) {
145                        quote! {
146                            #camel_case_ident ( #item_ty )
147                        }
148                    } else {
149                        quote! {
150                            #camel_case_ident ( #ty )
151                        }
152                    }
153                }
154                _ => quote! {
155                    #camel_case_ident ( () )
156                },
157            },
158        );
159        quote! {
160            #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
161            #vis enum #response_ident {
162                #( #variants ),*
163            }
164        }
165    }
166
167    fn trait_service(&self) -> TokenStream2 {
168        let &Self {
169            attrs,
170            rpcs,
171            vis,
172            trait_ident,
173            ..
174        } = self;
175
176        let unit_type: &Type = &parse_quote!(());
177        let rpc_fns = rpcs.iter().map(
178            |RpcMethod {
179                 attrs,
180                 args,
181                 receiver,
182                 ident,
183                 is_async,
184                 output,
185                 ..
186             }| {
187                if let ReturnType::Type(_, ref ty) = output {
188                    if let Some(item_ty) = stream_item_type(ty) {
189                        return quote_spanned! {ident.span()=>
190                            #( #attrs )*
191                            #is_async fn #ident(#receiver, #( #args ),*) -> impl web_rpc::futures_core::Stream<Item = #item_ty>;
192                        };
193                    }
194                }
195                let output = match output {
196                    ReturnType::Type(_, ref ty) => ty,
197                    ReturnType::Default => unit_type,
198                };
199                quote_spanned! {ident.span()=>
200                    #( #attrs )*
201                    #is_async fn #ident(#receiver, #( #args ),*) -> #output;
202                }
203            },
204        );
205
206        let forward_fns = rpcs
207            .iter()
208            .map(
209                |RpcMethod {
210                     attrs,
211                     args,
212                     receiver,
213                     ident,
214                     is_async,
215                     output,
216                     ..
217                 }| {
218                    {
219                        let output = if let ReturnType::Type(_, ref ty) = output {
220                            if let Some(item_ty) = stream_item_type(ty) {
221                                quote! { impl web_rpc::futures_core::Stream<Item = #item_ty> }
222                            } else {
223                                let ty: &Type = ty;
224                                quote! { #ty }
225                            }
226                        } else {
227                            let ty = unit_type;
228                            quote! { #ty }
229                        };
230                        let do_await = match is_async {
231                            Some(token) => quote_spanned!(token.span=> .await),
232                            None => quote!(),
233                        };
234                        let forward_args = args.iter().filter_map(|arg| match &*arg.pat {
235                            Pat::Ident(ident) => Some(&ident.ident),
236                            _ => None,
237                        });
238                        quote_spanned! {ident.span()=>
239                            #( #attrs )*
240                            #is_async fn #ident(#receiver, #( #args ),*) -> #output {
241                                T::#ident(self, #( #forward_args ),*)#do_await
242                            }
243                        }
244                    }
245                },
246            )
247            .collect::<Vec<_>>();
248
249        quote! {
250            #( #attrs )*
251            #[allow(async_fn_in_trait)]
252            #vis trait #trait_ident {
253                #( #rpc_fns )*
254            }
255
256            impl<T> #trait_ident for std::sync::Arc<T> where T: #trait_ident {
257                #( #forward_fns )*
258            }
259            impl<T> #trait_ident for std::boxed::Box<T> where T: #trait_ident {
260                #( #forward_fns )*
261            }
262            impl<T> #trait_ident for std::rc::Rc<T> where T: #trait_ident {
263                #( #forward_fns )*
264            }
265        }
266    }
267
268    fn struct_client(&self) -> TokenStream2 {
269        let &Self {
270            vis,
271            client_ident,
272            request_ident,
273            response_ident,
274            camel_case_idents,
275            rpcs,
276            has_streaming_methods,
277            ..
278        } = self;
279
280        let rpc_fns = rpcs
281            .iter()
282            .zip(camel_case_idents.iter())
283            .map(|(RpcMethod { attrs, args, transfer, post, ident, output, .. }, camel_case_ident)| {
284                /* sort arguments based on post and transfer attributes */
285                let serialize_arg_idents = args.iter()
286                    .filter_map(|arg| match &*arg.pat {
287                        Pat::Ident(ident) if !post.contains(&ident.ident) => Some(&ident.ident),
288                        _ => None
289                    });
290                let post_arg_idents = args.iter()
291                    .filter_map(|arg| match &*arg.pat {
292                        Pat::Ident(ident) if post.contains(&ident.ident) => Some(&ident.ident),
293                        _ => None
294                    });
295                let transfer_arg_idents = args.iter()
296                    .filter_map(|arg| match &*arg.pat {
297                        Pat::Ident(ident) if transfer.contains(&ident.ident) => Some(&ident.ident),
298                        _ => None
299                    });
300
301                // Check if this is a streaming method
302                let is_streaming = matches!(output, ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some());
303
304                if is_streaming {
305                    let item_ty = match output {
306                        ReturnType::Type(_, ref ty) => stream_item_type(ty).unwrap(),
307                        _ => unreachable!(),
308                    };
309
310                    // Common: send the request
311                    let send_request = quote! {
312                        let __seq_id = self.seq_id.replace_with(|seq_id| seq_id.wrapping_add(1));
313                        let __request = #request_ident::#camel_case_ident {
314                            #( #serialize_arg_idents ),*
315                        };
316                        let __header = web_rpc::MessageHeader::Request(__seq_id);
317                        let __header_bytes = web_rpc::bincode::serialize(&__header).unwrap();
318                        let __header_buffer = web_rpc::js_sys::Uint8Array::from(&__header_bytes[..]).buffer();
319                        let __payload_bytes = web_rpc::bincode::serialize(&__request).unwrap();
320                        let __payload_buffer = web_rpc::js_sys::Uint8Array::from(&__payload_bytes[..]).buffer();
321                        let __post: &[&web_rpc::wasm_bindgen::JsValue] =
322                            &[__header_buffer.as_ref(), __payload_buffer.as_ref(), #( #post_arg_idents.as_ref() ),*];
323                        let __post = web_rpc::js_sys::Array::from_iter(__post);
324                        let __transfer: &[&web_rpc::wasm_bindgen::JsValue] =
325                            &[__header_buffer.as_ref(), __payload_buffer.as_ref(), #( #transfer_arg_idents.as_ref() ),*];
326                        let __transfer = web_rpc::js_sys::Array::from_iter(__transfer);
327                        self.port.post_message(&__post, &__transfer).unwrap();
328                    };
329
330                    let unpack_stream_item = if post.contains(&Ident::new("return", output.span())) {
331                        quote! {
332                            |(_response, __post_array)| {
333                                web_rpc::wasm_bindgen::JsCast::dyn_into::<#item_ty>(__post_array.shift())
334                                    .unwrap()
335                            }
336                        }
337                    } else {
338                        quote! {
339                            |(__response, _post_array)| {
340                                let #response_ident::#camel_case_ident(__inner) = __response else {
341                                    panic!("received incorrect response variant")
342                                };
343                                __inner
344                            }
345                        }
346                    };
347
348                    quote! {
349                        #( #attrs )*
350                        #vis fn #ident(
351                            &self,
352                            #( #args ),*
353                        ) -> web_rpc::client::StreamReceiver<#item_ty> {
354                            #send_request
355                            let (__item_tx, __item_rx) = web_rpc::futures_channel::mpsc::unbounded();
356                            self.stream_callback_map.borrow_mut().insert(__seq_id, __item_tx);
357                            let __mapped_rx = web_rpc::futures_util::StreamExt::map(
358                                __item_rx,
359                                #unpack_stream_item
360                            );
361                            let __abort_sender = self.abort_sender.clone();
362                            let __stream_callback_map = self.stream_callback_map.clone();
363                            let __dispatcher = self.dispatcher.clone();
364                            web_rpc::client::StreamReceiver::new(
365                                __mapped_rx,
366                                __dispatcher,
367                                std::boxed::Box::new(move || {
368                                    __stream_callback_map.borrow_mut().remove(&__seq_id);
369                                    (__abort_sender)(__seq_id);
370                                }),
371                            )
372                        }
373                    }
374                } else {
375                    // Non-streaming (original logic)
376                    let return_type = match output {
377                        ReturnType::Type(_, ref ty) => quote! {
378                            web_rpc::client::RequestFuture<#ty>
379                        },
380                        _ => quote!(())
381                    };
382                    let maybe_register_callback = match output {
383                        ReturnType::Type(_, _) => quote! {
384                            let (__response_tx, __response_rx) =
385                                web_rpc::futures_channel::oneshot::channel();
386                            self.callback_map.borrow_mut().insert(__seq_id, __response_tx);
387                        },
388                        _ => Default::default()
389                    };
390
391                    let unpack_response = if post.contains(&Ident::new("return", output.span())) {
392                        let unit_output: &Type = &parse_quote!(());
393                        let output = match output {
394                            ReturnType::Type(_, ref ty) => ty,
395                            _ => unit_output
396                        };
397                        quote! {
398                            let (_, __post_response) = response;
399                            web_rpc::wasm_bindgen::JsCast::dyn_into::<#output>(__post_response.shift())
400                                .unwrap()
401                        }
402                    } else {
403                        quote! {
404                            let (__serialize_response, _) = response;
405                            let #response_ident::#camel_case_ident(__inner) = __serialize_response else {
406                                panic!("received incorrect response variant")
407                            };
408                            __inner
409                        }
410                    };
411
412                    let maybe_unpack_and_return_future = match output {
413                        ReturnType::Type(_, _) => quote! {
414                            let __response_future = web_rpc::futures_util::FutureExt::map(
415                                __response_rx,
416                                |response| {
417                                    let response = response.unwrap();
418                                    #unpack_response
419                                }
420                            );
421                            let __abort_sender = self.abort_sender.clone();
422                            let __dispatcher = self.dispatcher.clone();
423                            web_rpc::client::RequestFuture::new(
424                                __response_future,
425                                __dispatcher,
426                                std::boxed::Box::new(move || (__abort_sender)(__seq_id)))
427                        },
428                        _ => Default::default()
429                    };
430
431                    quote! {
432                        #( #attrs )*
433                        #vis fn #ident(
434                            &self,
435                            #( #args ),*
436                        ) -> #return_type {
437                            let __seq_id = self.seq_id.replace_with(|seq_id| seq_id.wrapping_add(1));
438                            let __request = #request_ident::#camel_case_ident {
439                                #( #serialize_arg_idents ),*
440                            };
441                            let __header = web_rpc::MessageHeader::Request(__seq_id);
442                            let __header_bytes = web_rpc::bincode::serialize(&__header).unwrap();
443                            let __header_buffer = web_rpc::js_sys::Uint8Array::from(&__header_bytes[..]).buffer();
444                            let __payload_bytes = web_rpc::bincode::serialize(&__request).unwrap();
445                            let __payload_buffer = web_rpc::js_sys::Uint8Array::from(&__payload_bytes[..]).buffer();
446                            let __post: &[&web_rpc::wasm_bindgen::JsValue] =
447                                &[__header_buffer.as_ref(), __payload_buffer.as_ref(), #( #post_arg_idents.as_ref() ),*];
448                            let __post = web_rpc::js_sys::Array::from_iter(__post);
449                            let __transfer: &[&web_rpc::wasm_bindgen::JsValue] =
450                                &[__header_buffer.as_ref(), __payload_buffer.as_ref(), #( #transfer_arg_idents.as_ref() ),*];
451                            let __transfer = web_rpc::js_sys::Array::from_iter(__transfer);
452                            #maybe_register_callback
453                            self.port.post_message(&__post, &__transfer).unwrap();
454                            #maybe_unpack_and_return_future
455                        }
456                    }
457                }
458            });
459
460        let stream_callback_map_field = if has_streaming_methods {
461            quote! {
462                stream_callback_map: std::rc::Rc<
463                    std::cell::RefCell<
464                        web_rpc::client::StreamCallbackMap<#response_ident>
465                    >
466                >,
467            }
468        } else {
469            quote!()
470        };
471
472        let stream_callback_map_pat = if has_streaming_methods {
473            quote! { stream_callback_map, }
474        } else {
475            quote! { _, }
476        };
477
478        let stream_callback_map_init = if has_streaming_methods {
479            quote! { stream_callback_map, }
480        } else {
481            quote! {}
482        };
483
484        quote! {
485            #[derive(core::clone::Clone)]
486            #vis struct #client_ident {
487                callback_map: std::rc::Rc<
488                    std::cell::RefCell<
489                        web_rpc::client::CallbackMap<#response_ident>
490                    >
491                >,
492                #stream_callback_map_field
493                port: web_rpc::port::Port,
494                listener: std::rc::Rc<web_rpc::gloo_events::EventListener>,
495                dispatcher: web_rpc::futures_util::future::Shared<
496                    web_rpc::futures_core::future::LocalBoxFuture<'static, ()>
497                >,
498                abort_sender: std::rc::Rc<dyn std::ops::Fn(usize)>,
499                seq_id: std::rc::Rc<std::cell::RefCell<usize>>
500            }
501            impl std::fmt::Debug for #client_ident {
502                fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
503                    formatter.debug_struct(std::stringify!(#client_ident))
504                        .finish()
505                }
506            }
507            impl web_rpc::client::Client for #client_ident {
508                type Response = #response_ident;
509            }
510            impl From<web_rpc::client::Configuration<#response_ident>>
511                for #client_ident {
512                fn from((callback_map, #stream_callback_map_pat port, listener, dispatcher, abort_sender):
513                    web_rpc::client::Configuration<#response_ident>) -> Self {
514                    Self {
515                        callback_map,
516                        #stream_callback_map_init
517                        port,
518                        listener,
519                        dispatcher,
520                        abort_sender,
521                        seq_id: std::default::Default::default()
522                    }
523                }
524            }
525            impl #client_ident {
526                #( #rpc_fns )*
527            }
528        }
529    }
530
531    fn struct_server(&self) -> TokenStream2 {
532        let &Self {
533            vis,
534            trait_ident,
535            service_ident,
536            request_ident,
537            response_ident,
538            camel_case_idents,
539            rpcs,
540            has_borrowed_args,
541            ..
542        } = self;
543
544        let request_type = if has_borrowed_args {
545            quote! { #request_ident<'_> }
546        } else {
547            quote! { #request_ident }
548        };
549
550        let handlers = rpcs.iter()
551            .zip(camel_case_idents.iter())
552            .map(|(RpcMethod { is_async, ident, args, transfer, post, output, .. }, camel_case_ident)| {
553                let serialize_arg_idents = args.iter()
554                    .filter_map(|arg| match &*arg.pat {
555                        Pat::Ident(ident) if !post.contains(&ident.ident) => Some(&ident.ident),
556                        _ => None
557                    });
558                let extract_js_args = args.iter()
559                    .filter_map(|arg| match &*arg.pat {
560                        Pat::Ident(pat_ident) if post.contains(&pat_ident.ident) => {
561                            let arg_pat = &arg.pat;
562                            let arg_ty = &arg.ty;
563                            if let Type::Reference(type_ref) = &**arg_ty {
564                                let inner_ty = &type_ref.elem;
565                                let tmp_ident = format_ident!("__tmp_{}", pat_ident.ident);
566                                Some(quote! {
567                                    let #tmp_ident = __js_args.shift();
568                                    let #arg_pat: #arg_ty = web_rpc::wasm_bindgen::JsCast::dyn_ref::<#inner_ty>(&#tmp_ident)
569                                        .unwrap();
570                                })
571                            } else {
572                                Some(quote! {
573                                    let #arg_pat = web_rpc::wasm_bindgen::JsCast::dyn_into::<#arg_ty>(__js_args.shift())
574                                        .unwrap();
575                                })
576                            }
577                        },
578                        _ => None
579                    });
580
581                // Check if this is a streaming method
582                let is_streaming = matches!(output, ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some());
583
584                if is_streaming {
585                    let call_args = args.iter().filter_map(|arg| match &*arg.pat {
586                        Pat::Ident(ident) => Some(&ident.ident),
587                        _ => None
588                    });
589                    let return_ident = Ident::new("return", output.span());
590                    let wrap_item = match (post.contains(&return_ident), transfer.contains(&return_ident)) {
591                        (false, _) => quote! {
592                            let __response = #response_ident::#camel_case_ident(__item);
593                            let __post = web_rpc::js_sys::Array::new();
594                            let __transfer = web_rpc::js_sys::Array::new();
595                        },
596                        (true, false) => quote! {
597                            let __response = #response_ident::#camel_case_ident(());
598                            let __post = web_rpc::js_sys::Array::of1(__item.as_ref());
599                            let __transfer = web_rpc::js_sys::Array::new();
600                        },
601                        (true, true) => quote! {
602                            let __response = #response_ident::#camel_case_ident(());
603                            let __post = web_rpc::js_sys::Array::of1(__item.as_ref());
604                            let __transfer = web_rpc::js_sys::Array::of1(__item.as_ref());
605                        },
606                    };
607                    // Build the forwarding closure (reused for both async/sync)
608                    let fwd_body = quote! {
609                        let __stream_tx_clone = __stream_tx.clone();
610                        web_rpc::pin_utils::pin_mut!(__user_rx);
611                        let __fwd = async move {
612                            while let Some(__item) = web_rpc::futures_util::StreamExt::next(&mut __user_rx).await {
613                                #wrap_item
614                                if __stream_tx_clone.unbounded_send((__seq_id, Some((__response, __post, __transfer)))).is_err() {
615                                    break;
616                                }
617                            }
618                        };
619                        let __fwd = web_rpc::futures_util::FutureExt::fuse(__fwd);
620                        web_rpc::pin_utils::pin_mut!(__fwd);
621                        web_rpc::futures_util::select! {
622                            _ = __abort_rx => {},
623                            _ = __fwd => {},
624                        }
625                        let _ = __stream_tx.unbounded_send((__seq_id, None));
626                        web_rpc::service::ExecuteResult::StreamComplete
627                    };
628
629                    match is_async {
630                        Some(_) => quote! {
631                            #request_ident::#camel_case_ident { #( #serialize_arg_idents ),* } => {
632                                #( #extract_js_args )*
633                                let __get_rx = web_rpc::futures_util::FutureExt::fuse(
634                                    self.server_impl.#ident(#( #call_args ),*)
635                                );
636                                web_rpc::pin_utils::pin_mut!(__get_rx);
637                                let __maybe_rx = web_rpc::futures_util::select! {
638                                    _ = __abort_rx => None,
639                                    __rx = __get_rx => Some(__rx),
640                                };
641                                if let Some(mut __user_rx) = __maybe_rx {
642                                    #fwd_body
643                                } else {
644                                    let _ = __stream_tx.unbounded_send((__seq_id, None));
645                                    web_rpc::service::ExecuteResult::StreamComplete
646                                }
647                            }
648                        },
649                        None => quote! {
650                            #request_ident::#camel_case_ident { #( #serialize_arg_idents ),* } => {
651                                #( #extract_js_args )*
652                                let mut __user_rx = self.server_impl.#ident(#( #call_args ),*);
653                                #fwd_body
654                            }
655                        },
656                    }
657                } else {
658                    // Non-streaming (original logic, but wrapped in ExecuteResult::Response)
659                    let return_ident = Ident::new("return", output.span());
660                    let return_response = match (post.contains(&return_ident), transfer.contains(&return_ident)) {
661                        (false, _) => quote! {
662                            let __post = web_rpc::js_sys::Array::new();
663                            let __transfer = web_rpc::js_sys::Array::new();
664                            (#response_ident::#camel_case_ident(__response), __post, __transfer)
665                        },
666                        (true, false) => quote! {
667                            let __post = web_rpc::js_sys::Array::of1(__response.as_ref());
668                            let __transfer = web_rpc::js_sys::Array::new();
669                            (#response_ident::#camel_case_ident(()), __post, __transfer)
670                        },
671                        (true, true) => quote! {
672                            let __post = web_rpc::js_sys::Array::of1(__response.as_ref());
673                            let __transfer = web_rpc::js_sys::Array::of1(__response.as_ref());
674                            (#response_ident::#camel_case_ident(()), __post, __transfer)
675                        }
676                    };
677                    let call_args = args.iter().filter_map(|arg| match &*arg.pat {
678                        Pat::Ident(ident) => Some(&ident.ident),
679                        _ => None
680                    });
681                    match is_async {
682                        Some(_) => quote! {
683                            #request_ident::#camel_case_ident { #( #serialize_arg_idents ),* } => {
684                                #( #extract_js_args )*
685                                let __task =
686                                    web_rpc::futures_util::FutureExt::fuse(self.server_impl.#ident(#( #call_args ),*));
687                                web_rpc::pin_utils::pin_mut!(__task);
688                                web_rpc::service::ExecuteResult::Response(
689                                    web_rpc::futures_util::select! {
690                                        _ = __abort_rx => None,
691                                        __response = __task => Some({
692                                            #return_response
693                                        })
694                                    }
695                                )
696                            }
697                        },
698                        None => quote! {
699                            #request_ident::#camel_case_ident { #( #serialize_arg_idents ),* } => {
700                                #( #extract_js_args )*
701                                let __response = self.server_impl.#ident(#( #call_args ),*);
702                                web_rpc::service::ExecuteResult::Response(
703                                    Some({
704                                        #return_response
705                                    })
706                                )
707                            }
708                        }
709                    }
710                }
711            });
712
713        quote! {
714            #vis struct #service_ident<T> {
715                server_impl: T
716            }
717            impl<T: #trait_ident> web_rpc::service::Service for #service_ident<T> {
718                type Response = #response_ident;
719                async fn execute(
720                    &self,
721                    __seq_id: usize,
722                    mut __abort_rx: web_rpc::futures_channel::oneshot::Receiver<()>,
723                    __payload: std::vec::Vec<u8>,
724                    __js_args: web_rpc::js_sys::Array,
725                    __stream_tx: web_rpc::futures_channel::mpsc::UnboundedSender<
726                        web_rpc::service::StreamMessage<Self::Response>
727                    >,
728                ) -> (usize, web_rpc::service::ExecuteResult<Self::Response>) {
729                    let __request: #request_type = web_rpc::bincode::deserialize(&__payload).unwrap();
730                    let __result = match __request {
731                        #( #handlers )*
732                    };
733                    (__seq_id, __result)
734                }
735            }
736            impl<T: #trait_ident> std::convert::From<T> for #service_ident<T> {
737                fn from(server_impl: T) -> Self {
738                    Self { server_impl }
739                }
740            }
741        }
742    }
743}
744
745impl<'a> ToTokens for ServiceGenerator<'a> {
746    fn to_tokens(&self, output: &mut TokenStream2) {
747        output.extend(vec![
748            self.enum_request(),
749            self.enum_response(),
750            self.trait_service(),
751            self.struct_client(),
752            self.struct_server(),
753        ])
754    }
755}
756
757impl Parse for Service {
758    fn parse(input: ParseStream) -> syn::Result<Self> {
759        let attrs = input.call(Attribute::parse_outer)?;
760        let vis = input.parse()?;
761        input.parse::<Token![trait]>()?;
762        let ident: Ident = input.parse()?;
763        let content;
764        braced!(content in input);
765        let mut rpcs = Vec::<RpcMethod>::new();
766        while !content.is_empty() {
767            rpcs.push(content.parse()?);
768        }
769
770        Ok(Self {
771            attrs,
772            vis,
773            ident,
774            rpcs,
775        })
776    }
777}
778
779impl Parse for RpcMethod {
780    fn parse(input: ParseStream) -> syn::Result<Self> {
781        let mut errors = Ok(());
782        let attrs = input.call(Attribute::parse_outer)?;
783        let (post_attrs, attrs): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|attr| {
784            attr.path
785                .segments
786                .last()
787                .is_some_and(|last_segment| last_segment.ident == "post")
788        });
789        let mut transfer: HashSet<Ident> = HashSet::new();
790        let mut post: HashSet<Ident> = HashSet::new();
791        for post_attr in post_attrs {
792            let parsed_args =
793                post_attr.parse_args_with(Punctuated::<NestedMeta, Token![,]>::parse_terminated)?;
794            for parsed_arg in parsed_args {
795                match &parsed_arg {
796                    NestedMeta::Meta(meta) => match meta {
797                        Meta::Path(path) => {
798                            if let Some(segment) = path.segments.last() {
799                                post.insert(segment.ident.clone());
800                            }
801                        }
802                        Meta::List(list) => match list.path.segments.last() {
803                            Some(last_segment) if last_segment.ident == "transfer" => {
804                                if list.nested.len() != 1 {
805                                    extend_errors!(
806                                        errors,
807                                        syn::Error::new(
808                                            parsed_arg.span(),
809                                            "Syntax error in post attribute"
810                                        )
811                                    );
812                                }
813                                match list.nested.first() {
814                                    Some(NestedMeta::Meta(Meta::Path(path))) => {
815                                        match path.segments.last() {
816                                            Some(segment) => {
817                                                post.insert(segment.ident.clone());
818                                                transfer.insert(segment.ident.clone());
819                                            }
820                                            _ => extend_errors!(
821                                                errors,
822                                                syn::Error::new(
823                                                    parsed_arg.span(),
824                                                    "Syntax error in post attribute"
825                                                )
826                                            ),
827                                        }
828                                    }
829                                    _ => extend_errors!(
830                                        errors,
831                                        syn::Error::new(
832                                            parsed_arg.span(),
833                                            "Syntax error in post attribute"
834                                        )
835                                    ),
836                                }
837                            }
838                            _ => extend_errors!(
839                                errors,
840                                syn::Error::new(
841                                    parsed_arg.span(),
842                                    "Syntax error in post attribute"
843                                )
844                            ),
845                        },
846                        _ => extend_errors!(
847                            errors,
848                            syn::Error::new(parsed_arg.span(), "Syntax error in post attribute")
849                        ),
850                    },
851                    _ => extend_errors!(
852                        errors,
853                        syn::Error::new(parsed_arg.span(), "Syntax error in post attribute")
854                    ),
855                }
856            }
857        }
858
859        let is_async = input.parse::<Token![async]>().ok();
860        input.parse::<Token![fn]>()?;
861        let ident: Ident = input.parse()?;
862        let content;
863        parenthesized!(content in input);
864        let mut receiver: Option<syn::Receiver> = None;
865        let mut args = Vec::new();
866        for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
867            match arg {
868                FnArg::Typed(captured) => match &*captured.pat {
869                    Pat::Ident(_) => args.push(captured),
870                    _ => {
871                        extend_errors!(
872                            errors,
873                            syn::Error::new(
874                                captured.pat.span(),
875                                "patterns are not allowed in RPC arguments"
876                            )
877                        )
878                    }
879                },
880                FnArg::Receiver(ref recv) => {
881                    if recv.reference.is_none() || recv.mutability.is_some() {
882                        extend_errors!(
883                            errors,
884                            syn::Error::new(
885                                arg.span(),
886                                "RPC methods only support `&self` as a receiver"
887                            )
888                        );
889                    }
890                    receiver = Some(recv.clone());
891                }
892            }
893        }
894        let receiver = match receiver {
895            Some(r) => r,
896            None => {
897                extend_errors!(
898                    errors,
899                    syn::Error::new(
900                        ident.span(),
901                        "RPC methods must include `&self` as the first parameter"
902                    )
903                );
904                parse_quote!(&self)
905            }
906        };
907        let output: ReturnType = input.parse()?;
908        input.parse::<Token![;]>()?;
909
910        let arg_names: HashSet<_> = args
911            .iter()
912            .filter_map(|arg| match &*arg.pat {
913                Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()),
914                _ => None,
915            })
916            .collect();
917        let return_ident = Ident::new("return", output.span());
918        for ident in &post {
919            if *ident != return_ident && !arg_names.contains(ident) {
920                extend_errors!(
921                    errors,
922                    syn::Error::new(
923                        ident.span(),
924                        format!("`{}` does not match any parameter", ident)
925                    )
926                );
927            }
928        }
929        for ident in &transfer {
930            if *ident != return_ident && !post.contains(ident) {
931                extend_errors!(
932                    errors,
933                    syn::Error::new(
934                        ident.span(),
935                        format!("`{}` is marked as transfer but not as post", ident)
936                    )
937                );
938            }
939        }
940        errors?;
941
942        Ok(Self {
943            is_async,
944            attrs,
945            receiver,
946            ident,
947            args,
948            post,
949            transfer,
950            output,
951        })
952    }
953}
954
955/// This attribute macro should applied to traits that need to be turned into RPCs. The
956/// macro will consume the trait and output three items in its place. For example,
957/// a trait `Calculator` will be replaced with two structs `CalculatorClient` and
958/// `CalculatorService` and a new trait by the same name. All methods must include
959/// `&self` as their first parameter.
960#[proc_macro_attribute]
961pub fn service(_attr: TokenStream, input: TokenStream) -> TokenStream {
962    let Service {
963        ref attrs,
964        ref vis,
965        ref ident,
966        ref rpcs,
967    } = parse_macro_input!(input as Service);
968
969    let camel_case_fn_names: &Vec<_> = &rpcs
970        .iter()
971        .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
972        .collect();
973
974    let has_borrowed_args = rpcs.iter().any(|rpc| {
975        rpc.args.iter().any(|arg| {
976            matches!(&*arg.pat, Pat::Ident(pat_ident) if !rpc.post.contains(&pat_ident.ident))
977                && matches!(&*arg.ty, Type::Reference(_))
978        })
979    });
980
981    let has_streaming_methods = rpcs.iter().any(
982        |rpc| matches!(&rpc.output, ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some()),
983    );
984
985    ServiceGenerator {
986        trait_ident: ident,
987        service_ident: &format_ident!("{}Service", ident),
988        client_ident: &format_ident!("{}Client", ident),
989        request_ident: &format_ident!("{}Request", ident),
990        response_ident: &format_ident!("{}Response", ident),
991        vis,
992        attrs,
993        rpcs,
994        camel_case_idents: &rpcs
995            .iter()
996            .zip(camel_case_fn_names.iter())
997            .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
998            .collect::<Vec<_>>(),
999        has_borrowed_args,
1000        has_streaming_methods,
1001    }
1002    .into_token_stream()
1003    .into()
1004}
1005
/// Converts a `snake_case` identifier into `CamelCase`.
///
/// Underscores are dropped. The first character of each underscore-separated
/// word is upper-cased and every other character is lower-cased, so
/// `"get_HTTP_status"` becomes `"GetHttpStatus"`. Leading, trailing and
/// repeated underscores produce no output of their own.
fn snake_to_camel(ident_str: &str) -> String {
    let mut camel = String::with_capacity(ident_str.len());
    for word in ident_str.split('_') {
        let mut chars = word.chars();
        // Empty words come from leading, trailing or doubled underscores.
        let Some(head) = chars.next() else {
            continue;
        };
        camel.extend(head.to_uppercase());
        camel.extend(chars.flat_map(char::to_lowercase));
    }
    camel.shrink_to_fit();
    camel
}