1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use syn::{
7 braced,
8 ext::IdentExt,
9 parenthesized,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse_quote,
12 punctuated::Punctuated,
13 spanned::Spanned,
14 token::Comma,
15 Attribute, FnArg, Ident, Lifetime, Meta, NestedMeta, Pat, PatType, ReturnType, Token, Type,
16 Visibility,
17};
18
19macro_rules! extend_errors {
20 ($errors: ident, $e: expr) => {
21 match $errors {
22 Ok(_) => $errors = Err($e),
23 Err(ref mut errors) => errors.extend($e),
24 }
25 };
26}
27
28fn stream_item_type(ty: &Type) -> Option<&Type> {
30 if let Type::ImplTrait(impl_trait) = ty {
31 for bound in &impl_trait.bounds {
32 if let syn::TypeParamBound::Trait(trait_bound) = bound {
33 let last_segment = trait_bound.path.segments.last()?;
34 if last_segment.ident == "Stream" {
35 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
36 for arg in &args.args {
37 if let syn::GenericArgument::Binding(binding) = arg {
38 if binding.ident == "Item" {
39 return Some(&binding.ty);
40 }
41 }
42 }
43 }
44 }
45 }
46 }
47 }
48 None
49}
50
51fn option_inner_type(ty: &Type) -> Option<&Type> {
53 if let Type::Path(type_path) = ty {
54 let last_seg = type_path.path.segments.last()?;
55 if last_seg.ident == "Option" {
56 if let syn::PathArguments::AngleBracketed(args) = &last_seg.arguments {
57 if args.args.len() == 1 {
58 if let syn::GenericArgument::Type(inner) = &args.args[0] {
59 return Some(inner);
60 }
61 }
62 }
63 }
64 }
65 None
66}
67
68fn result_inner_types(ty: &Type) -> Option<(&Type, &Type)> {
70 if let Type::Path(type_path) = ty {
71 let last_seg = type_path.path.segments.last()?;
72 if last_seg.ident == "Result" {
73 if let syn::PathArguments::AngleBracketed(args) = &last_seg.arguments {
74 if args.args.len() == 2 {
75 if let (syn::GenericArgument::Type(ok_ty), syn::GenericArgument::Type(err_ty)) =
76 (&args.args[0], &args.args[1])
77 {
78 return Some((ok_ty, err_ty));
79 }
80 }
81 }
82 }
83 }
84 None
85}
86
/// How a `#[post]`-ed value (argument or return value) is wrapped.
///
/// This decides what goes into the bincode-serialized message and what is sent
/// through the JS `postMessage` array.
enum PostWrapperKind<'a> {
    /// Neither `Option` nor `Result` (by last path segment): the value itself is
    /// posted and nothing is serialized for it.
    Bare,
    /// `Option<inner_ty>`: a `bool` presence flag is serialized, and the inner
    /// value is posted only when present.
    Option { inner_ty: &'a Type },
    /// `Result<ok_ty, err_ty>`: a `bool` "is ok" flag is serialized, and whichever
    /// of the `Ok` / `Err` values is present is posted.
    Result { ok_ty: &'a Type, err_ty: &'a Type },
}
96
97fn post_wrapper_kind(ty: &Type) -> PostWrapperKind<'_> {
98 if let Some(inner) = option_inner_type(ty) {
99 PostWrapperKind::Option { inner_ty: inner }
100 } else if let Some((ok_ty, err_ty)) = result_inner_types(ty) {
101 PostWrapperKind::Result { ok_ty, err_ty }
102 } else {
103 PostWrapperKind::Bare
104 }
105}
106
/// A parsed service `trait` definition.
struct Service {
    /// Outer attributes written on the trait.
    attrs: Vec<Attribute>,
    /// Visibility of the trait.
    vis: Visibility,
    /// The trait's name.
    ident: Ident,
    /// Every method declared inside the trait body, in source order.
    rpcs: Vec<RpcMethod>,
}
113
/// One method declaration inside the service trait.
struct RpcMethod {
    /// The `async` keyword, when the method is declared `async`.
    is_async: Option<Token![async]>,
    /// Outer attributes other than `#[post(...)]`; the `post` attributes are
    /// consumed while parsing and turned into `post` / `transfer` below.
    attrs: Vec<Attribute>,
    /// The method's `self` receiver, interpolated as-is into generated signatures.
    receiver: syn::Receiver,
    /// The method name.
    ident: Ident,
    /// The typed (non-receiver) arguments.
    args: Vec<PatType>,
    /// Names given as `transfer(name)` inside `#[post(...)]`. Every such name is
    /// also inserted into `post`.
    transfer: HashSet<Ident>,
    /// Names listed in `#[post(...)]`. These values are sent as JS objects next to
    /// the serialized message instead of being serialized. The name `return`
    /// refers to the method's return value.
    post: HashSet<Ident>,
    /// Declared return type (`ReturnType::Default` when none is written).
    output: ReturnType,
}
124
/// Borrowed view of a parsed service plus the identifiers of every item the macro
/// emits: request/response enums, the trait, the client struct and the server struct.
struct ServiceGenerator<'a> {
    /// Name of the user-facing trait.
    trait_ident: &'a Ident,
    /// Name of the generated server wrapper struct.
    service_ident: &'a Ident,
    /// Name of the generated client struct.
    client_ident: &'a Ident,
    /// Name of the generated request enum.
    request_ident: &'a Ident,
    /// Name of the generated response enum.
    response_ident: &'a Ident,
    /// Visibility applied to all generated items.
    vis: &'a Visibility,
    /// Outer attributes re-emitted on the generated trait.
    attrs: &'a [Attribute],
    /// The service methods.
    rpcs: &'a [RpcMethod],
    /// Enum variant name for each entry of `rpcs`, paired with it positionally (zipped).
    camel_case_idents: &'a [Ident],
    /// When true, the request enum gets an `'a` lifetime and reference-typed,
    /// non-posted arguments are rewritten to `&'a T`.
    has_borrowed_args: bool,
    /// When true, the client struct gets a `stream_callback_map` field.
    has_streaming_methods: bool,
}
138
139impl<'a> ServiceGenerator<'a> {
140 fn enum_request(&self) -> TokenStream2 {
141 let &Self {
142 vis,
143 request_ident,
144 camel_case_idents,
145 rpcs,
146 has_borrowed_args,
147 ..
148 } = self;
149 let lifetime = if has_borrowed_args {
150 quote!(<'a>)
151 } else {
152 quote!()
153 };
154 let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
155 |(RpcMethod { args, post, .. }, camel_case_ident)| {
156 let fields = args.iter().filter_map(|arg| {
157 let is_post =
158 matches!(&*arg.pat, Pat::Ident(ident) if post.contains(&ident.ident));
159 if is_post {
160 let pat = &arg.pat;
162 match post_wrapper_kind(&arg.ty) {
163 PostWrapperKind::Option { .. } => Some(quote! { #pat: bool }),
164 PostWrapperKind::Result { .. } => Some(quote! { #pat: bool }),
165 PostWrapperKind::Bare => None, }
167 } else {
168 Some(if has_borrowed_args {
170 if let Type::Reference(type_ref) = &*arg.ty {
171 let mut type_ref = type_ref.clone();
172 type_ref.lifetime =
173 Some(Lifetime::new("'a", type_ref.and_token.span()));
174 let pat = &arg.pat;
175 quote! { #pat: #type_ref }
176 } else {
177 quote! { #arg }
178 }
179 } else {
180 quote! { #arg }
181 })
182 }
183 });
184 quote! {
185 #camel_case_ident { #( #fields ),* }
186 }
187 },
188 );
189 quote! {
190 #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
191 #vis enum #request_ident #lifetime {
192 #( #variants ),*
193 }
194 }
195 }
196
197 fn enum_response(&self) -> TokenStream2 {
198 let &Self {
199 vis,
200 response_ident,
201 camel_case_idents,
202 rpcs,
203 ..
204 } = self;
205 let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
206 |(RpcMethod { output, post, .. }, camel_case_ident)| match output {
207 ReturnType::Type(_, ty) if !post.contains(&Ident::new("return", output.span())) => {
208 if let Some(item_ty) = stream_item_type(ty) {
210 quote! {
211 #camel_case_ident ( #item_ty )
212 }
213 } else {
214 quote! {
215 #camel_case_ident ( #ty )
216 }
217 }
218 }
219 ReturnType::Type(_, ty) => {
220 let effective_ty = stream_item_type(ty).unwrap_or(ty);
222 match post_wrapper_kind(effective_ty) {
223 PostWrapperKind::Bare => quote! {
224 #camel_case_ident ( () )
225 },
226 PostWrapperKind::Option { .. } => quote! {
227 #camel_case_ident ( bool )
228 },
229 PostWrapperKind::Result { .. } => quote! {
230 #camel_case_ident ( bool )
231 },
232 }
233 }
234 _ => quote! {
235 #camel_case_ident ( () )
236 },
237 },
238 );
239 quote! {
240 #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
241 #vis enum #response_ident {
242 #( #variants ),*
243 }
244 }
245 }
246
247 fn trait_service(&self) -> TokenStream2 {
248 let &Self {
249 attrs,
250 rpcs,
251 vis,
252 trait_ident,
253 ..
254 } = self;
255
256 let unit_type: &Type = &parse_quote!(());
257 let rpc_fns = rpcs.iter().map(
258 |RpcMethod {
259 attrs,
260 args,
261 receiver,
262 ident,
263 is_async,
264 output,
265 ..
266 }| {
267 if let ReturnType::Type(_, ref ty) = output {
268 if let Some(item_ty) = stream_item_type(ty) {
269 return quote_spanned! {ident.span()=>
270 #( #attrs )*
271 #is_async fn #ident(#receiver, #( #args ),*) -> impl web_rpc::futures_core::Stream<Item = #item_ty>;
272 };
273 }
274 }
275 let output = match output {
276 ReturnType::Type(_, ref ty) => ty,
277 ReturnType::Default => unit_type,
278 };
279 quote_spanned! {ident.span()=>
280 #( #attrs )*
281 #is_async fn #ident(#receiver, #( #args ),*) -> #output;
282 }
283 },
284 );
285
286 let forward_fns = rpcs
287 .iter()
288 .map(
289 |RpcMethod {
290 attrs,
291 args,
292 receiver,
293 ident,
294 is_async,
295 output,
296 ..
297 }| {
298 {
299 let output = if let ReturnType::Type(_, ref ty) = output {
300 if let Some(item_ty) = stream_item_type(ty) {
301 quote! { impl web_rpc::futures_core::Stream<Item = #item_ty> }
302 } else {
303 let ty: &Type = ty;
304 quote! { #ty }
305 }
306 } else {
307 let ty = unit_type;
308 quote! { #ty }
309 };
310 let do_await = match is_async {
311 Some(token) => quote_spanned!(token.span=> .await),
312 None => quote!(),
313 };
314 let forward_args = args.iter().filter_map(|arg| match &*arg.pat {
315 Pat::Ident(ident) => Some(&ident.ident),
316 _ => None,
317 });
318 quote_spanned! {ident.span()=>
319 #( #attrs )*
320 #is_async fn #ident(#receiver, #( #args ),*) -> #output {
321 T::#ident(self, #( #forward_args ),*)#do_await
322 }
323 }
324 }
325 },
326 )
327 .collect::<Vec<_>>();
328
329 quote! {
330 #( #attrs )*
331 #[allow(async_fn_in_trait)]
332 #vis trait #trait_ident {
333 #( #rpc_fns )*
334 }
335
336 impl<T> #trait_ident for std::sync::Arc<T> where T: #trait_ident {
337 #( #forward_fns )*
338 }
339 impl<T> #trait_ident for std::boxed::Box<T> where T: #trait_ident {
340 #( #forward_fns )*
341 }
342 impl<T> #trait_ident for std::rc::Rc<T> where T: #trait_ident {
343 #( #forward_fns )*
344 }
345 }
346 }
347
    /// Emits the client struct, its `Debug`, `web_rpc::client::Client` and
    /// `From<web_rpc::client::Configuration<..>>` impls, and one inherent method per RPC.
    ///
    /// Every generated method:
    /// 1. takes the next sequence id from `seq_id` (wrapping);
    /// 2. serializes a header and the request variant with bincode;
    /// 3. calls `self.port.post_message` with the header/payload buffers followed by any
    ///    `#[post]`-ed arguments in declaration order. The buffers are always in the
    ///    transfer list; `transfer(..)` arguments are added to it as well.
    ///
    /// It then returns one of:
    /// - a `StreamReceiver` for `impl Stream` returns, fed through `stream_callback_map`;
    /// - a `RequestFuture` for other declared return types, fed through `callback_map`;
    /// - `()` when nothing is returned.
    ///
    /// Streams and futures are given a callback that calls `abort_sender` with the
    /// sequence id. When that callback fires is decided by `web_rpc::client` (not visible here).
    fn struct_client(&self) -> TokenStream2 {
        let &Self {
            vis,
            client_ident,
            request_ident,
            response_ident,
            camel_case_idents,
            rpcs,
            has_streaming_methods,
            ..
        } = self;

        let rpc_fns = rpcs
            .iter()
            .zip(camel_case_idents.iter())
            .map(|(RpcMethod { attrs, args, transfer, post, ident, output, .. }, camel_case_ident)| {
                // Field initialisers for the request variant. Non-posted args are moved in;
                // posted `Option`/`Result` args contribute only their `is_some()` / `is_ok()`
                // flag; posted bare args are omitted (they travel only as JS values).
                let request_struct_fields: Vec<_> = args.iter()
                    .filter_map(|arg| match &*arg.pat {
                        Pat::Ident(pat_ident) => {
                            let id = &pat_ident.ident;
                            if post.contains(id) {
                                match post_wrapper_kind(&arg.ty) {
                                    PostWrapperKind::Bare => None,
                                    PostWrapperKind::Option { .. } => {
                                        Some(quote! { #id: #id.is_some() })
                                    }
                                    PostWrapperKind::Result { .. } => {
                                        Some(quote! { #id: #id.is_ok() })
                                    }
                                }
                            } else {
                                Some(quote! { #id })
                            }
                        }
                        _ => None
                    })
                    .collect();

                // Statements pushing each posted arg (and, for `transfer(..)` args, its
                // transfer entry) onto the dynamically built `__post` / `__transfer` arrays.
                // `None` pushes nothing; both `Ok` and `Err` push their value.
                let post_pushes: Vec<_> = args.iter()
                    .filter_map(|arg| match &*arg.pat {
                        Pat::Ident(pat_ident) if post.contains(&pat_ident.ident) => {
                            let id = &pat_ident.ident;
                            let is_transfer = transfer.contains(&pat_ident.ident);
                            match post_wrapper_kind(&arg.ty) {
                                PostWrapperKind::Bare => {
                                    let transfer_push = if is_transfer {
                                        quote! { __transfer.push(#id.as_ref()); }
                                    } else {
                                        quote! {}
                                    };
                                    Some(quote! {
                                        __post.push(#id.as_ref());
                                        #transfer_push
                                    })
                                }
                                PostWrapperKind::Option { .. } => {
                                    let transfer_push = if is_transfer {
                                        quote! { __transfer.push(__val.as_ref()); }
                                    } else {
                                        quote! {}
                                    };
                                    Some(quote! {
                                        if let Some(ref __val) = #id {
                                            __post.push(__val.as_ref());
                                            #transfer_push
                                        }
                                    })
                                }
                                PostWrapperKind::Result { .. } => {
                                    // NOTE(review): the Err value is also added to the transfer
                                    // list here, while the server-side return path only transfers
                                    // `Ok` values. Confirm which behaviour is intended.
                                    let transfer_push = if is_transfer {
                                        quote! { __transfer.push(__val.as_ref()); }
                                    } else {
                                        quote! {}
                                    };
                                    Some(quote! {
                                        match #id {
                                            Ok(ref __val) => {
                                                __post.push(__val.as_ref());
                                                #transfer_push
                                            }
                                            Err(ref __val) => {
                                                __post.push(__val.as_ref());
                                                #transfer_push
                                            }
                                        }
                                    })
                                }
                            }
                        }
                        _ => None
                    })
                    .collect();

                // Bare `transfer(..)` args, used by the static-array send path below.
                let bare_transfer_arg_idents: Vec<_> = args.iter()
                    .filter_map(|arg| match &*arg.pat {
                        Pat::Ident(pat_ident) if transfer.contains(&pat_ident.ident)
                            && matches!(post_wrapper_kind(&arg.ty), PostWrapperKind::Bare) =>
                        {
                            Some(&pat_ident.ident)
                        }
                        _ => None
                    })
                    .collect();

                // Whether any posted arg is an `Option`/`Result`, i.e. whether the number of
                // posted JS values is only known at runtime.
                let has_wrapped_post_args = args.iter().any(|arg| {
                    matches!(&*arg.pat, Pat::Ident(pat_ident)
                        if post.contains(&pat_ident.ident)
                            && !matches!(post_wrapper_kind(&arg.ty), PostWrapperKind::Bare))
                });

                // Both paths bind `__seq_id`, which the code after `#send_request` relies on.
                let send_request = if has_wrapped_post_args {
                    // Runtime-sized arrays: header and payload first, then `post_pushes`.
                    quote! {
                        let __seq_id = self.seq_id.replace_with(|seq_id| seq_id.wrapping_add(1));
                        let __request = #request_ident::#camel_case_ident {
                            #( #request_struct_fields ),*
                        };
                        let __header = web_rpc::MessageHeader::Request(__seq_id);
                        let __header_bytes = web_rpc::bincode::serialize(&__header).unwrap();
                        let __header_buffer = web_rpc::js_sys::Uint8Array::from(&__header_bytes[..]).buffer();
                        let __payload_bytes = web_rpc::bincode::serialize(&__request).unwrap();
                        let __payload_buffer = web_rpc::js_sys::Uint8Array::from(&__payload_bytes[..]).buffer();
                        let __post = web_rpc::js_sys::Array::new();
                        let __transfer = web_rpc::js_sys::Array::new();
                        __post.push(__header_buffer.as_ref());
                        __post.push(__payload_buffer.as_ref());
                        __transfer.push(__header_buffer.as_ref());
                        __transfer.push(__payload_buffer.as_ref());
                        #( #post_pushes )*
                        self.port.post_message(&__post, &__transfer).unwrap();
                    }
                } else {
                    // All posted args are bare, so the arrays can be built from fixed slices.
                    let bare_post_arg_idents: Vec<_> = args.iter()
                        .filter_map(|arg| match &*arg.pat {
                            Pat::Ident(pat_ident) if post.contains(&pat_ident.ident) => Some(&pat_ident.ident),
                            _ => None
                        })
                        .collect();
                    quote! {
                        let __seq_id = self.seq_id.replace_with(|seq_id| seq_id.wrapping_add(1));
                        let __request = #request_ident::#camel_case_ident {
                            #( #request_struct_fields ),*
                        };
                        let __header = web_rpc::MessageHeader::Request(__seq_id);
                        let __header_bytes = web_rpc::bincode::serialize(&__header).unwrap();
                        let __header_buffer = web_rpc::js_sys::Uint8Array::from(&__header_bytes[..]).buffer();
                        let __payload_bytes = web_rpc::bincode::serialize(&__request).unwrap();
                        let __payload_buffer = web_rpc::js_sys::Uint8Array::from(&__payload_bytes[..]).buffer();
                        let __post: &[&web_rpc::wasm_bindgen::JsValue] =
                            &[__header_buffer.as_ref(), __payload_buffer.as_ref(), #( #bare_post_arg_idents.as_ref() ),*];
                        let __post = web_rpc::js_sys::Array::from_iter(__post);
                        let __transfer: &[&web_rpc::wasm_bindgen::JsValue] =
                            &[__header_buffer.as_ref(), __payload_buffer.as_ref(), #( #bare_transfer_arg_idents.as_ref() ),*];
                        let __transfer = web_rpc::js_sys::Array::from_iter(__transfer);
                        self.port.post_message(&__post, &__transfer).unwrap();
                    }
                };

                let is_streaming = matches!(output, ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some());

                if is_streaming {
                    let item_ty = match output {
                        ReturnType::Type(_, ref ty) => stream_item_type(ty).unwrap(),
                        _ => unreachable!(),
                    };

                    // Closure mapping each `(response, posted JS array)` pair to a stream item.
                    let unpack_stream_item = if post.contains(&Ident::new("return", output.span())) {
                        match post_wrapper_kind(item_ty) {
                            PostWrapperKind::Bare => quote! {
                                |(_response, __post_array)| {
                                    web_rpc::wasm_bindgen::JsCast::dyn_into::<#item_ty>(__post_array.shift())
                                        .unwrap()
                                }
                            },
                            PostWrapperKind::Option { inner_ty } => quote! {
                                |(__response, __post_array)| {
                                    let #response_ident::#camel_case_ident(__has_value) = __response else {
                                        panic!("received incorrect response variant")
                                    };
                                    if __has_value {
                                        Some(web_rpc::wasm_bindgen::JsCast::dyn_into::<#inner_ty>(__post_array.shift())
                                            .unwrap())
                                    } else {
                                        None
                                    }
                                }
                            },
                            PostWrapperKind::Result { ok_ty, err_ty } => quote! {
                                |(__response, __post_array)| {
                                    let #response_ident::#camel_case_ident(__is_ok) = __response else {
                                        panic!("received incorrect response variant")
                                    };
                                    if __is_ok {
                                        Ok(web_rpc::wasm_bindgen::JsCast::dyn_into::<#ok_ty>(__post_array.shift())
                                            .unwrap())
                                    } else {
                                        Err(web_rpc::wasm_bindgen::JsCast::dyn_into::<#err_ty>(__post_array.shift())
                                            .unwrap())
                                    }
                                }
                            },
                        }
                    } else {
                        quote! {
                            |(__response, _post_array)| {
                                let #response_ident::#camel_case_ident(__inner) = __response else {
                                    panic!("received incorrect response variant")
                                };
                                __inner
                            }
                        }
                    };

                    // The abort callback also removes this sequence id from `stream_callback_map`.
                    quote! {
                        #( #attrs )*
                        #vis fn #ident(
                            &self,
                            #( #args ),*
                        ) -> web_rpc::client::StreamReceiver<#item_ty> {
                            #send_request
                            let (__item_tx, __item_rx) = web_rpc::futures_channel::mpsc::unbounded();
                            self.stream_callback_map.borrow_mut().insert(__seq_id, __item_tx);
                            let __mapped_rx = web_rpc::futures_util::StreamExt::map(
                                __item_rx,
                                #unpack_stream_item
                            );
                            let __abort_sender = self.abort_sender.clone();
                            let __stream_callback_map = self.stream_callback_map.clone();
                            let __dispatcher = self.dispatcher.clone();
                            web_rpc::client::StreamReceiver::new(
                                __mapped_rx,
                                __dispatcher,
                                std::boxed::Box::new(move || {
                                    __stream_callback_map.borrow_mut().remove(&__seq_id);
                                    (__abort_sender)(__seq_id);
                                }),
                            )
                        }
                    }
                } else {
                    let return_type = match output {
                        ReturnType::Type(_, ref ty) => quote! {
                            web_rpc::client::RequestFuture<#ty>
                        },
                        _ => quote!(())
                    };
                    // Only methods that return something wait for a reply.
                    let maybe_register_callback = match output {
                        ReturnType::Type(_, _) => quote! {
                            let (__response_tx, __response_rx) =
                                web_rpc::futures_channel::oneshot::channel();
                            self.callback_map.borrow_mut().insert(__seq_id, __response_tx);
                        },
                        _ => Default::default()
                    };

                    // Statements turning `response: (Response, posted JS array)` into the return value.
                    let unpack_response = if post.contains(&Ident::new("return", output.span())) {
                        let unit_output: &Type = &parse_quote!(());
                        let ret_ty = match output {
                            ReturnType::Type(_, ref ty) => ty,
                            _ => unit_output
                        };
                        match post_wrapper_kind(ret_ty) {
                            PostWrapperKind::Bare => quote! {
                                let (_, __post_response) = response;
                                web_rpc::wasm_bindgen::JsCast::dyn_into::<#ret_ty>(__post_response.shift())
                                    .unwrap()
                            },
                            PostWrapperKind::Option { inner_ty } => quote! {
                                let (__serialize_response, __post_response) = response;
                                let #response_ident::#camel_case_ident(__has_value) = __serialize_response else {
                                    panic!("received incorrect response variant")
                                };
                                if __has_value {
                                    Some(web_rpc::wasm_bindgen::JsCast::dyn_into::<#inner_ty>(__post_response.shift())
                                        .unwrap())
                                } else {
                                    None
                                }
                            },
                            PostWrapperKind::Result { ok_ty, err_ty } => quote! {
                                let (__serialize_response, __post_response) = response;
                                let #response_ident::#camel_case_ident(__is_ok) = __serialize_response else {
                                    panic!("received incorrect response variant")
                                };
                                if __is_ok {
                                    Ok(web_rpc::wasm_bindgen::JsCast::dyn_into::<#ok_ty>(__post_response.shift())
                                        .unwrap())
                                } else {
                                    Err(web_rpc::wasm_bindgen::JsCast::dyn_into::<#err_ty>(__post_response.shift())
                                        .unwrap())
                                }
                            },
                        }
                    } else {
                        quote! {
                            let (__serialize_response, _) = response;
                            let #response_ident::#camel_case_ident(__inner) = __serialize_response else {
                                panic!("received incorrect response variant")
                            };
                            __inner
                        }
                    };

                    let maybe_unpack_and_return_future = match output {
                        ReturnType::Type(_, _) => quote! {
                            let __response_future = web_rpc::futures_util::FutureExt::map(
                                __response_rx,
                                |response| {
                                    let response = response.unwrap();
                                    #unpack_response
                                }
                            );
                            let __abort_sender = self.abort_sender.clone();
                            let __dispatcher = self.dispatcher.clone();
                            web_rpc::client::RequestFuture::new(
                                __response_future,
                                __dispatcher,
                                std::boxed::Box::new(move || (__abort_sender)(__seq_id)))
                        },
                        _ => Default::default()
                    };

                    quote! {
                        #( #attrs )*
                        #vis fn #ident(
                            &self,
                            #( #args ),*
                        ) -> #return_type {
                            #send_request
                            #maybe_register_callback
                            #maybe_unpack_and_return_future
                        }
                    }
                }
            });

        // The stream callback map only exists when at least one method streams; otherwise
        // its slot in the `Configuration` tuple is ignored with `_`.
        let stream_callback_map_field = if has_streaming_methods {
            quote! {
                stream_callback_map: std::rc::Rc<
                    std::cell::RefCell<
                        web_rpc::client::StreamCallbackMap<#response_ident>
                    >
                >,
            }
        } else {
            quote!()
        };

        let stream_callback_map_pat = if has_streaming_methods {
            quote! { stream_callback_map, }
        } else {
            quote! { _, }
        };

        let stream_callback_map_init = if has_streaming_methods {
            quote! { stream_callback_map, }
        } else {
            quote! {}
        };

        quote! {
            #[derive(core::clone::Clone)]
            #vis struct #client_ident {
                callback_map: std::rc::Rc<
                    std::cell::RefCell<
                        web_rpc::client::CallbackMap<#response_ident>
                    >
                >,
                #stream_callback_map_field
                port: web_rpc::port::Port,
                listener: std::rc::Rc<web_rpc::gloo_events::EventListener>,
                dispatcher: web_rpc::futures_util::future::Shared<
                    web_rpc::futures_core::future::LocalBoxFuture<'static, ()>
                >,
                abort_sender: std::rc::Rc<dyn std::ops::Fn(usize)>,
                seq_id: std::rc::Rc<std::cell::RefCell<usize>>
            }
            impl std::fmt::Debug for #client_ident {
                fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    formatter.debug_struct(std::stringify!(#client_ident))
                        .finish()
                }
            }
            impl web_rpc::client::Client for #client_ident {
                type Response = #response_ident;
            }
            impl From<web_rpc::client::Configuration<#response_ident>>
                for #client_ident {
                fn from((callback_map, #stream_callback_map_pat port, listener, dispatcher, abort_sender):
                    web_rpc::client::Configuration<#response_ident>) -> Self {
                    Self {
                        callback_map,
                        #stream_callback_map_init
                        port,
                        listener,
                        dispatcher,
                        abort_sender,
                        seq_id: std::default::Default::default()
                    }
                }
            }
            impl #client_ident {
                #( #rpc_fns )*
            }
        }
    }
