1#![doc = include_str!("../README.md")]
2#![doc(test(attr(deny(warnings))))]
3#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4
5pub use prost_build as prost;
6use prost_build::{Config, Module, Service, ServiceGenerator};
7use regex::Regex;
8use std::collections::HashSet;
9use std::fmt::Write;
10use std::io::{Error, Result};
11use std::path::{Path, PathBuf};
12use std::{env, fs};
13
/// Builder that configures `prost-build` and the Twirp service generator, then
/// runs the code generation via [`TwirpBuilder::compile_protos`].
#[derive(Default)]
pub struct TwirpBuilder {
    /// The underlying `prost-build` configuration used to compile the protos.
    config: Config,
    /// The service generator registered on `config` when compiling.
    generator: TwirpServiceGenerator,
    /// Optional type name domain; `compile_protos` falls back to
    /// `type.googleapis.com` when this is `None`.
    type_name_domain: Option<String>,
}
23
24impl TwirpBuilder {
25 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn from_prost(config: Config) -> Self {
32 Self {
33 config,
34 generator: TwirpServiceGenerator::new(),
35 type_name_domain: None,
36 }
37 }
38
39 pub fn with_client(mut self) -> Self {
41 self.generator = self.generator.with_client();
42 self
43 }
44
45 pub fn with_server(mut self) -> Self {
47 self.generator = self.generator.with_server();
48 self
49 }
50
51 pub fn with_grpc(mut self) -> Self {
53 self.generator = self.generator.with_grpc();
54 self
55 }
56
57 pub fn with_axum_request_extractor(
75 mut self,
76 name: impl Into<String>,
77 type_name: impl Into<String>,
78 ) -> Self {
79 self.generator = self.generator.with_axum_request_extractor(name, type_name);
80 self
81 }
82
83 pub fn with_type_name_domain(mut self, domain: impl Into<String>) -> Self {
87 self.type_name_domain = Some(domain.into());
88 self
89 }
90
91 pub fn compile_protos(
93 mut self,
94 protos: &[impl AsRef<Path>],
95 includes: &[impl AsRef<Path>],
96 ) -> Result<()> {
97 let out_dir = PathBuf::from(
98 env::var_os("OUT_DIR").ok_or_else(|| Error::other("OUT_DIR is not set"))?,
99 );
100
101 for proto in protos {
103 println!("cargo:rerun-if-changed={}", proto.as_ref().display());
104 }
105 self.config
106 .enable_type_names()
107 .type_name_domain(
108 ["."],
109 self.type_name_domain
110 .as_deref()
111 .unwrap_or("type.googleapis.com"),
112 )
113 .service_generator(Box::new(self.generator));
114
115 prost_reflect_build::Builder::new()
117 .file_descriptor_set_bytes("self::FILE_DESCRIPTOR_SET_BYTES")
118 .configure(&mut self.config, protos, includes)?;
119
120 let config = self.config.skip_protoc_run();
122 let file_descriptor_set = config.load_fds(protos, includes)?;
123 let modules = file_descriptor_set
124 .file
125 .iter()
126 .map(|fd| Module::from_protobuf_package_name(fd.package()))
127 .collect::<HashSet<_>>();
128
129 config.compile_fds(file_descriptor_set)?;
131
132 let re = Regex::new(r"^(\s*)pub mod \w+ \{\s*$").expect("Failed to compile regex");
137
138 for module in modules {
140 let file_path = Path::new(&out_dir).join(module.to_file_name_or("_"));
141 if !file_path.exists() {
142 continue; }
144 let original_content = fs::read_to_string(&file_path)?;
145
146 let mut modified_content = original_content
148 .lines()
149 .flat_map(|line| {
150 if let Some(captures) = re.captures(line) {
151 let indentation = captures.get(1).map_or("", |m| m.as_str());
152 vec![
153 line.to_string(),
154 format!(" {}{}", indentation, "#[allow(unused_imports)]"),
156 format!(
157 " {}{}",
158 indentation, "use super::FILE_DESCRIPTOR_SET_BYTES;"
159 ),
160 ]
161 } else {
162 vec![line.to_string()]
163 }
164 })
165 .collect::<Vec<_>>();
166
167 modified_content.push("const FILE_DESCRIPTOR_SET_BYTES: &[u8] = include_bytes!(\"file_descriptor_set.bin\");\n".to_string());
168 let file_content = modified_content.join("\n");
169
170 fs::write(&file_path, &file_content)?;
171 }
172
173 Ok(())
174 }
175}
176
/// Settings of the service generator: which artifacts to emit and which extra
/// request extractor arguments to add to the server trait methods.
#[derive(Default)]
struct TwirpServiceGenerator {
    /// Emit the client struct.
    client: bool,
    /// Emit the server trait.
    server: bool,
    /// Emit gRPC support in the server trait.
    grpc: bool,
    /// `(argument name, type name)` pairs, in insertion order.
    request_extractors: Vec<(String, String)>,
}

