#![doc = include_str!("../README.md")]
#![doc(test(attr(deny(warnings))))]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
pub use prost_build as prost;
use prost_build::{Config, Module, Service, ServiceGenerator};
use regex::Regex;
use std::collections::HashSet;
use std::fmt::Write;
use std::io::{Error, Result};
use std::path::{Path, PathBuf};
use std::{env, fs};
/// Builder for Twirp code generation, intended to be driven from a `build.rs` script.
///
/// Wraps a [`prost_build::Config`] and layers Twirp client/server generation and
/// prost-reflect descriptors on top of it; see [`TwirpBuilder::compile_protos`].
#[derive(Default)]
pub struct TwirpBuilder {
    /// Underlying prost configuration the Twirp settings are applied to.
    config: Config,
    /// Domain used for message type URLs; `None` falls back to `type.googleapis.com`.
    type_name_domain: Option<String>,
    /// Emits the Twirp client and/or server code for every service.
    generator: TwirpServiceGenerator,
}
impl TwirpBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn from_prost(config: Config) -> Self {
Self {
config,
generator: TwirpServiceGenerator::new(),
type_name_domain: None,
}
}
pub fn with_client(mut self) -> Self {
self.generator = self.generator.with_client();
self
}
pub fn with_server(mut self) -> Self {
self.generator = self.generator.with_server();
self
}
pub fn with_axum_request_extractor(
mut self,
name: impl Into<String>,
type_name: impl Into<String>,
) -> Self {
self.generator = self.generator.with_axum_request_extractor(name, type_name);
self
}
pub fn with_type_name_domain(mut self, domain: impl Into<String>) -> Self {
self.type_name_domain = Some(domain.into());
self
}
pub fn compile_protos(
mut self,
protos: &[impl AsRef<Path>],
includes: &[impl AsRef<Path>],
) -> Result<()> {
let out_dir = PathBuf::from(
env::var_os("OUT_DIR").ok_or_else(|| Error::other("OUT_DIR is not set"))?,
);
for proto in protos {
println!("cargo:rerun-if-changed={}", proto.as_ref().display());
}
self.config
.enable_type_names()
.type_name_domain(
["."],
self.type_name_domain
.as_deref()
.unwrap_or("type.googleapis.com"),
)
.service_generator(Box::new(self.generator));
prost_reflect_build::Builder::new()
.file_descriptor_set_bytes("self::FILE_DESCRIPTOR_SET_BYTES")
.configure(&mut self.config, protos, includes)?;
let config = self.config.skip_protoc_run();
let file_descriptor_set = config.load_fds(protos, includes)?;
let modules = file_descriptor_set
.file
.iter()
.map(|fd| Module::from_protobuf_package_name(fd.package()))
.collect::<HashSet<_>>();
for file_descriptor in &file_descriptor_set.file {
for service in &file_descriptor.service {
for method in &service.method {
if method.client_streaming() {
return Err(Error::other(format!(
"Client streaming is not supported in method {} of service {} in file {}",
method.name(), service.name(), file_descriptor.name()
)));
}
if method.server_streaming() {
return Err(Error::other(format!(
"Server streaming is not supported in method {} of service {} in file {}",
method.name(), service.name(), file_descriptor.name()
)));
}
}
}
}
config.compile_fds(file_descriptor_set)?;
let re = Regex::new(r"^(\s*)pub mod \w+ \{\s*$").expect("Failed to compile regex");
for module in modules {
let file_path = Path::new(&out_dir).join(module.to_file_name_or("_"));
if !file_path.exists() {
continue; }
let original_content = fs::read_to_string(&file_path)?;
let mut modified_content = original_content
.lines()
.flat_map(|line| {
if let Some(captures) = re.captures(line) {
let indentation = captures.get(1).map_or("", |m| m.as_str());
vec![
line.to_string(),
format!(" {}{}", indentation, "#[allow(unused_imports)]"),
format!(
" {}{}",
indentation, "use super::FILE_DESCRIPTOR_SET_BYTES;"
),
]
} else {
vec![line.to_string()]
}
})
.collect::<Vec<_>>();
modified_content.push("const FILE_DESCRIPTOR_SET_BYTES: &[u8] = include_bytes!(\"file_descriptor_set.bin\");\n".to_string());
let file_content = modified_content.join("\n");
fs::write(&file_path, &file_content)?;
}
Ok(())
}
}
/// prost service generator that emits Twirp client structs and/or server traits.
#[derive(Default)]
struct TwirpServiceGenerator {
    /// Emit a `{Service}Client<C>` struct per service.
    client: bool,
    /// Emit a `{Service}` server trait (with router constructors) per service.
    server: bool,
    /// `(argument name, argument type)` pairs appended, in order, to every server
    /// trait method.
    request_extractors: Vec<(String, String)>,
}

