twirp_build/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use quote::format_ident;
4use syn::parse_quote;
5
6/// Generates twirp services for protobuf rpc service definitions.
7///
8/// In your `build.rs`, using `prost_build`, you can wire in the twirp
9/// `ServiceGenerator` to produce a Rust server for your proto services.
10///
11/// Add a call to `.service_generator(twirp_build::service_generator())` in
12/// main() of `build.rs`.
13pub fn service_generator() -> Box<ServiceGenerator> {
14    Box::new(ServiceGenerator {})
15}
16
17struct Service {
18    /// The name of the server trait, as parsed into a Rust identifier.
19    rpc_trait_name: syn::Ident,
20
21    /// The fully qualified protobuf name of this Service.
22    fqn: String,
23
24    /// The methods that make up this service.
25    methods: Vec<Method>,
26}
27
28struct Method {
29    /// The name of the method, as parsed into a Rust identifier.
30    name: syn::Ident,
31
32    /// The name of the method as it appears in the protobuf definition.
33    proto_name: String,
34
35    /// The input type of this method.
36    input_type: syn::Type,
37
38    /// The output type of this method.
39    output_type: syn::Type,
40}
41
42impl Service {
43    fn from_prost(s: prost_build::Service) -> Self {
44        let fqn = format!("{}.{}", s.package, s.proto_name);
45        let rpc_trait_name = format_ident!("{}", &s.name);
46        let methods = s
47            .methods
48            .into_iter()
49            .map(|m| Method::from_prost(&s.package, &s.proto_name, m))
50            .collect();
51
52        Self {
53            rpc_trait_name,
54            fqn,
55            methods,
56        }
57    }
58}
59
60impl Method {
61    fn from_prost(pkg_name: &str, svc_name: &str, m: prost_build::Method) -> Self {
62        let as_type = |s| -> syn::Type {
63            let Ok(typ) = syn::parse_str::<syn::Type>(s) else {
64                panic!(
65                    "twirp-build failed generated invalid Rust while processing {pkg}.{svc}/{name}). this is a bug in twirp-build, please file a GitHub issue",
66                    pkg = pkg_name,
67                    svc = svc_name,
68                    name = m.proto_name,
69                );
70            };
71            typ
72        };
73
74        let input_type = as_type(&m.input_type);
75        let output_type = as_type(&m.output_type);
76        let name = format_ident!("{}", m.name);
77        let message = m.proto_name;
78
79        Self {
80            name,
81            proto_name: message,
82            input_type,
83            output_type,
84        }
85    }
86}
87
/// A `prost_build` service generator that emits twirp server, router,
/// client, and direct-handler code for each protobuf service.
///
/// Usually obtained via `service_generator()`.
#[derive(Debug, Default, Clone)]
pub struct ServiceGenerator;
89
90impl prost_build::ServiceGenerator for ServiceGenerator {
91    fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
92        let service = Service::from_prost(service);
93
94        // generate the twirp server
95        let service_fqn_path = format!("/{}", service.fqn);
96        let mut trait_methods: Vec<syn::TraitItemFn> = Vec::with_capacity(service.methods.len());
97        let mut proxy_methods: Vec<syn::ImplItemFn> = Vec::with_capacity(service.methods.len());
98        for m in &service.methods {
99            let name = &m.name;
100            let input_type = &m.input_type;
101            let output_type = &m.output_type;
102
103            trait_methods.push(parse_quote! {
104                async fn #name(&self, req: twirp::Request<#input_type>) -> twirp::Result<twirp::Response<#output_type>>;
105            });
106
107            proxy_methods.push(parse_quote! {
108                async fn #name(&self, req: twirp::Request<#input_type>) -> twirp::Result<twirp::Response<#output_type>> {
109                    T::#name(&*self, req).await
110                }
111            });
112        }
113
114        let rpc_trait_name = &service.rpc_trait_name;
115        let server_trait: syn::ItemTrait = parse_quote! {
116            #[twirp::async_trait::async_trait]
117            pub trait #rpc_trait_name: Send + Sync {
118                #(#trait_methods)*
119            }
120        };
121        let server_trait_impl: syn::ItemImpl = parse_quote! {
122            #[twirp::async_trait::async_trait]
123            impl<T> #rpc_trait_name for std::sync::Arc<T>
124            where
125                T: #rpc_trait_name + Sync + Send
126            {
127                #(#proxy_methods)*
128            }
129        };
130
131        // generate the router
132        let mut expr: syn::Expr = parse_quote! {
133            twirp::details::TwirpRouterBuilder::new(#service_fqn_path, api)
134        };
135        for m in &service.methods {
136            let name = &m.name;
137            let input_type = &m.input_type;
138            let path = format!("/{}", m.proto_name);
139
140            expr = parse_quote! {
141                #expr.route(#path, |api: T, req: twirp::Request<#input_type>| async move {
142                    api.#name(req).await
143                })
144            };
145        }
146        let router: syn::ItemFn = parse_quote! {
147            pub fn router<T>(api: T) -> twirp::Router
148                where
149                    T: #rpc_trait_name + Clone + Send + Sync + 'static
150                {
151                    #expr.build()
152                }
153        };
154
155        //
156        // generate the twirp client
157        //
158        let mut client_methods: Vec<syn::ImplItemFn> = Vec::with_capacity(service.methods.len());
159        for m in &service.methods {
160            let name = &m.name;
161            let input_type = &m.input_type;
162            let output_type = &m.output_type;
163            let request_path = format!("{}/{}", service.fqn, m.proto_name);
164
165            client_methods.push(parse_quote! {
166                async fn #name(&self, req: twirp::Request<#input_type>) -> twirp::Result<twirp::Response<#output_type>> {
167                    self.request(#request_path, req).await
168                }
169            })
170        }
171        let client_trait: syn::ItemImpl = parse_quote! {
172            #[twirp::async_trait::async_trait]
173            impl #rpc_trait_name for twirp::client::Client {
174                #(#client_methods)*
175            }
176        };
177
178        //
179        // generate the client mock helpers
180        //
181        // TODO: Gate this code on a feature flag e.g. `std::env::var("CARGO_CFG_FEATURE_<FEATURE>").is_ok()`
182        //
183        let service_fqn = &service.fqn;
184        let handler_name = format_ident!("{rpc_trait_name}Handler");
185        let handler_struct: syn::ItemStruct = parse_quote! {
186            pub struct #handler_name {
187                inner: std::sync::Arc<dyn #rpc_trait_name>,
188            }
189        };
190        let mut method_matches: Vec<syn::Arm> = Vec::with_capacity(service.methods.len());
191        for m in &service.methods {
192            let name = &m.name;
193            let method = &m.proto_name;
194            method_matches.push(parse_quote! {
195                #method => {
196                    twirp::details::encode_response(self.inner.#name(twirp::details::decode_request(req).await?).await?)
197                }
198            });
199        }
200        let handler_impl: syn::ItemImpl = parse_quote! {
201            impl #handler_name {
202                #[allow(clippy::new_ret_no_self)]
203                pub fn new<M: #rpc_trait_name + 'static>(inner: M) -> Self {
204                    Self { inner: std::sync::Arc::new(inner) }
205                }
206            }
207
208        };
209        let handler_direct_impl: syn::ItemImpl = parse_quote! {
210            #[twirp::async_trait::async_trait]
211            impl twirp::client::DirectHandler for #handler_name {
212                fn service(&self) -> &str {
213                    #service_fqn
214                }
215                async fn handle(&self, method: &str, req: twirp::reqwest::Request) -> twirp::Result<twirp::reqwest::Response> {
216                    match method {
217                        #(#method_matches)*
218                        _ => Err(twirp::bad_route(format!("unknown rpc `{method}` for service `{}`, url: {:?}", #service_fqn, req.url()))),
219                    }
220                }
221            }
222        };
223        let direct_api_handler: syn::ItemMod = parse_quote! {
224            #[allow(dead_code)]
225            pub mod handler {
226                use super::*;
227
228                #handler_struct
229                #handler_impl
230                #handler_direct_impl
231            }
232        };
233
234        // generate the service and client as a single file. run it through
235        // prettyplease before outputting it.
236        let ast: syn::File = parse_quote! {
237            pub use twirp;
238
239            #server_trait
240            #server_trait_impl
241
242            #router
243
244            #client_trait
245
246            #direct_api_handler
247        };
248
249        let code = prettyplease::unparse(&ast);
250        buf.push_str(&code);
251    }
252}