1#![recursion_limit = "512"]
8
9extern crate proc_macro;
10extern crate proc_macro2;
11extern crate quote;
12extern crate syn;
13
14use proc_macro::TokenStream;
15use proc_macro2::TokenStream as TokenStream2;
16use quote::{format_ident, quote, ToTokens};
17use syn::{
18 braced,
19 ext::IdentExt,
20 parenthesized,
21 parse::{Parse, ParseStream},
22 parse_macro_input, parse_quote,
23 spanned::Spanned,
24 token::Comma,
25 AttrStyle, Attribute, Expr, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, Path,
26 ReturnType, Token, Type, Visibility,
27};
28
/// Records a `syn::Error` (`$e`) into an accumulator `$errors: Result<_, syn::Error>`.
///
/// While `$errors` is still `Ok`, the first error replaces it. After that, every further error
/// is appended to the accumulated one, so callers can report all problems at once instead of
/// stopping at the first.
macro_rules! extend_errors {
    ($errors: ident, $e: expr) => {
        match &mut $errors {
            // First failure: switch the accumulator into the error state. This arm must stay
            // first so that the error type of a bare `Ok(())` accumulator is inferred from
            // `$e` before `extend` is resolved in the next arm.
            Ok(_) => $errors = Err($e),
            // Later failures: merge into the errors collected so far.
            Err(accumulated) => accumulated.extend($e),
        }
    };
}
40
/// A service definition parsed from the trait annotated with `#[tarpc::service]`.
struct Service {
    /// Outer attributes on the trait (e.g. doc comments); re-emitted on the generated trait.
    attrs: Vec<Attribute>,
    /// Visibility of the trait; reused for the generated trait, structs, enums and client fns.
    vis: Visibility,
    /// The trait's name; the generated `…Client`, `…Request`, `…Response`, `…Stub` and
    /// `Serve…` names are formatted from it.
    ident: Ident,
    /// The RPC methods declared in the trait body, in declaration order.
    rpcs: Vec<RpcMethod>,
}
47
/// One `async fn name(args) -> Ret;` declaration inside a service trait.
struct RpcMethod {
    /// Outer attributes on the method (doc comments, `#[cfg(...)]`, ...).
    attrs: Vec<Attribute>,
    /// The method name.
    ident: Ident,
    /// The declared typed arguments. `self` receivers are rejected during parsing, and each
    /// pattern is a `Pat::Ident` (checked in `RpcMethod::parse`).
    args: Vec<PatType>,
    /// The declared return type; `ReturnType::Default` is later treated as `()`.
    output: ReturnType,
}
54
55impl Parse for Service {
56 fn parse(input: ParseStream) -> syn::Result<Self> {
57 let attrs = input.call(Attribute::parse_outer)?;
58 let vis = input.parse()?;
59 input.parse::<Token![trait]>()?;
60 let ident: Ident = input.parse()?;
61 let content;
62 braced!(content in input);
63 let mut rpcs = Vec::<RpcMethod>::new();
64 while !content.is_empty() {
65 rpcs.push(content.parse()?);
66 }
67 let mut ident_errors = Ok(());
68 for rpc in &rpcs {
69 if rpc.ident == "new" {
70 extend_errors!(
71 ident_errors,
72 syn::Error::new(
73 rpc.ident.span(),
74 format!(
75 "method name conflicts with generated fn `{}Client::new`",
76 ident.unraw()
77 )
78 )
79 );
80 }
81 if rpc.ident == "serve" {
82 extend_errors!(
83 ident_errors,
84 syn::Error::new(
85 rpc.ident.span(),
86 format!("method name conflicts with generated fn `{ident}::serve`")
87 )
88 );
89 }
90 }
91 ident_errors?;
92
93 Ok(Self {
94 attrs,
95 vis,
96 ident,
97 rpcs,
98 })
99 }
100}
101
102impl Parse for RpcMethod {
103 fn parse(input: ParseStream) -> syn::Result<Self> {
104 let attrs = input.call(Attribute::parse_outer)?;
105 input.parse::<Token![async]>()?;
106 input.parse::<Token![fn]>()?;
107 let ident = input.parse()?;
108 let content;
109 parenthesized!(content in input);
110 let mut args = Vec::new();
111 let mut errors = Ok(());
112 for arg in content.parse_terminated(FnArg::parse, Comma)? {
113 match arg {
114 FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => {
115 args.push(captured);
116 }
117 FnArg::Typed(captured) => {
118 extend_errors!(
119 errors,
120 syn::Error::new(captured.pat.span(), "patterns aren't allowed in RPC args")
121 );
122 }
123 FnArg::Receiver(_) => {
124 extend_errors!(
125 errors,
126 syn::Error::new(arg.span(), "method args cannot start with self")
127 );
128 }
129 }
130 }
131 errors?;
132 let output = input.parse()?;
133 input.parse::<Token![;]>()?;
134
135 Ok(Self {
136 attrs,
137 ident,
138 args,
139 output,
140 })
141 }
142}
143
/// Options parsed from the `#[tarpc::service(...)]` attribute arguments.
#[derive(Default)]
struct DeriveMeta {
    /// Derives to put on the generated request/response enums. `None` when the attribute set
    /// neither `derive` nor `derive_serde`; `service` then derives serde's traits only if the
    /// `serde1` feature is enabled.
    derive: Option<Derive>,
    /// Extra token streams (e.g. the `derive_serde` deprecation hack) appended to the output.
    warnings: Vec<TokenStream2>,
}
149
150impl DeriveMeta {
151 fn with_derives(mut self, new: Vec<Path>) -> Self {
152 match self.derive.as_mut() {
153 Some(Derive::Explicit(old)) => old.extend(new),
154 _ => self.derive = Some(Derive::Explicit(new)),
155 }
156
157 self
158 }
159}
160
/// The derive configuration requested for the generated request/response enums.
enum Derive {
    /// Paths listed via `derive = [...]`, emitted together in one `#[derive(...)]`.
    Explicit(Vec<Path>),
    /// The deprecated `derive_serde = <bool>` form: whether to derive serde's
    /// `Serialize`/`Deserialize` (via `::tarpc::serde`).
    Serde(bool),
}
165
impl Parse for DeriveMeta {
    /// Parses the `#[tarpc::service(...)]` arguments: a comma-separated list of `name = value`
    /// items. Recognized items are `derive = [Path, ...]` and the deprecated
    /// `derive_serde = <bool>`; anything else is an error.
    ///
    /// Errors are accumulated (via `extend_errors!`) instead of returned at the first problem,
    /// so one invocation reports every invalid item. Using `derive_serde` also queues a
    /// deprecation snippet in `warnings`. Combining `derive` with `derive_serde`, or repeating
    /// either item, is reported as an error.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // The options parsed so far while everything is valid; becomes `Err` on the first
        // recorded error, after which further errors are appended to it.
        let mut result = Ok(DeriveMeta::default());

