twirp_build_rs/
lib.rs

1use std::fmt::Write;
2
3/// Generates twirp services for protobuf rpc service definitions.
4///
5/// In your `build.rs`, using `prost_build`, you can wire in the twirp
6/// `ServiceGenerator` to produce a Rust server for your proto services.
7///
8/// Add a call to `.service_generator(twirp_build::service_generator())` in
9/// main() of `build.rs`.
10pub fn service_generator() -> Box<ServiceGenerator> {
11    Box::new(ServiceGenerator {})
12}
13
/// A stateless `prost_build::ServiceGenerator` that emits twirp server and
/// client bindings for each protobuf rpc service it is handed.
///
/// Usually obtained through [`service_generator`], but it can also be
/// constructed directly (`ServiceGenerator` or `ServiceGenerator::default()`).
#[derive(Debug, Default, Clone, Copy)]
pub struct ServiceGenerator;
15
16impl prost_build::ServiceGenerator for ServiceGenerator {
17    fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
18        let service_name = service.name;
19        let service_fqn = format!("{}.{}", service.package, service.proto_name);
20        writeln!(buf).unwrap();
21
22        writeln!(buf, "pub use twirp;").unwrap();
23        writeln!(buf).unwrap();
24        writeln!(buf, "pub const SERVICE_FQN: &str = \"/{service_fqn}\";").unwrap();
25
26        //
27        // generate the twirp server
28        //
29        writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap();
30        writeln!(buf, "pub trait {} {{", service_name).unwrap();
31        for m in &service.methods {
32            writeln!(
33                buf,
34                "    async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;",
35                m.name, m.input_type, m.output_type,
36            )
37            .unwrap();
38        }
39        writeln!(buf, "}}").unwrap();
40
41        // add_service
42        writeln!(
43            buf,
44            r#"pub fn router<T>(api: std::sync::Arc<T>) -> twirp::Router
45where
46    T: {service_name} + Send + Sync + 'static,
47{{
48    twirp::details::TwirpRouterBuilder::new(api)"#,
49        )
50        .unwrap();
51        for m in &service.methods {
52            let uri = &m.proto_name;
53            let req_type = &m.input_type;
54            let rust_method_name = &m.name;
55            writeln!(
56                buf,
57                r#"        .route("/{uri}", |api: std::sync::Arc<T>, ctx: twirp::Context, req: {req_type}| async move {{
58            api.{rust_method_name}(ctx, req).await
59        }})"#,
60            )
61            .unwrap();
62        }
63        writeln!(
64            buf,
65            r#"
66        .build()
67}}"#
68        )
69        .unwrap();
70
71        //
72        // generate the twirp client
73        //
74        writeln!(buf).unwrap();
75        writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap();
76        writeln!(
77            buf,
78            "pub trait {service_name}Client: Send + Sync + std::fmt::Debug {{",
79        )
80        .unwrap();
81        for m in &service.methods {
82            // Define: <METHOD>
83            writeln!(
84                buf,
85                "    async fn {}(&self, req: {}) -> Result<{}, twirp::ClientError>;",
86                m.name, m.input_type, m.output_type,
87            )
88            .unwrap();
89        }
90        writeln!(buf, "}}").unwrap();
91
92        // Implement the rpc traits for: `twirp::client::Client`
93        writeln!(buf, "#[twirp::async_trait::async_trait]").unwrap();
94        writeln!(
95            buf,
96            "impl {service_name}Client for twirp::client::Client {{",
97        )
98        .unwrap();
99        for m in &service.methods {
100            // Define the rpc `<METHOD>`
101            writeln!(
102                buf,
103                "    async fn {}(&self, req: {}) -> Result<{}, twirp::ClientError> {{",
104                m.name, m.input_type, m.output_type,
105            )
106            .unwrap();
107            writeln!(
108                buf,
109                r#"    let url = self.base_url.join("{}/{}")?;"#,
110                service_fqn, m.proto_name,
111            )
112            .unwrap();
113            writeln!(buf, "    self.request(url, req).await",).unwrap();
114            writeln!(buf, "    }}").unwrap();
115        }
116        writeln!(buf, "}}").unwrap();
117    }
118}