Skip to main content

structured_proxy/transcode/
mod.rs

1//! REST→gRPC transcoding layer.
2//!
3//! Reads `google.api.http` annotations from proto service descriptors
4//! and builds axum routes that proxy JSON/form requests to gRPC upstream.
5//!
6//! Generic: works with ANY proto descriptor set. No product-specific code.
7
8pub mod body;
9pub mod codec;
10pub mod error;
11pub mod metadata;
12pub mod request;
13
14use axum::extract::{Path, RawQuery, State};
15use axum::http::{HeaderMap, StatusCode};
16use axum::response::sse::{Event, KeepAlive, Sse};
17use axum::response::{IntoResponse, Response};
18use axum::routing::{delete, get, patch, post, put, MethodRouter};
19use axum::{Json, Router};
20use futures::StreamExt;
21use prost_reflect::{DescriptorPool, DynamicMessage, MethodDescriptor, SerializeOptions};
22use tonic::client::Grpc;
23
24use crate::config::AliasConfig;
25
26/// Trait for state types that support REST→gRPC transcoding.
27///
28/// Implement this for your application's state type to use `transcode::routes()`.
29/// Provides the minimal interface needed by transcode handlers.
30pub trait TranscodeState: Clone + Send + Sync + 'static {
31    /// Lazy gRPC channel to upstream service.
32    fn grpc_channel(&self) -> tonic::transport::Channel;
33    /// Headers to forward from HTTP to gRPC metadata.
34    fn forwarded_headers(&self) -> &[String];
35    /// SSE keep-alive interval (seconds) for server-streaming responses.
36    fn sse_keep_alive_secs(&self) -> u64;
37}
38
39impl TranscodeState for crate::ProxyState {
40    fn grpc_channel(&self) -> tonic::transport::Channel {
41        self.grpc_channel.clone()
42    }
43    fn forwarded_headers(&self) -> &[String] {
44        &self.forwarded_headers
45    }
46    fn sse_keep_alive_secs(&self) -> u64 {
47        self.sse_keep_alive_secs
48    }
49}
50
51/// Route entry extracted from proto HTTP annotations.
52#[derive(Debug, Clone)]
53struct RouteEntry {
54    /// HTTP path pattern (e.g., "/v1/auth/opaque/login/start").
55    http_path: String,
56    /// HTTP method (GET, POST, PUT, PATCH, DELETE).
57    http_method: HttpMethod,
58    /// gRPC path (e.g., "/sid.v1.AuthService/OpaqueLoginStart"), parsed once at
59    /// route-build time so each request clones a cheap `Bytes` refcount.
60    grpc_path: axum::http::uri::PathAndQuery,
61    /// Method descriptor for input/output message resolution.
62    method: MethodDescriptor,
63    /// How the request body maps onto the gRPC request message.
64    body: request::BodyMapping,
65    /// Optional response subfield to return as the HTTP body (`response_body`).
66    response_body: Option<String>,
67}
68
/// HTTP verbs reachable through the standard `google.api.http` pattern
/// fields (`get`, `post`, `put`, `patch`, `delete`).
#[derive(Debug, Clone, Copy)]
enum HttpMethod {
    /// Bound via the rule's `get` field.
    Get,
    /// Bound via the rule's `post` field.
    Post,
    /// Bound via the rule's `put` field.
    Put,
    /// Bound via the rule's `patch` field.
    Patch,
    /// Bound via the rule's `delete` field.
    Delete,
}
77
78/// Build transcoded REST→gRPC routes from a descriptor pool.
79///
80/// Takes a `DescriptorPool` and optional path aliases from config.
81/// Returns an axum Router that transcodes REST requests to gRPC calls.
82pub fn routes<S: TranscodeState>(pool: &DescriptorPool, aliases: &[AliasConfig]) -> Router<S> {
83    let entries = extract_routes(pool);
84    if entries.is_empty() {
85        tracing::warn!("No HTTP-annotated RPCs found in proto descriptors");
86        return Router::new();
87    }
88
89    tracing::info!("Registering {} transcoded REST→gRPC routes", entries.len());
90
91    let mut router: Router<S> = Router::new();
92    for entry in &entries {
93        let entry_clone = std::sync::Arc::new(entry.clone());
94
95        let handler = move |proxy_state: State<S>,
96                            headers: HeaderMap,
97                            path_params: Path<std::collections::HashMap<String, String>>,
98                            raw_query: RawQuery,
99                            body: axum::body::Bytes| {
100            transcode_handler(
101                proxy_state,
102                headers,
103                path_params,
104                raw_query,
105                body,
106                entry_clone,
107            )
108        };
109
110        let method_router: MethodRouter<S> = match entry.http_method {
111            HttpMethod::Get => get(handler),
112            HttpMethod::Post => post(handler),
113            HttpMethod::Put => put(handler),
114            HttpMethod::Patch => patch(handler),
115            HttpMethod::Delete => delete(handler),
116        };
117
118        let axum_path = proto_path_to_axum(&entry.http_path);
119        router = router.route(&axum_path, method_router);
120
121        // Register aliases from config
122        for alias in aliases {
123            if let Some(suffix) = entry.http_path.strip_prefix(&alias.to) {
124                // Build alias path: alias.from with the matched suffix
125                let alias_path = if alias.from.ends_with("/{path}") {
126                    let prefix = alias.from.trim_end_matches("/{path}");
127                    format!("{}{}", prefix, suffix)
128                } else {
129                    continue;
130                };
131
132                let alias_entry = std::sync::Arc::new(entry.clone());
133                let alias_handler =
134                    move |proxy_state: State<S>,
135                          headers: HeaderMap,
136                          path_params: Path<std::collections::HashMap<String, String>>,
137                          raw_query: RawQuery,
138                          body: axum::body::Bytes| {
139                        transcode_handler(
140                            proxy_state,
141                            headers,
142                            path_params,
143                            raw_query,
144                            body,
145                            alias_entry,
146                        )
147                    };
148                let alias_method: MethodRouter<S> = match entry.http_method {
149                    HttpMethod::Get => get(alias_handler),
150                    HttpMethod::Post => post(alias_handler),
151                    HttpMethod::Put => put(alias_handler),
152                    HttpMethod::Patch => patch(alias_handler),
153                    HttpMethod::Delete => delete(alias_handler),
154                };
155                router = router.route(&alias_path, alias_method);
156            }
157        }
158    }
159
160    // Server-streaming RPCs
161    let streaming_entries = extract_streaming_routes(pool);
162    for entry in &streaming_entries {
163        let entry_clone = std::sync::Arc::new(entry.clone());
164        let axum_path = proto_path_to_axum(&entry.http_path);
165
166        let handler = move |proxy_state: State<S>, headers: HeaderMap| {
167            streaming_handler(proxy_state, headers, entry_clone)
168        };
169
170        let method_router: MethodRouter<S> = match entry.http_method {
171            HttpMethod::Get => get(handler),
172            HttpMethod::Post => post(handler),
173            _ => continue,
174        };
175
176        router = router.route(&axum_path, method_router);
177    }
178
179    router
180}
181
182/// JSON serialization options shared by the unary and streaming response paths,
183/// so a given message serializes identically regardless of RPC kind.
184fn response_serialize_options() -> SerializeOptions {
185    SerializeOptions::new()
186        .skip_default_fields(false)
187        .stringify_64_bit_integers(true)
188}
189
190/// Serialize one streamed gRPC message to a compact JSON string.
191fn message_to_json_string(msg: &DynamicMessage, opts: &SerializeOptions) -> Result<String, String> {
192    let value = msg
193        .serialize_with_options(serde_json::value::Serializer, opts)
194        .map_err(|e| e.to_string())?;
195    serde_json::to_string(&value).map_err(|e| e.to_string())
196}
197
198/// Terminal error frame for a stream that failed mid-flight. Shared by the
199/// NDJSON and SSE paths so a client sees the same shape in either format.
200fn stream_error_json(status: &tonic::Status) -> serde_json::Value {
201    serde_json::json!({
202        "error": error::grpc_code_name(status.code()),
203        "message": status.message(),
204        "code": status.code() as i32,
205    })
206}
207
208/// Whether the client negotiated a Server-Sent Events response via `Accept`.
209///
210/// Considers every `Accept` header line (a client may send more than one) and
211/// every comma-separated media range within each. Matches `text/event-stream`
212/// case-insensitively and honors the quality factor: per RFC 7231 §5.3.1 a
213/// `q=0` weight means the type is explicitly not acceptable, so it does not
214/// select the SSE path.
215fn wants_sse(headers: &HeaderMap) -> bool {
216    headers
217        .get_all(axum::http::header::ACCEPT)
218        .iter()
219        .filter_map(|v| v.to_str().ok())
220        .flat_map(|accept| accept.split(','))
221        .any(accept_range_selects_sse)
222}
223
/// Whether a single `Accept` media range selects `text/event-stream` with a
/// non-zero quality factor.
///
/// The media type is compared case-insensitively after trimming. Only the
/// first `q` parameter (its name is case-insensitive) is consulted:
/// - no `q` parameter: accepted (the default weight is 1.0);
/// - a `q` value that does not parse as a number: accepted (treated as 1.0);
/// - otherwise: accepted only if the parsed weight is strictly positive.
fn accept_range_selects_sse(range: &str) -> bool {
    let mut pieces = range.split(';');
    let is_event_stream = pieces
        .next()
        .is_some_and(|media| media.trim().eq_ignore_ascii_case("text/event-stream"));
    if !is_event_stream {
        return false;
    }

    let weight = pieces.find_map(|param| {
        let (key, value) = param.split_once('=').unwrap_or((param, ""));
        key.trim()
            .eq_ignore_ascii_case("q")
            .then(|| value.trim().parse::<f32>().unwrap_or(1.0))
    });

    match weight {
        Some(q) => q > 0.0,
        None => true,
    }
}
243
244/// Handler for server-streaming RPCs.
245///
246/// Returns Server-Sent Events when the client sends `Accept: text/event-stream`,
247/// otherwise newline-delimited JSON (NDJSON). In both formats a gRPC error
248/// mid-stream is delivered as an explicit terminal frame before the stream is
249/// closed cleanly, rather than truncating the HTTP body.
250async fn streaming_handler<S: TranscodeState>(
251    State(proxy_state): State<S>,
252    headers: HeaderMap,
253    entry: std::sync::Arc<RouteEntry>,
254) -> Response {
255    let channel = proxy_state.grpc_channel();
256
257    let input_desc = entry.method.input();
258    let request_msg = DynamicMessage::new(input_desc);
259
260    let grpc_metadata =
261        metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
262    let mut grpc_request = tonic::Request::new(request_msg);
263    *grpc_request.metadata_mut() = grpc_metadata;
264    metadata::apply_request_deadline(&mut grpc_request, &headers);
265
266    let output_desc = entry.method.output();
267    let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
268    let grpc_path = entry.grpc_path.clone();
269
270    let mut grpc_client = Grpc::new(channel);
271    if let Err(e) = grpc_client.ready().await {
272        return (
273            StatusCode::SERVICE_UNAVAILABLE,
274            Json(serde_json::json!({
275                "error": "UNAVAILABLE",
276                "message": format!("gRPC upstream not ready: {e}"),
277            })),
278        )
279            .into_response();
280    }
281
282    let use_sse = wants_sse(&headers);
283
284    match grpc_client
285        .server_streaming(grpc_request, grpc_path, grpc_codec)
286        .await
287    {
288        Ok(response) => {
289            let stream = response.into_inner();
290            if use_sse {
291                sse_response(stream, proxy_state.sse_keep_alive_secs())
292            } else {
293                ndjson_response(stream)
294            }
295        }
296        Err(status) => error::status_to_response(status),
297    }
298}
299
/// A single, already-serialized JSON frame of a streaming response.
///
/// `Data` carries one upstream message. `Error` is terminal: `json_frames`
/// ends the stream right after yielding it. An error frame, whether it came
/// from a gRPC status or a serialization failure, is therefore always the
/// last thing a client sees.
enum StreamFrame {
    /// Serialized upstream message.
    Data(String),
    /// Serialized terminal error payload.
    Error(String),
}
309
310/// Turn a gRPC message stream into a stream of serialized JSON frames, stopping
311/// after the first error so error frames are unambiguously terminal.
312///
313/// Both a gRPC `Status` and a per-message serialization failure become a
314/// terminal [`StreamFrame::Error`]; downstream messages the upstream might
315/// still emit are dropped rather than streamed past the error.
316fn json_frames<St>(stream: St) -> impl futures::Stream<Item = StreamFrame> + Send + 'static
317where
318    St: futures::Stream<Item = Result<DynamicMessage, tonic::Status>> + Send + 'static,
319{
320    let opts = response_serialize_options();
321    stream.scan(false, move |stopped, result| {
322        if *stopped {
323            return futures::future::ready(None);
324        }
325        let frame = match result {
326            Ok(msg) => match message_to_json_string(&msg, &opts) {
327                Ok(s) => StreamFrame::Data(s),
328                Err(e) => {
329                    *stopped = true;
330                    StreamFrame::Error(
331                        serde_json::json!({
332                            "error": "INTERNAL",
333                            "message": format!("serialization error: {e}"),
334                        })
335                        .to_string(),
336                    )
337                }
338            },
339            Err(status) => {
340                *stopped = true;
341                StreamFrame::Error(stream_error_json(&status).to_string())
342            }
343        };
344        futures::future::ready(Some(frame))
345    })
346}
347
348/// Build an NDJSON (`application/x-ndjson`) streaming response.
349fn ndjson_response<St>(stream: St) -> Response
350where
351    St: futures::Stream<Item = Result<DynamicMessage, tonic::Status>> + Send + 'static,
352{
353    // Data and error frames are both JSON lines; an error is distinguished by
354    // its `error` field and by being the final line (see `json_frames`).
355    let byte_stream = json_frames(stream).map(|frame| {
356        let mut line = match frame {
357            StreamFrame::Data(s) | StreamFrame::Error(s) => s,
358        };
359        line.push('\n');
360        Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(line))
361    });
362
363    let body = axum::body::Body::from_stream(byte_stream);
364    // Body framing (chunked on HTTP/1.1, DATA frames on HTTP/2) is chosen by
365    // hyper from the protocol version; setting transfer-encoding by hand would
366    // be redundant on HTTP/1.1 and illegal on HTTP/2.
367    Response::builder()
368        .status(StatusCode::OK)
369        .header("content-type", "application/x-ndjson")
370        .body(body)
371        .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
372}
373
374/// Build a Server-Sent Events (`text/event-stream`) streaming response.
375fn sse_response<St>(stream: St, keep_alive_secs: u64) -> Response
376where
377    St: futures::Stream<Item = Result<DynamicMessage, tonic::Status>> + Send + 'static,
378{
379    // Terminal errors use the `stream-error` event type, not the reserved
380    // `error` type that the browser EventSource dispatches for transport
381    // failures — clients listen for it via addEventListener("stream-error").
382    let event_stream = json_frames(stream).map(|frame| {
383        let event = match frame {
384            StreamFrame::Data(s) => Event::default().data(s),
385            StreamFrame::Error(s) => Event::default().event("stream-error").data(s),
386        };
387        Ok::<Event, std::convert::Infallible>(event)
388    });
389
390    Sse::new(event_stream)
391        .keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(keep_alive_secs)))
392        .into_response()
393}
394
395/// Generic transcoding handler.
396async fn transcode_handler<S: TranscodeState>(
397    State(proxy_state): State<S>,
398    headers: HeaderMap,
399    Path(path_params): Path<std::collections::HashMap<String, String>>,
400    RawQuery(raw_query): RawQuery,
401    body_bytes: axum::body::Bytes,
402    entry: std::sync::Arc<RouteEntry>,
403) -> Response {
404    let channel = proxy_state.grpc_channel();
405
406    // Only read the body when the rule maps it onto the message.
407    let json_body = match entry.body {
408        request::BodyMapping::None => serde_json::Value::Null,
409        _ => {
410            let ct = body::content_type(&headers);
411            match body::parse_body(ct, &body_bytes) {
412                Ok(v) => v,
413                Err(e) => {
414                    return (
415                        StatusCode::BAD_REQUEST,
416                        Json(serde_json::json!({
417                            "error": "INVALID_ARGUMENT",
418                            "message": format!("failed to parse request body: {e}"),
419                        })),
420                    )
421                        .into_response();
422                }
423            }
424        }
425    };
426
427    // Query string → field bindings (fields not bound by path or body).
428    // A malformed query is a client error: reject it rather than silently
429    // dropping every query-bound field.
430    let query_pairs = match request::parse_query(raw_query.as_deref()) {
431        Ok(pairs) => pairs,
432        Err(e) => {
433            return (
434                StatusCode::BAD_REQUEST,
435                Json(serde_json::json!({
436                    "error": "INVALID_ARGUMENT",
437                    "message": e,
438                })),
439            )
440                .into_response();
441        }
442    };
443
444    let input_desc = entry.method.input();
445    let request_json = match request::build_request_json(
446        &input_desc,
447        &entry.body,
448        json_body,
449        &path_params,
450        &query_pairs,
451    ) {
452        Ok(v) => v,
453        Err(e) => {
454            return (
455                StatusCode::BAD_REQUEST,
456                Json(serde_json::json!({
457                    "error": "INVALID_ARGUMENT",
458                    "message": e,
459                })),
460            )
461                .into_response();
462        }
463    };
464
465    let request_msg = match DynamicMessage::deserialize(input_desc, request_json) {
466        Ok(msg) => msg,
467        Err(e) => {
468            return (
469                StatusCode::BAD_REQUEST,
470                Json(serde_json::json!({
471                    "error": "INVALID_ARGUMENT",
472                    "message": format!("failed to decode request: {e}"),
473                })),
474            )
475                .into_response();
476        }
477    };
478
479    let grpc_metadata =
480        metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
481    let mut grpc_request = tonic::Request::new(request_msg);
482    *grpc_request.metadata_mut() = grpc_metadata;
483    metadata::apply_request_deadline(&mut grpc_request, &headers);
484
485    let output_desc = entry.method.output();
486    let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
487    let grpc_path = entry.grpc_path.clone();
488
489    let mut grpc_client = Grpc::new(channel);
490    if let Err(e) = grpc_client.ready().await {
491        return (
492            StatusCode::SERVICE_UNAVAILABLE,
493            Json(serde_json::json!({
494                "error": "UNAVAILABLE",
495                "message": format!("gRPC upstream not ready: {e}"),
496            })),
497        )
498            .into_response();
499    }
500
501    match grpc_client.unary(grpc_request, grpc_path, grpc_codec).await {
502        Ok(response) => {
503            let response_msg = response.into_inner();
504            let serialize_opts = response_serialize_options();
505            match response_msg
506                .serialize_with_options(serde_json::value::Serializer, &serialize_opts)
507            {
508                Ok(json_value) => {
509                    // `response_body` returns just that subfield as the HTTP body.
510                    let out = match &entry.response_body {
511                        Some(path) => request::extract_response_body(&json_value, path)
512                            .unwrap_or_else(|| {
513                                tracing::warn!(
514                                    response_body = %path,
515                                    "configured response_body path not found in response; \
516                                     returning null"
517                                );
518                                serde_json::Value::Null
519                            }),
520                        None => json_value,
521                    };
522                    (StatusCode::OK, Json(out)).into_response()
523                }
524                Err(e) => {
525                    tracing::error!("Failed to serialize gRPC response: {e}");
526                    (
527                        StatusCode::INTERNAL_SERVER_ERROR,
528                        Json(serde_json::json!({
529                            "error": "INTERNAL",
530                            "message": "failed to serialize response",
531                        })),
532                    )
533                        .into_response()
534                }
535            }
536        }
537        Err(status) => error::status_to_response(status),
538    }
539}
540
541/// Extract HTTP route entries from proto descriptors.
542fn extract_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
543    let http_ext = match pool.get_extension_by_name("google.api.http") {
544        Some(ext) => ext,
545        None => {
546            tracing::warn!("google.api.http extension not found in descriptor pool");
547            return Vec::new();
548        }
549    };
550
551    let mut entries = Vec::new();
552
553    for service in pool.services() {
554        for method in service.methods() {
555            if method.is_client_streaming() || method.is_server_streaming() {
556                continue;
557            }
558
559            let grpc_path = format!("/{}/{}", service.full_name(), method.name());
560            let grpc_path: axum::http::uri::PathAndQuery = match grpc_path.parse() {
561                Ok(p) => p,
562                Err(e) => {
563                    tracing::error!("skipping route with invalid gRPC path '{grpc_path}': {e}");
564                    continue;
565                }
566            };
567
568            for binding in extract_http_bindings(&method, &http_ext) {
569                entries.push(RouteEntry {
570                    http_path: binding.http_path,
571                    http_method: binding.http_method,
572                    grpc_path: grpc_path.clone(),
573                    method: method.clone(),
574                    body: binding.body,
575                    response_body: binding.response_body,
576                });
577            }
578        }
579    }
580
581    entries
582}
583
584/// Extract server-streaming HTTP route entries.
585fn extract_streaming_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
586    let http_ext = match pool.get_extension_by_name("google.api.http") {
587        Some(ext) => ext,
588        None => return Vec::new(),
589    };
590
591    let mut entries = Vec::new();
592
593    for service in pool.services() {
594        for method in service.methods() {
595            if !method.is_server_streaming() || method.is_client_streaming() {
596                continue;
597            }
598
599            let grpc_path = format!("/{}/{}", service.full_name(), method.name());
600            let grpc_path: axum::http::uri::PathAndQuery = match grpc_path.parse() {
601                Ok(p) => p,
602                Err(e) => {
603                    tracing::error!("skipping route with invalid gRPC path '{grpc_path}': {e}");
604                    continue;
605                }
606            };
607
608            for binding in extract_http_bindings(&method, &http_ext) {
609                tracing::info!(
610                    "Registering streaming route: {} {} → {}",
611                    match binding.http_method {
612                        HttpMethod::Get => "GET",
613                        HttpMethod::Post => "POST",
614                        _ => "OTHER",
615                    },
616                    binding.http_path,
617                    grpc_path
618                );
619                entries.push(RouteEntry {
620                    http_path: binding.http_path,
621                    http_method: binding.http_method,
622                    grpc_path: grpc_path.clone(),
623                    method: method.clone(),
624                    body: binding.body,
625                    response_body: binding.response_body,
626                });
627            }
628        }
629    }
630
631    entries
632}
633
634/// A single HTTP binding parsed from a `google.api.http` rule.
635struct HttpBinding {
636    http_method: HttpMethod,
637    http_path: String,
638    body: request::BodyMapping,
639    response_body: Option<String>,
640}
641
642/// Extract all HTTP bindings (the primary rule plus any `additional_bindings`)
643/// from a method's `google.api.http` extension.
644fn extract_http_bindings(
645    method: &MethodDescriptor,
646    http_ext: &prost_reflect::ExtensionDescriptor,
647) -> Vec<HttpBinding> {
648    let options = method.options();
649    if !options.has_extension(http_ext) {
650        return Vec::new();
651    }
652
653    let prost_reflect::Value::Message(rule_msg) = options.get_extension(http_ext).into_owned()
654    else {
655        return Vec::new();
656    };
657
658    collect_bindings(&rule_msg)
659}
660
661/// Collect the primary binding plus every `additional_bindings` entry from an
662/// `HttpRule` message.
663fn collect_bindings(rule_msg: &DynamicMessage) -> Vec<HttpBinding> {
664    let mut bindings = Vec::new();
665    if let Some(binding) = parse_http_rule(rule_msg) {
666        bindings.push(binding);
667    }
668
669    // additional_bindings is a repeated HttpRule; each carries its own
670    // method/path/body. The proto forbids nesting them further.
671    if let Some(field) = rule_msg.get_field_by_name("additional_bindings") {
672        if let prost_reflect::Value::List(list) = field.into_owned() {
673            for item in list {
674                if let prost_reflect::Value::Message(sub) = item {
675                    if let Some(binding) = parse_http_rule(&sub) {
676                        bindings.push(binding);
677                    }
678                }
679            }
680        }
681    }
682
683    bindings
684}
685
686/// Parse a single `HttpRule` message into a binding (method+path required).
687fn parse_http_rule(rule_msg: &DynamicMessage) -> Option<HttpBinding> {
688    let (http_method, http_path) = [
689        ("get", HttpMethod::Get),
690        ("post", HttpMethod::Post),
691        ("put", HttpMethod::Put),
692        ("delete", HttpMethod::Delete),
693        ("patch", HttpMethod::Patch),
694    ]
695    .into_iter()
696    .find_map(
697        |(name, http_method)| match rule_msg.get_field_by_name(name)?.into_owned() {
698            prost_reflect::Value::String(path) if !path.is_empty() => Some((http_method, path)),
699            _ => None,
700        },
701    )?;
702
703    let body = rule_msg
704        .get_field_by_name("body")
705        .and_then(|v| match v.into_owned() {
706            prost_reflect::Value::String(s) => Some(request::BodyMapping::parse(&s)),
707            _ => None,
708        })
709        .unwrap_or(request::BodyMapping::None);
710
711    let response_body =
712        rule_msg
713            .get_field_by_name("response_body")
714            .and_then(|v| match v.into_owned() {
715                prost_reflect::Value::String(s) if !s.is_empty() => Some(s),
716                _ => None,
717            });
718
719    Some(HttpBinding {
720        http_method,
721        http_path,
722        body,
723        response_body,
724    })
725}
726
727/// Convert a `google.api.http` path template to axum 0.8 path syntax.
728///
729/// The proto `{param}` form IS axum 0.8's native capture syntax, so plain
730/// single-segment params pass through verbatim. Only field-path templates and
731/// bare wildcards need rewriting (axum 0.7 used `:param`; 0.8 uses `{param}`
732/// and rejects any segment starting with `:`):
733/// - `{name=*}`  (single segment)      -> `{name}`
734/// - `{name=**}` (multi-segment) -> `{*name}` (axum catch-all)
735/// - bare `*` segment            -> `{wildcardN}`
736/// - bare `**` segment           -> `{*wildcardN}` (axum catch-all)
737pub fn proto_path_to_axum(path: &str) -> String {
738    let mut out = String::with_capacity(path.len());
739
740    let segments = split_top_level(path);
741    let last = segments.len().saturating_sub(1);
742    for (idx, segment) in segments.iter().enumerate() {
743        if idx > 0 {
744            out.push('/');
745        }
746        out.push_str(&convert_segment(segment, idx, idx == last));
747    }
748
749    out
750}
751
/// Split `path` on every `/` that sits outside a `{...}` brace span.
///
/// google.api.http field templates may contain slashes inside one capture
/// (e.g. the AIP-127 resource name `{name=shelves/*/books/*}`). A plain
/// `str::split('/')` would break such a capture into invalid fragments.
/// Brace nesting is tracked, and a stray `}` never drives the depth below
/// zero (it is kept as a literal). An unclosed `{` keeps the remainder of the
/// path in one segment. The result always has at least one element.
fn split_top_level(path: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut nesting: usize = 0;
    let mut seg_start = 0;

    // Only ASCII delimiters are inspected, so byte offsets are always char
    // boundaries and slicing by them is safe for any UTF-8 input.
    for (pos, byte) in path.bytes().enumerate() {
        match byte {
            b'{' => nesting += 1,
            b'}' => nesting = nesting.saturating_sub(1),
            b'/' if nesting == 0 => {
                pieces.push(&path[seg_start..pos]);
                seg_start = pos + 1;
            }
            _ => {}
        }
    }

    pieces.push(&path[seg_start..]);
    pieces
}
779
780/// Convert a single top-level path segment from proto template to axum 0.8 form.
781///
782/// `is_last` indicates the terminal segment: axum permits a catch-all capture
783/// (`{*name}`) only there, so catch-alls in any other position must degrade.
784fn convert_segment(segment: &str, idx: usize, is_last: bool) -> String {
785    if let Some(inner) = segment.strip_prefix('{').and_then(|s| s.strip_suffix('}')) {
786        // Brace capture, possibly with a `name=template` field path.
787        if let Some((name, template)) = inner.split_once('=') {
788            return match template {
789                // Single-segment field path collapses to a plain capture.
790                "*" => format!("{{{name}}}"),
791                // Multi-segment catch-all maps to axum's `{*name}` (terminal only).
792                "**" => catch_all(name, is_last),
793                // Templates with interspersed literals (`{name=shelves/*/books/*}`)
794                // have no faithful axum form: axum cannot bind literal segments
795                // into one capture. Collapse to a catch-all so routing stays
796                // deterministic and the field still binds to the matched tail,
797                // and warn so the limitation surfaces instead of mis-routing.
798                _ => {
799                    tracing::warn!(
800                        template = %inner,
801                        "google.api.http multi-segment field template is not fully \
802                         supported; routing it as a catch-all capture"
803                    );
804                    catch_all(name, is_last)
805                }
806            };
807        }
808        // Plain `{name}` is already valid axum 0.8 syntax.
809        return format!("{{{inner}}}");
810    }
811
812    // Bare wildcards: name them by position so multiple wildcards never collide.
813    match segment {
814        "**" => catch_all(&format!("wildcard{idx}"), is_last),
815        "*" => format!("{{wildcard{idx}}}"),
816        literal => literal.to_string(),
817    }
818}
819
820/// Emit an axum catch-all `{*name}` when `is_last`, else degrade to a
821/// single-segment `{name}` capture.
822///
823/// axum accepts a catch-all only in the final path segment; a mid-path
824/// `{*name}` is rejected at `Router::route()`. A non-terminal catch-all comes
825/// from a malformed or unsupported google.api.http template, so we degrade
826/// (capturing one segment) and warn rather than panic the whole router.
827fn catch_all(name: &str, is_last: bool) -> String {
828    if is_last {
829        format!("{{*{name}}}")
830    } else {
831        tracing::warn!(
832            capture = %name,
833            "catch-all in a non-terminal path segment is unrepresentable in axum; \
834             degrading to a single-segment capture"
835        );
836        format!("{{{name}}}")
837    }
838}
839
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a standalone `HttpRule`-shaped descriptor (self-referential
    /// `additional_bindings`) so the binding parser can be tested without the
    /// google.api extension wiring.
    fn http_rule_descriptor() -> prost_reflect::MessageDescriptor {
        use prost_reflect::prost::Message;
        use prost_reflect::prost_types::{
            field_descriptor_proto::{Label, Type},
            DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
        };

        let str_field = |name: &str, num: i32| FieldDescriptorProto {
            name: Some(name.to_string()),
            number: Some(num),
            label: Some(Label::Optional as i32),
            r#type: Some(Type::String as i32),
            ..Default::default()
        };
        let rule = DescriptorProto {
            name: Some("HttpRule".to_string()),
            field: vec![
                str_field("get", 2),
                str_field("put", 3),
                str_field("post", 4),
                str_field("delete", 5),
                str_field("patch", 6),
                str_field("body", 7),
                str_field("response_body", 12),
                FieldDescriptorProto {
                    name: Some("additional_bindings".to_string()),
                    number: Some(11),
                    label: Some(Label::Repeated as i32),
                    r#type: Some(Type::Message as i32),
                    type_name: Some(".gapi.HttpRule".to_string()),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        let file = FileDescriptorProto {
            name: Some("http.proto".to_string()),
            package: Some("gapi".to_string()),
            message_type: vec![rule],
            syntax: Some("proto3".to_string()),
            ..Default::default()
        };
        let fds = FileDescriptorSet { file: vec![file] };
        let pool = DescriptorPool::decode(fds.encode_to_vec().as_slice()).unwrap();
        pool.get_message_by_name("gapi.HttpRule").unwrap()
    }

    #[test]
    fn collect_bindings_reads_body_response_and_additional() {
        let desc = http_rule_descriptor();

        // additional_bindings entry: POST /v1/items with whole-body mapping.
        let mut extra = DynamicMessage::new(desc.clone());
        extra.set_field_by_name("post", prost_reflect::Value::String("/v1/items".into()));
        extra.set_field_by_name("body", prost_reflect::Value::String("*".into()));

        // primary rule: GET /v1/items/{id}, returns only the `result` subfield.
        let mut rule = DynamicMessage::new(desc);
        rule.set_field_by_name("get", prost_reflect::Value::String("/v1/items/{id}".into()));
        rule.set_field_by_name(
            "response_body",
            prost_reflect::Value::String("result".into()),
        );
        rule.set_field_by_name(
            "additional_bindings",
            prost_reflect::Value::List(vec![prost_reflect::Value::Message(extra)]),
        );

        let bindings = collect_bindings(&rule);
        assert_eq!(bindings.len(), 2);

        // Primary: GET, no body, response_body = result.
        assert!(matches!(bindings[0].http_method, HttpMethod::Get));
        assert_eq!(bindings[0].http_path, "/v1/items/{id}");
        assert_eq!(bindings[0].body, request::BodyMapping::None);
        assert_eq!(bindings[0].response_body.as_deref(), Some("result"));

        // Additional: POST, whole-body mapping, no response_body.
        assert!(matches!(bindings[1].http_method, HttpMethod::Post));
        assert_eq!(bindings[1].http_path, "/v1/items");
        assert_eq!(bindings[1].body, request::BodyMapping::Root);
        assert_eq!(bindings[1].response_body, None);
    }

    #[test]
    fn test_proto_path_to_axum() {
        // axum 0.8: proto `{param}` IS the native capture syntax, pass through verbatim.
        assert_eq!(proto_path_to_axum("/v1/profiles/{id}"), "/v1/profiles/{id}");
        assert_eq!(
            proto_path_to_axum("/v1/admin/profiles/{profile_id}/metadata/{key}"),
            "/v1/admin/profiles/{profile_id}/metadata/{key}"
        );
        assert_eq!(proto_path_to_axum("/v1/auth/login"), "/v1/auth/login");
    }

    #[test]
    fn test_proto_path_to_axum_wildcards() {
        // `{name=*}` single-segment field path collapses to a plain capture.
        assert_eq!(proto_path_to_axum("/v1/{name=*}"), "/v1/{name}");
        // `{name=**}` multi-segment catch-all maps to axum's `{*name}`.
        assert_eq!(
            proto_path_to_axum("/v1/files/{path=**}"),
            "/v1/files/{*path}"
        );
        // Bare wildcards get position-named captures so they never collide.
        // Index is the segment position after splitting on `/` (leading "" = 0).
        assert_eq!(proto_path_to_axum("/v1/*/items"), "/v1/{wildcard2}/items");
        assert_eq!(proto_path_to_axum("/v1/files/**"), "/v1/files/{*wildcard3}");
    }

    #[test]
    fn convert_segment_maps_each_segment_form() {
        // Literals and plain captures pass through unchanged in any position.
        assert_eq!(convert_segment("v1", 1, false), "v1");
        assert_eq!(convert_segment("{id}", 2, true), "{id}");
        assert_eq!(convert_segment("{id}", 2, false), "{id}");
        // A single-segment field path collapses to a plain capture.
        assert_eq!(convert_segment("{name=*}", 2, false), "{name}");
        // Bare wildcards are named by their position.
        assert_eq!(convert_segment("*", 2, false), "{wildcard2}");
        assert_eq!(convert_segment("**", 3, true), "{*wildcard3}");
        // A bare `**` off the tail degrades to a single-segment capture.
        assert_eq!(convert_segment("**", 3, false), "{wildcard3}");
    }

    #[test]
    fn catch_all_is_only_emitted_in_terminal_position() {
        assert_eq!(catch_all("rest", true), "{*rest}");
        assert_eq!(catch_all("rest", false), "{rest}");
    }

    #[test]
    fn non_terminal_catch_all_degrades_to_single_capture() {
        // A catch-all `{*name}` is only valid in axum's LAST path segment.
        // An unsupported/multi-segment field template in a NON-terminal position
        // (`/v1/{name=projects/*}/topics`) must NOT emit a mid-path catch-all —
        // axum rejects `/v1/{*name}/topics` at `Router::route()`. It degrades to
        // a single-segment capture instead.
        let path = proto_path_to_axum("/v1/{name=projects/*}/topics");
        assert_eq!(path, "/v1/{name}/topics");
        let _router: Router<()> = Router::new().route(&path, get(|| async { "ok" }));

        // The same guard applies to an explicit `**` template in non-terminal
        // position and a terminal one still yields a real catch-all.
        assert_eq!(proto_path_to_axum("/v1/{rest=**}/tail"), "/v1/{rest}/tail");
        assert_eq!(
            proto_path_to_axum("/v1/files/{rest=**}"),
            "/v1/files/{*rest}"
        );
    }

    #[test]
    fn multi_segment_field_template_does_not_fracture() {
        // google.api.http resource-name templates (AIP-127) embed slashes
        // inside a SINGLE brace span: `{name=shelves/*/books/*}`. Splitting on
        // `/` before brace parsing fractured this into invalid fragments and
        // produced a mangled axum path that panicked at `Router::route()`.
        // It must collapse to a single catch-all capture instead.
        let path = proto_path_to_axum("/v1/{name=shelves/*/books/*}");
        assert_eq!(path, "/v1/{*name}");
        // And the produced path must actually register on axum 0.8.
        let _router: Router<()> = Router::new().route(&path, get(|| async { "ok" }));
    }

    /// Regression for the axum 0.7→0.8 migration bug: `proto_path_to_axum`
    /// emitted `:id` syntax, which axum 0.8 rejects at `Router::route()` with
    /// a startup panic ("Path segments must not start with `:`"). Building the
    /// router over a brace-param path must NOT panic. Pre-fix this panicked.
    #[test]
    fn router_builds_with_brace_path_params_on_axum_0_8() {
        let axum_path = proto_path_to_axum("/v1/profiles/{id}");
        let _router: Router<()> = Router::new().route(&axum_path, get(|| async { "ok" }));

        // Deeper nesting and a catch-all also route without panicking.
        let nested = proto_path_to_axum("/v1/admin/profiles/{profile_id}/metadata/{key}");
        let catch_all_path = proto_path_to_axum("/v1/files/{path=**}");
        let _router: Router<()> = Router::new()
            .route(&nested, get(|| async { "ok" }))
            .route(&catch_all_path, get(|| async { "ok" }));
    }

    /// `Item { name: "alice", count: 42 }` — default fixture for the
    /// serialization helpers.
    fn item_message() -> DynamicMessage {
        item_message_named("alice", 42)
    }

    /// Build an `Item { name, count }` message from a freshly-decoded
    /// descriptor pool, used to exercise the streaming serialization helpers.
    fn item_message_named(name: &str, count: i64) -> DynamicMessage {
        use prost_reflect::prost::Message;
        use prost_reflect::prost_types::{
            field_descriptor_proto::{Label, Type},
            DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
        };

        let item = DescriptorProto {
            name: Some("Item".to_string()),
            field: vec![
                FieldDescriptorProto {
                    name: Some("name".to_string()),
                    number: Some(1),
                    label: Some(Label::Optional as i32),
                    r#type: Some(Type::String as i32),
                    ..Default::default()
                },
                FieldDescriptorProto {
                    name: Some("count".to_string()),
                    number: Some(2),
                    label: Some(Label::Optional as i32),
                    r#type: Some(Type::Int64 as i32),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        let file = FileDescriptorProto {
            name: Some("item.proto".to_string()),
            package: Some("test.v1".to_string()),
            message_type: vec![item],
            syntax: Some("proto3".to_string()),
            ..Default::default()
        };
        // Same encode/decode round-trip as `http_rule_descriptor`.
        let fds = FileDescriptorSet { file: vec![file] };
        let pool = DescriptorPool::decode(fds.encode_to_vec().as_slice()).unwrap();
        let desc = pool.get_message_by_name("test.v1.Item").unwrap();

        let mut msg = DynamicMessage::new(desc);
        msg.set_field_by_name("name", prost_reflect::Value::String(name.to_string()));
        msg.set_field_by_name("count", prost_reflect::Value::I64(count));
        msg
    }

    /// Collect a streaming response body into a single UTF-8 string.
    async fn collect_body(resp: Response) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn ndjson_error_frame_is_terminal() {
        // A gRPC error mid-stream must be the LAST frame: messages the upstream
        // would yield after the error are dropped, so the error line is an
        // unambiguous end-of-stream signal rather than a mid-stream marker.
        let items = vec![
            Ok(item_message_named("alice", 1)),
            Err(tonic::Status::internal("boom")),
            Ok(item_message_named("bob", 2)),
        ];
        let body = collect_body(ndjson_response(futures::stream::iter(items))).await;
        let lines: Vec<&str> = body.lines().collect();
        assert_eq!(lines.len(), 2, "stream must stop after the error frame");
        assert!(lines[0].contains("alice"));
        assert!(lines[1].contains("INTERNAL") && lines[1].contains("boom"));
        assert!(!body.contains("bob"), "post-error message must be dropped");
    }

    #[tokio::test]
    async fn sse_error_uses_distinct_event_name() {
        // The terminal error is sent as `event: stream-error`, not the reserved
        // `error` type that collides with the browser EventSource onerror.
        let items = vec![
            Ok(item_message_named("alice", 1)),
            Err(tonic::Status::permission_denied("nope")),
            Ok(item_message_named("bob", 2)),
        ];
        let body = collect_body(sse_response(futures::stream::iter(items), 15)).await;
        assert!(body.contains("stream-error"));
        assert!(body.contains("PERMISSION_DENIED"));
        assert!(!body.contains("bob"), "post-error message must be dropped");
    }

    #[test]
    fn wants_sse_detects_event_stream_accept() {
        let mut headers = HeaderMap::new();
        headers.insert("accept", "text/event-stream".parse().unwrap());
        assert!(wants_sse(&headers));
    }

    #[test]
    fn wants_sse_matches_within_list_and_ignores_params() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            "application/json, text/event-stream;q=0.9".parse().unwrap(),
        );
        assert!(wants_sse(&headers));
    }

    #[test]
    fn wants_sse_false_for_json_and_missing() {
        let mut headers = HeaderMap::new();
        headers.insert("accept", "application/json".parse().unwrap());
        assert!(!wants_sse(&headers));
        assert!(!wants_sse(&HeaderMap::new()));
    }

    #[test]
    fn wants_sse_rejects_explicit_q_zero() {
        // RFC 7231 §5.3.1: `q=0` means the media type is explicitly NOT
        // acceptable, so it must not select the SSE path.
        let mut headers = HeaderMap::new();
        headers.insert("accept", "text/event-stream;q=0".parse().unwrap());
        assert!(!wants_sse(&headers));
    }

    #[test]
    fn wants_sse_honors_second_accept_header_line() {
        // A client may send multiple `Accept` header lines; the negotiation
        // must consider all of them, not just the first.
        let mut headers = HeaderMap::new();
        headers.append("accept", "application/json".parse().unwrap());
        headers.append("accept", "text/event-stream".parse().unwrap());
        assert!(wants_sse(&headers));
    }

    #[test]
    fn message_to_json_string_stringifies_64bit() {
        let opts = response_serialize_options();
        let json = message_to_json_string(&item_message(), &opts).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["name"], "alice");
        // 64-bit integers are stringified to survive JS number precision limits.
        assert_eq!(value["count"], "42");
    }

    #[test]
    fn ndjson_response_omits_manual_transfer_encoding() {
        // hyper picks the framing per protocol version; a hand-set
        // transfer-encoding would be illegal on HTTP/2.
        let resp = ndjson_response(futures::stream::empty::<
            Result<DynamicMessage, tonic::Status>,
        >());
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/x-ndjson"
        );
        assert!(resp.headers().get("transfer-encoding").is_none());
    }

    #[test]
    fn stream_error_json_carries_grpc_code_name() {
        let status = tonic::Status::permission_denied("nope");
        let value = stream_error_json(&status);
        assert_eq!(value["error"], "PERMISSION_DENIED");
        assert_eq!(value["message"], "nope");
        assert_eq!(value["code"], tonic::Code::PermissionDenied as i32);
    }
}