Skip to main content

tonic_rest_build/codegen/
config.rs

1//! Configuration for REST route code generation.
2
3use std::collections::{HashMap, HashSet};
4
5/// Error returned by [`generate`](super::generate).
6#[derive(Debug, thiserror::Error)]
7#[non_exhaustive]
8pub enum GenerateError {
9    /// Proto `FileDescriptorSet` decoding failure.
10    #[error("failed to decode FileDescriptorSet: {0}")]
11    ProtoDecode(#[from] prost::DecodeError),
12
13    /// A nested path param (e.g., `{user_id.value}`) was found but
14    /// [`RestCodegenConfig::wrapper_type`] is not configured.
15    #[error(
16        "nested path param '{{{param}}}' requires wrapper_type to be configured. \
17         Call .wrapper_type(\"path::to::Uuid\") on RestCodegenConfig."
18    )]
19    MissingWrapperType {
20        /// The nested path parameter that triggered the error (e.g., `user_id.value`).
21        param: String,
22    },
23
24    /// Partial body selector (body is a field name, not `"*"`).
25    ///
26    /// Currently only `body: "*"` (whole message) is supported.
27    /// Field-level body selectors like `body: "user"` require sub-message
28    /// deserialization that is not yet implemented.
29    #[error(
30        "partial body selector `{body}` in method `{method}` is not supported; \
31         use `body: \"*\"` instead"
32    )]
33    UnsupportedBodySelector {
34        /// The RPC method name.
35        method: String,
36        /// The unsupported body selector value.
37        body: String,
38    },
39
40    /// Generic configuration error.
41    #[error("{0}")]
42    Config(String),
43}
44
/// Settings that drive REST route code generation.
///
/// Nothing service-specific is baked into the generator. Callers supply
/// everything project-specific: which proto packages to process, where the
/// generated types and runtime helpers live, and which methods skip
/// authentication.
///
/// # Auto-Discovery
///
/// If no package has been registered, [`generate`](super::generate) finds every
/// service in the descriptor set that carries `google.api.http` annotations.
/// It derives each Rust module path from the proto package by replacing dots
/// with `::` (for example `auth.v1` becomes `auth::v1`). That mirrors the module
/// layout `prost-build` produces by default.
///
/// # Examples
///
/// Zero configuration, relying on auto-discovery:
///
/// ```ignore
/// let config = RestCodegenConfig::new();
/// let code = tonic_rest_build::generate(&descriptor_bytes, &config)?;
/// ```
///
/// Explicit package mapping (useful with `pub use v1::*;` style re-exports):
///
/// ```ignore
/// let config = RestCodegenConfig::new()
///     .package("auth.v1", "auth")
///     .package("users.v1", "users")
///     .wrapper_type("crate::core::Uuid")
///     .extension_type("my_app::AuthInfo")
///     .public_methods(&["Login", "SignUp"]);
///
/// let code = tonic_rest_build::generate(&descriptor_bytes, &config)?;
/// ```
#[derive(Clone, Debug)]
pub struct RestCodegenConfig {
    /// Mapping from proto package name to Rust module path.
    ///
    /// Empty means auto-discovery. Every service annotated with
    /// `google.api.http` is processed, and its module path is the package name
    /// with dots replaced by `::` (`auth.v1` → `auth::v1`).
    ///
    /// Non-empty means only the listed packages are processed:
    /// - key: proto package name, e.g. `"auth.v1"`
    /// - value: Rust module path, e.g. `"auth"` or `"auth::v1"`
    pub(crate) packages: HashMap<String, String>,

    /// Proto method names whose REST routes skip authentication.
    ///
    /// The generated code exposes these as `PUBLIC_REST_PATHS`.
    pub(crate) public_methods: HashSet<String>,

    /// Module that prefixes every generated proto type path (default `"crate"`).
    ///
    /// With `.package("auth.v1", "auth")`, `.auth.v1.User` becomes
    /// `{proto_root}::auth::User`.
    pub(crate) proto_root: String,

    /// Path of the runtime crate/module the handlers import from (default `"tonic_rest"`).
    ///
    /// Generated code refers to items such as `{runtime_crate}::RestError`.
    /// Use `"crate::rest"` when those runtime types are defined in-crate.
    pub(crate) runtime_crate: String,

    /// Rust path of the single-field wrapper message used for nested path
    /// params, e.g. `"crate::core::Uuid"`.
    ///
    /// When present, a param like `{user_id.value}` becomes
    /// `body.user_id = Some({wrapper_type} { value })`, a common shape for
    /// protobuf UUID wrappers. When absent, any path param containing `.` makes
    /// generation fail with [`GenerateError::MissingWrapperType`].
    pub(crate) wrapper_type: Option<String>,

