Skip to main content

web_rpc_macro/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use syn::{
7    braced,
8    ext::IdentExt,
9    parenthesized,
10    parse::{Parse, ParseStream},
11    parse_macro_input, parse_quote,
12    punctuated::Punctuated,
13    spanned::Spanned,
14    token::Comma,
15    Attribute, FnArg, Ident, Lifetime, Pat, PatType, ReturnType, Token, Type, Visibility,
16};
17
18macro_rules! extend_errors {
19    ($errors: ident, $e: expr) => {
20        match $errors {
21            Ok(_) => $errors = Err($e),
22            Err(ref mut errors) => errors.extend($e),
23        }
24    };
25}
26
27/// If `ty` is `impl Stream<Item = T>`, returns Some(T).
28fn stream_item_type(ty: &Type) -> Option<&Type> {
29    if let Type::ImplTrait(impl_trait) = ty {
30        for bound in &impl_trait.bounds {
31            if let syn::TypeParamBound::Trait(trait_bound) = bound {
32                let last_segment = trait_bound.path.segments.last()?;
33                if last_segment.ident == "Stream" {
34                    if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
35                        for arg in &args.args {
36                            if let syn::GenericArgument::Binding(binding) = arg {
37                                if binding.ident == "Item" {
38                                    return Some(&binding.ty);
39                                }
40                            }
41                        }
42                    }
43                }
44            }
45        }
46    }
47    None
48}
49
50/// If `ty` is `Option<T>`, returns Some(T).
51fn option_inner_type(ty: &Type) -> Option<&Type> {
52    if let Type::Path(type_path) = ty {
53        let last_seg = type_path.path.segments.last()?;
54        if last_seg.ident == "Option" {
55            if let syn::PathArguments::AngleBracketed(args) = &last_seg.arguments {
56                if args.args.len() == 1 {
57                    if let syn::GenericArgument::Type(inner) = &args.args[0] {
58                        return Some(inner);
59                    }
60                }
61            }
62        }
63    }
64    None
65}
66
67/// If `ty` is `Result<T, E>`, returns Some((T, E)).
68fn result_inner_types(ty: &Type) -> Option<(&Type, &Type)> {
69    if let Type::Path(type_path) = ty {
70        let last_seg = type_path.path.segments.last()?;
71        if last_seg.ident == "Result" {
72            if let syn::PathArguments::AngleBracketed(args) = &last_seg.arguments {
73                if args.args.len() == 2 {
74                    if let (syn::GenericArgument::Type(ok_ty), syn::GenericArgument::Type(err_ty)) =
75                        (&args.args[0], &args.args[1])
76                    {
77                        return Some((ok_ty, err_ty));
78                    }
79                }
80            }
81        }
82    }
83    None
84}
85
86/// True if `ty` is `&str` or `&[u8]` — the only reference shapes we route through
87/// the existing serde-borrowing path (zero-copy, with an `'a` lifetime injected
88/// into the request enum). Any other reference shape goes through the JS path.
89fn is_borrowed_serde_ref(ty: &Type) -> bool {
90    if let Type::Reference(r) = ty {
91        match &*r.elem {
92            Type::Path(p) if p.path.is_ident("str") => return true,
93            Type::Slice(s) => {
94                if let Type::Path(p) = &*s.elem {
95                    if p.path.is_ident("u8") {
96                        return true;
97                    }
98                }
99            }
100            _ => {}
101        }
102    }
103    false
104}
105
106/// True if `ty` is a reference to a presumed JS type (anything other than
107/// `&str`/`&[u8]`). The receiver side decodes these via `JsCast::dyn_ref`.
108fn is_js_ref(ty: &Type) -> bool {
109    matches!(ty, Type::Reference(_)) && !is_borrowed_serde_ref(ty)
110}
111
112/// Recursively emit code that encodes a value of type `ty` into a `WireArg`,
113/// pushing JS values onto `post` as a side-effect.
114///
115/// Caller supplies `value` as a token-tree expression (typically an ident binding).
116/// The emitted code matches structure on `Option`/`Result` and recurses; bare
117/// leaves dispatch through the autoref encoder traits.
118///
119/// Reference-to-JS types should be handled by the caller before invoking this
120/// helper — they cannot be encoded as nested elements (no `Decoder<&T>` impl on
121/// the receiver side).
122fn emit_encode(ty: &Type, value: TokenStream2, post: &TokenStream2) -> TokenStream2 {
123    // Match against `&#value` so the original binding remains accessible to any
124    // transfer-side code emitted alongside the encoder. Match ergonomics binds
125    // `__inner` as a reference inside each arm.
126    if let Some(inner) = option_inner_type(ty) {
127        let inner_enc = emit_encode(inner, quote!(__inner), post);
128        quote_spanned! {ty.span()=>
129            match &#value {
130                ::core::option::Option::Some(__inner) =>
131                    web_rpc::codec::WireArg::Some(std::boxed::Box::new(#inner_enc)),
132                ::core::option::Option::None =>
133                    web_rpc::codec::WireArg::None,
134            }
135        }
136    } else if let Some((ok, err)) = result_inner_types(ty) {
137        let ok_enc = emit_encode(ok, quote!(__inner), post);
138        let err_enc = emit_encode(err, quote!(__inner), post);
139        quote_spanned! {ty.span()=>
140            match &#value {
141                ::core::result::Result::Ok(__inner) =>
142                    web_rpc::codec::WireArg::Ok(std::boxed::Box::new(#ok_enc)),
143                ::core::result::Result::Err(__inner) =>
144                    web_rpc::codec::WireArg::Err(std::boxed::Box::new(#err_enc)),
145            }
146        }
147    } else {
148        quote_spanned! {ty.span()=>
149            {
150                #[allow(unused_imports)]
151                use web_rpc::codec::{
152                    __RpcJsEncode as _,
153                    __RpcSerialEncode as _,
154                };
155                (&#value).__rpc_encode(#post)
156            }
157        }
158    }
159}
160
161/// Recursively emit code that decodes a `WireArg` of type `ty` into a Rust value,
162/// shifting JS values off `post` as needed.
163///
164/// Caller supplies `wire` as a token-tree expression evaluating to a `WireArg`.
165/// Reference-to-JS types should be handled by the caller — see `emit_encode`.
166fn emit_decode(ty: &Type, wire: TokenStream2, post: &TokenStream2) -> TokenStream2 {
167    if let Some(inner) = option_inner_type(ty) {
168        let inner_dec = emit_decode(inner, quote!(*__inner), post);
169        quote_spanned! {ty.span()=>
170            match #wire {
171                web_rpc::codec::WireArg::Some(__inner) =>
172                    ::core::option::Option::Some(#inner_dec),
173                web_rpc::codec::WireArg::None =>
174                    ::core::option::Option::None,
175                _ => panic!("web_rpc: wire/type mismatch — expected Some or None"),
176            }
177        }
178    } else if let Some((ok, err)) = result_inner_types(ty) {
179        let ok_dec = emit_decode(ok, quote!(*__inner), post);
180        let err_dec = emit_decode(err, quote!(*__inner), post);
181        quote_spanned! {ty.span()=>
182            match #wire {
183                web_rpc::codec::WireArg::Ok(__inner) =>
184                    ::core::result::Result::Ok(#ok_dec),
185                web_rpc::codec::WireArg::Err(__inner) =>
186                    ::core::result::Result::Err(#err_dec),
187                _ => panic!("web_rpc: wire/type mismatch — expected Ok or Err"),
188            }
189        }
190    } else {
191        quote_spanned! {ty.span()=>
192            {
193                #[allow(unused_imports)]
194                use web_rpc::codec::{
195                    __RpcJsDecode as _,
196                    __RpcSerialDecode as _,
197                };
198                (&web_rpc::codec::Decoder::<#ty>::default()).__rpc_decode(#wire, #post)
199            }
200        }
201    }
202}
203
204struct Service {
205    attrs: Vec<Attribute>,
206    vis: Visibility,
207    ident: Ident,
208    rpcs: Vec<RpcMethod>,
209}
210
211struct RpcMethod {
212    is_async: Option<Token![async]>,
213    attrs: Vec<Attribute>,
214    receiver: syn::Receiver,
215    ident: Ident,
216    args: Vec<PatType>,
217    transfer: Vec<TransferClause>,
218    output: ReturnType,
219}
220
221/// One entry inside a `#[transfer(...)]` attribute.
222#[allow(dead_code)]
223enum TransferClause {
224    /// `name` — push the parameter itself, unconditionally.
225    BareParam(Ident),
226    /// `data => data.buffer()` — push the expression's result, unconditionally.
227    ParamExpr { name: Ident, body: syn::Expr },
228    /// `data => |Some(d)| d.buffer()` (closure) or
229    /// `data => match { Some(d) => d.buffer(), ... }` (match-block).
230    /// Each `Gate` becomes one `if let pat = &name { __transfer.push(body) }`.
231    ParamGated { name: Ident, gates: Vec<Gate> },
232    /// `return` — push the response value itself, unconditionally.
233    BareReturn,
234    /// `return => |Ok(o)| o.buffer()` or `return => match { ... }`.
235    ReturnGated { gates: Vec<Gate> },
236}
237
238#[allow(dead_code)]
239struct Gate {
240    pat: syn::Pat,
241    body: syn::Expr,
242}
243
244struct ServiceGenerator<'a> {
245    trait_ident: &'a Ident,
246    service_ident: &'a Ident,
247    client_ident: &'a Ident,
248    request_ident: &'a Ident,
249    response_ident: &'a Ident,
250    vis: &'a Visibility,
251    attrs: &'a [Attribute],
252    rpcs: &'a [RpcMethod],
253    camel_case_idents: &'a [Ident],
254    has_borrowed_args: bool,
255    has_streaming_methods: bool,
256}
257
258impl<'a> ServiceGenerator<'a> {
259    fn enum_request(&self) -> TokenStream2 {
260        let &Self {
261            vis,
262            request_ident,
263            camel_case_idents,
264            rpcs,
265            has_borrowed_args,
266            ..
267        } = self;
268        let lifetime = if has_borrowed_args {
269            quote!(<'a>)
270        } else {
271            quote!()
272        };
273        let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
274            |(RpcMethod { args, .. }, camel_case_ident)| {
275                let fields = args.iter().map(|arg| {
276                    let pat = &arg.pat;
277                    if is_borrowed_serde_ref(&arg.ty) {
278                        // `&str` / `&[u8]` — keep zero-copy serde borrowing path.
279                        let mut type_ref = match &*arg.ty {
280                            Type::Reference(r) => r.clone(),
281                            _ => unreachable!("is_borrowed_serde_ref guarantees a reference"),
282                        };
283                        type_ref.lifetime =
284                            Some(Lifetime::new("'a", type_ref.and_token.span()));
285                        quote_spanned! {arg.ty.span()=> #pat: #type_ref }
286                    } else {
287                        // Everything else (including `&JsT`) uses the universal
288                        // recursive WireArg representation.
289                        quote_spanned! {arg.ty.span()=>
290                            #pat: web_rpc::codec::WireArg
291                        }
292                    }
293                });
294                quote! {
295                    #camel_case_ident { #( #fields ),* }
296                }
297            },
298        );
299        quote! {
300            #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
301            #vis enum #request_ident #lifetime {
302                #( #variants ),*
303            }
304        }
305    }
306
307    fn enum_response(&self) -> TokenStream2 {
308        let &Self {
309            vis,
310            response_ident,
311            camel_case_idents,
312            rpcs,
313            ..
314        } = self;
315        let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
316            |(_method, camel_case_ident)| {
317                // Every method's response variant carries a single uniform
318                // `WireArg`. Notification methods (no return) still get a
319                // variant — the macro fills it with a placeholder that the
320                // client never reads.
321                quote! {
322                    #camel_case_ident ( web_rpc::codec::WireArg )
323                }
324            },
325        );
326        quote! {
327            #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
328            #vis enum #response_ident {
329                #( #variants ),*
330            }
331        }
332    }
333
334    fn trait_service(&self) -> TokenStream2 {
335        let &Self {
336            attrs,
337            rpcs,
338            vis,
339            trait_ident,
340            ..
341        } = self;
342
343        let unit_type: &Type = &parse_quote!(());
344        let rpc_fns = rpcs.iter().map(
345            |RpcMethod {
346                 attrs,
347                 args,
348                 receiver,
349                 ident,
350                 is_async,
351                 output,
352                 ..
353             }| {
354                if let ReturnType::Type(_, ref ty) = output {
355                    if let Some(item_ty) = stream_item_type(ty) {
356                        return quote_spanned! {ident.span()=>
357                            #( #attrs )*
358                            #is_async fn #ident(#receiver, #( #args ),*) -> impl web_rpc::futures_core::Stream<Item = #item_ty>;
359                        };
360                    }
361                }
362                let output = match output {
363                    ReturnType::Type(_, ref ty) => ty,
364                    ReturnType::Default => unit_type,
365                };
366                quote_spanned! {ident.span()=>
367                    #( #attrs )*
368                    #is_async fn #ident(#receiver, #( #args ),*) -> #output;
369                }
370            },
371        );
372
373        let forward_fns = rpcs
374            .iter()
375            .map(
376                |RpcMethod {
377                     attrs,
378                     args,
379                     receiver,
380                     ident,
381                     is_async,
382                     output,
383                     ..
384                 }| {
385                    {
386                        let output = if let ReturnType::Type(_, ref ty) = output {
387                            if let Some(item_ty) = stream_item_type(ty) {
388                                quote! { impl web_rpc::futures_core::Stream<Item = #item_ty> }
389                            } else {
390                                let ty: &Type = ty;
391                                quote! { #ty }
392                            }
393                        } else {
394                            let ty = unit_type;
395                            quote! { #ty }
396                        };
397                        let do_await = match is_async {
398                            Some(token) => quote_spanned!(token.span=> .await),
399                            None => quote!(),
400                        };
401                        let forward_args = args.iter().filter_map(|arg| match &*arg.pat {
402                            Pat::Ident(ident) => Some(&ident.ident),
403                            _ => None,
404                        });
405                        quote_spanned! {ident.span()=>
406                            #( #attrs )*
407                            #is_async fn #ident(#receiver, #( #args ),*) -> #output {
408                                T::#ident(self, #( #forward_args ),*)#do_await
409                            }
410                        }
411                    }
412                },
413            )
414            .collect::<Vec<_>>();
415
416        quote! {
417            #( #attrs )*
418            #[allow(async_fn_in_trait)]
419            #vis trait #trait_ident {
420                #( #rpc_fns )*
421            }
422
423            impl<T> #trait_ident for std::sync::Arc<T> where T: #trait_ident {
424                #( #forward_fns )*
425            }
426            impl<T> #trait_ident for std::boxed::Box<T> where T: #trait_ident {
427                #( #forward_fns )*
428            }
429            impl<T> #trait_ident for std::rc::Rc<T> where T: #trait_ident {
430                #( #forward_fns )*
431            }
432        }
433    }
434
435    fn struct_client(&self) -> TokenStream2 {
436        let &Self {
437            vis,
438            client_ident,
439            request_ident,
440            response_ident,
441            camel_case_idents,
442            rpcs,
443            has_streaming_methods,
444            ..
445        } = self;
446
447        let rpc_fns = rpcs
448            .iter()
449            .zip(camel_case_idents.iter())
450            .map(|(RpcMethod { attrs, args, transfer, ident, output, .. }, camel_case_ident)| {
451                // 1. Per-arg encoding: borrowed `&str`/`&[u8]` pass through inline;
452                // everything else routes through the autoref-dispatched `__rpc_encode`.
453                let mut arg_encodings = Vec::<TokenStream2>::new();
454                let mut request_struct_fields = Vec::<TokenStream2>::new();
455                for arg in args {
456                    let id = match &*arg.pat {
457                        Pat::Ident(p) => &p.ident,
458                        _ => continue,
459                    };
460                    if is_borrowed_serde_ref(&arg.ty) {
461                        request_struct_fields.push(quote! { #id });
462                    } else {
463                        let wire_ident = format_ident!("__wire_{}", id);
464                        let post = quote!(&__post);
465                        let enc = emit_encode(&arg.ty, quote!(#id), &post);
466                        arg_encodings.push(quote! { let #wire_ident = #enc; });
467                        request_struct_fields.push(quote! { #id: #wire_ident });
468                    }
469                }
470
471                // 2. Per-method transfer pushes (param-side only; return-side
472                // clauses are handled in struct_server).
473                let transfer_pushes = transfer.iter().filter_map(|c| match c {
474                    TransferClause::BareParam(name) => Some(quote! {
475                        __transfer.push(#name.as_ref());
476                    }),
477                    TransferClause::ParamExpr { name, body } => Some(quote_spanned! {body.span()=>
478                        {
479                            let _ = &#name; // ensure name is referenced
480                            __transfer.push((#body).as_ref());
481                        }
482                    }),
483                    TransferClause::ParamGated { name, gates } => {
484                        let arms = gates.iter().map(|g| {
485                            let pat = &g.pat;
486                            let body = &g.body;
487                            quote_spanned! {body.span()=>
488                                if let #pat = &#name {
489                                    __transfer.push((#body).as_ref());
490                                }
491                            }
492                        });
493                        Some(quote! { #( #arms )* })
494                    }
495                    TransferClause::BareReturn | TransferClause::ReturnGated { .. } => None,
496                });
497
498                let send_request = quote! {
499                    let __seq_id = self.seq_id.replace_with(|seq_id| seq_id.wrapping_add(1));
500                    let __post = web_rpc::js_sys::Array::new();
501                    let __transfer = web_rpc::js_sys::Array::new();
502                    #( #arg_encodings )*
503                    let __request = #request_ident::#camel_case_ident {
504                        #( #request_struct_fields ),*
505                    };
506                    let __header = web_rpc::MessageHeader::Request(__seq_id);
507                    let __header_bytes = web_rpc::bincode::serialize(&__header).unwrap();
508                    let __header_buffer = web_rpc::js_sys::Uint8Array::from(&__header_bytes[..]).buffer();
509                    let __payload_bytes = web_rpc::bincode::serialize(&__request).unwrap();
510                    let __payload_buffer = web_rpc::js_sys::Uint8Array::from(&__payload_bytes[..]).buffer();
511                    // Prepend [header, payload] in front of the encoded JS values.
512                    __post.unshift(&__payload_buffer);
513                    __post.unshift(&__header_buffer);
514                    __transfer.push(__header_buffer.as_ref());
515                    __transfer.push(__payload_buffer.as_ref());
516                    #( #transfer_pushes )*
517                    self.port.post_message(&__post, &__transfer).unwrap();
518                };
519
520                let is_streaming = matches!(
521                    output,
522                    ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some()
523                );
524
525                if is_streaming {
526                    let item_ty = match output {
527                        ReturnType::Type(_, ref ty) => stream_item_type(ty).unwrap(),
528                        _ => unreachable!(),
529                    };
530                    let dec = emit_decode(item_ty, quote!(__wire), &quote!(&__post_array));
531
532                    let unpack_stream_item = quote! {
533                        |(__response, __post_array): (#response_ident, web_rpc::js_sys::Array)| {
534                            let #response_ident::#camel_case_ident(__wire) = __response else {
535                                panic!("web_rpc: received incorrect response variant")
536                            };
537                            #dec
538                        }
539                    };
540
541                    quote! {
542                        #( #attrs )*
543                        #vis fn #ident(
544                            &self,
545                            #( #args ),*
546                        ) -> web_rpc::client::StreamReceiver<#item_ty> {
547                            #send_request
548                            let (__item_tx, __item_rx) = web_rpc::futures_channel::mpsc::unbounded();
549                            self.stream_callback_map.borrow_mut().insert(__seq_id, __item_tx);
550                            let __mapped_rx = web_rpc::futures_util::StreamExt::map(
551                                __item_rx,
552                                #unpack_stream_item
553                            );
554                            let __abort_sender = self.abort_sender.clone();
555                            let __stream_callback_map = self.stream_callback_map.clone();
556                            let __dispatcher = self.dispatcher.clone();
557                            web_rpc::client::StreamReceiver::new(
558                                __mapped_rx,
559                                __dispatcher,
560                                std::boxed::Box::new(move || {
561                                    __stream_callback_map.borrow_mut().remove(&__seq_id);
562                                    (__abort_sender)(__seq_id);
563                                }),
564                            )
565                        }
566                    }
567                } else {
568                    let return_type = match output {
569                        ReturnType::Type(_, ref ty) => quote! {
570                            web_rpc::client::RequestFuture<#ty>
571                        },
572                        _ => quote!(()),
573                    };
574                    let maybe_register_callback = match output {
575                        ReturnType::Type(_, _) => quote! {
576                            let (__response_tx, __response_rx) =
577                                web_rpc::futures_channel::oneshot::channel();
578                            self.callback_map.borrow_mut().insert(__seq_id, __response_tx);
579                        },
580                        _ => Default::default(),
581                    };
582
583                    let maybe_unpack_and_return_future = match output {
584                        ReturnType::Type(_, ref ret_ty) => {
585                            let dec = emit_decode(ret_ty, quote!(__wire), &quote!(&__post_array));
586                            quote! {
587                                let __response_future = web_rpc::futures_util::FutureExt::map(
588                                    __response_rx,
589                                    |response| {
590                                        let (__serialize_response, __post_array) = response.unwrap();
591                                        let #response_ident::#camel_case_ident(__wire) = __serialize_response else {
592                                            panic!("web_rpc: received incorrect response variant")
593                                        };
594                                        #dec
595                                    }
596                                );
597                                let __abort_sender = self.abort_sender.clone();
598                                let __dispatcher = self.dispatcher.clone();
599                                web_rpc::client::RequestFuture::new(
600                                    __response_future,
601                                    __dispatcher,
602                                    std::boxed::Box::new(move || (__abort_sender)(__seq_id)))
603                            }
604                        }
605                        _ => Default::default(),
606                    };
607
608                    quote! {
609                        #( #attrs )*
610                        #vis fn #ident(
611                            &self,
612                            #( #args ),*
613                        ) -> #return_type {
614                            #send_request
615                            #maybe_register_callback
616                            #maybe_unpack_and_return_future
617                        }
618                    }
619                }
620            });
621
622        let stream_callback_map_field = if has_streaming_methods {
623            quote! {
624                stream_callback_map: std::rc::Rc<
625                    std::cell::RefCell<
626                        web_rpc::client::StreamCallbackMap<#response_ident>
627                    >
628                >,
629            }
630        } else {
631            quote!()
632        };
633
634        let stream_callback_map_pat = if has_streaming_methods {
635            quote! { stream_callback_map, }
636        } else {
637            quote! { _, }
638        };
639
640        let stream_callback_map_init = if has_streaming_methods {
641            quote! { stream_callback_map, }
642        } else {
643            quote! {}
644        };
645
646        quote! {
647            #[derive(core::clone::Clone)]
648            #vis struct #client_ident {
649                callback_map: std::rc::Rc<
650                    std::cell::RefCell<
651                        web_rpc::client::CallbackMap<#response_ident>
652                    >
653                >,
654                #stream_callback_map_field
655                port: web_rpc::port::Port,
656                listener: std::rc::Rc<web_rpc::gloo_events::EventListener>,
657                dispatcher: web_rpc::futures_util::future::Shared<
658                    web_rpc::futures_core::future::LocalBoxFuture<'static, ()>
659                >,
660                abort_sender: std::rc::Rc<dyn std::ops::Fn(usize)>,
661                seq_id: std::rc::Rc<std::cell::RefCell<usize>>
662            }
663            impl std::fmt::Debug for #client_ident {
664                fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
665                    formatter.debug_struct(std::stringify!(#client_ident))
666                        .finish()
667                }
668            }
669            impl web_rpc::client::Client for #client_ident {
670                type Response = #response_ident;
671            }
672            impl From<web_rpc::client::Configuration<#response_ident>>
673                for #client_ident {
674                fn from((callback_map, #stream_callback_map_pat port, listener, dispatcher, abort_sender):
675                    web_rpc::client::Configuration<#response_ident>) -> Self {
676                    Self {
677                        callback_map,
678                        #stream_callback_map_init
679                        port,
680                        listener,
681                        dispatcher,
682                        abort_sender,
683                        seq_id: std::default::Default::default()
684                    }
685                }
686            }
687            impl #client_ident {
688                #( #rpc_fns )*
689            }
690        }
691    }
692
693    fn struct_server(&self) -> TokenStream2 {
694        let &Self {
695            vis,
696            trait_ident,
697            service_ident,
698            request_ident,
699            response_ident,
700            camel_case_idents,
701            rpcs,
702            has_borrowed_args,
703            ..
704        } = self;
705
706        let request_type = if has_borrowed_args {
707            quote! { #request_ident<'_> }
708        } else {
709            quote! { #request_ident }
710        };
711
712        let handlers = rpcs.iter()
713            .zip(camel_case_idents.iter())
714            .map(|(RpcMethod { is_async, ident, args, transfer, output, .. }, camel_case_ident)| {
715                // 1. Destructure pattern for the request enum variant.
716                // Borrowed args use their own ident; non-borrowed args bind to __wire_<id>.
717                let destructure_fields: Vec<_> = args.iter()
718                    .filter_map(|arg| {
719                        let id = match &*arg.pat {
720                            Pat::Ident(p) => &p.ident,
721                            _ => return None,
722                        };
723                        Some(if is_borrowed_serde_ref(&arg.ty) {
724                            quote! { #id }
725                        } else {
726                            let wire_ident = format_ident!("__wire_{}", id);
727                            quote! { #id: #wire_ident }
728                        })
729                    })
730                    .collect();
731
732                // 2. Per-arg decoding statements.
733                let arg_decodes: Vec<_> = args.iter()
734                    .filter_map(|arg| {
735                        let id = match &*arg.pat {
736                            Pat::Ident(p) => &p.ident,
737                            _ => return None,
738                        };
739                        if is_borrowed_serde_ref(&arg.ty) {
740                            // Already bound by destructuring.
741                            None
742                        } else if is_js_ref(&arg.ty) {
743                            // `&T` where T: JsCast — no `Decoder<&T>` impl, so we
744                            // shift from the post-array and bind via dyn_ref locally.
745                            let inner_ty = match &*arg.ty {
746                                Type::Reference(r) => &*r.elem,
747                                _ => unreachable!(),
748                            };
749                            let tmp_ident = format_ident!("__tmp_{}", id);
750                            let wire_ident = format_ident!("__wire_{}", id);
751                            let arg_ty = &arg.ty;
752                            Some(quote! {
753                                let #tmp_ident = match #wire_ident {
754                                    web_rpc::codec::WireArg::Js => __js_args.shift(),
755                                    _ => panic!("web_rpc: expected Js wire variant for reference arg"),
756                                };
757                                let #id: #arg_ty = web_rpc::wasm_bindgen::JsCast::dyn_ref::<#inner_ty>(&#tmp_ident)
758                                    .unwrap();
759                            })
760                        } else {
761                            let wire_ident = format_ident!("__wire_{}", id);
762                            let dec = emit_decode(&arg.ty, quote!(#wire_ident), &quote!(&__js_args));
763                            Some(quote! { let #id = #dec; })
764                        }
765                    })
766                    .collect();
767
768                let call_args: Vec<_> = args.iter().filter_map(|arg| match &*arg.pat {
769                    Pat::Ident(ident) => Some(&ident.ident),
770                    _ => None,
771                }).collect();
772
773                // Return-side transfer clauses (BareReturn / ReturnGated).
774                // The scrutinee is `__response` for non-streaming and `__item` for streaming.
775                let make_return_transfer = |scrutinee_ident: &Ident| -> TokenStream2 {
776                    let pushes = transfer.iter().filter_map(|c| match c {
777                        TransferClause::BareReturn => Some(quote! {
778                            __transfer.push(#scrutinee_ident.as_ref());
779                        }),
780                        TransferClause::ReturnGated { gates } => {
781                            let arms = gates.iter().map(|g| {
782                                let pat = &g.pat;
783                                let body = &g.body;
784                                quote_spanned! {body.span()=>
785                                    if let #pat = &#scrutinee_ident {
786                                        __transfer.push((#body).as_ref());
787                                    }
788                                }
789                            });
790                            Some(quote! { #( #arms )* })
791                        }
792                        _ => None,
793                    });
794                    quote! { #( #pushes )* }
795                };
796
797                let is_streaming = matches!(
798                    output,
799                    ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some()
800                );
801
802                if is_streaming {
803                    let item_ty = match output {
804                        ReturnType::Type(_, ref ty) => stream_item_type(ty).unwrap(),
805                        _ => unreachable!(),
806                    };
807                    let item_enc = emit_encode(item_ty, quote!(__item), &quote!(&__post));
808                    let item_ident = Ident::new("__item", proc_macro2::Span::call_site());
809                    let return_transfer = make_return_transfer(&item_ident);
810
811                    let wrap_item = quote! {
812                        let __post = web_rpc::js_sys::Array::new();
813                        let __transfer = web_rpc::js_sys::Array::new();
814                        let __wire_item = #item_enc;
815                        #return_transfer
816                        let __response = #response_ident::#camel_case_ident(__wire_item);
817                    };
818
819                    let fwd_body = quote! {
820                        let __stream_tx_clone = __stream_tx.clone();
821                        web_rpc::pin_utils::pin_mut!(__user_rx);
822                        let __fwd = async move {
823                            while let Some(__item) = web_rpc::futures_util::StreamExt::next(&mut __user_rx).await {
824                                #wrap_item
825                                if __stream_tx_clone.unbounded_send((__seq_id, Some((__response, __post, __transfer)))).is_err() {
826                                    break;
827                                }
828                            }
829                        };
830                        let __fwd = web_rpc::futures_util::FutureExt::fuse(__fwd);
831                        web_rpc::pin_utils::pin_mut!(__fwd);
832                        web_rpc::futures_util::select! {
833                            _ = __abort_rx => {},
834                            _ = __fwd => {},
835                        }
836                        let _ = __stream_tx.unbounded_send((__seq_id, None));
837                        web_rpc::service::ExecuteResult::StreamComplete
838                    };
839
840                    match is_async {
841                        Some(_) => quote! {
842                            #request_ident::#camel_case_ident { #( #destructure_fields ),* } => {
843                                #( #arg_decodes )*
844                                let __get_rx = web_rpc::futures_util::FutureExt::fuse(
845                                    self.server_impl.#ident(#( #call_args ),*)
846                                );
847                                web_rpc::pin_utils::pin_mut!(__get_rx);
848                                let __maybe_rx = web_rpc::futures_util::select! {
849                                    _ = __abort_rx => None,
850                                    __rx = __get_rx => Some(__rx),
851                                };
852                                if let Some(mut __user_rx) = __maybe_rx {
853                                    #fwd_body
854                                } else {
855                                    let _ = __stream_tx.unbounded_send((__seq_id, None));
856                                    web_rpc::service::ExecuteResult::StreamComplete
857                                }
858                            }
859                        },
860                        None => quote! {
861                            #request_ident::#camel_case_ident { #( #destructure_fields ),* } => {
862                                #( #arg_decodes )*
863                                let mut __user_rx = self.server_impl.#ident(#( #call_args ),*);
864                                #fwd_body
865                            }
866                        },
867                    }
868                } else {
869                    // Non-streaming.
870                    let resp_ident = Ident::new("__response", proc_macro2::Span::call_site());
871                    let return_transfer = make_return_transfer(&resp_ident);
872                    let return_response = match output {
873                        ReturnType::Type(_, ref ret_ty) => {
874                            let enc = emit_encode(ret_ty, quote!(__response), &quote!(&__post));
875                            quote! {
876                                let __post = web_rpc::js_sys::Array::new();
877                                let __transfer = web_rpc::js_sys::Array::new();
878                                let __wire = #enc;
879                                #return_transfer
880                                (#response_ident::#camel_case_ident(__wire), __post, __transfer)
881                            }
882                        }
883                        _ => {
884                            // Notification — emit a placeholder WireArg.
885                            quote! {
886                                let _ = __response;
887                                let __post = web_rpc::js_sys::Array::new();
888                                let __transfer = web_rpc::js_sys::Array::new();
889                                let __wire = web_rpc::codec::WireArg::Bytes(
890                                    web_rpc::bincode::serialize(&()).unwrap()
891                                );
892                                (#response_ident::#camel_case_ident(__wire), __post, __transfer)
893                            }
894                        }
895                    };
896
897                    match is_async {
898                        Some(_) => quote! {
899                            #request_ident::#camel_case_ident { #( #destructure_fields ),* } => {
900                                #( #arg_decodes )*
901                                let __task =
902                                    web_rpc::futures_util::FutureExt::fuse(self.server_impl.#ident(#( #call_args ),*));
903                                web_rpc::pin_utils::pin_mut!(__task);
904                                web_rpc::service::ExecuteResult::Response(
905                                    web_rpc::futures_util::select! {
906                                        _ = __abort_rx => None,
907                                        __response = __task => Some({
908                                            #return_response
909                                        })
910                                    }
911                                )
912                            }
913                        },
914                        None => quote! {
915                            #request_ident::#camel_case_ident { #( #destructure_fields ),* } => {
916                                #( #arg_decodes )*
917                                let __response = self.server_impl.#ident(#( #call_args ),*);
918                                web_rpc::service::ExecuteResult::Response(
919                                    Some({
920                                        #return_response
921                                    })
922                                )
923                            }
924                        }
925                    }
926                }
927            });
928
929        quote! {
930            #vis struct #service_ident<T> {
931                server_impl: T
932            }
933            impl<T: #trait_ident> web_rpc::service::Service for #service_ident<T> {
934                type Response = #response_ident;
935                async fn execute(
936                    &self,
937                    __seq_id: usize,
938                    mut __abort_rx: web_rpc::futures_channel::oneshot::Receiver<()>,
939                    __payload: std::vec::Vec<u8>,
940                    __js_args: web_rpc::js_sys::Array,
941                    __stream_tx: web_rpc::futures_channel::mpsc::UnboundedSender<
942                        web_rpc::service::StreamMessage<Self::Response>
943                    >,
944                ) -> (usize, web_rpc::service::ExecuteResult<Self::Response>) {
945                    let __request: #request_type = web_rpc::bincode::deserialize(&__payload).unwrap();
946                    let __result = match __request {
947                        #( #handlers )*
948                    };
949                    (__seq_id, __result)
950                }
951            }
952            impl<T: #trait_ident> std::convert::From<T> for #service_ident<T> {
953                fn from(server_impl: T) -> Self {
954                    Self { server_impl }
955                }
956            }
957        }
958    }
959}
960
961impl<'a> ToTokens for ServiceGenerator<'a> {
962    fn to_tokens(&self, output: &mut TokenStream2) {
963        output.extend(vec![
964            self.enum_request(),
965            self.enum_response(),
966            self.trait_service(),
967            self.struct_client(),
968            self.struct_server(),
969        ])
970    }
971}
972
973impl Parse for Service {
974    fn parse(input: ParseStream) -> syn::Result<Self> {
975        let attrs = input.call(Attribute::parse_outer)?;
976        let vis = input.parse()?;
977        input.parse::<Token![trait]>()?;
978        let ident: Ident = input.parse()?;
979        let content;
980        braced!(content in input);
981        let mut rpcs = Vec::<RpcMethod>::new();
982        while !content.is_empty() {
983            rpcs.push(content.parse()?);
984        }
985
986        Ok(Self {
987            attrs,
988            vis,
989            ident,
990            rpcs,
991        })
992    }
993}
994
995/// Parsed RHS of a `name => ...` or `return => ...` clause.
996enum TransferRhs {
997    Expr(syn::Expr),
998    Gates(Vec<Gate>),
999}
1000
1001fn parse_transfer_rhs(input: ParseStream) -> syn::Result<TransferRhs> {
1002    if input.peek(Token![|]) || input.peek(Token![||]) {
1003        // Closure form: `|pat| body` (or `|| body` — rejected).
1004        let closure: syn::ExprClosure = input.parse()?;
1005        if closure.inputs.len() != 1 {
1006            return Err(syn::Error::new_spanned(
1007                &closure,
1008                "transfer closure must have exactly one parameter",
1009            ));
1010        }
1011        let pat = closure.inputs.into_iter().next().unwrap();
1012        let body = *closure.body;
1013        Ok(TransferRhs::Gates(vec![Gate { pat, body }]))
1014    } else if input.peek(Token![match]) {
1015        // `match { arms }` — no scrutinee. Bespoke syntax.
1016        input.parse::<Token![match]>()?;
1017        let content;
1018        braced!(content in input);
1019        let arms: Punctuated<syn::Arm, Token![,]> =
1020            content.parse_terminated(syn::Arm::parse)?;
1021        let gates = arms
1022            .into_iter()
1023            .map(|a| Gate {
1024                pat: a.pat,
1025                body: *a.body,
1026            })
1027            .collect();
1028        Ok(TransferRhs::Gates(gates))
1029    } else {
1030        // Bare expression — only valid for params; the caller checks.
1031        let body: syn::Expr = input.parse()?;
1032        Ok(TransferRhs::Expr(body))
1033    }
1034}
1035
1036impl Parse for TransferClause {
1037    fn parse(input: ParseStream) -> syn::Result<Self> {
1038        let is_return = input.peek(Token![return]);
1039        let lhs_name: Option<Ident> = if is_return {
1040            input.parse::<Token![return]>()?;
1041            None
1042        } else {
1043            Some(input.parse()?)
1044        };
1045
1046        if input.peek(Token![=>]) {
1047            input.parse::<Token![=>]>()?;
1048            let rhs = parse_transfer_rhs(input)?;
1049            match (lhs_name, rhs) {
1050                (Some(name), TransferRhs::Expr(body)) => {
1051                    Ok(TransferClause::ParamExpr { name, body })
1052                }
1053                (Some(name), TransferRhs::Gates(gates)) => {
1054                    Ok(TransferClause::ParamGated { name, gates })
1055                }
1056                (None, TransferRhs::Gates(gates)) => {
1057                    Ok(TransferClause::ReturnGated { gates })
1058                }
1059                (None, TransferRhs::Expr(_)) => Err(syn::Error::new(
1060                    input.span(),
1061                    "`return =>` requires a closure (`|pat| body`) or `match { arms }` block",
1062                )),
1063            }
1064        } else {
1065            Ok(match lhs_name {
1066                Some(name) => TransferClause::BareParam(name),
1067                None => TransferClause::BareReturn,
1068            })
1069        }
1070    }
1071}
1072
1073impl Parse for RpcMethod {
1074    fn parse(input: ParseStream) -> syn::Result<Self> {
1075        let mut errors = Ok(());
1076        let attrs = input.call(Attribute::parse_outer)?;
1077
1078        // Reject the removed `#[post(...)]` attribute with a migration message.
1079        for attr in &attrs {
1080            if attr
1081                .path
1082                .segments
1083                .last()
1084                .is_some_and(|seg| seg.ident == "post")
1085            {
1086                extend_errors!(
1087                    errors,
1088                    syn::Error::new_spanned(
1089                        attr,
1090                        "`#[post(...)]` has been removed. JS-vs-serialize routing is now \
1091                         inferred from each argument and return type. For transfer semantics, \
1092                         use `#[transfer(...)]` (e.g. `#[transfer(canvas)]`, \
1093                         `#[transfer(data => data.buffer())]`, or \
1094                         `#[transfer(return => |Ok(o)| o.buffer())]`)."
1095                    )
1096                );
1097            }
1098        }
1099
1100        // Partition out the new `#[transfer(...)]` attribute(s).
1101        let (transfer_attrs, attrs): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|attr| {
1102            attr.path
1103                .segments
1104                .last()
1105                .is_some_and(|last_segment| last_segment.ident == "transfer")
1106        });
1107        let mut transfer: Vec<TransferClause> = Vec::new();
1108        for transfer_attr in transfer_attrs {
1109            let parsed = transfer_attr
1110                .parse_args_with(Punctuated::<TransferClause, Token![,]>::parse_terminated)?;
1111            transfer.extend(parsed.into_iter());
1112        }
1113
1114        let is_async = input.parse::<Token![async]>().ok();
1115        input.parse::<Token![fn]>()?;
1116        let ident: Ident = input.parse()?;
1117
1118        // Reject generic methods up front — autoref dispatch needs concrete types.
1119        if input.peek(Token![<]) {
1120            let generics: syn::Generics = input.parse()?;
1121            extend_errors!(
1122                errors,
1123                syn::Error::new_spanned(
1124                    generics,
1125                    "web_rpc::service trait methods may not have generic parameters; \
1126                     concrete types are required so the macro can route each argument."
1127                )
1128            );
1129        }
1130
1131        let content;
1132        parenthesized!(content in input);
1133        let mut receiver: Option<syn::Receiver> = None;
1134        let mut args = Vec::new();
1135        for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
1136            match arg {
1137                FnArg::Typed(captured) => match &*captured.pat {
1138                    Pat::Ident(_) => {
1139                        // Reject reference args other than `&str`/`&[u8]`/`&JsT`.
1140                        // (The is_js_ref / is_borrowed_serde_ref classifiers will
1141                        // accept any reference; we let them through here and rely
1142                        // on the receiver-side decoder to fail on unsupported
1143                        // shapes. A dedicated diagnostic comes later.)
1144                        args.push(captured)
1145                    }
1146                    _ => extend_errors!(
1147                        errors,
1148                        syn::Error::new(
1149                            captured.pat.span(),
1150                            "patterns are not allowed in RPC arguments"
1151                        )
1152                    ),
1153                },
1154                FnArg::Receiver(ref recv) => {
1155                    if recv.reference.is_none() || recv.mutability.is_some() {
1156                        extend_errors!(
1157                            errors,
1158                            syn::Error::new(
1159                                arg.span(),
1160                                "RPC methods only support `&self` as a receiver"
1161                            )
1162                        );
1163                    }
1164                    receiver = Some(recv.clone());
1165                }
1166            }
1167        }
1168        let receiver = match receiver {
1169            Some(r) => r,
1170            None => {
1171                extend_errors!(
1172                    errors,
1173                    syn::Error::new(
1174                        ident.span(),
1175                        "RPC methods must include `&self` as the first parameter"
1176                    )
1177                );
1178                parse_quote!(&self)
1179            }
1180        };
1181        let output: ReturnType = input.parse()?;
1182        input.parse::<Token![;]>()?;
1183
1184        // Validate that every transfer clause references a real parameter
1185        // (or `return`, which has no name to check).
1186        let arg_names: HashSet<_> = args
1187            .iter()
1188            .filter_map(|arg| match &*arg.pat {
1189                Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()),
1190                _ => None,
1191            })
1192            .collect();
1193        for clause in &transfer {
1194            let name_ref = match clause {
1195                TransferClause::BareParam(name)
1196                | TransferClause::ParamExpr { name, .. }
1197                | TransferClause::ParamGated { name, .. } => Some(name),
1198                TransferClause::BareReturn | TransferClause::ReturnGated { .. } => None,
1199            };
1200            if let Some(name) = name_ref {
1201                if !arg_names.contains(name) {
1202                    extend_errors!(
1203                        errors,
1204                        syn::Error::new(
1205                            name.span(),
1206                            format!(
1207                                "`{}` in #[transfer(...)] does not match any parameter",
1208                                name
1209                            )
1210                        )
1211                    );
1212                }
1213            }
1214        }
1215        errors?;
1216
1217        Ok(Self {
1218            is_async,
1219            attrs,
1220            receiver,
1221            ident,
1222            args,
1223            transfer,
1224            output,
1225        })
1226    }
1227}
1228
1229/// This attribute macro should applied to traits that need to be turned into RPCs. The
1230/// macro will consume the trait and output three items in its place. For example,
1231/// a trait `Calculator` will be replaced with two structs `CalculatorClient` and
1232/// `CalculatorService` and a new trait by the same name. All methods must include
1233/// `&self` as their first parameter.
1234#[proc_macro_attribute]
1235pub fn service(_attr: TokenStream, input: TokenStream) -> TokenStream {
1236    let Service {
1237        ref attrs,
1238        ref vis,
1239        ref ident,
1240        ref rpcs,
1241    } = parse_macro_input!(input as Service);
1242
1243    let camel_case_fn_names: &Vec<_> = &rpcs
1244        .iter()
1245        .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
1246        .collect();
1247
1248    let has_borrowed_args = rpcs
1249        .iter()
1250        .any(|rpc| rpc.args.iter().any(|arg| is_borrowed_serde_ref(&arg.ty)));
1251
1252    let has_streaming_methods = rpcs.iter().any(
1253        |rpc| matches!(&rpc.output, ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some()),
1254    );
1255
1256    ServiceGenerator {
1257        trait_ident: ident,
1258        service_ident: &format_ident!("{}Service", ident),
1259        client_ident: &format_ident!("{}Client", ident),
1260        request_ident: &format_ident!("{}Request", ident),
1261        response_ident: &format_ident!("{}Response", ident),
1262        vis,
1263        attrs,
1264        rpcs,
1265        camel_case_idents: &rpcs
1266            .iter()
1267            .zip(camel_case_fn_names.iter())
1268            .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
1269            .collect::<Vec<_>>(),
1270        has_borrowed_args,
1271        has_streaming_methods,
1272    }
1273    .into_token_stream()
1274    .into()
1275}
1276
/// Converts a `snake_case` identifier to `CamelCase`.
///
/// Underscores act only as word separators and are dropped, so leading,
/// trailing and repeated underscores produce nothing. The first character of
/// each word is upper-cased and every other character is lower-cased
/// (e.g. `already_CAPS` becomes `AlreadyCaps`).
fn snake_to_camel(ident_str: &str) -> String {
    let mut result = String::with_capacity(ident_str.len());

    for word in ident_str.split('_').filter(|w| !w.is_empty()) {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            result.extend(head.to_uppercase());
            result.extend(chars.flat_map(char::to_lowercase));
        }
    }

    result.shrink_to_fit();
    result
}