impl TwirpServiceGenerator {
    /// Returns a generator with client and server generation disabled and no
    /// request extractors.
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns on client generation, keeping every other setting.
    pub fn with_client(self) -> Self {
        Self {
            client: true,
            ..self
        }
    }

    /// Turns on server generation, keeping every other setting.
    pub fn with_server(self) -> Self {
        Self {
            server: true,
            ..self
        }
    }

    /// Registers one more extractor argument for generated server methods.
    pub fn with_axum_request_extractor(
        mut self,
        name: impl Into<String>,
        type_name: impl Into<String>,
    ) -> Self {
        let extractor = (name.into(), type_name.into());
        self.request_extractors.push(extractor);
        self
    }
}
impl ServiceGenerator for TwirpServiceGenerator {
    /// Appends the Twirp client and/or server code for `service` to `buf`.
    ///
    /// # Panics
    ///
    /// Panics if `do_generate` reports a formatting error.
    fn generate(&mut self, service: Service, buf: &mut String) {
        let outcome = self.do_generate(service, buf);
        outcome.expect("failed to generate Twirp service");
    }
}
impl TwirpServiceGenerator {
    /// Emits the client struct and/or server trait for `service` into `buf`,
    /// depending on which of `client` / `server` are enabled.
    fn do_generate(&mut self, service: Service, buf: &mut String) -> std::fmt::Result {
        if self.client {
            Self::write_client(&service, buf)?;
        }
        if self.server {
            self.write_server(&service, buf)?;
        }
        Ok(())
    }

    /// Emits `{Service}Client<C>`: a wrapper over `::twurst_client::TwirpHttpClient<C>`
    /// with one async method per RPC, each calling `/{package}.{Service}/{Method}`.
    fn write_client(service: &Service, buf: &mut String) -> std::fmt::Result {
        writeln!(buf)?;
        for comment in &service.comments.leading {
            writeln!(buf, "/// {comment}")?;
        }
        if service.options.deprecated.unwrap_or(false) {
            writeln!(buf, "#[deprecated]")?;
        }
        writeln!(buf, "#[derive(Clone)]")?;
        writeln!(
            buf,
            "pub struct {}Client<C: ::twurst_client::TwirpHttpService> {{",
            service.name
        )?;
        writeln!(buf, " client: ::twurst_client::TwirpHttpClient<C>")?;
        writeln!(buf, "}}")?;
        writeln!(buf)?;
        writeln!(
            buf,
            "impl<C: ::twurst_client::TwirpHttpService> {}Client<C> {{",
            service.name
        )?;
        writeln!(
            buf,
            " pub fn new(client: impl Into<::twurst_client::TwirpHttpClient<C>>) -> Self {{"
        )?;
        writeln!(buf, " Self {{ client: client.into() }}")?;
        writeln!(buf, " }}")?;
        for method in &service.methods {
            for comment in &method.comments.leading {
                writeln!(buf, " /// {comment}")?;
            }
            if method.options.deprecated.unwrap_or(false) {
                // Indented like the surrounding method-level lines.
                writeln!(buf, " #[deprecated]")?;
            }
            writeln!(
                buf,
                " pub async fn {}(&self, request: &{}) -> Result<{}, ::twurst_client::TwirpError> {{",
                method.name, method.input_type, method.output_type,
            )?;
            writeln!(
                buf,
                " self.client.call(\"/{}.{}/{}\", request).await",
                service.package, service.proto_name, method.proto_name,
            )?;
            writeln!(buf, " }}")?;
        }
        writeln!(buf, "}}")
    }