763
764 fn struct_server(&self) -> TokenStream2 {
765 let &Self {
766 vis,
767 trait_ident,
768 service_ident,
769 request_ident,
770 response_ident,
771 camel_case_idents,
772 rpcs,
773 has_borrowed_args,
774 ..
775 } = self;
776
777 let request_type = if has_borrowed_args {
778 quote! { #request_ident<'_> }
779 } else {
780 quote! { #request_ident }
781 };
782
783 let handlers = rpcs.iter()
784 .zip(camel_case_idents.iter())
785 .map(|(RpcMethod { is_async, ident, args, transfer, post, output, .. }, camel_case_ident)| {
786 let destructure_arg_idents: Vec<_> = args.iter()
789 .filter_map(|arg| match &*arg.pat {
790 Pat::Ident(pat_ident) => {
791 let id = &pat_ident.ident;
792 if post.contains(id) {
793 match post_wrapper_kind(&arg.ty) {
795 PostWrapperKind::Bare => None,
796 _ => Some(id),
797 }
798 } else {
799 Some(id)
800 }
801 }
802 _ => None
803 })
804 .collect();
805 let extract_js_args = args.iter()
806 .filter_map(|arg| match &*arg.pat {
807 Pat::Ident(pat_ident) if post.contains(&pat_ident.ident) => {
808 let arg_pat = &arg.pat;
809 let arg_ty = &arg.ty;
810 match post_wrapper_kind(arg_ty) {
811 PostWrapperKind::Bare => {
812 if let Type::Reference(type_ref) = &**arg_ty {
814 let inner_ty = &type_ref.elem;
815 let tmp_ident = format_ident!("__tmp_{}", pat_ident.ident);
816 Some(quote! {
817 let #tmp_ident = __js_args.shift();
818 let #arg_pat: #arg_ty = web_rpc::wasm_bindgen::JsCast::dyn_ref::<#inner_ty>(&#tmp_ident)
819 .unwrap();
820 })
821 } else {
822 Some(quote! {
823 let #arg_pat = web_rpc::wasm_bindgen::JsCast::dyn_into::<#arg_ty>(__js_args.shift())
824 .unwrap();
825 })
826 }
827 }
828 PostWrapperKind::Option { inner_ty } => {
829 Some(quote! {
832 let #arg_pat: #arg_ty = if #arg_pat {
833 Some(web_rpc::wasm_bindgen::JsCast::dyn_into::<#inner_ty>(__js_args.shift())
834 .unwrap())
835 } else {
836 None
837 };
838 })
839 }
840 PostWrapperKind::Result { ok_ty, .. } => {
841 Some(quote! {
844 let #arg_pat: #arg_ty = match #arg_pat {
845 Ok(()) => Ok(web_rpc::wasm_bindgen::JsCast::dyn_into::<#ok_ty>(__js_args.shift())
846 .unwrap()),
847 Err(__e) => Err(__e),
848 };
849 })
850 }
851 }
852 },
853 _ => None
854 });
855
856 let is_streaming = matches!(output, ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some());
858
859 if is_streaming {
860 let call_args = args.iter().filter_map(|arg| match &*arg.pat {
861 Pat::Ident(ident) => Some(&ident.ident),
862 _ => None
863 });
864 let return_ident = Ident::new("return", output.span());
865 let is_post_return = post.contains(&return_ident);
866 let is_transfer_return = transfer.contains(&return_ident);
867 let item_ty = match output {
868 ReturnType::Type(_, ref ty) => stream_item_type(ty).unwrap(),
869 _ => unreachable!(),
870 };
871 let wrap_item = if !is_post_return {
872 quote! {
873 let __response = #response_ident::#camel_case_ident(__item);
874 let __post = web_rpc::js_sys::Array::new();
875 let __transfer = web_rpc::js_sys::Array::new();
876 }
877 } else {
878 let wrapper_kind = post_wrapper_kind(item_ty);
879 match wrapper_kind {
880 PostWrapperKind::Bare => {
881 let transfer_code = if is_transfer_return {
882 quote! { web_rpc::js_sys::Array::of1(__item.as_ref()) }
883 } else {
884 quote! { web_rpc::js_sys::Array::new() }
885 };
886 quote! {
887 let __response = #response_ident::#camel_case_ident(());
888 let __post = web_rpc::js_sys::Array::of1(__item.as_ref());
889 let __transfer = #transfer_code;
890 }
891 }
892 PostWrapperKind::Option { .. } => {
893 let transfer_push = if is_transfer_return {
894 quote! { __transfer.push(__val.as_ref()); }
895 } else {
896 quote! {}
897 };
898 quote! {
899 let (__response, __post, __transfer) = match __item {
900 Some(ref __val) => {
901 let __post = web_rpc::js_sys::Array::of1(__val.as_ref());
902 let __transfer = web_rpc::js_sys::Array::new();
903 #transfer_push
904 (#response_ident::#camel_case_ident(true), __post, __transfer)
905 }
906 None => {
907 (#response_ident::#camel_case_ident(false),
908 web_rpc::js_sys::Array::new(),
909 web_rpc::js_sys::Array::new())
910 }
911 };
912 }
913 }
914 PostWrapperKind::Result { .. } => {
915 let transfer_push = if is_transfer_return {
916 quote! { __transfer.push(__val.as_ref()); }
917 } else {
918 quote! {}
919 };
920 quote! {
921 let (__response, __post, __transfer) = match __item {
922 Ok(ref __val) => {
923 let __post = web_rpc::js_sys::Array::of1(__val.as_ref());
924 let __transfer = web_rpc::js_sys::Array::new();
925 #transfer_push
926 (#response_ident::#camel_case_ident(true), __post, __transfer)
927 }
928 Err(ref __val) => {
929 let __post = web_rpc::js_sys::Array::of1(__val.as_ref());
930 let __transfer = web_rpc::js_sys::Array::new();
931 (#response_ident::#camel_case_ident(false), __post, __transfer)
932 }
933 };
934 }
935 }
936 }
937 };
938 let fwd_body = quote! {
940 let __stream_tx_clone = __stream_tx.clone();
941 web_rpc::pin_utils::pin_mut!(__user_rx);
942 let __fwd = async move {
943 while let Some(__item) = web_rpc::futures_util::StreamExt::next(&mut __user_rx).await {
944 #wrap_item
945 if __stream_tx_clone.unbounded_send((__seq_id, Some((__response, __post, __transfer)))).is_err() {
946 break;
947 }
948 }
949 };
950 let __fwd = web_rpc::futures_util::FutureExt::fuse(__fwd);
951 web_rpc::pin_utils::pin_mut!(__fwd);
952 web_rpc::futures_util::select! {
953 _ = __abort_rx => {},
954 _ = __fwd => {},
955 }
956 let _ = __stream_tx.unbounded_send((__seq_id, None));
957 web_rpc::service::ExecuteResult::StreamComplete
958 };
959
960 match is_async {
961 Some(_) => quote! {
962 #request_ident::#camel_case_ident { #( #destructure_arg_idents ),* } => {
963 #( #extract_js_args )*
964 let __get_rx = web_rpc::futures_util::FutureExt::fuse(
965 self.server_impl.#ident(#( #call_args ),*)
966 );
967 web_rpc::pin_utils::pin_mut!(__get_rx);
968 let __maybe_rx = web_rpc::futures_util::select! {
969 _ = __abort_rx => None,
970 __rx = __get_rx => Some(__rx),
971 };
972 if let Some(mut __user_rx) = __maybe_rx {
973 #fwd_body
974 } else {
975 let _ = __stream_tx.unbounded_send((__seq_id, None));
976 web_rpc::service::ExecuteResult::StreamComplete
977 }
978 }
979 },
980 None => quote! {
981 #request_ident::#camel_case_ident { #( #destructure_arg_idents ),* } => {
982 #( #extract_js_args )*
983 let mut __user_rx = self.server_impl.#ident(#( #call_args ),*);
984 #fwd_body
985 }
986 },
987 }
988 } else {
989 let return_ident = Ident::new("return", output.span());
991 let is_post_return = post.contains(&return_ident);
992 let is_transfer_return = transfer.contains(&return_ident);
993 let ret_ty = match output {
994 ReturnType::Type(_, ref ty) => Some(ty.as_ref()),
995 _ => None,
996 };
997 let return_response = if !is_post_return {
998 quote! {
999 let __post = web_rpc::js_sys::Array::new();
1000 let __transfer = web_rpc::js_sys::Array::new();
1001 (#response_ident::#camel_case_ident(__response), __post, __transfer)
1002 }
1003 } else {
1004 let wrapper_kind = ret_ty.map(|ty| post_wrapper_kind(ty));
1005 match wrapper_kind {
1006 Some(PostWrapperKind::Option { .. }) => {
1007 let transfer_push = if is_transfer_return {
1008 quote! { __transfer.push(__val.as_ref()); }
1009 } else {
1010 quote! {}
1011 };
1012 quote! {
1013 match __response {
1014 Some(ref __val) => {
1015 let __post = web_rpc::js_sys::Array::of1(__val.as_ref());
1016 let __transfer = web_rpc::js_sys::Array::new();
1017 #transfer_push
1018 (#response_ident::#camel_case_ident(true), __post, __transfer)
1019 }
1020 None => {
1021 (#response_ident::#camel_case_ident(false),
1022 web_rpc::js_sys::Array::new(),
1023 web_rpc::js_sys::Array::new())
1024 }
1025 }
1026 }
1027 }
1028 Some(PostWrapperKind::Result { .. }) => {
1029 let transfer_push = if is_transfer_return {
1030 quote! { __transfer.push(__val.as_ref()); }
1031 } else {
1032 quote! {}
1033 };
1034 quote! {
1035 match __response {
1036 Ok(ref __val) => {
1037 let __post = web_rpc::js_sys::Array::of1(__val.as_ref());
1038 let __transfer = web_rpc::js_sys::Array::new();
1039 #transfer_push
1040 (#response_ident::#camel_case_ident(true), __post, __transfer)
1041 }
1042 Err(ref __val) => {
1043 let __post = web_rpc::js_sys::Array::of1(__val.as_ref());
1044 let __transfer = web_rpc::js_sys::Array::new();
1045 (#response_ident::#camel_case_ident(false), __post, __transfer)
1046 }
1047 }
1048 }
1049 }
1050 _ => {
1051 let transfer_code = if is_transfer_return {
1053 quote! { web_rpc::js_sys::Array::of1(__response.as_ref()) }
1054 } else {
1055 quote! { web_rpc::js_sys::Array::new() }
1056 };
1057 quote! {
1058 let __post = web_rpc::js_sys::Array::of1(__response.as_ref());
1059 let __transfer = #transfer_code;
1060 (#response_ident::#camel_case_ident(()), __post, __transfer)
1061 }
1062 }
1063 }
1064 };
1065 let call_args = args.iter().filter_map(|arg| match &*arg.pat {
1066 Pat::Ident(ident) => Some(&ident.ident),
1067 _ => None
1068 });
1069 match is_async {
1070 Some(_) => quote! {
1071 #request_ident::#camel_case_ident { #( #destructure_arg_idents ),* } => {
1072 #( #extract_js_args )*
1073 let __task =
1074 web_rpc::futures_util::FutureExt::fuse(self.server_impl.#ident(#( #call_args ),*));
1075 web_rpc::pin_utils::pin_mut!(__task);
1076 web_rpc::service::ExecuteResult::Response(
1077 web_rpc::futures_util::select! {
1078 _ = __abort_rx => None,
1079 __response = __task => Some({
1080 #return_response
1081 })
1082 }
1083 )
1084 }
1085 },
1086 None => quote! {
1087 #request_ident::#camel_case_ident { #( #destructure_arg_idents ),* } => {
1088 #( #extract_js_args )*
1089 let __response = self.server_impl.#ident(#( #call_args ),*);
1090 web_rpc::service::ExecuteResult::Response(
1091 Some({
1092 #return_response
1093 })
1094 )
1095 }
1096 }
1097 }
1098 }
1099 });
1100
1101 quote! {
1102 #vis struct #service_ident<T> {
1103 server_impl: T
1104 }
1105 impl<T: #trait_ident> web_rpc::service::Service for #service_ident<T> {
1106 type Response = #response_ident;
1107 async fn execute(
1108 &self,
1109 __seq_id: usize,
1110 mut __abort_rx: web_rpc::futures_channel::oneshot::Receiver<()>,
1111 __payload: std::vec::Vec<u8>,
1112 __js_args: web_rpc::js_sys::Array,
1113 __stream_tx: web_rpc::futures_channel::mpsc::UnboundedSender<
1114 web_rpc::service::StreamMessage<Self::Response>
1115 >,
1116 ) -> (usize, web_rpc::service::ExecuteResult<Self::Response>) {
1117 let __request: #request_type = web_rpc::bincode::deserialize(&__payload).unwrap();
1118 let __result = match __request {
1119 #( #handlers )*
1120 };
1121 (__seq_id, __result)
1122 }
1123 }
1124 impl<T: #trait_ident> std::convert::From<T> for #service_ident<T> {
1125 fn from(server_impl: T) -> Self {
1126 Self { server_impl }
1127 }
1128 }
1129 }
1130 }
1131}
1132
1133impl<'a> ToTokens for ServiceGenerator<'a> {
1134 fn to_tokens(&self, output: &mut TokenStream2) {
1135 output.extend(vec![
1136 self.enum_request(),
1137 self.enum_response(),
1138 self.trait_service(),
1139 self.struct_client(),
1140 self.struct_server(),
1141 ])
1142 }
1143}
1144
1145impl Parse for Service {
1146 fn parse(input: ParseStream) -> syn::Result<Self> {
1147 let attrs = input.call(Attribute::parse_outer)?;
1148 let vis = input.parse()?;
1149 input.parse::<Token![trait]>()?;
1150 let ident: Ident = input.parse()?;
1151 let content;
1152 braced!(content in input);
1153 let mut rpcs = Vec::<RpcMethod>::new();
1154 while !content.is_empty() {
1155 rpcs.push(content.parse()?);
1156 }
1157
1158 Ok(Self {
1159 attrs,
1160 vis,
1161 ident,
1162 rpcs,
1163 })
1164 }
1165}
1166
1167impl Parse for RpcMethod {
1168 fn parse(input: ParseStream) -> syn::Result<Self> {
1169 let mut errors = Ok(());
1170 let attrs = input.call(Attribute::parse_outer)?;
1171 let (post_attrs, attrs): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|attr| {
1172 attr.path
1173 .segments
1174 .last()
1175 .is_some_and(|last_segment| last_segment.ident == "post")
1176 });
1177 let mut transfer: HashSet<Ident> = HashSet::new();
1178 let mut post: HashSet<Ident> = HashSet::new();
1179 for post_attr in post_attrs {
1180 let parsed_args =
1181 post_attr.parse_args_with(Punctuated::<NestedMeta, Token![,]>::parse_terminated)?;
1182 for parsed_arg in parsed_args {
1183 match &parsed_arg {
1184 NestedMeta::Meta(meta) => match meta {
1185 Meta::Path(path) => {
1186 if let Some(segment) = path.segments.last() {
1187 post.insert(segment.ident.clone());
1188 }
1189 }
1190 Meta::List(list) => match list.path.segments.last() {
1191 Some(last_segment) if last_segment.ident == "transfer" => {
1192 if list.nested.len() != 1 {
1193 extend_errors!(
1194 errors,
1195 syn::Error::new(
1196 parsed_arg.span(),
1197 "Syntax error in post attribute"
1198 )
1199 );
1200 }
1201 match list.nested.first() {
1202 Some(NestedMeta::Meta(Meta::Path(path))) => {
1203 match path.segments.last() {
1204 Some(segment) => {
1205 post.insert(segment.ident.clone());
1206 transfer.insert(segment.ident.clone());
1207 }
1208 _ => extend_errors!(
1209 errors,
1210 syn::Error::new(
1211 parsed_arg.span(),
1212 "Syntax error in post attribute"
1213 )
1214 ),
1215 }
1216 }
1217 _ => extend_errors!(
1218 errors,
1219 syn::Error::new(
1220 parsed_arg.span(),
1221 "Syntax error in post attribute"
1222 )
1223 ),
1224 }
1225 }
1226 _ => extend_errors!(
1227 errors,
1228 syn::Error::new(
1229 parsed_arg.span(),
1230 "Syntax error in post attribute"
1231 )
1232 ),
1233 },
1234 _ => extend_errors!(
1235 errors,
1236 syn::Error::new(parsed_arg.span(), "Syntax error in post attribute")
1237 ),
1238 },
1239 _ => extend_errors!(
1240 errors,
1241 syn::Error::new(parsed_arg.span(), "Syntax error in post attribute")
1242 ),
1243 }
1244 }
1245 }
1246
1247 let is_async = input.parse::<Token![async]>().ok();
1248 input.parse::<Token![fn]>()?;
1249 let ident: Ident = input.parse()?;
1250 let content;
1251 parenthesized!(content in input);
1252 let mut receiver: Option<syn::Receiver> = None;
1253 let mut args = Vec::new();
1254 for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
1255 match arg {
1256 FnArg::Typed(captured) => match &*captured.pat {
1257 Pat::Ident(_) => args.push(captured),
1258 _ => {
1259 extend_errors!(
1260 errors,
1261 syn::Error::new(
1262 captured.pat.span(),
1263 "patterns are not allowed in RPC arguments"
1264 )
1265 )
1266 }
1267 },
1268 FnArg::Receiver(ref recv) => {
1269 if recv.reference.is_none() || recv.mutability.is_some() {
1270 extend_errors!(
1271 errors,
1272 syn::Error::new(
1273 arg.span(),
1274 "RPC methods only support `&self` as a receiver"
1275 )
1276 );
1277 }
1278 receiver = Some(recv.clone());
1279 }
1280 }
1281 }
1282 let receiver = match receiver {
1283 Some(r) => r,
1284 None => {
1285 extend_errors!(
1286 errors,
1287 syn::Error::new(
1288 ident.span(),
1289 "RPC methods must include `&self` as the first parameter"
1290 )
1291 );
1292 parse_quote!(&self)
1293 }
1294 };
1295 let output: ReturnType = input.parse()?;
1296 input.parse::<Token![;]>()?;
1297
1298 let arg_names: HashSet<_> = args
1299 .iter()
1300 .filter_map(|arg| match &*arg.pat {
1301 Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()),
1302 _ => None,
1303 })
1304 .collect();
1305 let return_ident = Ident::new("return", output.span());
1306 for ident in &post {
1307 if *ident != return_ident && !arg_names.contains(ident) {
1308 extend_errors!(
1309 errors,
1310 syn::Error::new(
1311 ident.span(),
1312 format!("`{}` does not match any parameter", ident)
1313 )
1314 );
1315 }
1316 }
1317 for ident in &transfer {
1318 if *ident != return_ident && !post.contains(ident) {
1319 extend_errors!(
1320 errors,
1321 syn::Error::new(
1322 ident.span(),
1323 format!("`{}` is marked as transfer but not as post", ident)
1324 )
1325 );
1326 }
1327 }
1328 errors?;
1329
1330 Ok(Self {
1331 is_async,
1332 attrs,
1333 receiver,
1334 ident,
1335 args,
1336 post,
1337 transfer,
1338 output,
1339 })
1340 }
1341}
1342
1343#[proc_macro_attribute]
1349pub fn service(_attr: TokenStream, input: TokenStream) -> TokenStream {
1350 let Service {
1351 ref attrs,
1352 ref vis,
1353 ref ident,
1354 ref rpcs,
1355 } = parse_macro_input!(input as Service);
1356
1357 let camel_case_fn_names: &Vec<_> = &rpcs
1358 .iter()
1359 .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
1360 .collect();
1361
1362 let has_borrowed_args = rpcs.iter().any(|rpc| {
1363 rpc.args.iter().any(|arg| {
1364 matches!(&*arg.pat, Pat::Ident(pat_ident) if !rpc.post.contains(&pat_ident.ident))
1365 && matches!(&*arg.ty, Type::Reference(_))
1366 })
1367 });
1368
1369 let has_streaming_methods = rpcs.iter().any(
1370 |rpc| matches!(&rpc.output, ReturnType::Type(_, ref ty) if stream_item_type(ty).is_some()),
1371 );
1372
1373 ServiceGenerator {
1374 trait_ident: ident,
1375 service_ident: &format_ident!("{}Service", ident),
1376 client_ident: &format_ident!("{}Client", ident),
1377 request_ident: &format_ident!("{}Request", ident),
1378 response_ident: &format_ident!("{}Response", ident),
1379 vis,
1380 attrs,
1381 rpcs,
1382 camel_case_idents: &rpcs
1383 .iter()
1384 .zip(camel_case_fn_names.iter())
1385 .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
1386 .collect::<Vec<_>>(),
1387 has_borrowed_args,
1388 has_streaming_methods,
1389 }
1390 .into_token_stream()
1391 .into()
1392}
1393
/// Converts a `snake_case` identifier into `UpperCamelCase`.
///
/// The input is split on `_`, and empty pieces (from leading, trailing or
/// repeated underscores) contribute nothing. In every remaining piece the
/// first character is uppercased and the rest are lowercased. For example,
/// `get_user` becomes `GetUser` and `HTTP_get` becomes `HttpGet`.
fn snake_to_camel(ident_str: &str) -> String {
    ident_str
        .split('_')
        .flat_map(|word| {
            let mut chars = word.chars();
            let head = chars.next().into_iter().flat_map(char::to_uppercase);
            let tail = chars.flat_map(char::to_lowercase);
            head.chain(tail)
        })
        .collect()
}