tonic_rest_build/codegen/config.rs
1//! Configuration for REST route code generation.
2
3use std::collections::{HashMap, HashSet};
4
5/// Error returned by [`generate`](super::generate).
6#[derive(Debug, thiserror::Error)]
7#[non_exhaustive]
8pub enum GenerateError {
9 /// Proto `FileDescriptorSet` decoding failure.
10 #[error("failed to decode FileDescriptorSet: {0}")]
11 ProtoDecode(#[from] prost::DecodeError),
12
13 /// A nested path param (e.g., `{user_id.value}`) was found but
14 /// [`RestCodegenConfig::wrapper_type`] is not configured.
15 #[error(
16 "nested path param '{{{param}}}' requires wrapper_type to be configured. \
17 Call .wrapper_type(\"path::to::Uuid\") on RestCodegenConfig."
18 )]
19 MissingWrapperType {
20 /// The nested path parameter that triggered the error (e.g., `user_id.value`).
21 param: String,
22 },
23
24 /// Partial body selector (body is a field name, not `"*"`).
25 ///
26 /// Currently only `body: "*"` (whole message) is supported.
27 /// Field-level body selectors like `body: "user"` require sub-message
28 /// deserialization that is not yet implemented.
29 #[error(
30 "partial body selector `{body}` in method `{method}` is not supported; \
31 use `body: \"*\"` instead"
32 )]
33 UnsupportedBodySelector {
34 /// The RPC method name.
35 method: String,
36 /// The unsupported body selector value.
37 body: String,
38 },
39
40 /// Generic configuration error.
41 #[error("{0}")]
42 Config(String),
43}
44
/// Settings that drive REST route code generation.
///
/// The generator itself knows nothing about a particular service: which proto
/// packages to emit, where the generated types live, and which methods skip
/// authentication are all supplied through this value.
///
/// # Auto-Discovery
///
/// With no packages registered, [`generate`](super::generate) discovers every
/// service carrying `google.api.http` annotations in the descriptor set and
/// derives the Rust module path from the proto package name by replacing dots
/// with `::` (e.g., `auth.v1` → `auth::v1`), which is the layout `prost-build`
/// produces.
///
/// # Examples
///
/// Rely on auto-discovery:
///
/// ```ignore
/// let config = RestCodegenConfig::new();
/// let code = tonic_rest_build::generate(&descriptor_bytes, &config)?;
/// ```
///
/// Map packages explicitly (useful with `pub use v1::*;` re-exports):
///
/// ```ignore
/// let config = RestCodegenConfig::new()
///     .package("auth.v1", "auth")
///     .package("users.v1", "users")
///     .wrapper_type("crate::core::Uuid")
///     .extension_type("my_app::AuthInfo")
///     .public_methods(&["Login", "SignUp"]);
///
/// let code = tonic_rest_build::generate(&descriptor_bytes, &config)?;
/// ```
#[derive(Clone, Debug)]
pub struct RestCodegenConfig {
    /// Map from proto package name (e.g., `"auth.v1"`) to Rust module path
    /// (e.g., `"auth"` or `"auth::v1"`).
    ///
    /// An empty map enables auto-discovery: every annotated service is included
    /// and its module path is inferred by turning dots into `::`. A non-empty
    /// map restricts generation to the listed packages.
    pub(crate) packages: HashMap<String, String>,

    /// Proto method names (`PascalCase`) whose REST paths skip authentication;
    /// they are emitted as `PUBLIC_REST_PATHS` in the generated code.
    pub(crate) public_methods: HashSet<String>,

    /// Module under which the proto-generated types live (`"crate"` by default),
    /// so `.auth.v1.User` resolves to `{proto_root}::auth::User`.
    pub(crate) proto_root: String,

    /// Path of the runtime crate or module (`"tonic_rest"` by default) that
    /// generated handlers reference for `RestError` and friends; use
    /// `"crate::rest"` when those runtime types are defined in-crate.
    pub(crate) runtime_crate: String,

    /// Rust type used for single-field wrapper messages (e.g., `"crate::core::Uuid"`).
    ///
    /// With a value, a nested path param such as `{user_id.value}` becomes
    /// `body.user_id = Some({wrapper_type} { value })`, which is the usual shape
    /// for UUID wrappers. Without one, nested params containing `.` are reported
    /// as a [`GenerateError`].
    pub(crate) wrapper_type: Option<String>,

    /// Interval between SSE keep-alive messages, in seconds (15 by default).
    pub(crate) sse_keep_alive_secs: u64,

