Skip to main content

structured_proxy/transcode/
mod.rs

1//! REST→gRPC transcoding layer.
2//!
3//! Reads `google.api.http` annotations from proto service descriptors
4//! and builds axum routes that proxy JSON/form requests to gRPC upstream.
5//!
6//! Generic: works with ANY proto descriptor set. No product-specific code.
7
8pub mod body;
9pub mod codec;
10pub mod error;
11pub mod metadata;
12
13use axum::extract::{Path, State};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use axum::routing::{delete, get, patch, post, put, MethodRouter};
17use axum::{Json, Router};
18use futures::StreamExt;
19use prost_reflect::{DescriptorPool, DynamicMessage, MethodDescriptor, SerializeOptions};
20use tonic::client::Grpc;
21
22use crate::config::AliasConfig;
23
24/// Trait for state types that support REST→gRPC transcoding.
25///
26/// Implement this for your application's state type to use `transcode::routes()`.
27/// Provides the minimal interface needed by transcode handlers.
28pub trait TranscodeState: Clone + Send + Sync + 'static {
29    /// Lazy gRPC channel to upstream service.
30    fn grpc_channel(&self) -> tonic::transport::Channel;
31    /// Headers to forward from HTTP to gRPC metadata.
32    fn forwarded_headers(&self) -> &[String];
33}
34
35impl TranscodeState for crate::ProxyState {
36    fn grpc_channel(&self) -> tonic::transport::Channel {
37        self.grpc_channel.clone()
38    }
39    fn forwarded_headers(&self) -> &[String] {
40        &self.forwarded_headers
41    }
42}
43
44/// Route entry extracted from proto HTTP annotations.
45#[derive(Debug, Clone)]
46struct RouteEntry {
47    /// HTTP path pattern (e.g., "/v1/auth/opaque/login/start").
48    http_path: String,
49    /// HTTP method (GET, POST, PUT, PATCH, DELETE).
50    http_method: HttpMethod,
51    /// gRPC path (e.g., "/sid.v1.AuthService/OpaqueLoginStart").
52    grpc_path: String,
53    /// Method descriptor for input/output message resolution.
54    method: MethodDescriptor,
55}
56
/// HTTP verbs a `google.api.http` rule can bind an RPC to in this proxy.
#[derive(Copy, Clone, Debug)]
enum HttpMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `PATCH`
    Patch,
    /// `DELETE`
    Delete,
}
65
66/// Build transcoded REST→gRPC routes from a descriptor pool.
67///
68/// Takes a `DescriptorPool` and optional path aliases from config.
69/// Returns an axum Router that transcodes REST requests to gRPC calls.
70pub fn routes<S: TranscodeState>(pool: &DescriptorPool, aliases: &[AliasConfig]) -> Router<S> {
71    let entries = extract_routes(pool);
72    if entries.is_empty() {
73        tracing::warn!("No HTTP-annotated RPCs found in proto descriptors");
74        return Router::new();
75    }
76
77    tracing::info!("Registering {} transcoded REST→gRPC routes", entries.len());
78
79    let mut router: Router<S> = Router::new();
80    for entry in &entries {
81        let entry_clone = entry.clone();
82
83        let handler = move |proxy_state: State<S>,
84                            headers: HeaderMap,
85                            path_params: Path<std::collections::HashMap<String, String>>,
86                            body: axum::body::Bytes| {
87            transcode_handler(proxy_state, headers, path_params, body, entry_clone)
88        };
89
90        let method_router: MethodRouter<S> = match entry.http_method {
91            HttpMethod::Get => get(handler),
92            HttpMethod::Post => post(handler),
93            HttpMethod::Put => put(handler),
94            HttpMethod::Patch => patch(handler),
95            HttpMethod::Delete => delete(handler),
96        };
97
98        let axum_path = proto_path_to_axum(&entry.http_path);
99        router = router.route(&axum_path, method_router);
100
101        // Register aliases from config
102        for alias in aliases {
103            if let Some(suffix) = entry.http_path.strip_prefix(&alias.to) {
104                // Build alias path: alias.from with the matched suffix
105                let alias_path = if alias.from.ends_with("/{path}") {
106                    let prefix = alias.from.trim_end_matches("/{path}");
107                    format!("{}{}", prefix, suffix)
108                } else {
109                    continue;
110                };
111
112                let alias_entry = entry.clone();
113                let alias_handler = move |proxy_state: State<S>,
114                                          headers: HeaderMap,
115                                          path_params: Path<std::collections::HashMap<String, String>>,
116                                          body: axum::body::Bytes| {
117                    transcode_handler(proxy_state, headers, path_params, body, alias_entry)
118                };
119                let alias_method: MethodRouter<S> = match entry.http_method {
120                    HttpMethod::Get => get(alias_handler),
121                    HttpMethod::Post => post(alias_handler),
122                    HttpMethod::Put => put(alias_handler),
123                    HttpMethod::Patch => patch(alias_handler),
124                    HttpMethod::Delete => delete(alias_handler),
125                };
126                router = router.route(&alias_path, alias_method);
127            }
128        }
129    }
130
131    // Server-streaming RPCs
132    let streaming_entries = extract_streaming_routes(pool);
133    for entry in &streaming_entries {
134        let entry_clone = entry.clone();
135        let axum_path = proto_path_to_axum(&entry.http_path);
136
137        let handler = move |proxy_state: State<S>, headers: HeaderMap| {
138            streaming_handler(proxy_state, headers, entry_clone)
139        };
140
141        let method_router: MethodRouter<S> = match entry.http_method {
142            HttpMethod::Get => get(handler),
143            HttpMethod::Post => post(handler),
144            _ => continue,
145        };
146
147        router = router.route(&axum_path, method_router);
148    }
149
150    router
151}
152
153/// Handler for server-streaming RPCs (NDJSON response).
154async fn streaming_handler<S: TranscodeState>(
155    State(proxy_state): State<S>,
156    headers: HeaderMap,
157    entry: RouteEntry,
158) -> Response {
159    let channel = proxy_state.grpc_channel();
160
161    let input_desc = entry.method.input();
162    let request_msg = DynamicMessage::new(input_desc);
163
164    let grpc_metadata =
165        metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
166    let mut grpc_request = tonic::Request::new(request_msg);
167    *grpc_request.metadata_mut() = grpc_metadata;
168
169    let output_desc = entry.method.output();
170    let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
171    let grpc_path: axum::http::uri::PathAndQuery = match entry.grpc_path.parse() {
172        Ok(p) => p,
173        Err(e) => {
174            tracing::error!("Invalid gRPC path '{}': {e}", entry.grpc_path);
175            return (
176                StatusCode::INTERNAL_SERVER_ERROR,
177                Json(serde_json::json!({
178                    "error": "INTERNAL",
179                    "message": "invalid gRPC path configuration",
180                })),
181            )
182                .into_response();
183        }
184    };
185
186    let mut grpc_client = Grpc::new(channel);
187    if let Err(e) = grpc_client.ready().await {
188        return (
189            StatusCode::SERVICE_UNAVAILABLE,
190            Json(serde_json::json!({
191                "error": "UNAVAILABLE",
192                "message": format!("gRPC upstream not ready: {e}"),
193            })),
194        )
195            .into_response();
196    }
197
198    match grpc_client
199        .server_streaming(grpc_request, grpc_path, grpc_codec)
200        .await
201    {
202        Ok(response) => {
203            let stream = response.into_inner();
204            let serialize_opts = SerializeOptions::new()
205                .skip_default_fields(false)
206                .stringify_64_bit_integers(true);
207
208            let byte_stream = stream.map(move |result| match result {
209                Ok(msg) => {
210                    match msg.serialize_with_options(
211                        serde_json::value::Serializer,
212                        &serialize_opts,
213                    ) {
214                        Ok(json_value) => {
215                            let mut bytes =
216                                serde_json::to_vec(&json_value).unwrap_or_default();
217                            bytes.push(b'\n');
218                            Ok::<axum::body::Bytes, std::io::Error>(
219                                axum::body::Bytes::from(bytes),
220                            )
221                        }
222                        Err(e) => Err(std::io::Error::other(format!(
223                            "serialization error: {e}"
224                        ))),
225                    }
226                }
227                Err(status) => {
228                    Err(std::io::Error::other(format!("gRPC stream error: {status}")))
229                }
230            });
231
232            let body = axum::body::Body::from_stream(byte_stream);
233            Response::builder()
234                .status(StatusCode::OK)
235                .header("content-type", "application/x-ndjson")
236                .header("transfer-encoding", "chunked")
237                .body(body)
238                .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
239        }
240        Err(status) => error::status_to_response(status),
241    }
242}
243
244/// Generic transcoding handler.
245async fn transcode_handler<S: TranscodeState>(
246    State(proxy_state): State<S>,
247    headers: HeaderMap,
248    Path(path_params): Path<std::collections::HashMap<String, String>>,
249    body_bytes: axum::body::Bytes,
250    entry: RouteEntry,
251) -> Response {
252    let channel = proxy_state.grpc_channel();
253
254    let ct = body::content_type(&headers);
255    let mut json_body = match body::parse_body(ct, &body_bytes) {
256        Ok(v) => v,
257        Err(e) => {
258            return (
259                StatusCode::BAD_REQUEST,
260                Json(serde_json::json!({
261                    "error": "INVALID_ARGUMENT",
262                    "message": format!("failed to parse request body: {e}"),
263                })),
264            )
265                .into_response();
266        }
267    };
268
269    if !path_params.is_empty() {
270        if let Some(obj) = json_body.as_object_mut() {
271            for (key, value) in &path_params {
272                obj.insert(key.clone(), serde_json::Value::String(value.clone()));
273            }
274        }
275    }
276
277    let input_desc = entry.method.input();
278    let request_msg = match DynamicMessage::deserialize(input_desc, json_body) {
279        Ok(msg) => msg,
280        Err(e) => {
281            return (
282                StatusCode::BAD_REQUEST,
283                Json(serde_json::json!({
284                    "error": "INVALID_ARGUMENT",
285                    "message": format!("failed to decode request: {e}"),
286                })),
287            )
288                .into_response();
289        }
290    };
291
292    let grpc_metadata =
293        metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
294    let mut grpc_request = tonic::Request::new(request_msg);
295    *grpc_request.metadata_mut() = grpc_metadata;
296
297    let output_desc = entry.method.output();
298    let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
299    let grpc_path: axum::http::uri::PathAndQuery = match entry.grpc_path.parse() {
300        Ok(p) => p,
301        Err(e) => {
302            tracing::error!("Invalid gRPC path '{}': {e}", entry.grpc_path);
303            return (
304                StatusCode::INTERNAL_SERVER_ERROR,
305                Json(serde_json::json!({
306                    "error": "INTERNAL",
307                    "message": "invalid gRPC path configuration",
308                })),
309            )
310                .into_response();
311        }
312    };
313
314    let mut grpc_client = Grpc::new(channel);
315    if let Err(e) = grpc_client.ready().await {
316        return (
317            StatusCode::SERVICE_UNAVAILABLE,
318            Json(serde_json::json!({
319                "error": "UNAVAILABLE",
320                "message": format!("gRPC upstream not ready: {e}"),
321            })),
322        )
323            .into_response();
324    }
325
326    match grpc_client.unary(grpc_request, grpc_path, grpc_codec).await {
327        Ok(response) => {
328            let response_msg = response.into_inner();
329            let serialize_opts = SerializeOptions::new()
330                .skip_default_fields(false)
331                .stringify_64_bit_integers(true);
332            match response_msg
333                .serialize_with_options(serde_json::value::Serializer, &serialize_opts)
334            {
335                Ok(json_value) => (StatusCode::OK, Json(json_value)).into_response(),
336                Err(e) => {
337                    tracing::error!("Failed to serialize gRPC response: {e}");
338                    (
339                        StatusCode::INTERNAL_SERVER_ERROR,
340                        Json(serde_json::json!({
341                            "error": "INTERNAL",
342                            "message": "failed to serialize response",
343                        })),
344                    )
345                        .into_response()
346                }
347            }
348        }
349        Err(status) => error::status_to_response(status),
350    }
351}
352
353/// Extract HTTP route entries from proto descriptors.
354fn extract_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
355    let http_ext = match pool.get_extension_by_name("google.api.http") {
356        Some(ext) => ext,
357        None => {
358            tracing::warn!("google.api.http extension not found in descriptor pool");
359            return Vec::new();
360        }
361    };
362
363    let mut entries = Vec::new();
364
365    for service in pool.services() {
366        for method in service.methods() {
367            if method.is_client_streaming() || method.is_server_streaming() {
368                continue;
369            }
370
371            let grpc_path = format!("/{}/{}", service.full_name(), method.name());
372
373            if let Some((http_method, http_path)) = extract_http_rule(&method, &http_ext) {
374                entries.push(RouteEntry {
375                    http_path,
376                    http_method,
377                    grpc_path,
378                    method: method.clone(),
379                });
380            }
381        }
382    }
383
384    entries
385}
386
387/// Extract server-streaming HTTP route entries.
388fn extract_streaming_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
389    let http_ext = match pool.get_extension_by_name("google.api.http") {
390        Some(ext) => ext,
391        None => return Vec::new(),
392    };
393
394    let mut entries = Vec::new();
395
396    for service in pool.services() {
397        for method in service.methods() {
398            if !method.is_server_streaming() || method.is_client_streaming() {
399                continue;
400            }
401
402            let grpc_path = format!("/{}/{}", service.full_name(), method.name());
403
404            if let Some((http_method, http_path)) = extract_http_rule(&method, &http_ext) {
405                tracing::info!(
406                    "Registering streaming route: {} {} → {}",
407                    match http_method {
408                        HttpMethod::Get => "GET",
409                        HttpMethod::Post => "POST",
410                        _ => "OTHER",
411                    },
412                    http_path,
413                    grpc_path
414                );
415                entries.push(RouteEntry {
416                    http_path,
417                    http_method,
418                    grpc_path,
419                    method: method.clone(),
420                });
421            }
422        }
423    }
424
425    entries
426}
427
428/// Extract the HTTP method and path from a method's `google.api.http` extension.
429fn extract_http_rule(
430    method: &MethodDescriptor,
431    http_ext: &prost_reflect::ExtensionDescriptor,
432) -> Option<(HttpMethod, String)> {
433    let options = method.options();
434
435    if !options.has_extension(http_ext) {
436        return None;
437    }
438
439    let http_rule = options.get_extension(http_ext);
440    if let prost_reflect::Value::Message(rule_msg) = http_rule.into_owned() {
441        for (method_name, http_method) in [
442            ("get", HttpMethod::Get),
443            ("post", HttpMethod::Post),
444            ("put", HttpMethod::Put),
445            ("delete", HttpMethod::Delete),
446            ("patch", HttpMethod::Patch),
447        ] {
448            if let Some(val) = rule_msg.get_field_by_name(method_name) {
449                if let prost_reflect::Value::String(path) = val.into_owned() {
450                    if !path.is_empty() {
451                        return Some((http_method, path));
452                    }
453                }
454            }
455        }
456    }
457
458    None
459}
460
/// Convert a `google.api.http` path template into an axum (`:param` / `*wildcard`)
/// route path.
///
/// * `{field}` and `{field=*}` become the single-segment parameter `:field`.
/// * `{field=**}`, and patterns spanning several segments such as
///   `{field=shelves/*}`, become the catch-all `*field`. The captured value then keeps
///   its `/` separators. axum accepts a catch-all only as the last segment of a route.
/// * Literal text is copied unchanged. A stray `}` is dropped, and an unterminated `{`
///   is treated as a variable running to the end of the path.
pub fn proto_path_to_axum(path: &str) -> String {
    let mut result = String::with_capacity(path.len());
    let mut chars = path.chars();

    while let Some(ch) = chars.next() {
        match ch {
            '{' => {
                // Consume the whole variable, including its closing brace.
                let variable: String = chars.by_ref().take_while(|&c| c != '}').collect();
                let (field, pattern) = variable
                    .split_once('=')
                    .unwrap_or((variable.as_str(), "*"));
                // Keeping `=pattern` in the output would produce an invalid axum
                // route. Map it to the capture kind it needs instead.
                let multi_segment = pattern.contains('/') || pattern.contains("**");
                result.push(if multi_segment { '*' } else { ':' });
                result.push_str(field);
            }
            // A stray closing brace carries no meaning in an axum path.
            '}' => {}
            _ => result.push(ch),
        }
    }

    result
}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of proto templates and the axum paths they must convert to.
    #[test]
    fn test_proto_path_to_axum() {
        let cases = [
            ("/v1/profiles/{id}", "/v1/profiles/:id"),
            (
                "/v1/admin/profiles/{profile_id}/metadata/{key}",
                "/v1/admin/profiles/:profile_id/metadata/:key",
            ),
            ("/v1/auth/login", "/v1/auth/login"),
        ];

        for (proto_path, expected) in cases {
            assert_eq!(proto_path_to_axum(proto_path), expected, "converting {proto_path}");
        }
    }
}