impl TwirpServiceGenerator {
    /// Builds a generator with every option switched off and no extractor.
    pub fn new() -> Self {
        Self {
            client: false,
            server: false,
            grpc: false,
            request_extractors: Vec::new(),
        }
    }

    /// Returns the generator with client generation switched on.
    pub fn with_client(self) -> Self {
        Self {
            client: true,
            ..self
        }
    }

    /// Returns the generator with server generation switched on.
    pub fn with_server(self) -> Self {
        Self {
            server: true,
            ..self
        }
    }

    /// Returns the generator with gRPC support switched on.
    pub fn with_grpc(self) -> Self {
        Self { grpc: true, ..self }
    }

    /// Returns the generator with one more `(name, type_name)` extractor
    /// appended after the ones already registered.
    pub fn with_axum_request_extractor(
        self,
        name: impl Into<String>,
        type_name: impl Into<String>,
    ) -> Self {
        let extractor = (name.into(), type_name.into());
        let mut request_extractors = self.request_extractors;
        request_extractors.push(extractor);
        Self {
            request_extractors,
            ..self
        }
    }
}
222
223impl ServiceGenerator for TwirpServiceGenerator {
224 fn generate(&mut self, service: Service, buf: &mut String) {
225 self.do_generate(service, buf)
226 .expect("failed to generate Twirp service")
227 }
228}
229
230impl TwirpServiceGenerator {
231 fn do_generate(&mut self, service: Service, buf: &mut String) -> std::fmt::Result {
232 if self.client {
233 writeln!(buf)?;
234 for comment in &service.comments.leading {
235 writeln!(buf, "/// {comment}")?;
236 }
237 if service.options.deprecated.unwrap_or(false) {
238 writeln!(buf, "#[deprecated]")?;
239 }
240 writeln!(buf, "#[derive(Clone)]")?;
241 writeln!(
242 buf,
243 "pub struct {}Client<C: ::twurst_client::TwirpHttpService> {{",
244 service.name
245 )?;
246 writeln!(buf, " client: ::twurst_client::TwirpHttpClient<C>")?;
247 writeln!(buf, "}}")?;
248 writeln!(buf)?;
249 writeln!(
250 buf,
251 "impl<C: ::twurst_client::TwirpHttpService> {}Client<C> {{",
252 service.name
253 )?;
254 writeln!(
255 buf,
256 " pub fn new(client: impl Into<::twurst_client::TwirpHttpClient<C>>) -> Self {{"
257 )?;
258 writeln!(buf, " Self {{ client: client.into() }}")?;
259 writeln!(buf, " }}")?;
260 for method in &service.methods {
261 if method.client_streaming || method.server_streaming {
262 continue; }
264 for comment in &method.comments.leading {
265 writeln!(buf, " /// {comment}")?;
266 }
267 if method.options.deprecated.unwrap_or(false) {
268 writeln!(buf, "#[deprecated]")?;
269 }
270 writeln!(
271 buf,
272 " pub async fn {}(&self, request: &{}) -> Result<{}, ::twurst_client::TwirpError> {{",
273 method.name, method.input_type, method.output_type,
274 )?;
275 writeln!(
276 buf,
277 " self.client.call(\"/{}.{}/{}\", request).await",
278 service.package, service.proto_name, method.proto_name,
279 )?;
280 writeln!(buf, " }}")?;
281 }
282 writeln!(buf, "}}")?;
283 }
284
285 if self.server {
286 writeln!(buf)?;
287 for comment in &service.comments.leading {
288 writeln!(buf, "/// {comment}")?;
289 }
290 writeln!(buf, "#[::twurst_server::codegen::trait_variant_make(Send)]")?;
291 writeln!(buf, "pub trait {} {{", service.name)?;
292 for method in &service.methods {
293 if !self.grpc && (method.client_streaming || method.server_streaming) {
294 continue; }
296 for comment in &method.comments.leading {
297 writeln!(buf, " /// {comment}")?;
298 }
299 write!(buf, " async fn {}(&self, request: ", method.name)?;
300 if method.client_streaming {
301 write!(
302 buf,
303 "impl ::twurst_server::codegen::Stream<Item=Result<{},::twurst_client::TwirpError>> + Send + 'static",
304 method.input_type,
305 )?;
306 } else {
307 write!(buf, "{}", method.input_type)?;
308 }
309 for (arg_name, arg_type) in &self.request_extractors {
310 write!(buf, ", {arg_name}: {arg_type}")?;
311 }
312 writeln!(buf, ") -> Result<")?;
313 if method.server_streaming {
314 writeln!(buf, "Box<dyn ::twurst_server::codegen::Stream<Item=Result<{}, ::twurst_server::TwirpError>> + Send>", method.output_type)?;
316 } else {
317 writeln!(buf, "{}", method.output_type)?;
318 }
319 writeln!(buf, ", ::twurst_server::TwirpError>;")?;
320 }
321 writeln!(buf)?;
322 writeln!(
323 buf,
324 " fn into_router<S: Clone + Send + Sync + 'static>(self) -> ::twurst_server::codegen::Router<S> where Self : Sized + Send + Sync + 'static {{"
325 )?;
326 writeln!(
327 buf,
328 " ::twurst_server::codegen::TwirpRouter::new(::std::sync::Arc::new(self))"
329 )?;
330 for method in &service.methods {
331 if method.client_streaming || method.server_streaming {
332 writeln!(
333 buf,
334 " .route_streaming(\"/{}.{}/{}\")",
335 service.package, service.proto_name, method.proto_name,
336 )?;
337 continue;
338 }
339 write!(
340 buf,
341 " .route(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: {}",
342 service.package, service.proto_name, method.proto_name, method.input_type,
343 )?;
344 if self.request_extractors.is_empty() {
345 write!(buf, ", _: ::twurst_server::codegen::RequestParts, _: S")?;
346 } else {
347 write!(
348 buf,
349 ", mut parts: ::twurst_server::codegen::RequestParts, state: S",
350 )?;
351 }
352 write!(buf, "| {{")?;
353 writeln!(buf, " async move {{")?;
354 write!(buf, " service.{}(request", method.name)?;
355 for (_name, type_name) in &self.request_extractors {
356 write!(
357 buf,
358 ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &state).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
359 )?;
360 }
361 writeln!(buf, ").await")?;
362 writeln!(buf, " }}")?;
363 writeln!(buf, " }})")?;
364 }
365 writeln!(buf, " .build()")?;
366 writeln!(buf, " }}")?;
367
368 if self.grpc {
369 writeln!(buf)?;
370 writeln!(
371 buf,
372 " fn into_grpc_router(self) -> ::twurst_server::codegen::Router where Self : Sized + Send + Sync + 'static {{"
373 )?;
374 writeln!(
375 buf,
376 " ::twurst_server::codegen::GrpcRouter::new(::std::sync::Arc::new(self))"
377 )?;
378 for method in &service.methods {
379 let method_name = match (method.client_streaming, method.server_streaming) {
380 (false, false) => "route",
381 (false, true) => "route_server_streaming",
382 (true, false) => "route_client_streaming",
383 (true, true) => "route_streaming",
384 };
385 write!(
386 buf,
387 " .{}(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: ",method_name,
388 service.package, service.proto_name, method.proto_name,
389 )?;
390 if method.client_streaming {
391 write!(
392 buf,
393 "::twurst_server::codegen::GrpcClientStream<{}>",
394 method.input_type,
395 )?;
396 } else {
397 write!(buf, "{}", method.input_type)?;
398 }
399 if self.request_extractors.is_empty() {
400 write!(buf, ", _: ::twurst_server::codegen::RequestParts")?;
401 } else {
402 write!(buf, ", mut parts: ::twurst_server::codegen::RequestParts")?;
403 }
404 write!(buf, "| {{")?;
405 write!(buf, " async move {{")?;
406 if method.server_streaming {
407 write!(buf, "Ok(Box::into_pin(")?;
408 }
409 write!(buf, "service.{}(request", method.name)?;
410 for (_name, type_name) in &self.request_extractors {
411 write!(
412 buf,
413 ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &()).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
414 )?;
415 }
416 write!(buf, ").await")?;
417 if method.server_streaming {
418 write!(buf, "?))")?;
419 }
420 writeln!(buf, "}}")?;
421 writeln!(buf, " }})")?;
422 }
423 writeln!(buf, " .build()")?;
424 writeln!(buf, " }}")?;
425 }
426
427 writeln!(buf, "}}")?;
428 }
429
430 Ok(())
431 }
432}