twurst_build/
lib.rs

1#![doc = include_str!("../README.md")]
2#![doc(
3    test(attr(deny(warnings))),
4    html_favicon_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png",
5    html_logo_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8
9pub use prost_build as prost;
10use prost_build::{Config, Module, Service, ServiceGenerator};
11use regex::Regex;
12use std::collections::HashSet;
13use std::fmt::Write;
14use std::io::{Error, Result};
15use std::path::{Path, PathBuf};
16use std::{env, fs};
17
18/// Builds protobuf bindings for Twirp.
19///
20/// Client and server are not enabled by defaults and must be enabled with the [`with_client`](Self::with_client) and [`with_server`](Self::with_server) methods.
21#[derive(Default)]
22pub struct TwirpBuilder {
23    config: Config,
24    generator: TwirpServiceGenerator,
25    type_name_domain: Option<String>,
26}
27
28impl TwirpBuilder {
29    /// Builder with the default prost [`Config`].
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Builder with a custom prost [`Config`].
35    pub fn from_prost(config: Config) -> Self {
36        Self {
37            config,
38            generator: TwirpServiceGenerator::new(),
39            type_name_domain: None,
40        }
41    }
42
43    /// Generates code for the Twirp client.
44    pub fn with_client(mut self) -> Self {
45        self.generator = self.generator.with_client();
46        self
47    }
48
49    /// Generates code for the Twirp server.
50    pub fn with_server(mut self) -> Self {
51        self.generator = self.generator.with_server();
52        self
53    }
54
55    /// Generates code for gRPC alongside Twirp.
56    pub fn with_grpc(mut self) -> Self {
57        self.generator = self.generator.with_grpc();
58        self
59    }
60
61    /// Adds an extra parameter to generated server methods that implements [`axum::FromRequestParts`](https://docs.rs/axum/latest/axum/extract/trait.FromRequestParts.html).
62    ///
63    /// For example
64    /// ```proto
65    /// message Service {
66    ///     rpc Test(TestRequest) returns (TestResponse) {}
67    /// }
68    /// ```
69    /// Compiled with option `.with_axum_request_extractor("headers", "::axum::http::HeaderMap")`
70    /// will generate the following code allowing to extract the request headers:
71    /// ```ignore
72    /// trait Service {
73    ///     async fn test(request: TestRequest, headers: ::axum::http::HeaderMap) -> Result<TestResponse, TwirpError>;
74    /// }
75    /// ```
76    ///
77    /// Note that the parameter type must implement [`axum::FromRequestParts`](https://docs.rs/axum/latest/axum/extract/trait.FromRequestParts.html).
78    pub fn with_axum_request_extractor(
79        mut self,
80        name: impl Into<String>,
81        type_name: impl Into<String>,
82    ) -> Self {
83        self.generator = self.generator.with_axum_request_extractor(name, type_name);
84        self
85    }
86
87    /// Customizes the type name domain.
88    ///
89    /// By default, 'type.googleapis.com' is used.
90    pub fn with_type_name_domain(mut self, domain: impl Into<String>) -> Self {
91        self.type_name_domain = Some(domain.into());
92        self
93    }
94
95    /// Do compile the protos.
96    pub fn compile_protos(
97        mut self,
98        protos: &[impl AsRef<Path>],
99        includes: &[impl AsRef<Path>],
100    ) -> Result<()> {
101        let out_dir = PathBuf::from(
102            env::var_os("OUT_DIR").ok_or_else(|| Error::other("OUT_DIR is not set"))?,
103        );
104
105        // We make sure the script is executed again if a file changed
106        for proto in protos {
107            println!("cargo:rerun-if-changed={}", proto.as_ref().display());
108        }
109        self.config
110            .enable_type_names()
111            .type_name_domain(
112                ["."],
113                self.type_name_domain
114                    .as_deref()
115                    .unwrap_or("type.googleapis.com"),
116            )
117            .service_generator(Box::new(self.generator));
118
119        // We configure with prost reflect
120        prost_reflect_build::Builder::new()
121            .file_descriptor_set_bytes("self::FILE_DESCRIPTOR_SET_BYTES")
122            .configure(&mut self.config, protos, includes)?;
123
124        // We do the build itself while saving the list of modules
125        let config = self.config.skip_protoc_run();
126        let file_descriptor_set = config.load_fds(protos, includes)?;
127        let modules = file_descriptor_set
128            .file
129            .iter()
130            .map(|fd| Module::from_protobuf_package_name(fd.package()))
131            .collect::<HashSet<_>>();
132
133        // We generate the files
134        config.compile_fds(file_descriptor_set)?;
135
136        // TODO(vsiles) consider proper AST parsing in case we need to do something
137        // more robust
138        //
139        // prepare a regex to match `pub mod <module-name> {`
140        let re = Regex::new(r"^(\s*)pub mod \w+ \{\s*$").expect("Failed to compile regex");
141
142        // We add the file descriptor to every file to make reflection work automatically
143        for module in modules {
144            let file_path = Path::new(&out_dir).join(module.to_file_name_or("_"));
145            if !file_path.exists() {
146                continue; // We ignore not built files
147            }
148            let original_content = fs::read_to_string(&file_path)?;
149
150            // scan for nested modules and insert the right FILE_DESCRIPTOR_SET_BYTES definition
151            let mut modified_content = original_content
152                .lines()
153                .flat_map(|line| {
154                    if let Some(captures) = re.captures(line) {
155                        let indentation = captures.get(1).map_or("", |m| m.as_str());
156                        vec![
157                            line.to_string(),
158                            // if there is no nested type, the next line would generate a warning
159                            format!("    {}{}", indentation, "#[allow(unused_imports)]"),
160                            format!(
161                                "    {}{}",
162                                indentation, "use super::FILE_DESCRIPTOR_SET_BYTES;"
163                            ),
164                        ]
165                    } else {
166                        vec![line.to_string()]
167                    }
168                })
169                .collect::<Vec<_>>();
170
171            modified_content.push("const FILE_DESCRIPTOR_SET_BYTES: &[u8] = include_bytes!(\"file_descriptor_set.bin\");\n".to_string());
172            let file_content = modified_content.join("\n");
173
174            fs::write(&file_path, &file_content)?;
175        }
176
177        Ok(())
178    }
179}
180
/// Low level generator for Twirp related code.
///
/// This is only useful if you want to customize builds. For common use cases, please use [`TwirpBuilder`].
///
/// Should be given to [`Config::service_generator`].
///
/// Client and server are not enabled by default and must be enabled with the [`with_client`](Self::with_client) and [`with_server`](Self::with_server) methods.
#[derive(Default)]
struct TwirpServiceGenerator {
    /// Generate the `<Service>Client` struct.
    client: bool,
    /// Generate the server trait and its `into_router` method.
    server: bool,
    /// Add streaming methods and an `into_grpc_router` method to the server trait.
    grpc: bool,
    /// Extra `(parameter name, parameter type)` pairs appended to the server methods and
    /// extracted with axum's `FromRequestParts`.
    request_extractors: Vec<(String, String)>,
}
195
196impl TwirpServiceGenerator {
197    pub fn new() -> Self {
198        Self::default()
199    }
200
201    pub fn with_client(mut self) -> Self {
202        self.client = true;
203        self
204    }
205
206    pub fn with_server(mut self) -> Self {
207        self.server = true;
208        self
209    }
210
211    pub fn with_grpc(mut self) -> Self {
212        self.grpc = true;
213        self
214    }
215
216    pub fn with_axum_request_extractor(
217        mut self,
218        name: impl Into<String>,
219        type_name: impl Into<String>,
220    ) -> Self {
221        self.request_extractors
222            .push((name.into(), type_name.into()));
223        self
224    }
225}
226
227impl ServiceGenerator for TwirpServiceGenerator {
228    fn generate(&mut self, service: Service, buf: &mut String) {
229        self.do_generate(service, buf)
230            .expect("failed to generate Twirp service")
231    }
232}
233
234impl TwirpServiceGenerator {
235    fn do_generate(&mut self, service: Service, buf: &mut String) -> std::fmt::Result {
236        if self.client {
237            writeln!(buf)?;
238            for comment in &service.comments.leading {
239                writeln!(buf, "/// {comment}")?;
240            }
241            if service.options.deprecated.unwrap_or(false) {
242                writeln!(buf, "#[deprecated]")?;
243            }
244            writeln!(buf, "#[derive(Clone)]")?;
245            writeln!(
246                buf,
247                "pub struct {}Client<C: ::twurst_client::TwirpHttpService> {{",
248                service.name
249            )?;
250            writeln!(buf, "    client: ::twurst_client::TwirpHttpClient<C>")?;
251            writeln!(buf, "}}")?;
252            writeln!(buf)?;
253            writeln!(
254                buf,
255                "impl<C: ::twurst_client::TwirpHttpService> {}Client<C> {{",
256                service.name
257            )?;
258            writeln!(
259                buf,
260                "    pub fn new(client: impl Into<::twurst_client::TwirpHttpClient<C>>) -> Self {{"
261            )?;
262            writeln!(buf, "        Self {{ client: client.into() }}")?;
263            writeln!(buf, "    }}")?;
264            for method in &service.methods {
265                if method.client_streaming || method.server_streaming {
266                    continue; // Not supported
267                }
268                for comment in &method.comments.leading {
269                    writeln!(buf, "    /// {comment}")?;
270                }
271                if method.options.deprecated.unwrap_or(false) {
272                    writeln!(buf, "#[deprecated]")?;
273                }
274                writeln!(
275                    buf,
276                    "    pub async fn {}(&self, request: &{}) -> Result<{}, ::twurst_client::TwirpError> {{",
277                    method.name, method.input_type, method.output_type,
278                )?;
279                writeln!(
280                    buf,
281                    "        self.client.call(\"/{}.{}/{}\", request).await",
282                    service.package, service.proto_name, method.proto_name,
283                )?;
284                writeln!(buf, "    }}")?;
285            }
286            writeln!(buf, "}}")?;
287        }
288
289        if self.server {
290            writeln!(buf)?;
291            for comment in &service.comments.leading {
292                writeln!(buf, "/// {comment}")?;
293            }
294            writeln!(buf, "#[::twurst_server::codegen::trait_variant_make(Send)]")?;
295            writeln!(buf, "pub trait {} {{", service.name)?;
296            for method in &service.methods {
297                if !self.grpc && (method.client_streaming || method.server_streaming) {
298                    continue; // No streaming
299                }
300                for comment in &method.comments.leading {
301                    writeln!(buf, "    /// {comment}")?;
302                }
303                write!(buf, "    async fn {}(&self, request: ", method.name)?;
304                if method.client_streaming {
305                    write!(
306                        buf,
307                        "impl ::twurst_server::codegen::Stream<Item=Result<{},::twurst_client::TwirpError>> + Send + 'static",
308                        method.input_type,
309                    )?;
310                } else {
311                    write!(buf, "{}", method.input_type)?;
312                }
313                for (arg_name, arg_type) in &self.request_extractors {
314                    write!(buf, ", {arg_name}: {arg_type}")?;
315                }
316                writeln!(buf, ") -> Result<")?;
317                if method.server_streaming {
318                    // TODO: move back to `impl` when we will be able to use precise capturing to not capture &self
319                    writeln!(
320                        buf,
321                        "Box<dyn ::twurst_server::codegen::Stream<Item=Result<{}, ::twurst_server::TwirpError>> + Send>",
322                        method.output_type
323                    )?;
324                } else {
325                    writeln!(buf, "{}", method.output_type)?;
326                }
327                writeln!(buf, ", ::twurst_server::TwirpError>;")?;
328            }
329            writeln!(buf)?;
330            writeln!(
331                buf,
332                "    fn into_router<S: Clone + Send + Sync + 'static>(self) -> ::twurst_server::codegen::Router<S> where Self : Sized + Send + Sync + 'static {{"
333            )?;
334            writeln!(
335                buf,
336                "        ::twurst_server::codegen::TwirpRouter::new(::std::sync::Arc::new(self))"
337            )?;
338            for method in &service.methods {
339                if method.client_streaming || method.server_streaming {
340                    writeln!(
341                        buf,
342                        "            .route_streaming(\"/{}.{}/{}\")",
343                        service.package, service.proto_name, method.proto_name,
344                    )?;
345                    continue;
346                }
347                write!(
348                    buf,
349                    "            .route(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: {}",
350                    service.package, service.proto_name, method.proto_name, method.input_type,
351                )?;
352                if self.request_extractors.is_empty() {
353                    write!(buf, ", _: ::twurst_server::codegen::RequestParts, _: S")?;
354                } else {
355                    write!(
356                        buf,
357                        ", mut parts: ::twurst_server::codegen::RequestParts, state: S",
358                    )?;
359                }
360                write!(buf, "| {{")?;
361                writeln!(buf, "                async move {{")?;
362                write!(buf, "                    service.{}(request", method.name)?;
363                for (_name, type_name) in &self.request_extractors {
364                    write!(
365                        buf,
366                        ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &state).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
367                    )?;
368                }
369                writeln!(buf, ").await")?;
370                writeln!(buf, "                }}")?;
371                writeln!(buf, "            }})")?;
372            }
373            writeln!(buf, "            .build()")?;
374            writeln!(buf, "    }}")?;
375
376            if self.grpc {
377                writeln!(buf)?;
378                writeln!(
379                    buf,
380                    "    fn into_grpc_router(self) -> ::twurst_server::codegen::Router where Self : Sized + Send + Sync + 'static {{"
381                )?;
382                writeln!(
383                    buf,
384                    "        ::twurst_server::codegen::GrpcRouter::new(::std::sync::Arc::new(self))"
385                )?;
386                for method in &service.methods {
387                    let method_name = match (method.client_streaming, method.server_streaming) {
388                        (false, false) => "route",
389                        (false, true) => "route_server_streaming",
390                        (true, false) => "route_client_streaming",
391                        (true, true) => "route_streaming",
392                    };
393                    write!(
394                        buf,
395                        "            .{}(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: ",
396                        method_name, service.package, service.proto_name, method.proto_name,
397                    )?;
398                    if method.client_streaming {
399                        write!(
400                            buf,
401                            "::twurst_server::codegen::GrpcClientStream<{}>",
402                            method.input_type,
403                        )?;
404                    } else {
405                        write!(buf, "{}", method.input_type)?;
406                    }
407                    if self.request_extractors.is_empty() {
408                        write!(buf, ", _: ::twurst_server::codegen::RequestParts")?;
409                    } else {
410                        write!(buf, ", mut parts: ::twurst_server::codegen::RequestParts")?;
411                    }
412                    write!(buf, "| {{")?;
413                    write!(buf, "                async move {{")?;
414                    if method.server_streaming {
415                        write!(buf, "Ok(Box::into_pin(")?;
416                    }
417                    write!(buf, "service.{}(request", method.name)?;
418                    for (_name, type_name) in &self.request_extractors {
419                        write!(
420                            buf,
421                            ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &()).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
422                        )?;
423                    }
424                    write!(buf, ").await")?;
425                    if method.server_streaming {
426                        write!(buf, "?))")?;
427                    }
428                    writeln!(buf, "}}")?;
429                    writeln!(buf, "            }})")?;
430                }
431                writeln!(buf, "            .build()")?;
432                writeln!(buf, "    }}")?;
433            }
434
435            writeln!(buf, "}}")?;
436        }
437
438        Ok(())
439    }
440}