1#![doc = include_str!("../README.md")]
2#![doc(
3 test(attr(deny(warnings))),
4 html_favicon_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png",
5 html_logo_url = "https://raw.githubusercontent.com/helsing-ai/twurst/main/docs/img/twurst.png"
6)]
7#![cfg_attr(docsrs, feature(doc_auto_cfg))]
8
9use self::proto_path_map::ProtoPathMap;
10pub use prost_build as prost;
11use prost_build::{Config, Module, Service, ServiceGenerator};
12use regex::Regex;
13use std::collections::HashSet;
14use std::fmt::Write;
15use std::io::{Error, Result};
16use std::path::{Path, PathBuf};
17use std::{env, fs};
18
19mod proto_path_map;
20
/// Builder that compiles protobuf files into prost messages plus Twirp client and/or
/// server code (optionally with gRPC support); meant to be driven from a build script.
#[derive(Default)]
pub struct TwirpBuilder {
    /// Underlying prost configuration, extended by `compile_protos`.
    config: Config,
    /// Service generator emitting the client/server code for each protobuf service.
    generator: TwirpServiceGenerator,
    /// Domain used for protobuf type names; `None` means `type.googleapis.com`.
    type_name_domain: Option<String>,
}
30
31impl TwirpBuilder {
32 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn from_prost(config: Config) -> Self {
39 Self {
40 config,
41 generator: TwirpServiceGenerator::new(),
42 type_name_domain: None,
43 }
44 }
45
46 pub fn with_client(mut self) -> Self {
48 self.generator = self.generator.with_client();
49 self
50 }
51
52 pub fn with_server(mut self) -> Self {
54 self.generator = self.generator.with_server();
55 self
56 }
57
58 pub fn with_grpc(mut self) -> Self {
60 self.generator = self.generator.with_grpc();
61 self
62 }
63
64 #[deprecated(
65 since = "0.3.1",
66 note = "replaced with with_default_axum_request_extractor"
67 )]
68 pub fn with_axum_request_extractor(
69 self,
70 name: impl Into<String>,
71 type_name: impl Into<String>,
72 ) -> Self {
73 self.with_default_axum_request_extractor(name, type_name)
74 }
75
76 pub fn with_default_axum_request_extractor(
97 mut self,
98 name: impl Into<String>,
99 type_name: impl Into<String>,
100 ) -> Self {
101 self.generator = self
102 .generator
103 .with_default_axum_request_extractor(name, type_name);
104 self
105 }
106
107 pub fn with_service_specific_axum_request_extractor(
198 mut self,
199 name: impl Into<String>,
200 type_name: impl Into<String>,
201 service_path: impl Into<String>,
202 ) -> Self {
203 self.generator = self.generator.with_service_specific_axum_request_extractor(
204 name,
205 type_name,
206 service_path,
207 );
208 self
209 }
210
211 pub fn with_type_name_domain(mut self, domain: impl Into<String>) -> Self {
215 self.type_name_domain = Some(domain.into());
216 self
217 }
218
219 pub fn compile_protos(
221 mut self,
222 protos: &[impl AsRef<Path>],
223 includes: &[impl AsRef<Path>],
224 ) -> Result<()> {
225 let out_dir = PathBuf::from(
226 env::var_os("OUT_DIR").ok_or_else(|| Error::other("OUT_DIR is not set"))?,
227 );
228
229 for proto in protos {
231 println!("cargo:rerun-if-changed={}", proto.as_ref().display());
232 }
233
234 self.config
235 .enable_type_names()
236 .type_name_domain(
237 ["."],
238 self.type_name_domain
239 .as_deref()
240 .unwrap_or("type.googleapis.com"),
241 )
242 .service_generator(Box::new(self.generator));
243
244 prost_reflect_build::Builder::new()
246 .file_descriptor_set_bytes("self::FILE_DESCRIPTOR_SET_BYTES")
247 .configure(&mut self.config, protos, includes)?;
248
249 let config = self.config.skip_protoc_run();
251 let file_descriptor_set = config.load_fds(protos, includes)?;
252 let modules = file_descriptor_set
253 .file
254 .iter()
255 .map(|fd| Module::from_protobuf_package_name(fd.package()))
256 .collect::<HashSet<_>>();
257
258 config.compile_fds(file_descriptor_set)?;
260
261 let re = Regex::new(r"^(\s*)pub mod \w+ \{\s*$").expect("Failed to compile regex");
266
267 for module in modules {
269 let file_path = Path::new(&out_dir).join(module.to_file_name_or("_"));
270 if !file_path.exists() {
271 continue; }
273 let original_content = fs::read_to_string(&file_path)?;
274
275 let mut modified_content = original_content
277 .lines()
278 .flat_map(|line| {
279 if let Some(captures) = re.captures(line) {
280 let indentation = captures.get(1).map_or("", |m| m.as_str());
281 vec![
282 line.to_string(),
283 format!(" {}{}", indentation, "#[allow(unused_imports)]"),
285 format!(
286 " {}{}",
287 indentation, "use super::FILE_DESCRIPTOR_SET_BYTES;"
288 ),
289 ]
290 } else {
291 vec![line.to_string()]
292 }
293 })
294 .collect::<Vec<_>>();
295
296 modified_content.push("const FILE_DESCRIPTOR_SET_BYTES: &[u8] = include_bytes!(\"file_descriptor_set.bin\");\n".to_string());
297 let file_content = modified_content.join("\n");
298
299 fs::write(&file_path, &file_content)?;
300 }
301
302 Ok(())
303 }
304}
305
/// prost `ServiceGenerator` that writes Twirp client and/or server code for each
/// protobuf service.
#[derive(Default)]
struct TwirpServiceGenerator {
    /// Emit a `<Service>Client` struct per service.
    client: bool,
    /// Emit a `<Service>` server trait with an `into_router` method.
    server: bool,
    /// Server side only: include streaming methods in the trait and emit
    /// `into_grpc_router`.
    grpc: bool,
    /// `(argument name, extractor type)` pairs appended to every server method when no
    /// service-specific extractor matches the service.
    default_request_extractors: Vec<(String, String)>,
    /// Extractors registered per service path; when any of them match a service they
    /// are used instead of `default_request_extractors`.
    matched_request_extractors: ProtoPathMap<(String, String)>,
}
323
324impl TwirpServiceGenerator {
325 pub fn new() -> Self {
326 Self::default()
327 }
328
329 pub fn with_client(mut self) -> Self {
330 self.client = true;
331 self
332 }
333
334 pub fn with_server(mut self) -> Self {
335 self.server = true;
336 self
337 }
338
339 pub fn with_grpc(mut self) -> Self {
340 self.grpc = true;
341 self
342 }
343
344 pub fn with_default_axum_request_extractor(
345 mut self,
346 name: impl Into<String>,
347 type_name: impl Into<String>,
348 ) -> Self {
349 self.default_request_extractors
350 .push((name.into(), type_name.into()));
351 self
352 }
353
354 pub fn with_service_specific_axum_request_extractor(
356 mut self,
357 name: impl Into<String>,
358 type_name: impl Into<String>,
359 service_proto_path: impl Into<String>,
360 ) -> Self {
361 self.matched_request_extractors
362 .insert(service_proto_path.into(), (name.into(), type_name.into()));
363 self
364 }
365}
366
367impl ServiceGenerator for TwirpServiceGenerator {
368 fn generate(&mut self, service: Service, buf: &mut String) {
369 self.do_generate(service, buf)
370 .expect("failed to generate Twirp service")
371 }
372}
373
374impl TwirpServiceGenerator {
375 fn do_generate(&mut self, service: Service, buf: &mut String) -> std::fmt::Result {
376 let mut service_matches = self
377 .matched_request_extractors
378 .service_matches(&service)
379 .peekable();
380
381 let extractors: Vec<_> = if service_matches.peek().is_some() {
382 service_matches.collect()
383 } else {
384 self.default_request_extractors.iter().collect()
385 };
386
387 if self.client {
388 writeln!(buf)?;
389 for comment in &service.comments.leading {
390 writeln!(buf, "/// {comment}")?;
391 }
392 if service.options.deprecated.unwrap_or(false) {
393 writeln!(buf, "#[deprecated]")?;
394 }
395 writeln!(buf, "#[derive(Clone)]")?;
396 writeln!(
397 buf,
398 "pub struct {}Client<C: ::twurst_client::TwirpHttpService> {{",
399 service.name
400 )?;
401 writeln!(buf, " client: ::twurst_client::TwirpHttpClient<C>")?;
402 writeln!(buf, "}}")?;
403 writeln!(buf)?;
404 writeln!(
405 buf,
406 "impl<C: ::twurst_client::TwirpHttpService> {}Client<C> {{",
407 service.name
408 )?;
409 writeln!(
410 buf,
411 " pub fn new(client: impl Into<::twurst_client::TwirpHttpClient<C>>) -> Self {{"
412 )?;
413 writeln!(buf, " Self {{ client: client.into() }}")?;
414 writeln!(buf, " }}")?;
415 for method in &service.methods {
416 if method.client_streaming || method.server_streaming {
417 continue; }
419 for comment in &method.comments.leading {
420 writeln!(buf, " /// {comment}")?;
421 }
422 if method.options.deprecated.unwrap_or(false) {
423 writeln!(buf, "#[deprecated]")?;
424 }
425 writeln!(
426 buf,
427 " pub async fn {}(&self, request: &{}) -> Result<{}, ::twurst_client::TwirpError> {{",
428 method.name, method.input_type, method.output_type,
429 )?;
430 writeln!(
431 buf,
432 " self.client.call(\"/{}.{}/{}\", request).await",
433 service.package, service.proto_name, method.proto_name,
434 )?;
435 writeln!(buf, " }}")?;
436 }
437 writeln!(buf, "}}")?;
438 }
439
440 if self.server {
441 writeln!(buf)?;
442 for comment in &service.comments.leading {
443 writeln!(buf, "/// {comment}")?;
444 }
445 writeln!(buf, "#[::twurst_server::codegen::trait_variant_make(Send)]")?;
446 writeln!(buf, "pub trait {} {{", service.name)?;
447 for method in &service.methods {
448 if !self.grpc && (method.client_streaming || method.server_streaming) {
449 continue; }
451 for comment in &method.comments.leading {
452 writeln!(buf, " /// {comment}")?;
453 }
454 write!(buf, " async fn {}(&self, request: ", method.name)?;
455 if method.client_streaming {
456 write!(
457 buf,
458 "impl ::twurst_server::codegen::Stream<Item=Result<{},::twurst_client::TwirpError>> + Send + 'static",
459 method.input_type,
460 )?;
461 } else {
462 write!(buf, "{}", method.input_type)?;
463 }
464 for (arg_name, arg_type) in &extractors {
465 write!(buf, ", {arg_name}: {arg_type}")?;
466 }
467 writeln!(buf, ") -> Result<")?;
468 if method.server_streaming {
469 writeln!(
471 buf,
472 "Box<dyn ::twurst_server::codegen::Stream<Item=Result<{}, ::twurst_server::TwirpError>> + Send>",
473 method.output_type
474 )?;
475 } else {
476 writeln!(buf, "{}", method.output_type)?;
477 }
478 writeln!(buf, ", ::twurst_server::TwirpError>;")?;
479 }
480 writeln!(buf)?;
481 writeln!(
482 buf,
483 " fn into_router<S: Clone + Send + Sync + 'static>(self) -> ::twurst_server::codegen::Router<S> where Self : Sized + Send + Sync + 'static {{"
484 )?;
485 writeln!(
486 buf,
487 " ::twurst_server::codegen::TwirpRouter::new(::std::sync::Arc::new(self))"
488 )?;
489 for method in &service.methods {
490 if method.client_streaming || method.server_streaming {
491 writeln!(
492 buf,
493 " .route_streaming(\"/{}.{}/{}\")",
494 service.package, service.proto_name, method.proto_name,
495 )?;
496 continue;
497 }
498 write!(
499 buf,
500 " .route(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: {}",
501 service.package, service.proto_name, method.proto_name, method.input_type,
502 )?;
503 if extractors.is_empty() {
504 write!(buf, ", _: ::twurst_server::codegen::RequestParts, _: S")?;
505 } else {
506 write!(
507 buf,
508 ", mut parts: ::twurst_server::codegen::RequestParts, state: S",
509 )?;
510 }
511 write!(buf, "| {{")?;
512 writeln!(buf, " async move {{")?;
513 write!(buf, " service.{}(request", method.name)?;
514 for (_name, type_name) in &extractors {
515 write!(
516 buf,
517 ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &state).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
518 )?;
519 }
520 writeln!(buf, ").await")?;
521 writeln!(buf, " }}")?;
522 writeln!(buf, " }})")?;
523 }
524 writeln!(buf, " .build()")?;
525 writeln!(buf, " }}")?;
526
527 if self.grpc {
528 writeln!(buf)?;
529 writeln!(
530 buf,
531 " fn into_grpc_router(self) -> ::twurst_server::codegen::Router where Self : Sized + Send + Sync + 'static {{"
532 )?;
533 writeln!(
534 buf,
535 " ::twurst_server::codegen::GrpcRouter::new(::std::sync::Arc::new(self))"
536 )?;
537 for method in &service.methods {
538 let method_name = match (method.client_streaming, method.server_streaming) {
539 (false, false) => "route",
540 (false, true) => "route_server_streaming",
541 (true, false) => "route_client_streaming",
542 (true, true) => "route_streaming",
543 };
544 write!(
545 buf,
546 " .{}(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: ",
547 method_name, service.package, service.proto_name, method.proto_name,
548 )?;
549 if method.client_streaming {
550 write!(
551 buf,
552 "::twurst_server::codegen::GrpcClientStream<{}>",
553 method.input_type,
554 )?;
555 } else {
556 write!(buf, "{}", method.input_type)?;
557 }
558 if extractors.is_empty() {
559 write!(buf, ", _: ::twurst_server::codegen::RequestParts")?;
560 } else {
561 write!(buf, ", mut parts: ::twurst_server::codegen::RequestParts")?;
562 }
563 write!(buf, "| {{")?;
564 write!(buf, " async move {{")?;
565 if method.server_streaming {
566 write!(buf, "Ok(Box::into_pin(")?;
567 }
568 write!(buf, "service.{}(request", method.name)?;
569 for (_name, type_name) in &extractors {
570 write!(
571 buf,
572 ", match <{type_name} as ::twurst_server::codegen::FromRequestParts<_>>::from_request_parts(&mut parts, &()).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
573 )?;
574 }
575 write!(buf, ").await")?;
576 if method.server_streaming {
577 write!(buf, "?))")?;
578 }
579 writeln!(buf, "}}")?;
580 writeln!(buf, " }})")?;
581 }
582 writeln!(buf, " .build()")?;
583 writeln!(buf, " }}")?;
584 }
585
586 writeln!(buf, "}}")?;
587 }
588
589 Ok(())
590 }
591}