        // The raw `derive` / `derive_serde` items, kept so duplicates can be reported with
        // their spans after the loop.
        let mut derives = Vec::new();
        let mut derive_serde = Vec::new();
        let mut has_derive_serde = false;
        let mut has_explicit_derives = false;

        let meta_items = input.parse_terminated(MetaNameValue::parse, Comma)?;
        for meta in meta_items {
            // Only single-segment keys are recognized.
            if meta.path.segments.len() != 1 {
                extend_errors!(
                    result,
                    syn::Error::new(
                        meta.span(),
                        "tarpc::service does not support this meta item"
                    )
                );
                continue;
            }
            let segment = meta.path.segments.first().unwrap();
            if segment.ident == "derive" {
                has_explicit_derives = true;
                // The value must be an array expression such as `[Clone, Hash]`.
                let Expr::Array(ref array) = meta.value else {
                    extend_errors!(
                        result,
                        syn::Error::new(
                            meta.span(),
                            "tarpc::service does not support this meta item"
                        )
                    );
                    continue;
                };

                // Array elements are parsed as expressions; keep those that are paths and
                // record an error for every other element.
                let paths = array
                    .elems
                    .iter()
                    .filter_map(|e| {
                        if let Expr::Path(path) = e {
                            Some(path.path.clone())
                        } else {
                            extend_errors!(
                                result,
                                syn::Error::new(e.span(), "Expected Path or Type")
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>();

                // No-op if an error has already been recorded.
                result = result.map(|d| d.with_derives(paths));
                derives.push(meta);
            } else if segment.ident == "derive_serde" {
                has_derive_serde = true;
                let Expr::Lit(expr_lit) = &meta.value else {
                    extend_errors!(
                        result,
                        syn::Error::new(meta.value.span(), "expected literal")
                    );
                    continue;
                };
                match expr_lit.lit {
                    // `true` is only honored when this crate is built with `serde1`.
                    Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => {
                        result = result.map(|d| DeriveMeta {
                            derive: Some(Derive::Serde(true)),
                            ..d
                        })
                    }
                    Lit::Bool(LitBool { value: true, .. }) => {
                        extend_errors!(
                            result,
                            syn::Error::new(
                                meta.span(),
                                "To enable serde, first enable the `serde1` feature of tarpc"
                            )
                        );
                    }
                    Lit::Bool(LitBool { value: false, .. }) => {
                        result = result.map(|d| DeriveMeta {
                            derive: Some(Derive::Serde(false)),
                            ..d
                        })
                    }
                    _ => extend_errors!(
                        result,
                        syn::Error::new(
                            expr_lit.lit.span(),
                            "`derive_serde` expects a value of type `bool`"
                        )
                    ),
                }
                // Only literal-valued `derive_serde` items reach this point (non-literals
                // `continue` above), so only these count toward the duplicate check.
                derive_serde.push(meta);
            } else {
                extend_errors!(
                    result,
                    syn::Error::new(
                        meta.span(),
                        "tarpc::service does not support this meta item"
                    )
                );
                continue;
            }
        }

        if has_derive_serde {
            // Emitted into the expansion: referencing a `#[deprecated]` item makes rustc emit a
            // deprecation warning carrying this note, which is how the deprecation of
            // `derive_serde` presumably surfaces to users.
            let deprecation_hack = quote! {
                const _: () = {
                    #[deprecated(
                        note = "\nThe form `tarpc::service(derive_serde = true)` is deprecated.\
                        \nUse `tarpc::service(derive = [Serialize, Deserialize])`."
                    )]
                    const DEPRECATED_SYNTAX: () = ();
                    let _ = DEPRECATED_SYNTAX;
                };
            };

            result = result.map(|mut d| {
                d.warnings.push(deprecation_hack.to_token_stream());
                d
            });
        }

        // Non-short-circuiting `&` on two bools; equivalent to `&&` here.
        if has_explicit_derives & has_derive_serde {
            extend_errors!(
                result,
                syn::Error::new(
                    input.span(),
                    "tarpc does not support `derive_serde` and `derive` at the same time"
                )
            );
        }

        // Report every occurrence (not just the second) so each one is highlighted.
        if derive_serde.len() > 1 {
            for (i, derive_serde) in derive_serde.iter().enumerate() {
                extend_errors!(
                    result,
                    syn::Error::new(
                        derive_serde.span(),
                        format!(
                            "`derive_serde` appears more than once (occurrence #{})",
                            i + 1
                        )
                    )
                );
            }
        }

        if derives.len() > 1 {
            for (i, derive) in derives.iter().enumerate() {
                extend_errors!(
                    result,
                    syn::Error::new(
                        derive.span(),
                        format!("`derive` appears more than once (occurrence #{})", i + 1)
                    )
                );
            }
        }

        result
    }
}
329
/// Attribute macro that prepends serde's `Serialize`/`Deserialize` derives (through tarpc's
/// `::tarpc::serde` re-export, with `#[serde(crate = "::tarpc::serde")]`) to the annotated
/// item. Only compiled when the `serde1` feature is enabled; the attribute arguments are ignored.
#[proc_macro_attribute]
#[cfg(feature = "serde1")]
pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let serde_attrs = quote! {
        #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
        #[serde(crate = "::tarpc::serde")]
    };
    let item = proc_macro2::TokenStream::from(item);
    // Attributes first, then the untouched item tokens.
    quote!(#serde_attrs #item).into()
}
349
350fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
351 rpcs.iter()
352 .map(|rpc| {
353 rpc.attrs
354 .iter()
355 .filter(|att| {
356 att.style == AttrStyle::Outer
357 && match &att.meta {
358 syn::Meta::List(syn::MetaList { path, .. }) => {
359 path.get_ident() == Some(&Ident::new("cfg", rpc.ident.span()))
360 }
361 _ => false,
362 }
363 })
364 .collect::<Vec<_>>()
365 })
366 .collect::<Vec<_>>()
367}
368
369#[proc_macro_attribute]
415pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
416 let derive_meta = parse_macro_input!(attr as DeriveMeta);
417 let unit_type: &Type = &parse_quote!(());
418 let Service {
419 ref attrs,
420 ref vis,
421 ref ident,
422 ref rpcs,
423 } = parse_macro_input!(input as Service);
424
425 let camel_case_fn_names: &Vec<_> = &rpcs
426 .iter()
427 .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string()))
428 .collect();
429 let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::<Vec<_>>();
430
431 let derives = match derive_meta.derive.as_ref() {
432 Some(Derive::Explicit(paths)) => {
433 if !paths.is_empty() {
434 Some(quote! {
435 #[derive(
436 #(
437 #paths
438 ),*
439 )]
440 })
441 } else {
442 None
443 }
444 }
445 Some(Derive::Serde(serde)) => {
446 if *serde {
447 Some(quote! {
448 #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
449 #[serde(crate = "::tarpc::serde")]
450 })
451 } else {
452 None
453 }
454 }
455 None => {
456 if cfg!(feature = "serde1") {
457 Some(quote! {
458 #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)]
459 #[serde(crate = "::tarpc::serde")]
460 })
461 } else {
462 None
463 }
464 }
465 };
466
467 let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
468 let request_names = methods
469 .iter()
470 .map(|m| format!("{ident}.{m}"))
471 .collect::<Vec<_>>();
472
473 ServiceGenerator {
474 service_ident: ident,
475 client_stub_ident: &format_ident!("{}Stub", ident),
476 server_ident: &format_ident!("Serve{}", ident),
477 client_ident: &format_ident!("{}Client", ident),
478 request_ident: &format_ident!("{}Request", ident),
479 response_ident: &format_ident!("{}Response", ident),
480 vis,
481 args,
482 method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
483 method_cfgs: &collect_cfg_attrs(rpcs),
484 method_idents: &methods,
485 request_names: &request_names,
486 attrs,
487 rpcs,
488 return_types: &rpcs
489 .iter()
490 .map(|rpc| match rpc.output {
491 ReturnType::Type(_, ref ty) => ty.as_ref(),
492 ReturnType::Default => unit_type,
493 })
494 .collect::<Vec<_>>(),
495 arg_pats: &args
496 .iter()
497 .map(|args| args.iter().map(|arg| &*arg.pat).collect())
498 .collect::<Vec<_>>(),
499 camel_case_idents: &rpcs
500 .iter()
501 .zip(camel_case_fn_names.iter())
502 .map(|(rpc, name)| Ident::new(name, rpc.ident.span()))
503 .collect::<Vec<_>>(),
504 derives: derives.as_ref(),
505 warnings: &derive_meta.warnings,
506 }
507 .into_token_stream()
508 .into()
509}
510
/// Everything needed to emit the expansion of one `#[tarpc::service]` trait.
///
/// All per-method slices are indexed in the same order as `rpcs`.
struct ServiceGenerator<'a> {
    /// The service trait's name.
    service_ident: &'a Ident,
    /// `{Service}Stub`: the stub trait emitted alongside the service trait.
    client_stub_ident: &'a Ident,
    /// `Serve{Service}`: the struct wrapping a service implementation for serving.
    server_ident: &'a Ident,
    /// `{Service}Client`: the generated client struct.
    client_ident: &'a Ident,
    /// `{Service}Request`: the generated request enum.
    request_ident: &'a Ident,
    /// `{Service}Response`: the generated response enum.
    response_ident: &'a Ident,
    /// Visibility copied from the service trait.
    vis: &'a Visibility,
    /// Outer attributes of the service trait.
    attrs: &'a [Attribute],
    /// The parsed RPC methods.
    rpcs: &'a [RpcMethod],
    /// CamelCase enum variant name per method.
    camel_case_idents: &'a [Ident],
    /// Method name per method.
    method_idents: &'a [&'a Ident],
    /// `"{Service}.{method}"` per method, returned by the request enum's `RequestName::name`.
    request_names: &'a [String],
    /// All outer attributes per method.
    method_attrs: &'a [&'a [Attribute]],
    /// Only the outer `#[cfg(...)]` attributes per method.
    method_cfgs: &'a [Vec<&'a Attribute>],
    /// Typed arguments per method.
    args: &'a [&'a [PatType]],
    /// Return type per method (`()` when none was declared).
    return_types: &'a [&'a Type],
    /// Argument patterns (the bound names) per method.
    arg_pats: &'a [Vec<&'a Pat>],
    /// Optional derive attribute tokens put on the request and response enums.
    derives: Option<&'a TokenStream2>,
    /// Extra token streams appended verbatim to the expansion.
    warnings: &'a [TokenStream2],
}
534
535impl ServiceGenerator<'_> {
536 fn trait_service(&self) -> TokenStream2 {
537 let &Self {
538 attrs,
539 rpcs,
540 vis,
541 return_types,
542 service_ident,
543 client_stub_ident,
544 request_ident,
545 response_ident,
546 server_ident,
547 ..
548 } = self;
549
550 let rpc_fns = rpcs
551 .iter()
552 .zip(return_types.iter())
553 .map(
554 |(
555 RpcMethod {
556 attrs, ident, args, ..
557 },
558 output,
559 )| {
560 quote! {
561 #( #attrs )*
562 async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
563 }
564 },
565 );
566
567 let stub_doc = format!("The stub trait for service [`{service_ident}`].");
568 quote! {
569 #( #attrs )*
570 #vis trait #service_ident: ::core::marker::Sized {
571 #( #rpc_fns )*
572
573 fn serve(self) -> #server_ident<Self> {
576 #server_ident { service: self }
577 }
578 }
579
580 #[doc = #stub_doc]
581 #vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
582 }
583
584 impl<S> #client_stub_ident for S
585 where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
586 {
587 }
588 }
589 }
590
591 fn struct_server(&self) -> TokenStream2 {
592 let &Self {
593 vis, server_ident, ..
594 } = self;
595
596 quote! {
597 #[derive(Clone)]
599 #vis struct #server_ident<S> {
600 service: S,
601 }
602 }
603 }
604
605 fn impl_serve_for_server(&self) -> TokenStream2 {
606 let &Self {
607 request_ident,
608 server_ident,
609 service_ident,
610 response_ident,
611 camel_case_idents,
612 arg_pats,
613 method_idents,
614 method_cfgs,
615 ..
616 } = self;
617
618 quote! {
619 impl<S> ::tarpc::server::Serve for #server_ident<S>
620 where S: #service_ident
621 {
622 type Req = #request_ident;
623 type Resp = #response_ident;
624
625
626 async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident)
627 -> ::core::result::Result<#response_ident, ::tarpc::ServerError> {
628 match req {
629 #(
630 #( #method_cfgs )*
631 #request_ident::#camel_case_idents{ #( #arg_pats ),* } => {
632 ::core::result::Result::Ok(#response_ident::#camel_case_idents(
633 #service_ident::#method_idents(
634 self.service, ctx, #( #arg_pats ),*
635 ).await
636 ))
637 }
638 )*
639 }
640 }
641 }
642 }
643 }
644
645 fn enum_request(&self) -> TokenStream2 {
646 let &Self {
647 derives,
648 vis,
649 request_ident,
650 camel_case_idents,
651 args,
652 request_names,
653 method_cfgs,
654 ..
655 } = self;
656
657 quote! {
658 #[allow(missing_docs)]
660 #[derive(Debug)]
661 #derives
662 #vis enum #request_ident {
663 #(
664 #( #method_cfgs )*
665 #camel_case_idents{ #( #args ),* }
666 ),*
667 }
668 impl ::tarpc::RequestName for #request_ident {
669 fn name(&self) -> &str {
670 match self {
671 #(
672 #( #method_cfgs )*
673 #request_ident::#camel_case_idents{..} => {
674 #request_names
675 }
676 )*
677 }
678 }
679 }
680 }
681 }
682
683 fn enum_response(&self) -> TokenStream2 {
684 let &Self {
685 derives,
686 vis,
687 response_ident,
688 camel_case_idents,
689 return_types,
690 ..
691 } = self;
692
693 quote! {
694 #[allow(missing_docs)]
696 #[derive(Debug)]
697 #derives
698 #vis enum #response_ident {
699 #( #camel_case_idents(#return_types) ),*
700 }
701 }
702 }
703
704 fn struct_client(&self) -> TokenStream2 {
705 let &Self {
706 vis,
707 client_ident,
708 request_ident,
709 response_ident,
710 ..
711 } = self;
712
713 quote! {
714 #[allow(unused)]
715 #[derive(Clone, Debug)]
716 #vis struct #client_ident<
719 Stub = ::tarpc::client::Channel<#request_ident, #response_ident>
720 >(Stub);
721 }
722 }
723
724 fn impl_client_new(&self) -> TokenStream2 {
725 let &Self {
726 client_ident,
727 vis,
728 request_ident,
729 response_ident,
730 ..
731 } = self;
732
733 quote! {
734 impl #client_ident {
735 #vis fn new<T>(config: ::tarpc::client::Config, transport: T)
737 -> ::tarpc::client::NewClient<
738 Self,
739 ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
740 >
741 where
742 T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
743 {
744 let new_client = ::tarpc::client::new(config, transport);
745 ::tarpc::client::NewClient {
746 client: #client_ident(new_client.client),
747 dispatch: new_client.dispatch,
748 }
749 }
750 }
751
752 impl<Stub> ::core::convert::From<Stub> for #client_ident<Stub>
753 where Stub: ::tarpc::client::stub::Stub<
754 Req = #request_ident,
755 Resp = #response_ident>
756 {
757 fn from(stub: Stub) -> Self {
759 #client_ident(stub)
760 }
761
762 }
763 }
764 }
765
766 fn impl_client_rpc_methods(&self) -> TokenStream2 {
767 let &Self {
768 client_ident,
769 request_ident,
770 response_ident,
771 method_attrs,
772 vis,
773 method_idents,
774 args,
775 return_types,
776 arg_pats,
777 camel_case_idents,
778 ..
779 } = self;
780
781 quote! {
782 impl<Stub> #client_ident<Stub>
783 where Stub: ::tarpc::client::stub::Stub<
784 Req = #request_ident,
785 Resp = #response_ident>
786 {
787 #(
788 #[allow(unused)]
789 #( #method_attrs )*
790 #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*)
791 -> impl ::core::future::Future<Output = ::core::result::Result<#return_types, ::tarpc::client::RpcError>> + '_ {
792 let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
793 let resp = self.0.call(ctx, request);
794 async move {
795 match resp.await? {
796 #response_ident::#camel_case_idents(msg) => ::core::result::Result::Ok(msg),
797 _ => ::core::unreachable!(),
798 }
799 }
800 }
801 )*
802 }
803 }
804 }
805
806 fn emit_warnings(&self) -> TokenStream2 {
807 self.warnings.iter().map(|w| w.to_token_stream()).collect()
808 }
809}
810
811impl ToTokens for ServiceGenerator<'_> {
812 fn to_tokens(&self, output: &mut TokenStream2) {
813 output.extend(vec![
814 self.trait_service(),
815 self.struct_server(),
816 self.impl_serve_for_server(),
817 self.enum_request(),
818 self.enum_response(),
819 self.struct_client(),
820 self.impl_client_new(),
821 self.impl_client_rpc_methods(),
822 self.emit_warnings(),
823 ]);
824 }
825}
826
/// Converts a snake_case name into CamelCase.
///
/// Underscores are removed. In every underscore-delimited segment, the first character is
/// uppercased and all remaining characters are lowercased, so `aBc_dEf` becomes `AbcDef`.
/// Leading, trailing and repeated underscores produce no extra output.
fn snake_to_camel(ident_str: &str) -> String {
    ident_str
        .split('_')
        .filter(|segment| !segment.is_empty())
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().into_iter().flat_map(char::to_uppercase);
            head.chain(chars.flat_map(char::to_lowercase))
        })
        .collect()
}
845
#[test]
fn snake_to_camel_basic() {
    assert_eq!(snake_to_camel("abc_def"), "AbcDef");
}

#[test]
fn snake_to_camel_underscore_suffix() {
    assert_eq!(snake_to_camel("abc_def_"), "AbcDef");
}

#[test]
fn snake_to_camel_underscore_prefix() {
    assert_eq!(snake_to_camel("_abc_def"), "AbcDef");
}

#[test]
fn snake_to_camel_underscore_consecutive() {
    assert_eq!(snake_to_camel("abc__def"), "AbcDef");
}

#[test]
fn snake_to_camel_capital_in_middle() {
    assert_eq!(snake_to_camel("aBc_dEf"), "AbcDef");
}

/// Empty input yields an empty name (callers must not turn this into an `Ident`).
#[test]
fn snake_to_camel_empty() {
    assert_eq!(snake_to_camel(""), "");
}

/// A name made only of underscores collapses to nothing.
#[test]
fn snake_to_camel_only_underscores() {
    assert_eq!(snake_to_camel("___"), "");
}

/// Digits are kept in place and never trigger capitalization of what follows.
#[test]
fn snake_to_camel_digits() {
    assert_eq!(snake_to_camel("get_v2_data"), "GetV2Data");
}