1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use quote::{format_ident, quote, ToTokens};
6use syn::{
7 braced,
8 ext::IdentExt,
9 parenthesized,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse_quote,
12 punctuated::Punctuated,
13 spanned::Spanned,
14 token::Comma,
15 Attribute, FnArg, Ident, Meta, NestedMeta, Pat, PatType, ReturnType, Token, Type, Visibility,
16};
17
/// Accumulates an error into a `Result<(), E>` accumulator named by `$errors`.
///
/// If the accumulator is still `Ok`, it becomes `Err($e)`; otherwise `$e` is
/// appended to the already-collected error via `Extend`.
macro_rules! extend_errors {
    ($errors: ident, $e: expr) => {
        if let Err(ref mut accumulated) = $errors {
            accumulated.extend($e);
        } else {
            $errors = Err($e);
        }
    };
}
26
/// A trait declaration passed to `#[service]`, as produced by `Parse for Service`.
struct Service {
    /// Outer attributes on the trait; re-emitted on the generated trait.
    attrs: Vec<Attribute>,
    /// Visibility of the trait; reused for the generated enums, trait, client and server.
    vis: Visibility,
    /// The trait's name; the generated type names are derived from it.
    ident: Ident,
    /// One entry per method declared in the trait body, in declaration order.
    rpcs: Vec<RpcMethod>,
}
33
/// A single method declaration inside a `#[service]` trait.
struct RpcMethod {
    /// Whether the declaration was prefixed with `async`.
    is_async: bool,
    /// Outer attributes other than `#[post(...)]`; re-emitted on generated methods.
    attrs: Vec<Attribute>,
    /// The method name.
    ident: Ident,
    /// Typed arguments, excluding any receiver; the parser only accepts plain
    /// identifier patterns here.
    args: Vec<PatType>,
    /// Names wrapped in `transfer(...)` inside `#[post(...)]`. Every entry is
    /// also present in `post`.
    transfer: HashSet<Ident>,
    /// Names listed in `#[post(...)]` (argument names or `return`). These values
    /// travel as raw JS values alongside the message instead of being serialized.
    post: HashSet<Ident>,
    /// The declared return type, or `ReturnType::Default` when omitted.
    output: ReturnType,
}
43
/// Borrowed view of a parsed [`Service`] plus the derived identifiers needed to
/// emit the generated code.
struct ServiceGenerator<'a> {
    /// Name of the user's trait.
    trait_ident: &'a Ident,
    /// Name of the generated server struct (`{Trait}Service`).
    service_ident: &'a Ident,
    /// Name of the generated client struct (`{Trait}Client`).
    client_ident: &'a Ident,
    /// Name of the generated request enum (`{Trait}Request`).
    request_ident: &'a Ident,
    /// Name of the generated response enum (`{Trait}Response`).
    response_ident: &'a Ident,
    /// Visibility applied to the generated items.
    vis: &'a Visibility,
    /// Outer attributes of the user's trait.
    attrs: &'a [Attribute],
    /// The parsed RPC methods.
    rpcs: &'a [RpcMethod],
    /// CamelCase variant name for each entry of `rpcs`, index-aligned with it.
    camel_case_idents: &'a [Ident],
}
55
56impl<'a> ServiceGenerator<'a> {
57 fn enum_request(&self) -> TokenStream2 {
58 let &Self {
59 vis,
60 request_ident,
61 camel_case_idents,
62 rpcs,
63 ..
64 } = self;
65 let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
66 |(RpcMethod { args, post, .. }, camel_case_ident)| {
67 let args_filtered = args.iter().filter(
68 |arg| matches!(&*arg.pat, Pat::Ident(ident) if !post.contains(&ident.ident)),
69 );
70 quote! {
71 #camel_case_ident { #( #args_filtered ),* }
72 }
73 },
74 );
75 quote! {
76 #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
77 #vis enum #request_ident {
78 #( #variants ),*
79 }
80 }
81 }
82
83 fn enum_response(&self) -> TokenStream2 {
84 let &Self {
85 vis,
86 response_ident,
87 camel_case_idents,
88 rpcs,
89 ..
90 } = self;
91 let variants = rpcs.iter().zip(camel_case_idents.iter()).map(
92 |(RpcMethod { output, post, .. }, camel_case_ident)| match output {
93 ReturnType::Type(_, ty) if !post.contains(&Ident::new("return", output.span())) => {
94 quote! {
95 #camel_case_ident ( #ty )
96 }
97 }
98 _ => quote! {
99 #camel_case_ident ( () )
100 },
101 },
102 );
103 quote! {
104 #[derive(web_rpc::serde::Serialize, web_rpc::serde::Deserialize)]
105 #vis enum #response_ident {
106 #( #variants ),*
107 }
108 }
109 }
110
111 fn trait_service(&self) -> TokenStream2 {
112 let &Self {
113 attrs,
114 rpcs,
115 vis,
116 trait_ident,
117 ..
118 } = self;
119
120 let unit_type: &Type = &parse_quote!(());
121 let rpc_fns = rpcs.iter().map(
122 |RpcMethod {
123 attrs,
124 args,
125 ident,
126 is_async,
127 output,
128 ..
129 }| {
130 let output = match output {
131 ReturnType::Type(_, ref ty) => ty,
132 ReturnType::Default => unit_type,
133 };
134 let is_async = match is_async {
135 true => quote!(async),
136 false => quote!(),
137 };
138 quote! {
139 #( #attrs )*
140 #is_async fn #ident(&self, #( #args ),*) -> #output;
141 }
142 },
143 );
144
145 let forward_fns = rpcs
146 .iter()
147 .map(
148 |RpcMethod {
149 attrs,
150 args,
151 ident,
152 is_async,
153 output,
154 ..
155 }| {
156 let output = match output {
157 ReturnType::Type(_, ref ty) => ty,
158 ReturnType::Default => unit_type,
159 };
160 let do_await = match is_async {
161 true => quote!(.await),
162 false => quote!(),
163 };
164 let is_async = match is_async {
165 true => quote!(async),
166 false => quote!(),
167 };
168 let forward_args = args.iter().filter_map(|arg| match &*arg.pat {
169 Pat::Ident(ident) => Some(&ident.ident),
170 _ => None,
171 });
172 quote! {
173 #( #attrs )*
174 #is_async fn #ident(&self, #( #args ),*) -> #output {
175 T::#ident(self, #( #forward_args ),*)#do_await
176 }
177 }
178 },
179 )
180 .collect::<Vec<_>>();
181
182 quote! {
183 #( #attrs )*
184 #vis trait #trait_ident {
185 #( #rpc_fns )*
186 }
187
188 impl<T> #trait_ident for std::sync::Arc<T> where T: #trait_ident {
189 #( #forward_fns )*
190 }
191 impl<T> #trait_ident for std::boxed::Box<T> where T: #trait_ident {
192 #( #forward_fns )*
193 }
194 impl<T> #trait_ident for std::rc::Rc<T> where T: #trait_ident {
195 #( #forward_fns )*
196 }
197 }
198 }
199
200 fn struct_client(&self) -> TokenStream2 {
201 let &Self {
202 vis,
203 client_ident,
204 request_ident,
205 response_ident,
206 camel_case_idents,
207 rpcs,
208 ..
209 } = self;
210
211 let rpc_fns = rpcs
212 .iter()
213 .zip(camel_case_idents.iter())
214 .map(|(RpcMethod { attrs, args, transfer, post, ident, output, .. }, camel_case_ident)| {
215 let serialize_arg_idents = args.iter()
217 .filter_map(|arg| match &*arg.pat {
218 Pat::Ident(ident) if !post.contains(&ident.ident) => Some(&ident.ident),
219 _ => None
220 });
221 let post_arg_idents = args.iter()
222 .filter_map(|arg| match &*arg.pat {
223 Pat::Ident(ident) if post.contains(&ident.ident) => Some(&ident.ident),
224 _ => None
225 });
226 let transfer_arg_idents = args.iter()
227 .filter_map(|arg| match &*arg.pat {
228 Pat::Ident(ident) if transfer.contains(&ident.ident) => Some(&ident.ident),
229 _ => None
230 });
231
232 let return_type = match output {
233 ReturnType::Type(_, ref ty) => quote! {
234 web_rpc::client::RequestFuture<#ty>
235 },
236 _ => quote!(())
237 };
238 let maybe_register_callback = match output {
239 ReturnType::Type(_, _) => quote! {
240 let (__response_tx, __response_rx) =
241 web_rpc::futures_channel::oneshot::channel();
242 self.callback_map.borrow_mut().insert(__seq_id, __response_tx);
243 },
244 _ => Default::default()
245 };
246
247 let unpack_response = if post.contains(&Ident::new("return", output.span())) {
248 let unit_output: &Type = &parse_quote!(());
249 let output = match output {
250 ReturnType::Type(_, ref ty) => ty,
251 _ => unit_output
252 };
253 quote! {
254 let (_, __post_response) = response;
255 web_rpc::wasm_bindgen::JsCast::dyn_into::<#output>(__post_response.shift())
256 .unwrap()
257 }
258 } else {
259 quote! {
260 let (__serialize_response, _) = response;
261 let #response_ident::#camel_case_ident(__inner) = __serialize_response else {
262 panic!("received incorrect response variant")
263 };
264 __inner
265 }
266 };
267
268 let maybe_unpack_and_return_future = match output {
269 ReturnType::Type(_, _) => quote! {
270 let __response_future = web_rpc::futures_util::FutureExt::map(
271 __response_rx,
272 |response| {
273 let response = response.unwrap();
274 #unpack_response
275 }
276 );
277 let __abort_sender = self.abort_sender.clone();
278 let __dispatcher = self.dispatcher.clone();
279 web_rpc::client::RequestFuture::new(
280 __response_future,
281 __dispatcher,
282 std::boxed::Box::new(move || (__abort_sender)(__seq_id)))
283 },
284 _ => Default::default()
285 };
286
287 quote! {
288 #( #attrs )*
289 #vis fn #ident(
290 &self,
291 #( #args ),*
292 ) -> #return_type {
293 let __seq_id = self.seq_id.replace_with(|seq_id| seq_id.wrapping_add(1));
294 let __request = #request_ident::#camel_case_ident {
295 #( #serialize_arg_idents ),*
296 };
297 let __serialized = (self.request_serializer)(__seq_id, __request);
298 let __serialized = js_sys::Uint8Array::from(&__serialized[..]).buffer();
299 let __post: &[&wasm_bindgen::JsValue] =
300 &[__serialized.as_ref(), #( #post_arg_idents.as_ref() ),*];
301 let __post = web_rpc::js_sys::Array::from_iter(__post);
302 let __transfer: &[&wasm_bindgen::JsValue] =
303 &[__serialized.as_ref(), #( #transfer_arg_idents.as_ref() ),*];
304 let __transfer = web_rpc::js_sys::Array::from_iter(__transfer);
305 #maybe_register_callback
306 self.port.post_message(&__post, &__transfer).unwrap();
307 #maybe_unpack_and_return_future
308 }
309 }
310 });
311
312 quote! {
313 #[derive(core::clone::Clone)]
314 #vis struct #client_ident {
315 callback_map: std::rc::Rc<
316 std::cell::RefCell<
317 web_rpc::client::CallbackMap<#response_ident>
318 >
319 >,
320 port: web_rpc::port::Port,
321 listener: std::rc::Rc<web_rpc::gloo_events::EventListener>,
322 dispatcher: web_rpc::futures_util::future::Shared<
323 web_rpc::futures_core::future::LocalBoxFuture<'static, ()>
324 >,
325 request_serializer: std::rc::Rc<
326 dyn std::ops::Fn(usize, #request_ident) -> std::vec::Vec<u8>
327 >,
328 abort_sender: std::rc::Rc<dyn std::ops::Fn(usize)>,
329 seq_id: std::rc::Rc<std::cell::RefCell<usize>>
330 }
331 impl std::fmt::Debug for #client_ident {
332 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 formatter.debug_struct(std::stringify!(#client_ident))
334 .finish()
335 }
336 }
337 impl web_rpc::client::Client for #client_ident {
338 type Request = #request_ident;
339 type Response = #response_ident;
340 }
341 impl From<web_rpc::client::Configuration<#request_ident, #response_ident>>
342 for #client_ident {
343 fn from((callback_map, port, listener, dispatcher, request_serializer, abort_sender):
344 web_rpc::client::Configuration<#request_ident, #response_ident>) -> Self {
345 Self {
346 callback_map,
347 port,
348 listener,
349 dispatcher,
350 request_serializer,
351 abort_sender,
352 seq_id: std::default::Default::default()
353 }
354 }
355 }
356 impl #client_ident {
357 #( #rpc_fns )*
358 }
359 }
360 }
361
362 fn struct_server(&self) -> TokenStream2 {
363 let &Self {
364 vis,
365 trait_ident,
366 service_ident,
367 request_ident,
368 response_ident,
369 camel_case_idents,
370 rpcs,
371 ..
372 } = self;
373
374 let handlers = rpcs.iter()
375 .zip(camel_case_idents.iter())
376 .map(|(RpcMethod { is_async, ident, args, transfer, post, output, .. }, camel_case_ident)| {
377 let serialize_arg_idents = args.iter()
378 .filter_map(|arg| match &*arg.pat {
379 Pat::Ident(ident) if !post.contains(&ident.ident) => Some(&ident.ident),
380 _ => None
381 });
382 let extract_js_args = args.iter()
383 .filter_map(|arg| match &*arg.pat {
384 Pat::Ident(ident) if post.contains(&ident.ident) => {
385 let arg_pat = &arg.pat;
386 let arg_ty = &arg.ty;
387 Some(quote! {
388 let #arg_pat = web_rpc::wasm_bindgen::JsCast::dyn_into::<#arg_ty>(__js_args.shift())
389 .unwrap();
390 })
391 },
392 _ => None
393 });
394 let return_ident = Ident::new("return", output.span());
395 let return_response = match (post.contains(&return_ident), transfer.contains(&return_ident)) {
396 (false, _) => quote! {
397 let __post = web_rpc::js_sys::Array::new();
398 let __transfer = web_rpc::js_sys::Array::new();
399 (Self::Response::#camel_case_ident(__response), __post, __transfer)
400 },
401 (true, false) => quote! {
402 let __post = web_rpc::js_sys::Array::of1(__response.as_ref());
403 let __transfer = web_rpc::js_sys::Array::new();
404 (Self::Response::#camel_case_ident(()), __post, __transfer)
405 },
406 (true, true) => quote! {
407 let __post = web_rpc::js_sys::Array::of1(__response.as_ref());
408 let __transfer = web_rpc::js_sys::Array::of1(__response.as_ref());
409 (Self::Response::#camel_case_ident(()), __post, __transfer)
410 }
411 };
412 let args = args.iter().filter_map(|arg| match &*arg.pat {
413 Pat::Ident(ident) => Some(&ident.ident),
414 _ => None
415 });
416 match is_async {
417 true => quote! {
418 Self::Request::#camel_case_ident { #( #serialize_arg_idents ),* } => {
419 #( #extract_js_args )*
420 let __task =
421 web_rpc::futures_util::FutureExt::fuse(self.server_impl.#ident(#( #args ),*));
422 web_rpc::pin_utils::pin_mut!(__task);
423 web_rpc::futures_util::select! {
424 _ = __abort_rx => None,
425 __response = __task => Some({
426 #return_response
427 })
428 }
429 }
430 },
431 false => quote! {
432 Self::Request::#camel_case_ident { #( #serialize_arg_idents ),* } => {
433 #( #extract_js_args )*
434 let __response = self.server_impl.#ident(#( #args ),*);
435 Some({
436 #return_response
437 })
438 }
439 }
440 }
441 });
442
443 quote! {
444 #vis struct #service_ident<T> {
445 server_impl: T
446 }
447 impl<T: #trait_ident> web_rpc::service::Service for #service_ident<T> {
448 type Request = #request_ident;
449 type Response = #response_ident;
450 async fn execute(
451 &self,
452 __seq_id: usize,
453 mut __abort_rx: web_rpc::futures_channel::oneshot::Receiver<()>,
454 __request: Self::Request,
455 __js_args: web_rpc::js_sys::Array
456 ) -> (usize, Option<(Self::Response, web_rpc::js_sys::Array, web_rpc::js_sys::Array)>) {
457 let __result = match __request {
458 #( #handlers )*
459 };
460 (__seq_id, __result)
461 }
462 }
463 impl<T: #trait_ident> std::convert::From<T> for #service_ident<T> {
464 fn from(server_impl: T) -> Self {
465 Self { server_impl }
466 }
467 }
468 }
469 }
470}
471
472impl<'a> ToTokens for ServiceGenerator<'a> {
473 fn to_tokens(&self, output: &mut TokenStream2) {
474 output.extend(vec![
475 self.enum_request(),
476 self.enum_response(),
477 self.trait_service(),
478 self.struct_client(),
479 self.struct_server(),
480 ])
481 }
482}
483
484impl Parse for Service {
485 fn parse(input: ParseStream) -> syn::Result<Self> {
486 let attrs = input.call(Attribute::parse_outer)?;
487 let vis = input.parse()?;
488 input.parse::<Token![trait]>()?;
489 let ident: Ident = input.parse()?;
490 let content;
491 braced!(content in input);
492 let mut rpcs = Vec::<RpcMethod>::new();
493 while !content.is_empty() {
494 rpcs.push(content.parse()?);
495 }
496
497 Ok(Self {
498 attrs,
499 vis,
500 ident,
501 rpcs,
502 })
503 }
504}
505
506impl Parse for RpcMethod {
507 fn parse(input: ParseStream) -> syn::Result<Self> {
508 let mut errors = Ok(());
509 let attrs = input.call(Attribute::parse_outer)?;
510 let (post_attrs, attrs): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|attr| {
511 attr.path
512 .segments
513 .last()
514 .is_some_and(|last_segment| last_segment.ident == "post")
515 });
516 let mut transfer: HashSet<Ident> = HashSet::new();
517 let mut post: HashSet<Ident> = HashSet::new();
518 for post_attr in post_attrs {
519 let parsed_args =
520 post_attr.parse_args_with(Punctuated::<NestedMeta, Token![,]>::parse_terminated)?;
521 for parsed_arg in parsed_args {
522 match &parsed_arg {
523 NestedMeta::Meta(meta) => match meta {
524 Meta::Path(path) => {
525 if let Some(segment) = path.segments.last() {
526 post.insert(segment.ident.clone());
527 }
528 }
529 Meta::List(list) => match list.path.segments.last() {
530 Some(last_segment) if last_segment.ident == "transfer" => {
531 if list.nested.len() != 1 {
532 extend_errors!(
533 errors,
534 syn::Error::new(
535 parsed_arg.span(),
536 "Syntax error in post attribute"
537 )
538 );
539 }
540 match list.nested.first() {
541 Some(NestedMeta::Meta(Meta::Path(path))) => {
542 match path.segments.last() {
543 Some(segment) => {
544 post.insert(segment.ident.clone());
545 transfer.insert(segment.ident.clone());
546 }
547 _ => extend_errors!(
548 errors,
549 syn::Error::new(
550 parsed_arg.span(),
551 "Syntax error in post attribute"
552 )
553 ),
554 }
555 }
556 _ => extend_errors!(
557 errors,
558 syn::Error::new(
559 parsed_arg.span(),
560 "Syntax error in post attribute"
561 )
562 ),
563 }
564 }
565 _ => extend_errors!(
566 errors,
567 syn::Error::new(
568 parsed_arg.span(),
569 "Syntax error in post attribute"
570 )
571 ),
572 },
573 _ => extend_errors!(
574 errors,
575 syn::Error::new(parsed_arg.span(), "Syntax error in post attribute")
576 ),
577 },
578 _ => extend_errors!(
579 errors,
580 syn::Error::new(parsed_arg.span(), "Syntax error in post attribute")
581 ),
582 }
583 }
584 }
585
586 let is_async = input.parse::<Token![async]>().is_ok();
587 input.parse::<Token![fn]>()?;
588 let ident = input.parse()?;
589 let content;
590 parenthesized!(content in input);
591 let mut args = Vec::new();
592 for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
593 match arg {
594 FnArg::Typed(captured) => match &*captured.pat {
595 Pat::Ident(_) => args.push(captured),
596 _ => {
597 extend_errors!(
598 errors,
599 syn::Error::new(
600 captured.pat.span(),
601 "patterns are not allowed in RPC arguments"
602 )
603 )
604 }
605 },
606 FnArg::Receiver(_) => {
607 extend_errors!(
608 errors,
609 syn::Error::new(arg.span(), "receivers are not allowed in RPC arguments")
610 );
611 }
612 }
613 }
614 errors?;
615 let output = input.parse()?;
616 input.parse::<Token![;]>()?;
617
618 Ok(Self {
619 is_async,
620 attrs,
621 ident,
622 args,
623 post,
624 transfer,
625 output,
626 })
627 }
628}
629
630#[proc_macro_attribute]
636pub fn service(_attr: TokenStream, input: TokenStream) -> TokenStream {
637 let Service {
638 ref attrs,
639 ref vis,
640 ref ident,
641 ref rpcs,
642 } = parse_macro_input!(input as Service);
643
644 let camel_case_fn_names: &Vec<_> = &rpcs
645 .iter()
646 .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
647 .collect();
648
649 ServiceGenerator {
650 trait_ident: ident,
651 service_ident: &format_ident!("{}Service", ident),
652 client_ident: &format_ident!("{}Client", ident),
653 request_ident: &format_ident!("{}Request", ident),
654 response_ident: &format_ident!("{}Response", ident),
655 vis,
656 attrs,
657 rpcs,
658 camel_case_idents: &rpcs
659 .iter()
660 .zip(camel_case_fn_names.iter())
661 .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
662 .collect::<Vec<_>>(),
663 }
664 .into_token_stream()
665 .into()
666}
667
/// Converts a `snake_case` identifier into `CamelCase`.
///
/// Underscores are removed. The first character of each underscore-separated
/// word is uppercased and every other character is lowercased, so
/// `get_HTTP_status` becomes `GetHttpStatus`. Runs of underscores produce
/// nothing.
fn snake_to_camel(ident_str: &str) -> String {
    let mut camel = String::with_capacity(ident_str.len());
    for word in ident_str.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            camel.extend(first.to_uppercase());
            camel.extend(chars.flat_map(char::to_lowercase));
        }
    }
    camel.shrink_to_fit();
    camel
}