1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use syn::{
7 braced,
8 ext::IdentExt,
9 parenthesized,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse_quote,
12 punctuated::Punctuated,
13 spanned::Spanned,
14 token::Comma,
15 Attribute, FnArg, Ident, Lifetime, Pat, PatType, ReturnType, Token, Type, Visibility,
16};
17
/// Folds a new error into an accumulator of type `Result<(), E>`.
///
/// If the accumulator is still `Ok`, it becomes `Err(new_error)`. Otherwise the new error is
/// appended to the existing one via `Extend`, so several diagnostics can be reported together.
macro_rules! extend_errors {
    ($acc: ident, $err: expr) => {
        match $acc {
            Err(ref mut existing) => existing.extend($err),
            Ok(_) => $acc = Err($err),
        }
    };
}
26
27fn stream_item_type(ty: &Type) -> Option<&Type> {
29 if let Type::ImplTrait(impl_trait) = ty {
30 for bound in &impl_trait.bounds {
31 if let syn::TypeParamBound::Trait(trait_bound) = bound {
32 let last_segment = trait_bound.path.segments.last()?;
33 if last_segment.ident == "Stream" {
34 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
35 for arg in &args.args {
36 if let syn::GenericArgument::Binding(binding) = arg {
37 if binding.ident == "Item" {
38 return Some(&binding.ty);
39 }
40 }
41 }
42 }
43 }
44 }
45 }
46 }
47 None
48}
49
50fn option_inner_type(ty: &Type) -> Option<&Type> {
52 if let Type::Path(type_path) = ty {
53 let last_seg = type_path.path.segments.last()?;
54 if last_seg.ident == "Option" {
55 if let syn::PathArguments::AngleBracketed(args) = &last_seg.arguments {
56 if args.args.len() == 1 {
57 if let syn::GenericArgument::Type(inner) = &args.args[0] {
58 return Some(inner);
59 }
60 }
61 }
62 }
63 }
64 None
65}
66
67fn result_inner_types(ty: &Type) -> Option<(&Type, &Type)> {
69 if let Type::Path(type_path) = ty {
70 let last_seg = type_path.path.segments.last()?;
71 if last_seg.ident == "Result" {
72 if let syn::PathArguments::AngleBracketed(args) = &last_seg.arguments {
73 if args.args.len() == 2 {
74 if let (syn::GenericArgument::Type(ok_ty), syn::GenericArgument::Type(err_ty)) =
75 (&args.args[0], &args.args[1])
76 {
77 return Some((ok_ty, err_ty));
78 }
79 }
80 }
81 }
82 }
83 None
84}
85
86fn is_borrowed_serde_ref(ty: &Type) -> bool {
90 if let Type::Reference(r) = ty {
91 match &*r.elem {
92 Type::Path(p) if p.path.is_ident("str") => return true,
93 Type::Slice(s) => {
94 if let Type::Path(p) = &*s.elem {
95 if p.path.is_ident("u8") {
96 return true;
97 }
98 }
99 }
100 _ => {}
101 }
102 }
103 false
104}
105
106fn is_js_ref(ty: &Type) -> bool {
109 matches!(ty, Type::Reference(_)) && !is_borrowed_serde_ref(ty)
110}
111
112fn emit_encode(ty: &Type, value: TokenStream2, post: &TokenStream2) -> TokenStream2 {
123 if let Some(inner) = option_inner_type(ty) {
127 let inner_enc = emit_encode(inner, quote!(__inner), post);
128 quote_spanned! {ty.span()=>
129 match &#value {
130 ::core::option::Option::Some(__inner) =>
131 web_rpc::codec::WireArg::Some(std::boxed::Box::new(#inner_enc)),
132 ::core::option::Option::None =>
133 web_rpc::codec::WireArg::None,
134 }
135 }
136 } else if let Some((ok, err)) = result_inner_types(ty) {
137 let ok_enc = emit_encode(ok, quote!(__inner), post);
138 let err_enc = emit_encode(err, quote!(__inner), post);
139 quote_spanned! {ty.span()=>
140 match &#value {
141 ::core::result::Result::Ok(__inner) =>
142 web_rpc::codec::WireArg::Ok(std::boxed::Box::new(#ok_enc)),
143 ::core::result::Result::Err(__inner) =>
144 web_rpc::codec::WireArg::Err(std::boxed::Box::new(#err_enc)),
145 }
146 }
147 } else {
148 quote_spanned! {ty.span()=>
149 {
150 #[allow(unused_imports)]
151 use web_rpc::codec::{
152 __RpcJsEncode as _,
153 __RpcSerialEncode as _,
154 };
155 (&#value).__rpc_encode(#post)
156 }
157 }
158 }
159}
160
161fn emit_decode(ty: &Type, wire: TokenStream2, post: &TokenStream2) -> TokenStream2 {
167 if let Some(inner) = option_inner_type(ty) {
168 let inner_dec = emit_decode(inner, quote!(*__inner), post);
169 quote_spanned! {ty.span()=>
170 match #wire {
171 web_rpc::codec::WireArg::Some(__inner) =>
172 ::core::option::Option::Some(#inner_dec),
173 web_rpc::codec::WireArg::None =>
174 ::core::option::Option::None,
175 _ => panic!("web_rpc: wire/type mismatch — expected Some or None"),
176 }
177 }
178 } else if let Some((ok, err)) = result_inner_types(ty) {
179 let ok_dec = emit_decode(ok, quote!(*__inner), post);
180 let err_dec = emit_decode(err, quote!(*__inner), post);
181 quote_spanned! {ty.span()=>
182 match #wire {
183 web_rpc::codec::WireArg::Ok(__inner) =>
184 ::core::result::Result::Ok(#ok_dec),
185 web_rpc::codec::WireArg::Err(__inner) =>
186 ::core::result::Result::Err(#err_dec),
187 _ => panic!("web_rpc: wire/type mismatch — expected Ok or Err"),
188 }
189 }
190 } else {
191 quote_spanned! {ty.span()=>
192 {
193 #[allow(unused_imports)]
194 use web_rpc::codec::{
195 __RpcJsDecode as _,
196 __RpcSerialDecode as _,
197 };
198 (&web_rpc::codec::Decoder::<#ty>::default()).__rpc_decode(#wire, #post)
199 }
200 }
201 }
202}
203
/// A `trait` item annotated with `#[web_rpc::service]`, as read by its `Parse` impl.
struct Service {
    /// Outer attributes on the trait; re-emitted on the generated trait.
    attrs: Vec<Attribute>,
    /// Visibility of the trait; reused for the generated enums, trait, client and service.
    vis: Visibility,
    /// Trait name; the generated items are named `<ident>Request`, `<ident>Response`,
    /// `<ident>Client` and `<ident>Service`.
    ident: Ident,
    /// The trait's methods, in declaration order.
    rpcs: Vec<RpcMethod>,
}
210
/// One method signature from the service trait.
struct RpcMethod {
    /// The `async` keyword, if present. Generated forwarding and server code awaits the call
    /// only when this is `Some`.
    is_async: Option<Token![async]>,
    /// Method attributes other than `#[transfer(...)]`; re-emitted on generated methods.
    attrs: Vec<Attribute>,
    /// The method's reference receiver (`&self`).
    receiver: syn::Receiver,
    /// Method name; its CamelCase form names the request/response enum variants.
    ident: Ident,
    /// Typed parameters excluding the receiver; each pattern is a plain identifier.
    args: Vec<PatType>,
    /// Clauses gathered from every `#[transfer(...)]` attribute on the method.
    transfer: Vec<TransferClause>,
    /// Declared return type; `impl Stream<Item = T>` marks a streaming method.
    output: ReturnType,
}
220
/// One comma-separated clause inside `#[transfer(...)]`, naming JS values to put on the
/// transfer list of the posted message.
// NOTE(review): every variant field appears to be read by the generators in this file, so this
// `allow(dead_code)` may be stale — confirm before removing it.
#[allow(dead_code)]
enum TransferClause {
    /// `param`: transfer the parameter itself (`param.as_ref()`).
    BareParam(Ident),
    /// `param => expr`: transfer the value of `expr`.
    ParamExpr { name: Ident, body: syn::Expr },
    /// `param => |pat| expr` or `param => match { pat => expr, ... }`: for each gate whose
    /// pattern matches `&param`, transfer that gate's expression.
    ParamGated { name: Ident, gates: Vec<Gate> },
    /// `return`: transfer the return value (or each streamed item) itself.
    BareReturn,
    /// `return => |pat| expr` or `return => match { ... }`: gated transfer of the return
    /// value (or each streamed item).
    ReturnGated { gates: Vec<Gate> },
}
237
/// A single `pattern => expr` pair from a gated transfer clause. The generated code pushes
/// `body` onto the transfer list when the value matches `pat`; `body` may use the pattern's
/// bindings.
// NOTE(review): both fields are read by the client/server generators; this allow may be stale.
#[allow(dead_code)]
struct Gate {
    pat: syn::Pat,
    body: syn::Expr,
}
243
/// Borrowed inputs shared by the code generators for one service trait.
struct ServiceGenerator<'a> {
    /// Name of the user's trait.
    trait_ident: &'a Ident,
    /// `<Trait>Service`: server-side dispatcher wrapping a trait implementation.
    service_ident: &'a Ident,
    /// `<Trait>Client`: client struct with one method per RPC.
    client_ident: &'a Ident,
    /// `<Trait>Request`: enum with one variant per RPC carrying its arguments.
    request_ident: &'a Ident,
    /// `<Trait>Response`: enum with one variant per RPC carrying its encoded result.
    response_ident: &'a Ident,
    /// Visibility applied to the generated enums, trait, structs and client methods.
    vis: &'a Visibility,
    /// Outer attributes of the original trait.
    attrs: &'a [Attribute],
    /// Parsed trait methods.
    rpcs: &'a [RpcMethod],
    /// CamelCase variant name for each entry of `rpcs`, in the same order.
    camel_case_idents: &'a [Ident],
    /// Whether any method takes a `&str` / `&[u8]` argument (adds `<'a>` to the request enum).
    has_borrowed_args: bool,
    /// Whether any method returns `impl Stream<Item = ...>` (adds a stream callback map to the
    /// client).
    has_streaming_methods: bool,
}
257
258impl<'a> ServiceGenerator<'a> {
259 fn enum_request(&self) -> TokenStream2 {
260 let &Self {
261 vis,
262 request_ident,
263 camel_case_idents,
264 rpcs,
265 has_borrowed_args,
266 ..
267 } = self;
268 let lifetime = if has_borrowed_args {
269 quote!(<'a>)
270 } else {
271 quote!()
272 };
273 let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
274 |(RpcMethod { args, .. }, camel_case_ident)| {
275 let fields = args.iter().map(|arg| {
276 let pat = &arg.pat;
277 if is_borrowed_serde_ref(&arg.ty) {
278 let mut type_ref = match &*arg.ty {
280 Type::Reference(r) => r.clone(),
281 _ => unreachable!("is_borrowed_serde_ref guarantees a reference"),
282 };
283 type_ref.lifetime =
284 Some(Lifetime::new("'a", type_ref.and_token.span()));
285 quote_spanned! {arg.ty.span()=> #pat: #type_ref }
286 } else {
287 quote_spanned! {arg.ty.span()=>
290 #pat: web_rpc::codec::WireArg
291 }
292 }
293 });
294 quote! {
295 #camel_case_ident { #( #fields ),* }
296 }
297 },
298 );
299 quote! {
300 #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
301 #vis enum #request_ident #lifetime {
302 #( #variants ),*
303 }
304 }
305 }
306
307 fn enum_response(&self) -> TokenStream2 {
308 let &Self {
309 vis,
310 response_ident,
311 camel_case_idents,
312 rpcs,
313 ..
314 } = self;
315 let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
316 |(_method, camel_case_ident)| {
317 quote! {
322 #camel_case_ident ( web_rpc::codec::WireArg )
323 }
324 },
325 );
326 quote! {
327 #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
328 #vis enum #response_ident {
329 #( #variants ),*
330 }
331 }
332 }
333
334 fn trait_service(&self) -> TokenStream2 {
335 let &Self {
336 attrs,
337 rpcs,
338 vis,
339 trait_ident,
340 ..
341 } = self;
342
343 let unit_type: &Type = &parse_quote!(());
344 let rpc_fns = rpcs.iter().map(
345 |RpcMethod {
346 attrs,
347 args,
348 receiver,
349 ident,
350 is_async,
351 output,
352 ..
353 }| {
354 if let ReturnType::Type(_, ref ty) = output {
355 if let Some(item_ty) = stream_item_type(ty) {
356 return quote_spanned! {ident.span()=>
357 #( #attrs )*
358 #is_async fn #ident(#receiver, #( #args ),*) -> impl web_rpc::futures_core::Stream<Item = #item_ty>;
359 };
360 }
361 }
362 let output = match output {
363 ReturnType::Type(_, ref ty) => ty,
364 ReturnType::Default => unit_type,
365 };
366 quote_spanned! {ident.span()=>
367 #( #attrs )*
368 #is_async fn #ident(#receiver, #( #args ),*) -> #output;
369 }
370 },
371 );
372
373 let forward_fns = rpcs
374 .iter()
375 .map(
376 |RpcMethod {
377 attrs,
378 args,
379 receiver,
380 ident,
381 is_async,
382 output,
383 ..
384 }| {
385 {
386 let output = if let ReturnType::Type(_, ref ty) = output {
387 if let Some(item_ty) = stream_item_type(ty) {
388 quote! { impl web_rpc::futures_core::Stream<Item = #item_ty> }
389 } else {
390 let ty: &Type = ty;
391 quote! { #ty }
392 }
393 } else {
394 let ty = unit_type;
395 quote! { #ty }
396 };
397 let do_await = match is_async {
398 Some(token) => quote_spanned!(token.span=> .await),
399 None => quote!(),
400 };
401 let forward_args = args.iter().filter_map(|arg| match &*arg.pat {
402 Pat::Ident(ident) => Some(&ident.ident),
403 _ => None,
404 });
405 quote_spanned! {ident.span()=>
406 #( #attrs )*
407 #is_async fn #ident(#receiver, #( #args ),*) -> #output {
408 T::#ident(self, #( #forward_args ),*)#do_await
409 }
410 }
411 }
412 },
413 )
414 .collect::<Vec<_>>();
415
416 quote! {
417 #( #attrs )*
418 #[allow(async_fn_in_trait)]
419 #vis trait #trait_ident {
420 #( #rpc_fns )*
421 }
422
423 impl<T> #trait_ident for std::sync::Arc<T> where T: #trait_ident {
424 #( #forward_fns )*
425 }
426 impl<T> #trait_ident for std::boxed::Box<T> where T: #trait_ident {
427 #( #forward_fns )*
428 }
429 impl<T> #trait_ident for std::rc::Rc<T> where T: #trait_ident {
430 #( #forward_fns )*
431 }
432 }
433 }
434
435 fn struct_client(&self) -> TokenStream2 {
436 let &Self {
437 vis,
438 client_ident,
439 request_ident,
440 response_ident,
441 camel_case_idents,
442 rpcs,
443 has_streaming_methods,
444 ..
445 } = self;
446
447 let rpc_fns = rpcs
448 .iter()
449 .zip(camel_case_idents.iter())
450 .map(|(RpcMethod { attrs, args, transfer, ident, output, .. }, camel_case_ident)| {
451 let mut arg_encodings = Vec::<TokenStream2>::new();
454 let mut request_struct_fields = Vec::<TokenStream2>::new();
455 for arg in args {
456 let id = match &*arg.pat {
457 Pat::Ident(p) => &p.ident,
458 _ => continue,
459 };
460 if is_borrowed_serde_ref(&arg.ty) {
461 request_struct_fields.push(quote! { #id });
462 } else {
463 let wire_ident = format_ident!("__wire_{}", id);
464 let post = quote!(&__post);
465 let enc = emit_encode(&arg.ty, quote!(#id), &post);
466 arg_encodings.push(quote! { let #wire_ident = #enc; });
467 request_struct_fields.push(quote! { #id: #wire_ident });
468 }
469 }
470
471 let transfer_pushes = transfer.iter().filter_map(|c| match c {
474 TransferClause::BareParam(name) => Some(quote! {
475 __transfer.push(#name.as_ref());
476 }),
477 TransferClause::ParamExpr { name, body } => Some(quote_spanned! {body.span()=>
478 {
479 let _ = &#name; __transfer.push((#body).as_ref());
481 }
482 }),
483 TransferClause::ParamGated { name, gates } => {
484 let arms = gates.iter().map(|g| {
485 let pat = &g.pat;
486 let body = &g.body;
487 quote_spanned! {body.span()=>
488 if let #pat = &#name {
489 __transfer.push((#body).as_ref());
490 }
491 }
492 });
493 Some(quote! { #( #arms )* })
494 }
495 TransferClause::BareReturn | TransferClause::ReturnGated { .. } => None,
496 });
497
498 let send_request = quote! {
499 let __seq_id = self.seq_id.replace_with(|seq_id| seq_id.wrapping_add(1));
500 let __post = web_rpc::js_sys::Array::new();
501 let __transfer = web_rpc::js_sys::Array::new();
502 #( #arg_encodings )*
503 let __request = #request_ident::#camel_case_ident {
504 #( #request_struct_fields ),*
505 };
506 let __header = web_rpc::MessageHeader::Request(__seq_id);
507 let __header_bytes = web_rpc::bincode::serialize(&__header).unwrap();
508 let __header_buffer = web_rpc::js_sys::Uint8Array::from(&__header_bytes[..]).buffer();
509 let __payload_bytes = web_rpc::bincode::serialize(&__request).unwrap();
510 let __payload_buffer = web_rpc::js_sys::Uint8Array::from(&__payload_bytes[..]).buffer();
511 __post.unshift(&__payload_buffer);
513 __post.unshift(&__header_buffer);
514 __transfer.push(__header_buffer.as_ref());
515 __transfer.push(__payload_buffer.as_ref());
516 #( #transfer_pushes )*
517 self.port.post_message(&__post, &__transfer).unwrap();
518 };
519
520 let is_streaming = matches!(
521 output,
522 ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some()
523 );
524
525 if is_streaming {
526 let item_ty = match output {
527 ReturnType::Type(_, ref ty) => stream_item_type(ty).unwrap(),
528 _ => unreachable!(),
529 };
530 let dec = emit_decode(item_ty, quote!(__wire), "e!(&__post_array));
531
532 let unpack_stream_item = quote! {
533 |(__response, __post_array): (#response_ident, web_rpc::js_sys::Array)| {
534 let #response_ident::#camel_case_ident(__wire) = __response else {
535 panic!("web_rpc: received incorrect response variant")
536 };
537 #dec
538 }
539 };
540
541 quote! {
542 #( #attrs )*
543 #vis fn #ident(
544 &self,
545 #( #args ),*
546 ) -> web_rpc::client::StreamReceiver<#item_ty> {
547 #send_request
548 let (__item_tx, __item_rx) = web_rpc::futures_channel::mpsc::unbounded();
549 self.stream_callback_map.borrow_mut().insert(__seq_id, __item_tx);
550 let __mapped_rx = web_rpc::futures_util::StreamExt::map(
551 __item_rx,
552 #unpack_stream_item
553 );
554 let __abort_sender = self.abort_sender.clone();
555 let __stream_callback_map = self.stream_callback_map.clone();
556 let __dispatcher = self.dispatcher.clone();
557 web_rpc::client::StreamReceiver::new(
558 __mapped_rx,
559 __dispatcher,
560 std::boxed::Box::new(move || {
561 __stream_callback_map.borrow_mut().remove(&__seq_id);
562 (__abort_sender)(__seq_id);
563 }),
564 )
565 }
566 }
567 } else {
568 let return_type = match output {
569 ReturnType::Type(_, ref ty) => quote! {
570 web_rpc::client::RequestFuture<#ty>
571 },
572 _ => quote!(()),
573 };
574 let maybe_register_callback = match output {
575 ReturnType::Type(_, _) => quote! {
576 let (__response_tx, __response_rx) =
577 web_rpc::futures_channel::oneshot::channel();
578 self.callback_map.borrow_mut().insert(__seq_id, __response_tx);
579 },
580 _ => Default::default(),
581 };
582
583 let maybe_unpack_and_return_future = match output {
584 ReturnType::Type(_, ref ret_ty) => {
585 let dec = emit_decode(ret_ty, quote!(__wire), "e!(&__post_array));
586 quote! {
587 let __response_future = web_rpc::futures_util::FutureExt::map(
588 __response_rx,
589 |response| {
590 let (__serialize_response, __post_array) = response.unwrap();
591 let #response_ident::#camel_case_ident(__wire) = __serialize_response else {
592 panic!("web_rpc: received incorrect response variant")
593 };
594 #dec
595 }
596 );
597 let __abort_sender = self.abort_sender.clone();
598 let __dispatcher = self.dispatcher.clone();
599 web_rpc::client::RequestFuture::new(
600 __response_future,
601 __dispatcher,
602 std::boxed::Box::new(move || (__abort_sender)(__seq_id)))
603 }
604 }
605 _ => Default::default(),
606 };
607
608 quote! {
609 #( #attrs )*
610 #vis fn #ident(
611 &self,
612 #( #args ),*
613 ) -> #return_type {
614 #send_request
615 #maybe_register_callback
616 #maybe_unpack_and_return_future
617 }
618 }
619 }
620 });
621
622 let stream_callback_map_field = if has_streaming_methods {
623 quote! {
624 stream_callback_map: std::rc::Rc<
625 std::cell::RefCell<
626 web_rpc::client::StreamCallbackMap<#response_ident>
627 >
628 >,
629 }
630 } else {
631 quote!()
632 };
633
634 let stream_callback_map_pat = if has_streaming_methods {
635 quote! { stream_callback_map, }
636 } else {
637 quote! { _, }
638 };
639
640 let stream_callback_map_init = if has_streaming_methods {
641 quote! { stream_callback_map, }
642 } else {
643 quote! {}
644 };
645
646 quote! {
647 #[derive(core::clone::Clone)]
648 #vis struct #client_ident {
649 callback_map: std::rc::Rc<
650 std::cell::RefCell<
651 web_rpc::client::CallbackMap<#response_ident>
652 >
653 >,
654 #stream_callback_map_field
655 port: web_rpc::port::Port,
656 listener: std::rc::Rc<web_rpc::gloo_events::EventListener>,
657 dispatcher: web_rpc::futures_util::future::Shared<
658 web_rpc::futures_core::future::LocalBoxFuture<'static, ()>
659 >,
660 abort_sender: std::rc::Rc<dyn std::ops::Fn(usize)>,
661 seq_id: std::rc::Rc<std::cell::RefCell<usize>>
662 }
663 impl std::fmt::Debug for #client_ident {
664 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
665 formatter.debug_struct(std::stringify!(#client_ident))
666 .finish()
667 }
668 }
669 impl web_rpc::client::Client for #client_ident {
670 type Response = #response_ident;
671 }
672 impl From<web_rpc::client::Configuration<#response_ident>>
673 for #client_ident {
674 fn from((callback_map, #stream_callback_map_pat port, listener, dispatcher, abort_sender):
675 web_rpc::client::Configuration<#response_ident>) -> Self {
676 Self {
677 callback_map,
678 #stream_callback_map_init
679 port,
680 listener,
681 dispatcher,
682 abort_sender,
683 seq_id: std::default::Default::default()
684 }
685 }
686 }
687 impl #client_ident {
688 #( #rpc_fns )*
689 }
690 }
691 }
692
693 fn struct_server(&self) -> TokenStream2 {
694 let &Self {
695 vis,
696 trait_ident,
697 service_ident,
698 request_ident,
699 response_ident,
700 camel_case_idents,
701 rpcs,
702 has_borrowed_args,
703 ..
704 } = self;
705
706 let request_type = if has_borrowed_args {
707 quote! { #request_ident<'_> }
708 } else {
709 quote! { #request_ident }
710 };
711
712 let handlers = rpcs.iter()
713 .zip(camel_case_idents.iter())
714 .map(|(RpcMethod { is_async, ident, args, transfer, output, .. }, camel_case_ident)| {
715 let destructure_fields: Vec<_> = args.iter()
718 .filter_map(|arg| {
719 let id = match &*arg.pat {
720 Pat::Ident(p) => &p.ident,
721 _ => return None,
722 };
723 Some(if is_borrowed_serde_ref(&arg.ty) {
724 quote! { #id }
725 } else {
726 let wire_ident = format_ident!("__wire_{}", id);
727 quote! { #id: #wire_ident }
728 })
729 })
730 .collect();
731
732 let arg_decodes: Vec<_> = args.iter()
734 .filter_map(|arg| {
735 let id = match &*arg.pat {
736 Pat::Ident(p) => &p.ident,
737 _ => return None,
738 };
739 if is_borrowed_serde_ref(&arg.ty) {
740 None
742 } else if is_js_ref(&arg.ty) {
743 let inner_ty = match &*arg.ty {
746 Type::Reference(r) => &*r.elem,
747 _ => unreachable!(),
748 };
749 let tmp_ident = format_ident!("__tmp_{}", id);
750 let wire_ident = format_ident!("__wire_{}", id);
751 let arg_ty = &arg.ty;
752 Some(quote! {
753 let #tmp_ident = match #wire_ident {
754 web_rpc::codec::WireArg::Js => __js_args.shift(),
755 _ => panic!("web_rpc: expected Js wire variant for reference arg"),
756 };
757 let #id: #arg_ty = web_rpc::wasm_bindgen::JsCast::dyn_ref::<#inner_ty>(&#tmp_ident)
758 .unwrap();
759 })
760 } else {
761 let wire_ident = format_ident!("__wire_{}", id);
762 let dec = emit_decode(&arg.ty, quote!(#wire_ident), "e!(&__js_args));
763 Some(quote! { let #id = #dec; })
764 }
765 })
766 .collect();
767
768 let call_args: Vec<_> = args.iter().filter_map(|arg| match &*arg.pat {
769 Pat::Ident(ident) => Some(&ident.ident),
770 _ => None,
771 }).collect();
772
773 let make_return_transfer = |scrutinee_ident: &Ident| -> TokenStream2 {
776 let pushes = transfer.iter().filter_map(|c| match c {
777 TransferClause::BareReturn => Some(quote! {
778 __transfer.push(#scrutinee_ident.as_ref());
779 }),
780 TransferClause::ReturnGated { gates } => {
781 let arms = gates.iter().map(|g| {
782 let pat = &g.pat;
783 let body = &g.body;
784 quote_spanned! {body.span()=>
785 if let #pat = &#scrutinee_ident {
786 __transfer.push((#body).as_ref());
787 }
788 }
789 });
790 Some(quote! { #( #arms )* })
791 }
792 _ => None,
793 });
794 quote! { #( #pushes )* }
795 };
796
797 let is_streaming = matches!(
798 output,
799 ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some()
800 );
801
802 if is_streaming {
803 let item_ty = match output {
804 ReturnType::Type(_, ref ty) => stream_item_type(ty).unwrap(),
805 _ => unreachable!(),
806 };
807 let item_enc = emit_encode(item_ty, quote!(__item), "e!(&__post));
808 let item_ident = Ident::new("__item", proc_macro2::Span::call_site());
809 let return_transfer = make_return_transfer(&item_ident);
810
811 let wrap_item = quote! {
812 let __post = web_rpc::js_sys::Array::new();
813 let __transfer = web_rpc::js_sys::Array::new();
814 let __wire_item = #item_enc;
815 #return_transfer
816 let __response = #response_ident::#camel_case_ident(__wire_item);
817 };
818
819 let fwd_body = quote! {
820 let __stream_tx_clone = __stream_tx.clone();
821 web_rpc::pin_utils::pin_mut!(__user_rx);
822 let __fwd = async move {
823 while let Some(__item) = web_rpc::futures_util::StreamExt::next(&mut __user_rx).await {
824 #wrap_item
825 if __stream_tx_clone.unbounded_send((__seq_id, Some((__response, __post, __transfer)))).is_err() {
826 break;
827 }
828 }
829 };
830 let __fwd = web_rpc::futures_util::FutureExt::fuse(__fwd);
831 web_rpc::pin_utils::pin_mut!(__fwd);
832 web_rpc::futures_util::select! {
833 _ = __abort_rx => {},
834 _ = __fwd => {},
835 }
836 let _ = __stream_tx.unbounded_send((__seq_id, None));
837 web_rpc::service::ExecuteResult::StreamComplete
838 };
839
840 match is_async {
841 Some(_) => quote! {
842 #request_ident::#camel_case_ident { #( #destructure_fields ),* } => {
843 #( #arg_decodes )*
844 let __get_rx = web_rpc::futures_util::FutureExt::fuse(
845 self.server_impl.#ident(#( #call_args ),*)
846 );
847 web_rpc::pin_utils::pin_mut!(__get_rx);
848 let __maybe_rx = web_rpc::futures_util::select! {
849 _ = __abort_rx => None,
850 __rx = __get_rx => Some(__rx),
851 };
852 if let Some(mut __user_rx) = __maybe_rx {
853 #fwd_body
854 } else {
855 let _ = __stream_tx.unbounded_send((__seq_id, None));
856 web_rpc::service::ExecuteResult::StreamComplete
857 }
858 }
859 },
860 None => quote! {
861 #request_ident::#camel_case_ident { #( #destructure_fields ),* } => {
862 #( #arg_decodes )*
863 let mut __user_rx = self.server_impl.#ident(#( #call_args ),*);
864 #fwd_body
865 }
866 },
867 }
868 } else {
869 let resp_ident = Ident::new("__response", proc_macro2::Span::call_site());
871 let return_transfer = make_return_transfer(&resp_ident);
872 let return_response = match output {
873 ReturnType::Type(_, ref ret_ty) => {
874 let enc = emit_encode(ret_ty, quote!(__response), "e!(&__post));
875 quote! {
876 let __post = web_rpc::js_sys::Array::new();
877 let __transfer = web_rpc::js_sys::Array::new();
878 let __wire = #enc;
879 #return_transfer
880 (#response_ident::#camel_case_ident(__wire), __post, __transfer)
881 }
882 }
883 _ => {
884 quote! {
886 let _ = __response;
887 let __post = web_rpc::js_sys::Array::new();
888 let __transfer = web_rpc::js_sys::Array::new();
889 let __wire = web_rpc::codec::WireArg::Bytes(
890 web_rpc::bincode::serialize(&()).unwrap()
891 );
892 (#response_ident::#camel_case_ident(__wire), __post, __transfer)
893 }
894 }
895 };
896
897 match is_async {
898 Some(_) => quote! {
899 #request_ident::#camel_case_ident { #( #destructure_fields ),* } => {
900 #( #arg_decodes )*
901 let __task =
902 web_rpc::futures_util::FutureExt::fuse(self.server_impl.#ident(#( #call_args ),*));
903 web_rpc::pin_utils::pin_mut!(__task);
904 web_rpc::service::ExecuteResult::Response(
905 web_rpc::futures_util::select! {
906 _ = __abort_rx => None,
907 __response = __task => Some({
908 #return_response
909 })
910 }
911 )
912 }
913 },
914 None => quote! {
915 #request_ident::#camel_case_ident { #( #destructure_fields ),* } => {
916 #( #arg_decodes )*
917 let __response = self.server_impl.#ident(#( #call_args ),*);
918 web_rpc::service::ExecuteResult::Response(
919 Some({
920 #return_response
921 })
922 )
923 }
924 }
925 }
926 }
927 });
928
929 quote! {
930 #vis struct #service_ident<T> {
931 server_impl: T
932 }
933 impl<T: #trait_ident> web_rpc::service::Service for #service_ident<T> {
934 type Response = #response_ident;
935 async fn execute(
936 &self,
937 __seq_id: usize,
938 mut __abort_rx: web_rpc::futures_channel::oneshot::Receiver<()>,
939 __payload: std::vec::Vec<u8>,
940 __js_args: web_rpc::js_sys::Array,
941 __stream_tx: web_rpc::futures_channel::mpsc::UnboundedSender<
942 web_rpc::service::StreamMessage<Self::Response>
943 >,
944 ) -> (usize, web_rpc::service::ExecuteResult<Self::Response>) {
945 let __request: #request_type = web_rpc::bincode::deserialize(&__payload).unwrap();
946 let __result = match __request {
947 #( #handlers )*
948 };
949 (__seq_id, __result)
950 }
951 }
952 impl<T: #trait_ident> std::convert::From<T> for #service_ident<T> {
953 fn from(server_impl: T) -> Self {
954 Self { server_impl }
955 }
956 }
957 }
958 }
959}
960
961impl<'a> ToTokens for ServiceGenerator<'a> {
962 fn to_tokens(&self, output: &mut TokenStream2) {
963 output.extend(vec![
964 self.enum_request(),
965 self.enum_response(),
966 self.trait_service(),
967 self.struct_client(),
968 self.struct_server(),
969 ])
970 }
971}
972
973impl Parse for Service {
974 fn parse(input: ParseStream) -> syn::Result<Self> {
975 let attrs = input.call(Attribute::parse_outer)?;
976 let vis = input.parse()?;
977 input.parse::<Token![trait]>()?;
978 let ident: Ident = input.parse()?;
979 let content;
980 braced!(content in input);
981 let mut rpcs = Vec::<RpcMethod>::new();
982 while !content.is_empty() {
983 rpcs.push(content.parse()?);
984 }
985
986 Ok(Self {
987 attrs,
988 vis,
989 ident,
990 rpcs,
991 })
992 }
993}
994
/// The right-hand side of a `lhs => rhs` transfer clause, before it is combined with its lhs.
enum TransferRhs {
    /// A plain expression, as in `param => expr`.
    Expr(syn::Expr),
    /// Pattern-gated expressions from the closure (`|pat| expr`) or `match { ... }` form.
    Gates(Vec<Gate>),
}
1000
1001fn parse_transfer_rhs(input: ParseStream) -> syn::Result<TransferRhs> {
1002 if input.peek(Token![|]) || input.peek(Token![||]) {
1003 let closure: syn::ExprClosure = input.parse()?;
1005 if closure.inputs.len() != 1 {
1006 return Err(syn::Error::new_spanned(
1007 &closure,
1008 "transfer closure must have exactly one parameter",
1009 ));
1010 }
1011 let pat = closure.inputs.into_iter().next().unwrap();
1012 let body = *closure.body;
1013 Ok(TransferRhs::Gates(vec![Gate { pat, body }]))
1014 } else if input.peek(Token![match]) {
1015 input.parse::<Token![match]>()?;
1017 let content;
1018 braced!(content in input);
1019 let arms: Punctuated<syn::Arm, Token![,]> =
1020 content.parse_terminated(syn::Arm::parse)?;
1021 let gates = arms
1022 .into_iter()
1023 .map(|a| Gate {
1024 pat: a.pat,
1025 body: *a.body,
1026 })
1027 .collect();
1028 Ok(TransferRhs::Gates(gates))
1029 } else {
1030 let body: syn::Expr = input.parse()?;
1032 Ok(TransferRhs::Expr(body))
1033 }
1034}
1035
1036impl Parse for TransferClause {
1037 fn parse(input: ParseStream) -> syn::Result<Self> {
1038 let is_return = input.peek(Token![return]);
1039 let lhs_name: Option<Ident> = if is_return {
1040 input.parse::<Token![return]>()?;
1041 None
1042 } else {
1043 Some(input.parse()?)
1044 };
1045
1046 if input.peek(Token![=>]) {
1047 input.parse::<Token![=>]>()?;
1048 let rhs = parse_transfer_rhs(input)?;
1049 match (lhs_name, rhs) {
1050 (Some(name), TransferRhs::Expr(body)) => {
1051 Ok(TransferClause::ParamExpr { name, body })
1052 }
1053 (Some(name), TransferRhs::Gates(gates)) => {
1054 Ok(TransferClause::ParamGated { name, gates })
1055 }
1056 (None, TransferRhs::Gates(gates)) => {
1057 Ok(TransferClause::ReturnGated { gates })
1058 }
1059 (None, TransferRhs::Expr(_)) => Err(syn::Error::new(
1060 input.span(),
1061 "`return =>` requires a closure (`|pat| body`) or `match { arms }` block",
1062 )),
1063 }
1064 } else {
1065 Ok(match lhs_name {
1066 Some(name) => TransferClause::BareParam(name),
1067 None => TransferClause::BareReturn,
1068 })
1069 }
1070 }
1071}
1072
1073impl Parse for RpcMethod {
1074 fn parse(input: ParseStream) -> syn::Result<Self> {
1075 let mut errors = Ok(());
1076 let attrs = input.call(Attribute::parse_outer)?;
1077
1078 for attr in &attrs {
1080 if attr
1081 .path
1082 .segments
1083 .last()
1084 .is_some_and(|seg| seg.ident == "post")
1085 {
1086 extend_errors!(
1087 errors,
1088 syn::Error::new_spanned(
1089 attr,
1090 "`#[post(...)]` has been removed. JS-vs-serialize routing is now \
1091 inferred from each argument and return type. For transfer semantics, \
1092 use `#[transfer(...)]` (e.g. `#[transfer(canvas)]`, \
1093 `#[transfer(data => data.buffer())]`, or \
1094 `#[transfer(return => |Ok(o)| o.buffer())]`)."
1095 )
1096 );
1097 }
1098 }
1099
1100 let (transfer_attrs, attrs): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|attr| {
1102 attr.path
1103 .segments
1104 .last()
1105 .is_some_and(|last_segment| last_segment.ident == "transfer")
1106 });
1107 let mut transfer: Vec<TransferClause> = Vec::new();
1108 for transfer_attr in transfer_attrs {
1109 let parsed = transfer_attr
1110 .parse_args_with(Punctuated::<TransferClause, Token![,]>::parse_terminated)?;
1111 transfer.extend(parsed.into_iter());
1112 }
1113
1114 let is_async = input.parse::<Token![async]>().ok();
1115 input.parse::<Token![fn]>()?;
1116 let ident: Ident = input.parse()?;
1117
1118 if input.peek(Token![<]) {
1120 let generics: syn::Generics = input.parse()?;
1121 extend_errors!(
1122 errors,
1123 syn::Error::new_spanned(
1124 generics,
1125 "web_rpc::service trait methods may not have generic parameters; \
1126 concrete types are required so the macro can route each argument."
1127 )
1128 );
1129 }
1130
1131 let content;
1132 parenthesized!(content in input);
1133 let mut receiver: Option<syn::Receiver> = None;
1134 let mut args = Vec::new();
1135 for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
1136 match arg {
1137 FnArg::Typed(captured) => match &*captured.pat {
1138 Pat::Ident(_) => {
1139 args.push(captured)
1145 }
1146 _ => extend_errors!(
1147 errors,
1148 syn::Error::new(
1149 captured.pat.span(),
1150 "patterns are not allowed in RPC arguments"
1151 )
1152 ),
1153 },
1154 FnArg::Receiver(ref recv) => {
1155 if recv.reference.is_none() || recv.mutability.is_some() {
1156 extend_errors!(
1157 errors,
1158 syn::Error::new(
1159 arg.span(),
1160 "RPC methods only support `&self` as a receiver"
1161 )
1162 );
1163 }
1164 receiver = Some(recv.clone());
1165 }
1166 }
1167 }
1168 let receiver = match receiver {
1169 Some(r) => r,
1170 None => {
1171 extend_errors!(
1172 errors,
1173 syn::Error::new(
1174 ident.span(),
1175 "RPC methods must include `&self` as the first parameter"
1176 )
1177 );
1178 parse_quote!(&self)
1179 }
1180 };
1181 let output: ReturnType = input.parse()?;
1182 input.parse::<Token![;]>()?;
1183
1184 let arg_names: HashSet<_> = args
1187 .iter()
1188 .filter_map(|arg| match &*arg.pat {
1189 Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()),
1190 _ => None,
1191 })
1192 .collect();
1193 for clause in &transfer {
1194 let name_ref = match clause {
1195 TransferClause::BareParam(name)
1196 | TransferClause::ParamExpr { name, .. }
1197 | TransferClause::ParamGated { name, .. } => Some(name),
1198 TransferClause::BareReturn | TransferClause::ReturnGated { .. } => None,
1199 };
1200 if let Some(name) = name_ref {
1201 if !arg_names.contains(name) {
1202 extend_errors!(
1203 errors,
1204 syn::Error::new(
1205 name.span(),
1206 format!(
1207 "`{}` in #[transfer(...)] does not match any parameter",
1208 name
1209 )
1210 )
1211 );
1212 }
1213 }
1214 }
1215 errors?;
1216
1217 Ok(Self {
1218 is_async,
1219 attrs,
1220 receiver,
1221 ident,
1222 args,
1223 transfer,
1224 output,
1225 })
1226 }
1227}
1228
1229#[proc_macro_attribute]
1235pub fn service(_attr: TokenStream, input: TokenStream) -> TokenStream {
1236 let Service {
1237 ref attrs,
1238 ref vis,
1239 ref ident,
1240 ref rpcs,
1241 } = parse_macro_input!(input as Service);
1242
1243 let camel_case_fn_names: &Vec<_> = &rpcs
1244 .iter()
1245 .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
1246 .collect();
1247
1248 let has_borrowed_args = rpcs
1249 .iter()
1250 .any(|rpc| rpc.args.iter().any(|arg| is_borrowed_serde_ref(&arg.ty)));
1251
1252 let has_streaming_methods = rpcs.iter().any(
1253 |rpc| matches!(&rpc.output, ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some()),
1254 );
1255
1256 ServiceGenerator {
1257 trait_ident: ident,
1258 service_ident: &format_ident!("{}Service", ident),
1259 client_ident: &format_ident!("{}Client", ident),
1260 request_ident: &format_ident!("{}Request", ident),
1261 response_ident: &format_ident!("{}Response", ident),
1262 vis,
1263 attrs,
1264 rpcs,
1265 camel_case_idents: &rpcs
1266 .iter()
1267 .zip(camel_case_fn_names.iter())
1268 .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
1269 .collect::<Vec<_>>(),
1270 has_borrowed_args,
1271 has_streaming_methods,
1272 }
1273 .into_token_stream()
1274 .into()
1275}
1276
/// Converts a snake_case name to CamelCase.
///
/// The input is split on `_`. The first character of each non-empty word is uppercased and
/// the rest of the word is lowercased, so `get_user` becomes `GetUser` and `fooBAR` becomes
/// `Foobar`. Leading, trailing and repeated underscores are dropped.
fn snake_to_camel(ident_str: &str) -> String {
    let mut camel = String::with_capacity(ident_str.len());
    for word in ident_str.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            camel.extend(first.to_uppercase());
            camel.extend(chars.flat_map(char::to_lowercase));
        }
    }
    camel.shrink_to_fit();
    camel
}