1pub mod body;
9pub mod codec;
10pub mod error;
11pub mod metadata;
12pub mod request;
13
14use axum::extract::{Path, RawQuery, State};
15use axum::http::{HeaderMap, StatusCode};
16use axum::response::sse::{Event, KeepAlive, Sse};
17use axum::response::{IntoResponse, Response};
18use axum::routing::{delete, get, patch, post, put, MethodRouter};
19use axum::{Json, Router};
20use futures::StreamExt;
21use prost_reflect::{DescriptorPool, DynamicMessage, MethodDescriptor, SerializeOptions};
22use tonic::client::Grpc;
23
24use crate::config::AliasConfig;
25
/// Application state required by the transcoding handlers.
///
/// The router and handlers in this module are generic over this trait, so
/// the transcoder can be mounted on any state type that can supply an
/// upstream gRPC channel plus the per-request settings below.
pub trait TranscodeState: Clone + Send + Sync + 'static {
    /// Returns a channel to the upstream gRPC server; each handler calls this
    /// once per incoming request.
    fn grpc_channel(&self) -> tonic::transport::Channel;
    /// List handed to `metadata::http_headers_to_grpc_metadata` to decide which
    /// HTTP headers become gRPC metadata (presumably header names — confirm).
    fn forwarded_headers(&self) -> &[String];
    /// Interval, in seconds, between keep-alive messages on SSE responses.
    fn sse_keep_alive_secs(&self) -> u64;
}
38
39impl TranscodeState for crate::ProxyState {
40 fn grpc_channel(&self) -> tonic::transport::Channel {
41 self.grpc_channel.clone()
42 }
43 fn forwarded_headers(&self) -> &[String] {
44 &self.forwarded_headers
45 }
46 fn sse_keep_alive_secs(&self) -> u64 {
47 self.sse_keep_alive_secs
48 }
49}
50
/// One HTTP route derived from a single `google.api.http` binding of an RPC.
#[derive(Debug, Clone)]
struct RouteEntry {
    /// Path template exactly as written in the proto annotation
    /// (converted to axum syntax later by `proto_path_to_axum`).
    http_path: String,
    /// HTTP verb the binding is served on.
    http_method: HttpMethod,
    /// gRPC request path, built as `/<service full name>/<method name>`.
    grpc_path: axum::http::uri::PathAndQuery,
    /// Descriptor of the RPC; supplies the request (input) and response
    /// (output) message types.
    method: MethodDescriptor,
    /// How the HTTP body maps onto the request message (the rule's `body`).
    body: request::BodyMapping,
    /// Optional `response_body` field path; when set only that part of the
    /// reply is returned to the client.
    response_body: Option<String>,
}
68
/// HTTP verbs that a `google.api.http` rule can bind an RPC to
/// (the `custom` pattern is not represented).
#[derive(Debug, Clone, Copy)]
enum HttpMethod {
    Get,
    Post,
    Put,
    Patch,
    Delete,
}
77
78pub fn routes<S: TranscodeState>(pool: &DescriptorPool, aliases: &[AliasConfig]) -> Router<S> {
83 let entries = extract_routes(pool);
84 if entries.is_empty() {
85 tracing::warn!("No HTTP-annotated RPCs found in proto descriptors");
86 return Router::new();
87 }
88
89 tracing::info!("Registering {} transcoded REST→gRPC routes", entries.len());
90
91 let mut router: Router<S> = Router::new();
92 for entry in &entries {
93 let entry_clone = std::sync::Arc::new(entry.clone());
94
95 let handler = move |proxy_state: State<S>,
96 headers: HeaderMap,
97 path_params: Path<std::collections::HashMap<String, String>>,
98 raw_query: RawQuery,
99 body: axum::body::Bytes| {
100 transcode_handler(
101 proxy_state,
102 headers,
103 path_params,
104 raw_query,
105 body,
106 entry_clone,
107 )
108 };
109
110 let method_router: MethodRouter<S> = match entry.http_method {
111 HttpMethod::Get => get(handler),
112 HttpMethod::Post => post(handler),
113 HttpMethod::Put => put(handler),
114 HttpMethod::Patch => patch(handler),
115 HttpMethod::Delete => delete(handler),
116 };
117
118 let axum_path = proto_path_to_axum(&entry.http_path);
119 router = router.route(&axum_path, method_router);
120
121 for alias in aliases {
123 if let Some(suffix) = entry.http_path.strip_prefix(&alias.to) {
124 let alias_path = if alias.from.ends_with("/{path}") {
126 let prefix = alias.from.trim_end_matches("/{path}");
127 format!("{}{}", prefix, suffix)
128 } else {
129 continue;
130 };
131
132 let alias_entry = std::sync::Arc::new(entry.clone());
133 let alias_handler =
134 move |proxy_state: State<S>,
135 headers: HeaderMap,
136 path_params: Path<std::collections::HashMap<String, String>>,
137 raw_query: RawQuery,
138 body: axum::body::Bytes| {
139 transcode_handler(
140 proxy_state,
141 headers,
142 path_params,
143 raw_query,
144 body,
145 alias_entry,
146 )
147 };
148 let alias_method: MethodRouter<S> = match entry.http_method {
149 HttpMethod::Get => get(alias_handler),
150 HttpMethod::Post => post(alias_handler),
151 HttpMethod::Put => put(alias_handler),
152 HttpMethod::Patch => patch(alias_handler),
153 HttpMethod::Delete => delete(alias_handler),
154 };
155 router = router.route(&alias_path, alias_method);
156 }
157 }
158 }
159
160 let streaming_entries = extract_streaming_routes(pool);
162 for entry in &streaming_entries {
163 let entry_clone = std::sync::Arc::new(entry.clone());
164 let axum_path = proto_path_to_axum(&entry.http_path);
165
166 let handler = move |proxy_state: State<S>, headers: HeaderMap| {
167 streaming_handler(proxy_state, headers, entry_clone)
168 };
169
170 let method_router: MethodRouter<S> = match entry.http_method {
171 HttpMethod::Get => get(handler),
172 HttpMethod::Post => post(handler),
173 _ => continue,
174 };
175
176 router = router.route(&axum_path, method_router);
177 }
178
179 router
180}
181
182fn response_serialize_options() -> SerializeOptions {
185 SerializeOptions::new()
186 .skip_default_fields(false)
187 .stringify_64_bit_integers(true)
188}
189
190fn message_to_json_string(msg: &DynamicMessage, opts: &SerializeOptions) -> Result<String, String> {
192 let value = msg
193 .serialize_with_options(serde_json::value::Serializer, opts)
194 .map_err(|e| e.to_string())?;
195 serde_json::to_string(&value).map_err(|e| e.to_string())
196}
197
198fn stream_error_json(status: &tonic::Status) -> serde_json::Value {
201 serde_json::json!({
202 "error": error::grpc_code_name(status.code()),
203 "message": status.message(),
204 "code": status.code() as i32,
205 })
206}
207
208fn wants_sse(headers: &HeaderMap) -> bool {
216 headers
217 .get_all(axum::http::header::ACCEPT)
218 .iter()
219 .filter_map(|v| v.to_str().ok())
220 .flat_map(|accept| accept.split(','))
221 .any(accept_range_selects_sse)
222}
223
/// Decides whether one `Accept` media range selects `text/event-stream`.
///
/// The media type is compared case-insensitively (wildcards such as `*/*`
/// do not count). The first `q` parameter decides the outcome: the range
/// selects SSE when `q > 0`. A missing `q` counts as selected, and an
/// unparsable `q` is leniently treated as `1.0`.
fn accept_range_selects_sse(range: &str) -> bool {
    let mut pieces = range.split(';');
    let is_event_stream = pieces
        .next()
        .map(str::trim)
        .is_some_and(|media| media.eq_ignore_ascii_case("text/event-stream"));
    if !is_event_stream {
        return false;
    }

    let quality = pieces.find_map(|param| {
        let (key, value) = param.split_once('=').unwrap_or((param, ""));
        key.trim()
            .eq_ignore_ascii_case("q")
            .then(|| value.trim().parse::<f32>().unwrap_or(1.0))
    });

    match quality {
        Some(q) => q > 0.0,
        None => true,
    }
}
243
244async fn streaming_handler<S: TranscodeState>(
251 State(proxy_state): State<S>,
252 headers: HeaderMap,
253 entry: std::sync::Arc<RouteEntry>,
254) -> Response {
255 let channel = proxy_state.grpc_channel();
256
257 let input_desc = entry.method.input();
258 let request_msg = DynamicMessage::new(input_desc);
259
260 let grpc_metadata =
261 metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
262 let mut grpc_request = tonic::Request::new(request_msg);
263 *grpc_request.metadata_mut() = grpc_metadata;
264 metadata::apply_request_deadline(&mut grpc_request, &headers);
265
266 let output_desc = entry.method.output();
267 let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
268 let grpc_path = entry.grpc_path.clone();
269
270 let mut grpc_client = Grpc::new(channel);
271 if let Err(e) = grpc_client.ready().await {
272 return (
273 StatusCode::SERVICE_UNAVAILABLE,
274 Json(serde_json::json!({
275 "error": "UNAVAILABLE",
276 "message": format!("gRPC upstream not ready: {e}"),
277 })),
278 )
279 .into_response();
280 }
281
282 let use_sse = wants_sse(&headers);
283
284 match grpc_client
285 .server_streaming(grpc_request, grpc_path, grpc_codec)
286 .await
287 {
288 Ok(response) => {
289 let stream = response.into_inner();
290 if use_sse {
291 sse_response(stream, proxy_state.sse_keep_alive_secs())
292 } else {
293 ndjson_response(stream)
294 }
295 }
296 Err(status) => error::status_to_response(status),
297 }
298}
299
/// One serialized item of a transcoded server stream, before transport framing
/// (NDJSON line or SSE event) is applied.
enum StreamFrame {
    /// A successfully serialized response message, as compact JSON.
    Data(String),
    /// A JSON error object; `json_frames` ends the stream after emitting it.
    Error(String),
}
309
310fn json_frames<St>(stream: St) -> impl futures::Stream<Item = StreamFrame> + Send + 'static
317where
318 St: futures::Stream<Item = Result<DynamicMessage, tonic::Status>> + Send + 'static,
319{
320 let opts = response_serialize_options();
321 stream.scan(false, move |stopped, result| {
322 if *stopped {
323 return futures::future::ready(None);
324 }
325 let frame = match result {
326 Ok(msg) => match message_to_json_string(&msg, &opts) {
327 Ok(s) => StreamFrame::Data(s),
328 Err(e) => {
329 *stopped = true;
330 StreamFrame::Error(
331 serde_json::json!({
332 "error": "INTERNAL",
333 "message": format!("serialization error: {e}"),
334 })
335 .to_string(),
336 )
337 }
338 },
339 Err(status) => {
340 *stopped = true;
341 StreamFrame::Error(stream_error_json(&status).to_string())
342 }
343 };
344 futures::future::ready(Some(frame))
345 })
346}
347
348fn ndjson_response<St>(stream: St) -> Response
350where
351 St: futures::Stream<Item = Result<DynamicMessage, tonic::Status>> + Send + 'static,
352{
353 let byte_stream = json_frames(stream).map(|frame| {
356 let mut line = match frame {
357 StreamFrame::Data(s) | StreamFrame::Error(s) => s,
358 };
359 line.push('\n');
360 Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(line))
361 });
362
363 let body = axum::body::Body::from_stream(byte_stream);
364 Response::builder()
368 .status(StatusCode::OK)
369 .header("content-type", "application/x-ndjson")
370 .body(body)
371 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
372}
373
374fn sse_response<St>(stream: St, keep_alive_secs: u64) -> Response
376where
377 St: futures::Stream<Item = Result<DynamicMessage, tonic::Status>> + Send + 'static,
378{
379 let event_stream = json_frames(stream).map(|frame| {
383 let event = match frame {
384 StreamFrame::Data(s) => Event::default().data(s),
385 StreamFrame::Error(s) => Event::default().event("stream-error").data(s),
386 };
387 Ok::<Event, std::convert::Infallible>(event)
388 });
389
390 Sse::new(event_stream)
391 .keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(keep_alive_secs)))
392 .into_response()
393}
394
395async fn transcode_handler<S: TranscodeState>(
397 State(proxy_state): State<S>,
398 headers: HeaderMap,
399 Path(path_params): Path<std::collections::HashMap<String, String>>,
400 RawQuery(raw_query): RawQuery,
401 body_bytes: axum::body::Bytes,
402 entry: std::sync::Arc<RouteEntry>,
403) -> Response {
404 let channel = proxy_state.grpc_channel();
405
406 let json_body = match entry.body {
408 request::BodyMapping::None => serde_json::Value::Null,
409 _ => {
410 let ct = body::content_type(&headers);
411 match body::parse_body(ct, &body_bytes) {
412 Ok(v) => v,
413 Err(e) => {
414 return (
415 StatusCode::BAD_REQUEST,
416 Json(serde_json::json!({
417 "error": "INVALID_ARGUMENT",
418 "message": format!("failed to parse request body: {e}"),
419 })),
420 )
421 .into_response();
422 }
423 }
424 }
425 };
426
427 let query_pairs = match request::parse_query(raw_query.as_deref()) {
431 Ok(pairs) => pairs,
432 Err(e) => {
433 return (
434 StatusCode::BAD_REQUEST,
435 Json(serde_json::json!({
436 "error": "INVALID_ARGUMENT",
437 "message": e,
438 })),
439 )
440 .into_response();
441 }
442 };
443
444 let input_desc = entry.method.input();
445 let request_json = match request::build_request_json(
446 &input_desc,
447 &entry.body,
448 json_body,
449 &path_params,
450 &query_pairs,
451 ) {
452 Ok(v) => v,
453 Err(e) => {
454 return (
455 StatusCode::BAD_REQUEST,
456 Json(serde_json::json!({
457 "error": "INVALID_ARGUMENT",
458 "message": e,
459 })),
460 )
461 .into_response();
462 }
463 };
464
465 let request_msg = match DynamicMessage::deserialize(input_desc, request_json) {
466 Ok(msg) => msg,
467 Err(e) => {
468 return (
469 StatusCode::BAD_REQUEST,
470 Json(serde_json::json!({
471 "error": "INVALID_ARGUMENT",
472 "message": format!("failed to decode request: {e}"),
473 })),
474 )
475 .into_response();
476 }
477 };
478
479 let grpc_metadata =
480 metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
481 let mut grpc_request = tonic::Request::new(request_msg);
482 *grpc_request.metadata_mut() = grpc_metadata;
483 metadata::apply_request_deadline(&mut grpc_request, &headers);
484
485 let output_desc = entry.method.output();
486 let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
487 let grpc_path = entry.grpc_path.clone();
488
489 let mut grpc_client = Grpc::new(channel);
490 if let Err(e) = grpc_client.ready().await {
491 return (
492 StatusCode::SERVICE_UNAVAILABLE,
493 Json(serde_json::json!({
494 "error": "UNAVAILABLE",
495 "message": format!("gRPC upstream not ready: {e}"),
496 })),
497 )
498 .into_response();
499 }
500
501 match grpc_client.unary(grpc_request, grpc_path, grpc_codec).await {
502 Ok(response) => {
503 let response_msg = response.into_inner();
504 let serialize_opts = response_serialize_options();
505 match response_msg
506 .serialize_with_options(serde_json::value::Serializer, &serialize_opts)
507 {
508 Ok(json_value) => {
509 let out = match &entry.response_body {
511 Some(path) => request::extract_response_body(&json_value, path)
512 .unwrap_or_else(|| {
513 tracing::warn!(
514 response_body = %path,
515 "configured response_body path not found in response; \
516 returning null"
517 );
518 serde_json::Value::Null
519 }),
520 None => json_value,
521 };
522 (StatusCode::OK, Json(out)).into_response()
523 }
524 Err(e) => {
525 tracing::error!("Failed to serialize gRPC response: {e}");
526 (
527 StatusCode::INTERNAL_SERVER_ERROR,
528 Json(serde_json::json!({
529 "error": "INTERNAL",
530 "message": "failed to serialize response",
531 })),
532 )
533 .into_response()
534 }
535 }
536 }
537 Err(status) => error::status_to_response(status),
538 }
539}
540
541fn extract_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
543 let http_ext = match pool.get_extension_by_name("google.api.http") {
544 Some(ext) => ext,
545 None => {
546 tracing::warn!("google.api.http extension not found in descriptor pool");
547 return Vec::new();
548 }
549 };
550
551 let mut entries = Vec::new();
552
553 for service in pool.services() {
554 for method in service.methods() {
555 if method.is_client_streaming() || method.is_server_streaming() {
556 continue;
557 }
558
559 let grpc_path = format!("/{}/{}", service.full_name(), method.name());
560 let grpc_path: axum::http::uri::PathAndQuery = match grpc_path.parse() {
561 Ok(p) => p,
562 Err(e) => {
563 tracing::error!("skipping route with invalid gRPC path '{grpc_path}': {e}");
564 continue;
565 }
566 };
567
568 for binding in extract_http_bindings(&method, &http_ext) {
569 entries.push(RouteEntry {
570 http_path: binding.http_path,
571 http_method: binding.http_method,
572 grpc_path: grpc_path.clone(),
573 method: method.clone(),
574 body: binding.body,
575 response_body: binding.response_body,
576 });
577 }
578 }
579 }
580
581 entries
582}
583
584fn extract_streaming_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
586 let http_ext = match pool.get_extension_by_name("google.api.http") {
587 Some(ext) => ext,
588 None => return Vec::new(),
589 };
590
591 let mut entries = Vec::new();
592
593 for service in pool.services() {
594 for method in service.methods() {
595 if !method.is_server_streaming() || method.is_client_streaming() {
596 continue;
597 }
598
599 let grpc_path = format!("/{}/{}", service.full_name(), method.name());
600 let grpc_path: axum::http::uri::PathAndQuery = match grpc_path.parse() {
601 Ok(p) => p,
602 Err(e) => {
603 tracing::error!("skipping route with invalid gRPC path '{grpc_path}': {e}");
604 continue;
605 }
606 };
607
608 for binding in extract_http_bindings(&method, &http_ext) {
609 tracing::info!(
610 "Registering streaming route: {} {} → {}",
611 match binding.http_method {
612 HttpMethod::Get => "GET",
613 HttpMethod::Post => "POST",
614 _ => "OTHER",
615 },
616 binding.http_path,
617 grpc_path
618 );
619 entries.push(RouteEntry {
620 http_path: binding.http_path,
621 http_method: binding.http_method,
622 grpc_path: grpc_path.clone(),
623 method: method.clone(),
624 body: binding.body,
625 response_body: binding.response_body,
626 });
627 }
628 }
629 }
630
631 entries
632}
633
/// One binding parsed from an `HttpRule` message (the primary rule or one of
/// its `additional_bindings`).
struct HttpBinding {
    /// Verb taken from whichever of `get`/`post`/`put`/`delete`/`patch` is set.
    http_method: HttpMethod,
    /// The non-empty path template stored in that verb field.
    http_path: String,
    /// Parsed `body` field; `BodyMapping::None` when the field is absent.
    body: request::BodyMapping,
    /// `response_body` field, or `None` when absent or empty.
    response_body: Option<String>,
}
641
642fn extract_http_bindings(
645 method: &MethodDescriptor,
646 http_ext: &prost_reflect::ExtensionDescriptor,
647) -> Vec<HttpBinding> {
648 let options = method.options();
649 if !options.has_extension(http_ext) {
650 return Vec::new();
651 }
652
653 let prost_reflect::Value::Message(rule_msg) = options.get_extension(http_ext).into_owned()
654 else {
655 return Vec::new();
656 };
657
658 collect_bindings(&rule_msg)
659}
660
661fn collect_bindings(rule_msg: &DynamicMessage) -> Vec<HttpBinding> {
664 let mut bindings = Vec::new();
665 if let Some(binding) = parse_http_rule(rule_msg) {
666 bindings.push(binding);
667 }
668
669 if let Some(field) = rule_msg.get_field_by_name("additional_bindings") {
672 if let prost_reflect::Value::List(list) = field.into_owned() {
673 for item in list {
674 if let prost_reflect::Value::Message(sub) = item {
675 if let Some(binding) = parse_http_rule(&sub) {
676 bindings.push(binding);
677 }
678 }
679 }
680 }
681 }
682
683 bindings
684}
685
686fn parse_http_rule(rule_msg: &DynamicMessage) -> Option<HttpBinding> {
688 let (http_method, http_path) = [
689 ("get", HttpMethod::Get),
690 ("post", HttpMethod::Post),
691 ("put", HttpMethod::Put),
692 ("delete", HttpMethod::Delete),
693 ("patch", HttpMethod::Patch),
694 ]
695 .into_iter()
696 .find_map(
697 |(name, http_method)| match rule_msg.get_field_by_name(name)?.into_owned() {
698 prost_reflect::Value::String(path) if !path.is_empty() => Some((http_method, path)),
699 _ => None,
700 },
701 )?;
702
703 let body = rule_msg
704 .get_field_by_name("body")
705 .and_then(|v| match v.into_owned() {
706 prost_reflect::Value::String(s) => Some(request::BodyMapping::parse(&s)),
707 _ => None,
708 })
709 .unwrap_or(request::BodyMapping::None);
710
711 let response_body =
712 rule_msg
713 .get_field_by_name("response_body")
714 .and_then(|v| match v.into_owned() {
715 prost_reflect::Value::String(s) if !s.is_empty() => Some(s),
716 _ => None,
717 });
718
719 Some(HttpBinding {
720 http_method,
721 http_path,
722 body,
723 response_body,
724 })
725}
726
727pub fn proto_path_to_axum(path: &str) -> String {
738 let mut out = String::with_capacity(path.len());
739
740 let segments = split_top_level(path);
741 let last = segments.len().saturating_sub(1);
742 for (idx, segment) in segments.iter().enumerate() {
743 if idx > 0 {
744 out.push('/');
745 }
746 out.push_str(&convert_segment(segment, idx, idx == last));
747 }
748
749 out
750}
751
/// Splits `path` on `/` characters that are not inside `{...}`, so a field
/// template such as `{name=projects/*}` stays one segment.
///
/// A leading `/` produces an empty first segment, and unmatched `}` are
/// ignored. Scanning bytes is safe because `{`, `}` and `/` are ASCII and
/// never occur inside multi-byte UTF-8 sequences.
fn split_top_level(path: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut nesting: usize = 0;
    let mut piece_start = 0;

    for (pos, byte) in path.bytes().enumerate() {
        match byte {
            b'{' => nesting += 1,
            b'}' => nesting = nesting.saturating_sub(1),
            b'/' if nesting == 0 => {
                pieces.push(&path[piece_start..pos]);
                piece_start = pos + 1;
            }
            _ => {}
        }
    }

    pieces.push(&path[piece_start..]);
    pieces
}
779
780fn convert_segment(segment: &str, idx: usize, is_last: bool) -> String {
785 if let Some(inner) = segment.strip_prefix('{').and_then(|s| s.strip_suffix('}')) {
786 if let Some((name, template)) = inner.split_once('=') {
788 return match template {
789 "*" => format!("{{{name}}}"),
791 "**" => catch_all(name, is_last),
793 _ => {
799 tracing::warn!(
800 template = %inner,
801 "google.api.http multi-segment field template is not fully \
802 supported; routing it as a catch-all capture"
803 );
804 catch_all(name, is_last)
805 }
806 };
807 }
808 return format!("{{{inner}}}");
810 }
811
812 match segment {
814 "**" => catch_all(&format!("wildcard{idx}"), is_last),
815 "*" => format!("{{wildcard{idx}}}"),
816 literal => literal.to_string(),
817 }
818}
819
820fn catch_all(name: &str, is_last: bool) -> String {
828 if is_last {
829 format!("{{*{name}}}")
830 } else {
831 tracing::warn!(
832 capture = %name,
833 "catch-all in a non-terminal path segment is unrepresentable in axum; \
834 degrading to a single-segment capture"
835 );
836 format!("{{{name}}}")
837 }
838}
839
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal stand-in for `google.api.HttpRule` (package `gapi`),
    /// mirroring the real message's field numbers, so binding parsing can be
    /// tested without the googleapis descriptors.
    fn http_rule_descriptor() -> prost_reflect::MessageDescriptor {
        use prost_reflect::prost::Message;
        use prost_reflect::prost_types::{
            field_descriptor_proto::{Label, Type},
            DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
        };

        let str_field = |name: &str, num: i32| FieldDescriptorProto {
            name: Some(name.to_string()),
            number: Some(num),
            label: Some(Label::Optional as i32),
            r#type: Some(Type::String as i32),
            ..Default::default()
        };
        let rule = DescriptorProto {
            name: Some("HttpRule".to_string()),
            field: vec![
                str_field("get", 2),
                str_field("put", 3),
                str_field("post", 4),
                str_field("delete", 5),
                str_field("patch", 6),
                str_field("body", 7),
                str_field("response_body", 12),
                FieldDescriptorProto {
                    name: Some("additional_bindings".to_string()),
                    number: Some(11),
                    label: Some(Label::Repeated as i32),
                    r#type: Some(Type::Message as i32),
                    type_name: Some(".gapi.HttpRule".to_string()),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        let file = FileDescriptorProto {
            name: Some("http.proto".to_string()),
            package: Some("gapi".to_string()),
            message_type: vec![rule],
            syntax: Some("proto3".to_string()),
            ..Default::default()
        };
        let fds = FileDescriptorSet { file: vec![file] };
        let pool = DescriptorPool::decode(fds.encode_to_vec().as_slice()).unwrap();
        pool.get_message_by_name("gapi.HttpRule").unwrap()
    }

    #[test]
    fn collect_bindings_reads_body_response_and_additional() {
        let desc = http_rule_descriptor();

        // Additional binding: POST with `body: "*"`.
        let mut extra = DynamicMessage::new(desc.clone());
        extra.set_field_by_name("post", prost_reflect::Value::String("/v1/items".into()));
        extra.set_field_by_name("body", prost_reflect::Value::String("*".into()));

        // Primary binding: GET with a response_body, carrying the extra binding.
        let mut rule = DynamicMessage::new(desc);
        rule.set_field_by_name("get", prost_reflect::Value::String("/v1/items/{id}".into()));
        rule.set_field_by_name(
            "response_body",
            prost_reflect::Value::String("result".into()),
        );
        rule.set_field_by_name(
            "additional_bindings",
            prost_reflect::Value::List(vec![prost_reflect::Value::Message(extra)]),
        );

        let bindings = collect_bindings(&rule);
        assert_eq!(bindings.len(), 2);

        // The primary binding comes first; an unset body maps to `None`.
        assert!(matches!(bindings[0].http_method, HttpMethod::Get));
        assert_eq!(bindings[0].http_path, "/v1/items/{id}");
        assert_eq!(bindings[0].body, request::BodyMapping::None);
        assert_eq!(bindings[0].response_body.as_deref(), Some("result"));

        // The additional binding keeps its own body and does not inherit response_body.
        assert!(matches!(bindings[1].http_method, HttpMethod::Post));
        assert_eq!(bindings[1].http_path, "/v1/items");
        assert_eq!(bindings[1].body, request::BodyMapping::Root);
        assert_eq!(bindings[1].response_body, None);
    }

    #[test]
    fn test_proto_path_to_axum() {
        // Plain `{field}` captures and literals pass through unchanged.
        assert_eq!(proto_path_to_axum("/v1/profiles/{id}"), "/v1/profiles/{id}");
        assert_eq!(
            proto_path_to_axum("/v1/admin/profiles/{profile_id}/metadata/{key}"),
            "/v1/admin/profiles/{profile_id}/metadata/{key}"
        );
        assert_eq!(proto_path_to_axum("/v1/auth/login"), "/v1/auth/login");
    }

    #[test]
    fn test_proto_path_to_axum_wildcards() {
        assert_eq!(proto_path_to_axum("/v1/{name=*}"), "/v1/{name}");
        // A terminal `**` field template becomes an axum catch-all.
        assert_eq!(
            proto_path_to_axum("/v1/files/{path=**}"),
            "/v1/files/{*path}"
        );
        // Anonymous wildcards are named after their segment index.
        assert_eq!(proto_path_to_axum("/v1/*/items"), "/v1/{wildcard2}/items");
        assert_eq!(proto_path_to_axum("/v1/files/**"), "/v1/files/{*wildcard3}");
    }

    #[test]
    fn non_terminal_catch_all_degrades_to_single_capture() {
        // A multi-segment template in the middle of the path cannot be a
        // catch-all in axum, so it degrades to a single capture ...
        assert_eq!(
            proto_path_to_axum("/v1/{name=projects/*}/topics"),
            "/v1/{name}/topics"
        );
        // ... which axum accepts without panicking.
        let path = proto_path_to_axum("/v1/{name=projects/*}/topics");
        let _router: Router<()> = Router::new().route(&path, get(|| async { "ok" }));

        // Same for a non-terminal `**`, while a terminal one stays a catch-all.
        assert_eq!(proto_path_to_axum("/v1/{rest=**}/tail"), "/v1/{rest}/tail");
        assert_eq!(
            proto_path_to_axum("/v1/files/{rest=**}"),
            "/v1/files/{*rest}"
        );
    }

    #[test]
    fn multi_segment_field_template_does_not_fracture() {
        // The '/' inside the braces must not split the template into segments.
        assert_eq!(
            proto_path_to_axum("/v1/{name=shelves/*/books/*}"),
            "/v1/{*name}"
        );
        let path = proto_path_to_axum("/v1/{name=shelves/*/books/*}");
        // The converted path must be accepted by axum's router.
        let _router: Router<()> = Router::new().route(&path, get(|| async { "ok" }));
    }

    #[test]
    // Guards against regressing to the pre-0.8 `:param` / `*param` syntax,
    // which axum 0.8 rejects when routes are registered.
    fn router_builds_with_brace_path_params_on_axum_0_8() {
        let axum_path = proto_path_to_axum("/v1/profiles/{id}");
        let _router: Router<()> = Router::new().route(&axum_path, get(|| async { "ok" }));

        let nested = proto_path_to_axum("/v1/admin/profiles/{profile_id}/metadata/{key}");
        let catch_all = proto_path_to_axum("/v1/files/{path=**}");
        let _router: Router<()> = Router::new()
            .route(&nested, get(|| async { "ok" }))
            .route(&catch_all, get(|| async { "ok" }));
    }

    /// Default test message: `Item { name: "alice", count: 42 }`.
    fn item_message() -> DynamicMessage {
        item_message_named("alice", 42)
    }

    /// Builds a `test.v1.Item` message with a string `name` and an int64
    /// `count`; the int64 field exercises 64-bit stringification.
    fn item_message_named(name: &str, count: i64) -> DynamicMessage {
        use prost_reflect::prost::Message;
        use prost_reflect::prost_types::{
            field_descriptor_proto::{Label, Type},
            DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
        };

        let item = DescriptorProto {
            name: Some("Item".to_string()),
            field: vec![
                FieldDescriptorProto {
                    name: Some("name".to_string()),
                    number: Some(1),
                    label: Some(Label::Optional as i32),
                    r#type: Some(Type::String as i32),
                    ..Default::default()
                },
                FieldDescriptorProto {
                    name: Some("count".to_string()),
                    number: Some(2),
                    label: Some(Label::Optional as i32),
                    r#type: Some(Type::Int64 as i32),
                    ..Default::default()
                },
            ],
            ..Default::default()
        };
        let file = FileDescriptorProto {
            name: Some("item.proto".to_string()),
            package: Some("test.v1".to_string()),
            message_type: vec![item],
            syntax: Some("proto3".to_string()),
            ..Default::default()
        };
        let mut bytes = Vec::new();
        FileDescriptorSet { file: vec![file] }
            .encode(&mut bytes)
            .unwrap();
        let pool = DescriptorPool::decode(bytes.as_slice()).unwrap();
        let desc = pool.get_message_by_name("test.v1.Item").unwrap();

        let mut msg = DynamicMessage::new(desc);
        msg.set_field_by_name("name", prost_reflect::Value::String(name.to_string()));
        msg.set_field_by_name("count", prost_reflect::Value::I64(count));
        msg
    }

    /// Drains a response body into a UTF-8 string (no size limit).
    async fn collect_body(resp: Response) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn ndjson_error_frame_is_terminal() {
        // An upstream error mid-stream must end the stream: the message
        // queued after it must never be emitted.
        let items = vec![
            Ok(item_message_named("alice", 1)),
            Err(tonic::Status::internal("boom")),
            Ok(item_message_named("bob", 2)),
        ];
        let body = collect_body(ndjson_response(futures::stream::iter(items))).await;
        let lines: Vec<&str> = body.lines().collect();
        assert_eq!(lines.len(), 2, "stream must stop after the error frame");
        assert!(lines[0].contains("alice"));
        assert!(lines[1].contains("INTERNAL") && lines[1].contains("boom"));
        assert!(!body.contains("bob"), "post-error message must be dropped");
    }

    #[tokio::test]
    async fn sse_error_uses_distinct_event_name() {
        // SSE error frames are sent as `stream-error` events, and the stream
        // also terminates after them.
        let items = vec![
            Ok(item_message_named("alice", 1)),
            Err(tonic::Status::permission_denied("nope")),
            Ok(item_message_named("bob", 2)),
        ];
        let body = collect_body(sse_response(futures::stream::iter(items), 15)).await;
        assert!(body.contains("stream-error"));
        assert!(body.contains("PERMISSION_DENIED"));
        assert!(!body.contains("bob"), "post-error message must be dropped");
    }

    #[test]
    fn wants_sse_detects_event_stream_accept() {
        let mut headers = HeaderMap::new();
        headers.insert("accept", "text/event-stream".parse().unwrap());
        assert!(wants_sse(&headers));
    }

    #[test]
    fn wants_sse_matches_within_list_and_ignores_params() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "accept",
            "application/json, text/event-stream;q=0.9".parse().unwrap(),
        );
        assert!(wants_sse(&headers));
    }

    #[test]
    fn wants_sse_false_for_json_and_missing() {
        let mut headers = HeaderMap::new();
        headers.insert("accept", "application/json".parse().unwrap());
        assert!(!wants_sse(&headers));
        assert!(!wants_sse(&HeaderMap::new()));
    }

    #[test]
    fn wants_sse_rejects_explicit_q_zero() {
        let mut headers = HeaderMap::new();
        // `q=0` explicitly marks the media type as not acceptable.
        headers.insert("accept", "text/event-stream;q=0".parse().unwrap());
        assert!(!wants_sse(&headers));
    }

    #[test]
    fn wants_sse_honors_second_accept_header_line() {
        let mut headers = HeaderMap::new();
        // Multiple `Accept` header lines must all be consulted, not just the first.
        headers.append("accept", "application/json".parse().unwrap());
        headers.append("accept", "text/event-stream".parse().unwrap());
        assert!(wants_sse(&headers));
    }

    #[test]
    fn message_to_json_string_stringifies_64bit() {
        let opts = response_serialize_options();
        let json = message_to_json_string(&item_message(), &opts).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["name"], "alice");
        // int64 values are emitted as JSON strings.
        assert_eq!(value["count"], "42");
    }

    #[test]
    fn ndjson_response_omits_manual_transfer_encoding() {
        // Framing is left to the HTTP layer; only the content type is set.
        let resp = ndjson_response(futures::stream::empty::<
            Result<DynamicMessage, tonic::Status>,
        >());
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/x-ndjson"
        );
        assert!(resp.headers().get("transfer-encoding").is_none());
    }

    #[test]
    fn stream_error_json_carries_grpc_code_name() {
        let status = tonic::Status::permission_denied("nope");
        let value = stream_error_json(&status);
        assert_eq!(value["error"], "PERMISSION_DENIED");
        assert_eq!(value["message"], "nope");
        assert_eq!(value["code"], tonic::Code::PermissionDenied as i32);
    }
}