    /// Type pulled out of Axum request extensions (e.g., `"my_app::AuthInfo"`).
    ///
    /// If present, handlers take an `Option<Extension<{extension_type}>>`
    /// argument and forward its contents to `build_tonic_request`. If absent,
    /// no extension is extracted and `None::<()>` is passed instead.
    pub(crate) extension_type: Option<String>,

    /// Additional HTTP headers to copy from REST requests into gRPC metadata.
    ///
    /// If non-empty, handlers merge these with `FORWARDED_HEADERS` and call
    /// `build_tonic_request_with_headers` rather than `build_tonic_request`.
    /// Typical use is vendor headers such as `["cf-connecting-ip"]` (Cloudflare).
    pub(crate) extra_forwarded_headers: Vec<String>,
}
136
137impl Default for RestCodegenConfig {
138 fn default() -> Self {
139 Self {
140 packages: HashMap::new(),
141 public_methods: HashSet::new(),
142 proto_root: "crate".to_string(),
143 runtime_crate: "tonic_rest".to_string(),
144 wrapper_type: None,
145 sse_keep_alive_secs: 15,
146 extension_type: None,
147 extra_forwarded_headers: Vec::new(),
148 }
149 }
150}
151
152impl RestCodegenConfig {
153 /// Create a new config with defaults.
154 #[must_use]
155 pub fn new() -> Self {
156 Self::default()
157 }
158
159 /// Register a proto package for REST route generation.
160 ///
161 /// When at least one package is registered, only registered packages are
162 /// processed (auto-discovery is disabled).
163 ///
164 /// # Example
165 /// ```ignore
166 /// config.package("auth.v1", "auth")
167 /// .package("users.v1", "users");
168 /// ```
169 #[must_use]
170 pub fn package(mut self, proto_package: &str, rust_module: &str) -> Self {
171 self.packages
172 .insert(proto_package.to_string(), rust_module.to_string());
173 self
174 }
175
176 /// Set proto method names whose REST paths bypass authentication.
177 ///
178 /// Method names should be in `PascalCase` as defined in proto (e.g., `"Authenticate"`).
179 #[must_use]
180 pub fn public_methods(mut self, methods: &[&str]) -> Self {
181 self.public_methods = methods.iter().map(ToString::to_string).collect();
182 self
183 }
184
185 /// Set the root module path for proto-generated types.
186 ///
187 /// Default: `"crate"` — converts `.auth.v1.User` → `crate::auth::User`.
188 #[must_use]
189 pub fn proto_root(mut self, root: &str) -> Self {
190 self.proto_root = root.to_string();
191 self
192 }
193
194 /// Set the runtime crate/module path for generated handler imports.
195 ///
196 /// Default: `"tonic_rest"` — generates `tonic_rest::RestError`, etc.
197 /// Set to `"crate::rest"` if the runtime types live alongside the generated code.
198 #[must_use]
199 pub fn runtime_crate(mut self, path: &str) -> Self {
200 self.runtime_crate = path.to_string();
201 self
202 }
203
204 /// Set the Rust type path for single-field wrapper messages.
205 ///
206 /// Required when proto paths contain nested params like `{user_id.value}`.
207 /// Commonly used for UUID wrapper types. Without this, [`generate`](super::generate)
208 /// returns a [`GenerateError`] for nested path params.
209 #[must_use]
210 pub fn wrapper_type(mut self, type_path: &str) -> Self {
211 self.wrapper_type = Some(type_path.to_string());
212 self
213 }
214
215 /// Set the SSE keep-alive interval in seconds (default: 15).
216 ///
217 /// Values less than 1 are clamped to 1 to prevent continuous keep-alive spam.
218 #[must_use]
219 pub fn sse_keep_alive_secs(mut self, secs: u64) -> Self {
220 self.sse_keep_alive_secs = secs.max(1);
221 self
222 }
223
224 /// Set the extension type extracted from Axum request extensions.
225 ///
226 /// When set, generated handlers use `Option<Extension<T>>` to extract
227 /// the value and forward it to `build_tonic_request`. Typically used
228 /// for auth info (e.g., `"my_app::AuthInfo"`).
229 /// When `None`, handlers skip extension extraction entirely.
230 ///
231 /// # Example
232 /// ```ignore
233 /// config.extension_type("my_app::AuthInfo")
234 /// ```
235 #[must_use]
236 pub fn extension_type(mut self, type_path: &str) -> Self {
237 self.extension_type = Some(type_path.to_string());
238 self
239 }
240
241 /// Add extra HTTP headers to forward from REST requests to gRPC metadata.
242 ///
243 /// These are combined with the default `FORWARDED_HEADERS` at startup.
244 /// Use for vendor-specific headers like Cloudflare's `cf-connecting-ip`.
245 ///
246 /// # Example
247 /// ```ignore
248 /// // Forward Cloudflare client IP header
249 /// config.extra_forwarded_headers(&["cf-connecting-ip"])
250 /// ```
251 #[must_use]
252 pub fn extra_forwarded_headers(mut self, headers: &[&str]) -> Self {
253 self.extra_forwarded_headers = headers.iter().map(ToString::to_string).collect();
254 self
255 }
256
257 /// Resolve a proto package name to its Rust module name.
258 pub(crate) fn rust_module(&self, proto_package: &str) -> Option<&str> {
259 self.packages.get(proto_package).map(String::as_str)
260 }
261
262 /// Return the extension extractor line for the handler signature, or empty
263 /// string if no extension type is configured.
264 ///
265 /// With `extension_type("Foo")`: `" ext: Option<Extension<Foo>>,\n"`
266 /// Without: `""`
267 pub(crate) fn extension_extractor_line(&self) -> String {
268 self.extension_type.as_ref().map_or_else(String::new, |ty| {
269 format!(" ext: Option<Extension<{ty}>>,\n")
270 })
271 }
272
273 /// Return the extension binding + `build_tonic_request` call for the handler body.
274 ///
275 /// When `extra_forwarded_headers` is empty, uses `build_tonic_request`
276 /// (which forwards the default header set). When extra headers are
277 /// configured, uses `build_tonic_request_with_headers` with the
278 /// generated `ALL_FORWARDED_HEADERS` constant.
279 pub(crate) fn extension_and_request_lines(&self, body_var: &str) -> String {
280 let rt = &self.runtime_crate;
281 let build_fn = if self.extra_forwarded_headers.is_empty() {
282 match &self.extension_type {
283 Some(_) => format!("{rt}::build_tonic_request({body_var}, &headers, ext)",),
284 None => format!("{rt}::build_tonic_request::<_, ()>({body_var}, &headers, None)",),
285 }
286 } else {
287 match &self.extension_type {
288 Some(_) => format!(
289 "{rt}::build_tonic_request_with_headers({body_var}, &headers, ext, ALL_FORWARDED_HEADERS)",
290 ),
291 None => format!(
292 "{rt}::build_tonic_request_with_headers::<_, ()>({body_var}, &headers, None, ALL_FORWARDED_HEADERS)",
293 ),
294 }
295 };
296
297 match &self.extension_type {
298 Some(_) => format!(
299 " let ext = ext.map(|Extension(v)| v);\n\
300 \x20 let req = {build_fn};\n",
301 ),
302 None => format!(" let req = {build_fn};\n",),
303 }
304 }
305
306 /// Convert a fully-qualified proto type to a Rust type path.
307 ///
308 /// Uses the resolved packages map for accurate module resolution:
309 /// - `.auth.v1.User` → `{proto_root}::auth::User` (with `.package("auth.v1", "auth")`)
310 /// - `.auth.v1.User` → `{proto_root}::auth::v1::User` (auto-discovered)
311 /// - `.google.protobuf.Empty` → `()`
312 ///
313 /// Falls back to first-segment heuristic for types whose package is not
314 /// in the resolved map (e.g., cross-package references).
315 pub(crate) fn proto_type_to_rust(&self, proto_fqn: &str) -> String {
316 if proto_fqn == ".google.protobuf.Empty" {
317 return "()".to_string();
318 }
319
320 let trimmed = proto_fqn.trim_start_matches('.');
321
322 // Find the longest matching package prefix in the packages map
323 let mut best: Option<(&str, &str)> = None;
324 for (package, module) in &self.packages {
325 if let Some(rest) = trimmed.strip_prefix(package.as_str()) {
326 if rest.starts_with('.') && best.is_none_or(|(p, _)| package.len() > p.len()) {
327 best = Some((package.as_str(), module.as_str()));
328 }
329 }
330 }
331
332 if let Some((package, module)) = best {
333 let type_name = &trimmed[package.len() + 1..];
334 format!("{}::{module}::{type_name}", self.proto_root)
335 } else {
336 // Fallback: use first segment as module name
337 let parts: Vec<&str> = trimmed.split('.').collect();
338 if parts.len() >= 3 {
339 let package = parts[0];
340 let type_name = parts[parts.len() - 1];
341 format!("{}::{package}::{type_name}", self.proto_root)
342 } else {
343 proto_fqn.to_string()
344 }
345 }
346 }
347}