    /// Emits the `{Service}` server trait: one async method per RPC (plus the
    /// configured extractor arguments), `into_router`, and, with the `grpc`
    /// feature, `into_grpc_router`.
    fn write_server(&self, service: &Service, buf: &mut String) -> std::fmt::Result {
        writeln!(buf)?;
        for comment in &service.comments.leading {
            writeln!(buf, "/// {comment}")?;
        }
        writeln!(buf, "#[::twurst_server::codegen::trait_variant_make(Send)]")?;
        writeln!(buf, "pub trait {} {{", service.name)?;
        for method in &service.methods {
            for comment in &method.comments.leading {
                writeln!(buf, " /// {comment}")?;
            }
            write!(
                buf,
                " async fn {}(&self, request: {}",
                method.name, method.input_type
            )?;
            for (arg_name, arg_type) in &self.request_extractors {
                write!(buf, ", {arg_name}: {arg_type}")?;
            }
            writeln!(
                buf,
                ") -> Result<{}, ::twurst_server::TwirpError>;",
                method.output_type
            )?;
        }

        writeln!(buf)?;
        writeln!(
            buf,
            " fn into_router<S: Clone + Send + Sync + 'static>(self) -> ::twurst_server::codegen::Router<S> where Self : Sized + Send + Sync + 'static {{"
        )?;
        writeln!(
            buf,
            " ::twurst_server::codegen::TwirpRouter::new(::std::sync::Arc::new(self))"
        )?;
        self.write_routes(service, buf, Some("S"))?;
        writeln!(buf, " .build()")?;
        writeln!(buf, " }}")?;

        if cfg!(feature = "grpc") {
            writeln!(buf)?;
            writeln!(
                buf,
                " fn into_grpc_router(self) -> ::twurst_server::codegen::Router where Self : Sized + Send + Sync + 'static {{"
            )?;
            writeln!(
                buf,
                " ::twurst_server::codegen::GrpcRouter::new(::std::sync::Arc::new(self))"
            )?;
            self.write_routes(service, buf, None)?;
            writeln!(buf, " .build()")?;
            writeln!(buf, " }}")?;
        }
        writeln!(buf, "}}")
    }

    /// Emits one `.route("/{package}.{Service}/{Method}", |…| …)` call per method.
    ///
    /// `state_type` is the router state type passed as the closure's last parameter
    /// and forwarded to `FromRequestParts` (the Twirp router uses `Some("S")`).
    /// `None` means the closure has no state parameter and extractors receive `&()`
    /// (the gRPC router).
    fn write_routes(
        &self,
        service: &Service,
        buf: &mut String,
        state_type: Option<&str>,
    ) -> std::fmt::Result {
        // Parts/state are only bound by name when extractors actually use them.
        let extracts = !self.request_extractors.is_empty();
        let parts_arg = if extracts { "mut parts" } else { "_" };
        let state_arg = if extracts { "state" } else { "_" };
        let state_ref = if state_type.is_some() { "&state" } else { "&()" };
        for method in &service.methods {
            write!(
                buf,
                " .route(\"/{}.{}/{}\", |service: ::std::sync::Arc<Self>, request: {}",
                service.package, service.proto_name, method.proto_name, method.input_type,
            )?;
            write!(buf, ", {parts_arg}: ::twurst_server::codegen::RequestParts")?;
            if let Some(state_type) = state_type {
                write!(buf, ", {state_arg}: {state_type}")?;
            }
            write!(buf, "| {{")?;
            writeln!(buf, " async move {{")?;
            write!(buf, " service.{}(request", method.name)?;
            for _ in &self.request_extractors {
                write!(
                    buf,
                    ", match ::twurst_server::codegen::FromRequestParts::from_request_parts(&mut parts, {state_ref}).await {{ Ok(r) => r, Err(e) => {{ return Err(::twurst_server::codegen::twirp_error_from_response(e).await) }} }}"
                )?;
            }
            writeln!(buf, ").await")?;
            writeln!(buf, " }}")?;
            writeln!(buf, " }})")?;
        }
        Ok(())
    }
}