    /// Interval between SSE keep-alive messages, in seconds (default 15).
    pub(crate) sse_keep_alive_secs: u64,

    /// Concrete type pulled from the Axum request extensions, if any.
    ///
    /// When present, handlers declare `Option<Extension<{extension_type}>>` and
    /// hand the unwrapped value to `build_tonic_request`. A typical use is
    /// auth info such as `"my_app::AuthInfo"`. When absent, handlers take no
    /// extension and pass `None::<()>` instead.
    pub(crate) extension_type: Option<String>,

    /// Additional HTTP headers to copy from REST requests into gRPC metadata.
    ///
    /// When non-empty, generated handlers merge these with `FORWARDED_HEADERS`
    /// and call `build_tonic_request_with_headers` rather than
    /// `build_tonic_request`. Intended for vendor headers such as
    /// `"cf-connecting-ip"` (Cloudflare).
    pub(crate) extra_forwarded_headers: Vec<String>,
}
136
137impl Default for RestCodegenConfig {
138    fn default() -> Self {
139        Self {
140            packages: HashMap::new(),
141            public_methods: HashSet::new(),
142            proto_root: "crate".to_string(),
143            runtime_crate: "tonic_rest".to_string(),
144            wrapper_type: None,
145            sse_keep_alive_secs: 15,
146            extension_type: None,
147            extra_forwarded_headers: Vec::new(),
148        }
149    }
150}
151
152impl RestCodegenConfig {
153    /// Create a new config with defaults.
154    #[must_use]
155    pub fn new() -> Self {
156        Self::default()
157    }
158
159    /// Register a proto package for REST route generation.
160    ///
161    /// When at least one package is registered, only registered packages are
162    /// processed (auto-discovery is disabled).
163    ///
164    /// # Example
165    /// ```ignore
166    /// config.package("auth.v1", "auth")
167    ///       .package("users.v1", "users");
168    /// ```
169    #[must_use]
170    pub fn package(mut self, proto_package: &str, rust_module: &str) -> Self {
171        self.packages
172            .insert(proto_package.to_string(), rust_module.to_string());
173        self
174    }
175
176    /// Set proto method names whose REST paths bypass authentication.
177    ///
178    /// Method names should be in `PascalCase` as defined in proto (e.g., `"Authenticate"`).
179    #[must_use]
180    pub fn public_methods(mut self, methods: &[&str]) -> Self {
181        self.public_methods = methods.iter().map(ToString::to_string).collect();
182        self
183    }
184
185    /// Set the root module path for proto-generated types.
186    ///
187    /// Default: `"crate"` — converts `.auth.v1.User` → `crate::auth::User`.
188    #[must_use]
189    pub fn proto_root(mut self, root: &str) -> Self {
190        self.proto_root = root.to_string();
191        self
192    }
193
194    /// Set the runtime crate/module path for generated handler imports.
195    ///
196    /// Default: `"tonic_rest"` — generates `tonic_rest::RestError`, etc.
197    /// Set to `"crate::rest"` if the runtime types live alongside the generated code.
198    #[must_use]
199    pub fn runtime_crate(mut self, path: &str) -> Self {
200        self.runtime_crate = path.to_string();
201        self
202    }
203
204    /// Set the Rust type path for single-field wrapper messages.
205    ///
206    /// Required when proto paths contain nested params like `{user_id.value}`.
207    /// Commonly used for UUID wrapper types. Without this, [`generate`](super::generate)
208    /// returns a [`GenerateError`] for nested path params.
209    #[must_use]
210    pub fn wrapper_type(mut self, type_path: &str) -> Self {
211        self.wrapper_type = Some(type_path.to_string());
212        self
213    }
214
215    /// Set the SSE keep-alive interval in seconds (default: 15).
216    ///
217    /// Values less than 1 are clamped to 1 to prevent continuous keep-alive spam.
218    #[must_use]
219    pub fn sse_keep_alive_secs(mut self, secs: u64) -> Self {
220        self.sse_keep_alive_secs = secs.max(1);
221        self
222    }
223
224    /// Set the extension type extracted from Axum request extensions.
225    ///
226    /// When set, generated handlers use `Option<Extension<T>>` to extract
227    /// the value and forward it to `build_tonic_request`. Typically used
228    /// for auth info (e.g., `"my_app::AuthInfo"`).
229    /// When `None`, handlers skip extension extraction entirely.
230    ///
231    /// # Example
232    /// ```ignore
233    /// config.extension_type("my_app::AuthInfo")
234    /// ```
235    #[must_use]
236    pub fn extension_type(mut self, type_path: &str) -> Self {
237        self.extension_type = Some(type_path.to_string());
238        self
239    }
240
241    /// Add extra HTTP headers to forward from REST requests to gRPC metadata.
242    ///
243    /// These are combined with the default `FORWARDED_HEADERS` at startup.
244    /// Use for vendor-specific headers like Cloudflare's `cf-connecting-ip`.
245    ///
246    /// # Example
247    /// ```ignore
248    /// // Forward Cloudflare client IP header
249    /// config.extra_forwarded_headers(&["cf-connecting-ip"])
250    /// ```
251    #[must_use]
252    pub fn extra_forwarded_headers(mut self, headers: &[&str]) -> Self {
253        self.extra_forwarded_headers = headers.iter().map(ToString::to_string).collect();
254        self
255    }
256
257    /// Resolve a proto package name to its Rust module name.
258    pub(crate) fn rust_module(&self, proto_package: &str) -> Option<&str> {
259        self.packages.get(proto_package).map(String::as_str)
260    }
261
262    /// Return the extension extractor line for the handler signature, or empty
263    /// string if no extension type is configured.
264    ///
265    /// With `extension_type("Foo")`: `"    ext: Option<Extension<Foo>>,\n"`
266    /// Without:                      `""`
267    pub(crate) fn extension_extractor_line(&self) -> String {
268        self.extension_type.as_ref().map_or_else(String::new, |ty| {
269            format!("    ext: Option<Extension<{ty}>>,\n")
270        })
271    }
272
273    /// Return the extension binding + `build_tonic_request` call for the handler body.
274    ///
275    /// When `extra_forwarded_headers` is empty, uses `build_tonic_request`
276    /// (which forwards the default header set). When extra headers are
277    /// configured, uses `build_tonic_request_with_headers` with the
278    /// generated `ALL_FORWARDED_HEADERS` constant.
279    pub(crate) fn extension_and_request_lines(&self, body_var: &str) -> String {
280        let rt = &self.runtime_crate;
281        let build_fn = if self.extra_forwarded_headers.is_empty() {
282            match &self.extension_type {
283                Some(_) => format!("{rt}::build_tonic_request({body_var}, &headers, ext)",),
284                None => format!("{rt}::build_tonic_request::<_, ()>({body_var}, &headers, None)",),
285            }
286        } else {
287            match &self.extension_type {
288                Some(_) => format!(
289                    "{rt}::build_tonic_request_with_headers({body_var}, &headers, ext, ALL_FORWARDED_HEADERS)",
290                ),
291                None => format!(
292                    "{rt}::build_tonic_request_with_headers::<_, ()>({body_var}, &headers, None, ALL_FORWARDED_HEADERS)",
293                ),
294            }
295        };
296
297        match &self.extension_type {
298            Some(_) => format!(
299                "    let ext = ext.map(|Extension(v)| v);\n\
300                 \x20   let req = {build_fn};\n",
301            ),
302            None => format!("    let req = {build_fn};\n",),
303        }
304    }
305
306    /// Convert a fully-qualified proto type to a Rust type path.
307    ///
308    /// Uses the resolved packages map for accurate module resolution:
309    /// - `.auth.v1.User` → `{proto_root}::auth::User` (with `.package("auth.v1", "auth")`)
310    /// - `.auth.v1.User` → `{proto_root}::auth::v1::User` (auto-discovered)
311    /// - `.google.protobuf.Empty` → `()`
312    ///
313    /// Falls back to first-segment heuristic for types whose package is not
314    /// in the resolved map (e.g., cross-package references).
315    pub(crate) fn proto_type_to_rust(&self, proto_fqn: &str) -> String {
316        if proto_fqn == ".google.protobuf.Empty" {
317            return "()".to_string();
318        }
319
320        let trimmed = proto_fqn.trim_start_matches('.');
321
322        // Find the longest matching package prefix in the packages map
323        let mut best: Option<(&str, &str)> = None;
324        for (package, module) in &self.packages {
325            if let Some(rest) = trimmed.strip_prefix(package.as_str()) {
326                if rest.starts_with('.') && best.is_none_or(|(p, _)| package.len() > p.len()) {
327                    best = Some((package.as_str(), module.as_str()));
328                }
329            }
330        }
331
332        if let Some((package, module)) = best {
333            let type_name = &trimmed[package.len() + 1..];
334            format!("{}::{module}::{type_name}", self.proto_root)
335        } else {
336            // Fallback: use first segment as module name
337            let parts: Vec<&str> = trimmed.split('.').collect();
338            if parts.len() >= 3 {
339                let package = parts[0];
340                let type_name = parts[parts.len() - 1];
341                format!("{}::{package}::{type_name}", self.proto_root)
342            } else {
343                proto_fqn.to_string()
344            }
345        }
346    }
347}