Skip to main content

structured_proxy/transcode/
mod.rs

1//! REST→gRPC transcoding layer.
2//!
3//! Reads `google.api.http` annotations from proto service descriptors
4//! and builds axum routes that proxy JSON/form requests to gRPC upstream.
5//!
6//! Generic: works with ANY proto descriptor set. No product-specific code.
7
8pub mod body;
9pub mod codec;
10pub mod error;
11pub mod metadata;
12
13use axum::extract::{Path, State};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use axum::routing::{delete, get, patch, post, put, MethodRouter};
17use axum::{Json, Router};
18use futures::StreamExt;
19use prost_reflect::{DescriptorPool, DynamicMessage, MethodDescriptor, SerializeOptions};
20use tonic::client::Grpc;
21
22use crate::config::AliasConfig;
23
24/// Trait for state types that support REST→gRPC transcoding.
25///
26/// Implement this for your application's state type to use `transcode::routes()`.
27/// Provides the minimal interface needed by transcode handlers.
28pub trait TranscodeState: Clone + Send + Sync + 'static {
29    /// Lazy gRPC channel to upstream service.
30    fn grpc_channel(&self) -> tonic::transport::Channel;
31    /// Headers to forward from HTTP to gRPC metadata.
32    fn forwarded_headers(&self) -> &[String];
33}
34
35impl TranscodeState for crate::ProxyState {
36    fn grpc_channel(&self) -> tonic::transport::Channel {
37        self.grpc_channel.clone()
38    }
39    fn forwarded_headers(&self) -> &[String] {
40        &self.forwarded_headers
41    }
42}
43
44/// Route entry extracted from proto HTTP annotations.
45#[derive(Debug, Clone)]
46struct RouteEntry {
47    /// HTTP path pattern (e.g., "/v1/auth/opaque/login/start").
48    http_path: String,
49    /// HTTP method (GET, POST, PUT, PATCH, DELETE).
50    http_method: HttpMethod,
51    /// gRPC path (e.g., "/sid.v1.AuthService/OpaqueLoginStart").
52    grpc_path: String,
53    /// Method descriptor for input/output message resolution.
54    method: MethodDescriptor,
55}
56
/// HTTP verbs a `google.api.http` rule can bind an RPC to
/// (one per `get`/`post`/`put`/`patch`/`delete` rule field).
#[derive(Clone, Copy, Debug)]
enum HttpMethod {
    /// Bound through the rule's `get` field.
    Get,
    /// Bound through the rule's `post` field.
    Post,
    /// Bound through the rule's `put` field.
    Put,
    /// Bound through the rule's `patch` field.
    Patch,
    /// Bound through the rule's `delete` field.
    Delete,
}
65
66/// Build transcoded REST→gRPC routes from a descriptor pool.
67///
68/// Takes a `DescriptorPool` and optional path aliases from config.
69/// Returns an axum Router that transcodes REST requests to gRPC calls.
70pub fn routes<S: TranscodeState>(pool: &DescriptorPool, aliases: &[AliasConfig]) -> Router<S> {
71    let entries = extract_routes(pool);
72    if entries.is_empty() {
73        tracing::warn!("No HTTP-annotated RPCs found in proto descriptors");
74        return Router::new();
75    }
76
77    tracing::info!("Registering {} transcoded REST→gRPC routes", entries.len());
78
79    let mut router: Router<S> = Router::new();
80    for entry in &entries {
81        let entry_clone = entry.clone();
82
83        let handler = move |proxy_state: State<S>,
84                            headers: HeaderMap,
85                            path_params: Path<std::collections::HashMap<String, String>>,
86                            body: axum::body::Bytes| {
87            transcode_handler(proxy_state, headers, path_params, body, entry_clone)
88        };
89
90        let method_router: MethodRouter<S> = match entry.http_method {
91            HttpMethod::Get => get(handler),
92            HttpMethod::Post => post(handler),
93            HttpMethod::Put => put(handler),
94            HttpMethod::Patch => patch(handler),
95            HttpMethod::Delete => delete(handler),
96        };
97
98        let axum_path = proto_path_to_axum(&entry.http_path);
99        router = router.route(&axum_path, method_router);
100
101        // Register aliases from config
102        for alias in aliases {
103            if let Some(suffix) = entry.http_path.strip_prefix(&alias.to) {
104                // Build alias path: alias.from with the matched suffix
105                let alias_path = if alias.from.ends_with("/{path}") {
106                    let prefix = alias.from.trim_end_matches("/{path}");
107                    format!("{}{}", prefix, suffix)
108                } else {
109                    continue;
110                };
111
112                let alias_entry = entry.clone();
113                let alias_handler =
114                    move |proxy_state: State<S>,
115                          headers: HeaderMap,
116                          path_params: Path<std::collections::HashMap<String, String>>,
117                          body: axum::body::Bytes| {
118                        transcode_handler(proxy_state, headers, path_params, body, alias_entry)
119                    };
120                let alias_method: MethodRouter<S> = match entry.http_method {
121                    HttpMethod::Get => get(alias_handler),
122                    HttpMethod::Post => post(alias_handler),
123                    HttpMethod::Put => put(alias_handler),
124                    HttpMethod::Patch => patch(alias_handler),
125                    HttpMethod::Delete => delete(alias_handler),
126                };
127                router = router.route(&alias_path, alias_method);
128            }
129        }
130    }
131
132    // Server-streaming RPCs
133    let streaming_entries = extract_streaming_routes(pool);
134    for entry in &streaming_entries {
135        let entry_clone = entry.clone();
136        let axum_path = proto_path_to_axum(&entry.http_path);
137
138        let handler = move |proxy_state: State<S>, headers: HeaderMap| {
139            streaming_handler(proxy_state, headers, entry_clone)
140        };
141
142        let method_router: MethodRouter<S> = match entry.http_method {
143            HttpMethod::Get => get(handler),
144            HttpMethod::Post => post(handler),
145            _ => continue,
146        };
147
148        router = router.route(&axum_path, method_router);
149    }
150
151    router
152}
153
154/// Handler for server-streaming RPCs (NDJSON response).
155async fn streaming_handler<S: TranscodeState>(
156    State(proxy_state): State<S>,
157    headers: HeaderMap,
158    entry: RouteEntry,
159) -> Response {
160    let channel = proxy_state.grpc_channel();
161
162    let input_desc = entry.method.input();
163    let request_msg = DynamicMessage::new(input_desc);
164
165    let grpc_metadata =
166        metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
167    let mut grpc_request = tonic::Request::new(request_msg);
168    *grpc_request.metadata_mut() = grpc_metadata;
169
170    let output_desc = entry.method.output();
171    let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
172    let grpc_path: axum::http::uri::PathAndQuery = match entry.grpc_path.parse() {
173        Ok(p) => p,
174        Err(e) => {
175            tracing::error!("Invalid gRPC path '{}': {e}", entry.grpc_path);
176            return (
177                StatusCode::INTERNAL_SERVER_ERROR,
178                Json(serde_json::json!({
179                    "error": "INTERNAL",
180                    "message": "invalid gRPC path configuration",
181                })),
182            )
183                .into_response();
184        }
185    };
186
187    let mut grpc_client = Grpc::new(channel);
188    if let Err(e) = grpc_client.ready().await {
189        return (
190            StatusCode::SERVICE_UNAVAILABLE,
191            Json(serde_json::json!({
192                "error": "UNAVAILABLE",
193                "message": format!("gRPC upstream not ready: {e}"),
194            })),
195        )
196            .into_response();
197    }
198
199    match grpc_client
200        .server_streaming(grpc_request, grpc_path, grpc_codec)
201        .await
202    {
203        Ok(response) => {
204            let stream = response.into_inner();
205            let serialize_opts = SerializeOptions::new()
206                .skip_default_fields(false)
207                .stringify_64_bit_integers(true);
208
209            let byte_stream = stream.map(move |result| match result {
210                Ok(msg) => {
211                    match msg.serialize_with_options(serde_json::value::Serializer, &serialize_opts)
212                    {
213                        Ok(json_value) => {
214                            let mut bytes = serde_json::to_vec(&json_value).unwrap_or_default();
215                            bytes.push(b'\n');
216                            Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(bytes))
217                        }
218                        Err(e) => Err(std::io::Error::other(format!("serialization error: {e}"))),
219                    }
220                }
221                Err(status) => Err(std::io::Error::other(format!(
222                    "gRPC stream error: {status}"
223                ))),
224            });
225
226            let body = axum::body::Body::from_stream(byte_stream);
227            Response::builder()
228                .status(StatusCode::OK)
229                .header("content-type", "application/x-ndjson")
230                .header("transfer-encoding", "chunked")
231                .body(body)
232                .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
233        }
234        Err(status) => error::status_to_response(status),
235    }
236}
237
238/// Generic transcoding handler.
239async fn transcode_handler<S: TranscodeState>(
240    State(proxy_state): State<S>,
241    headers: HeaderMap,
242    Path(path_params): Path<std::collections::HashMap<String, String>>,
243    body_bytes: axum::body::Bytes,
244    entry: RouteEntry,
245) -> Response {
246    let channel = proxy_state.grpc_channel();
247
248    let ct = body::content_type(&headers);
249    let mut json_body = match body::parse_body(ct, &body_bytes) {
250        Ok(v) => v,
251        Err(e) => {
252            return (
253                StatusCode::BAD_REQUEST,
254                Json(serde_json::json!({
255                    "error": "INVALID_ARGUMENT",
256                    "message": format!("failed to parse request body: {e}"),
257                })),
258            )
259                .into_response();
260        }
261    };
262
263    if !path_params.is_empty() {
264        if let Some(obj) = json_body.as_object_mut() {
265            for (key, value) in &path_params {
266                obj.insert(key.clone(), serde_json::Value::String(value.clone()));
267            }
268        }
269    }
270
271    let input_desc = entry.method.input();
272    let request_msg = match DynamicMessage::deserialize(input_desc, json_body) {
273        Ok(msg) => msg,
274        Err(e) => {
275            return (
276                StatusCode::BAD_REQUEST,
277                Json(serde_json::json!({
278                    "error": "INVALID_ARGUMENT",
279                    "message": format!("failed to decode request: {e}"),
280                })),
281            )
282                .into_response();
283        }
284    };
285
286    let grpc_metadata =
287        metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
288    let mut grpc_request = tonic::Request::new(request_msg);
289    *grpc_request.metadata_mut() = grpc_metadata;
290
291    let output_desc = entry.method.output();
292    let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
293    let grpc_path: axum::http::uri::PathAndQuery = match entry.grpc_path.parse() {
294        Ok(p) => p,
295        Err(e) => {
296            tracing::error!("Invalid gRPC path '{}': {e}", entry.grpc_path);
297            return (
298                StatusCode::INTERNAL_SERVER_ERROR,
299                Json(serde_json::json!({
300                    "error": "INTERNAL",
301                    "message": "invalid gRPC path configuration",
302                })),
303            )
304                .into_response();
305        }
306    };
307
308    let mut grpc_client = Grpc::new(channel);
309    if let Err(e) = grpc_client.ready().await {
310        return (
311            StatusCode::SERVICE_UNAVAILABLE,
312            Json(serde_json::json!({
313                "error": "UNAVAILABLE",
314                "message": format!("gRPC upstream not ready: {e}"),
315            })),
316        )
317            .into_response();
318    }
319
320    match grpc_client.unary(grpc_request, grpc_path, grpc_codec).await {
321        Ok(response) => {
322            let response_msg = response.into_inner();
323            let serialize_opts = SerializeOptions::new()
324                .skip_default_fields(false)
325                .stringify_64_bit_integers(true);
326            match response_msg
327                .serialize_with_options(serde_json::value::Serializer, &serialize_opts)
328            {
329                Ok(json_value) => (StatusCode::OK, Json(json_value)).into_response(),
330                Err(e) => {
331                    tracing::error!("Failed to serialize gRPC response: {e}");
332                    (
333                        StatusCode::INTERNAL_SERVER_ERROR,
334                        Json(serde_json::json!({
335                            "error": "INTERNAL",
336                            "message": "failed to serialize response",
337                        })),
338                    )
339                        .into_response()
340                }
341            }
342        }
343        Err(status) => error::status_to_response(status),
344    }
345}
346
347/// Extract HTTP route entries from proto descriptors.
348fn extract_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
349    let http_ext = match pool.get_extension_by_name("google.api.http") {
350        Some(ext) => ext,
351        None => {
352            tracing::warn!("google.api.http extension not found in descriptor pool");
353            return Vec::new();
354        }
355    };
356
357    let mut entries = Vec::new();
358
359    for service in pool.services() {
360        for method in service.methods() {
361            if method.is_client_streaming() || method.is_server_streaming() {
362                continue;
363            }
364
365            let grpc_path = format!("/{}/{}", service.full_name(), method.name());
366
367            if let Some((http_method, http_path)) = extract_http_rule(&method, &http_ext) {
368                entries.push(RouteEntry {
369                    http_path,
370                    http_method,
371                    grpc_path,
372                    method: method.clone(),
373                });
374            }
375        }
376    }
377
378    entries
379}
380
381/// Extract server-streaming HTTP route entries.
382fn extract_streaming_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
383    let http_ext = match pool.get_extension_by_name("google.api.http") {
384        Some(ext) => ext,
385        None => return Vec::new(),
386    };
387
388    let mut entries = Vec::new();
389
390    for service in pool.services() {
391        for method in service.methods() {
392            if !method.is_server_streaming() || method.is_client_streaming() {
393                continue;
394            }
395
396            let grpc_path = format!("/{}/{}", service.full_name(), method.name());
397
398            if let Some((http_method, http_path)) = extract_http_rule(&method, &http_ext) {
399                tracing::info!(
400                    "Registering streaming route: {} {} → {}",
401                    match http_method {
402                        HttpMethod::Get => "GET",
403                        HttpMethod::Post => "POST",
404                        _ => "OTHER",
405                    },
406                    http_path,
407                    grpc_path
408                );
409                entries.push(RouteEntry {
410                    http_path,
411                    http_method,
412                    grpc_path,
413                    method: method.clone(),
414                });
415            }
416        }
417    }
418
419    entries
420}
421
422/// Extract the HTTP method and path from a method's `google.api.http` extension.
423fn extract_http_rule(
424    method: &MethodDescriptor,
425    http_ext: &prost_reflect::ExtensionDescriptor,
426) -> Option<(HttpMethod, String)> {
427    let options = method.options();
428
429    if !options.has_extension(http_ext) {
430        return None;
431    }
432
433    let http_rule = options.get_extension(http_ext);
434    if let prost_reflect::Value::Message(rule_msg) = http_rule.into_owned() {
435        for (method_name, http_method) in [
436            ("get", HttpMethod::Get),
437            ("post", HttpMethod::Post),
438            ("put", HttpMethod::Put),
439            ("delete", HttpMethod::Delete),
440            ("patch", HttpMethod::Patch),
441        ] {
442            if let Some(val) = rule_msg.get_field_by_name(method_name) {
443                if let prost_reflect::Value::String(path) = val.into_owned() {
444                    if !path.is_empty() {
445                        return Some((http_method, path));
446                    }
447                }
448            }
449        }
450    }
451
452    None
453}
454
/// Converts a `google.api.http` path template into an axum (`:param` style) route path.
///
/// - `{field}` and `{field=*}` become the single-segment capture `:field`.
/// - `{field=<pattern>}` with any other pattern (`**`, or a multi-segment pattern such as
///   `shelves/*/books/*`) becomes the catch-all `*field`. The capture takes the remaining
///   segments, slashes included, which is what such a field expects. Such a variable must
///   therefore be the last segment of the template.
/// - Dotted field paths (`{book.name}`) keep their dotted capture name.
/// - Literal text is copied unchanged, except that a stray `}` is dropped.
/// - An unterminated `{` starts a capture that runs to the end of the path, and any
///   later `{` in it becomes `:`.
///
/// NOTE(review): custom-verb suffixes (`/v1/{name}:cancel`) are not handled.
pub fn proto_path_to_axum(path: &str) -> String {
    let mut result = String::with_capacity(path.len());
    let mut rest = path;

    while let Some(open) = rest.find('{') {
        // Literal text before the variable.
        result.extend(rest[..open].chars().filter(|&c| c != '}'));
        let template = &rest[open + 1..];

        let Some(close) = template.find('}') else {
            // Malformed template: keep the plain `{` → `:` mapping for the remainder.
            result.push(':');
            result.extend(template.chars().map(|c| if c == '{' { ':' } else { c }));
            return result;
        };

        let variable = &template[..close];
        match variable.split_once('=') {
            // `{field}`: exactly one path segment.
            None => {
                result.push(':');
                result.push_str(variable);
            }
            // `{field=*}`: still exactly one path segment.
            Some((name, "" | "*")) => {
                result.push(':');
                result.push_str(name);
            }
            // `{field=**}` or a multi-segment pattern: the value spans several segments.
            Some((name, _)) => {
                result.push('*');
                result.push_str(name);
            }
        }

        rest = &template[close + 1..];
    }

    result.extend(rest.chars().filter(|&c| c != '}'));
    result
}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of annotated path templates and the axum routes they must map to.
    #[test]
    fn test_proto_path_to_axum() {
        let cases = [
            ("/v1/profiles/{id}", "/v1/profiles/:id"),
            (
                "/v1/admin/profiles/{profile_id}/metadata/{key}",
                "/v1/admin/profiles/:profile_id/metadata/:key",
            ),
            ("/v1/auth/login", "/v1/auth/login"),
        ];

        for (template, expected) in cases {
            assert_eq!(proto_path_to_axum(template), expected, "converting {template}");
        }
    }
}