1pub mod body;
9pub mod codec;
10pub mod error;
11pub mod metadata;
12
13use axum::extract::{Path, State};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use axum::routing::{delete, get, patch, post, put, MethodRouter};
17use axum::{Json, Router};
18use futures::StreamExt;
19use prost_reflect::{DescriptorPool, DynamicMessage, MethodDescriptor, SerializeOptions};
20use tonic::client::Grpc;
21
22use crate::config::AliasConfig;
23
24pub trait TranscodeState: Clone + Send + Sync + 'static {
29 fn grpc_channel(&self) -> tonic::transport::Channel;
31 fn forwarded_headers(&self) -> &[String];
33}
34
35impl TranscodeState for crate::ProxyState {
36 fn grpc_channel(&self) -> tonic::transport::Channel {
37 self.grpc_channel.clone()
38 }
39 fn forwarded_headers(&self) -> &[String] {
40 &self.forwarded_headers
41 }
42}
43
44#[derive(Debug, Clone)]
46struct RouteEntry {
47 http_path: String,
49 http_method: HttpMethod,
51 grpc_path: String,
53 method: MethodDescriptor,
55}
56
/// HTTP verbs that a `google.api.http` rule can bind an RPC to in this module.
#[derive(Clone, Copy, Debug)]
enum HttpMethod {
    /// Bound through the rule's `get` field.
    Get,
    /// Bound through the rule's `post` field.
    Post,
    /// Bound through the rule's `put` field.
    Put,
    /// Bound through the rule's `patch` field.
    Patch,
    /// Bound through the rule's `delete` field.
    Delete,
}
65
66pub fn routes<S: TranscodeState>(pool: &DescriptorPool, aliases: &[AliasConfig]) -> Router<S> {
71 let entries = extract_routes(pool);
72 if entries.is_empty() {
73 tracing::warn!("No HTTP-annotated RPCs found in proto descriptors");
74 return Router::new();
75 }
76
77 tracing::info!("Registering {} transcoded REST→gRPC routes", entries.len());
78
79 let mut router: Router<S> = Router::new();
80 for entry in &entries {
81 let entry_clone = entry.clone();
82
83 let handler = move |proxy_state: State<S>,
84 headers: HeaderMap,
85 path_params: Path<std::collections::HashMap<String, String>>,
86 body: axum::body::Bytes| {
87 transcode_handler(proxy_state, headers, path_params, body, entry_clone)
88 };
89
90 let method_router: MethodRouter<S> = match entry.http_method {
91 HttpMethod::Get => get(handler),
92 HttpMethod::Post => post(handler),
93 HttpMethod::Put => put(handler),
94 HttpMethod::Patch => patch(handler),
95 HttpMethod::Delete => delete(handler),
96 };
97
98 let axum_path = proto_path_to_axum(&entry.http_path);
99 router = router.route(&axum_path, method_router);
100
101 for alias in aliases {
103 if let Some(suffix) = entry.http_path.strip_prefix(&alias.to) {
104 let alias_path = if alias.from.ends_with("/{path}") {
106 let prefix = alias.from.trim_end_matches("/{path}");
107 format!("{}{}", prefix, suffix)
108 } else {
109 continue;
110 };
111
112 let alias_entry = entry.clone();
113 let alias_handler = move |proxy_state: State<S>,
114 headers: HeaderMap,
115 path_params: Path<std::collections::HashMap<String, String>>,
116 body: axum::body::Bytes| {
117 transcode_handler(proxy_state, headers, path_params, body, alias_entry)
118 };
119 let alias_method: MethodRouter<S> = match entry.http_method {
120 HttpMethod::Get => get(alias_handler),
121 HttpMethod::Post => post(alias_handler),
122 HttpMethod::Put => put(alias_handler),
123 HttpMethod::Patch => patch(alias_handler),
124 HttpMethod::Delete => delete(alias_handler),
125 };
126 router = router.route(&alias_path, alias_method);
127 }
128 }
129 }
130
131 let streaming_entries = extract_streaming_routes(pool);
133 for entry in &streaming_entries {
134 let entry_clone = entry.clone();
135 let axum_path = proto_path_to_axum(&entry.http_path);
136
137 let handler = move |proxy_state: State<S>, headers: HeaderMap| {
138 streaming_handler(proxy_state, headers, entry_clone)
139 };
140
141 let method_router: MethodRouter<S> = match entry.http_method {
142 HttpMethod::Get => get(handler),
143 HttpMethod::Post => post(handler),
144 _ => continue,
145 };
146
147 router = router.route(&axum_path, method_router);
148 }
149
150 router
151}
152
153async fn streaming_handler<S: TranscodeState>(
155 State(proxy_state): State<S>,
156 headers: HeaderMap,
157 entry: RouteEntry,
158) -> Response {
159 let channel = proxy_state.grpc_channel();
160
161 let input_desc = entry.method.input();
162 let request_msg = DynamicMessage::new(input_desc);
163
164 let grpc_metadata =
165 metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
166 let mut grpc_request = tonic::Request::new(request_msg);
167 *grpc_request.metadata_mut() = grpc_metadata;
168
169 let output_desc = entry.method.output();
170 let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
171 let grpc_path: axum::http::uri::PathAndQuery = match entry.grpc_path.parse() {
172 Ok(p) => p,
173 Err(e) => {
174 tracing::error!("Invalid gRPC path '{}': {e}", entry.grpc_path);
175 return (
176 StatusCode::INTERNAL_SERVER_ERROR,
177 Json(serde_json::json!({
178 "error": "INTERNAL",
179 "message": "invalid gRPC path configuration",
180 })),
181 )
182 .into_response();
183 }
184 };
185
186 let mut grpc_client = Grpc::new(channel);
187 if let Err(e) = grpc_client.ready().await {
188 return (
189 StatusCode::SERVICE_UNAVAILABLE,
190 Json(serde_json::json!({
191 "error": "UNAVAILABLE",
192 "message": format!("gRPC upstream not ready: {e}"),
193 })),
194 )
195 .into_response();
196 }
197
198 match grpc_client
199 .server_streaming(grpc_request, grpc_path, grpc_codec)
200 .await
201 {
202 Ok(response) => {
203 let stream = response.into_inner();
204 let serialize_opts = SerializeOptions::new()
205 .skip_default_fields(false)
206 .stringify_64_bit_integers(true);
207
208 let byte_stream = stream.map(move |result| match result {
209 Ok(msg) => {
210 match msg.serialize_with_options(
211 serde_json::value::Serializer,
212 &serialize_opts,
213 ) {
214 Ok(json_value) => {
215 let mut bytes =
216 serde_json::to_vec(&json_value).unwrap_or_default();
217 bytes.push(b'\n');
218 Ok::<axum::body::Bytes, std::io::Error>(
219 axum::body::Bytes::from(bytes),
220 )
221 }
222 Err(e) => Err(std::io::Error::other(format!(
223 "serialization error: {e}"
224 ))),
225 }
226 }
227 Err(status) => {
228 Err(std::io::Error::other(format!("gRPC stream error: {status}")))
229 }
230 });
231
232 let body = axum::body::Body::from_stream(byte_stream);
233 Response::builder()
234 .status(StatusCode::OK)
235 .header("content-type", "application/x-ndjson")
236 .header("transfer-encoding", "chunked")
237 .body(body)
238 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
239 }
240 Err(status) => error::status_to_response(status),
241 }
242}
243
244async fn transcode_handler<S: TranscodeState>(
246 State(proxy_state): State<S>,
247 headers: HeaderMap,
248 Path(path_params): Path<std::collections::HashMap<String, String>>,
249 body_bytes: axum::body::Bytes,
250 entry: RouteEntry,
251) -> Response {
252 let channel = proxy_state.grpc_channel();
253
254 let ct = body::content_type(&headers);
255 let mut json_body = match body::parse_body(ct, &body_bytes) {
256 Ok(v) => v,
257 Err(e) => {
258 return (
259 StatusCode::BAD_REQUEST,
260 Json(serde_json::json!({
261 "error": "INVALID_ARGUMENT",
262 "message": format!("failed to parse request body: {e}"),
263 })),
264 )
265 .into_response();
266 }
267 };
268
269 if !path_params.is_empty() {
270 if let Some(obj) = json_body.as_object_mut() {
271 for (key, value) in &path_params {
272 obj.insert(key.clone(), serde_json::Value::String(value.clone()));
273 }
274 }
275 }
276
277 let input_desc = entry.method.input();
278 let request_msg = match DynamicMessage::deserialize(input_desc, json_body) {
279 Ok(msg) => msg,
280 Err(e) => {
281 return (
282 StatusCode::BAD_REQUEST,
283 Json(serde_json::json!({
284 "error": "INVALID_ARGUMENT",
285 "message": format!("failed to decode request: {e}"),
286 })),
287 )
288 .into_response();
289 }
290 };
291
292 let grpc_metadata =
293 metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
294 let mut grpc_request = tonic::Request::new(request_msg);
295 *grpc_request.metadata_mut() = grpc_metadata;
296
297 let output_desc = entry.method.output();
298 let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
299 let grpc_path: axum::http::uri::PathAndQuery = match entry.grpc_path.parse() {
300 Ok(p) => p,
301 Err(e) => {
302 tracing::error!("Invalid gRPC path '{}': {e}", entry.grpc_path);
303 return (
304 StatusCode::INTERNAL_SERVER_ERROR,
305 Json(serde_json::json!({
306 "error": "INTERNAL",
307 "message": "invalid gRPC path configuration",
308 })),
309 )
310 .into_response();
311 }
312 };
313
314 let mut grpc_client = Grpc::new(channel);
315 if let Err(e) = grpc_client.ready().await {
316 return (
317 StatusCode::SERVICE_UNAVAILABLE,
318 Json(serde_json::json!({
319 "error": "UNAVAILABLE",
320 "message": format!("gRPC upstream not ready: {e}"),
321 })),
322 )
323 .into_response();
324 }
325
326 match grpc_client.unary(grpc_request, grpc_path, grpc_codec).await {
327 Ok(response) => {
328 let response_msg = response.into_inner();
329 let serialize_opts = SerializeOptions::new()
330 .skip_default_fields(false)
331 .stringify_64_bit_integers(true);
332 match response_msg
333 .serialize_with_options(serde_json::value::Serializer, &serialize_opts)
334 {
335 Ok(json_value) => (StatusCode::OK, Json(json_value)).into_response(),
336 Err(e) => {
337 tracing::error!("Failed to serialize gRPC response: {e}");
338 (
339 StatusCode::INTERNAL_SERVER_ERROR,
340 Json(serde_json::json!({
341 "error": "INTERNAL",
342 "message": "failed to serialize response",
343 })),
344 )
345 .into_response()
346 }
347 }
348 }
349 Err(status) => error::status_to_response(status),
350 }
351}
352
353fn extract_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
355 let http_ext = match pool.get_extension_by_name("google.api.http") {
356 Some(ext) => ext,
357 None => {
358 tracing::warn!("google.api.http extension not found in descriptor pool");
359 return Vec::new();
360 }
361 };
362
363 let mut entries = Vec::new();
364
365 for service in pool.services() {
366 for method in service.methods() {
367 if method.is_client_streaming() || method.is_server_streaming() {
368 continue;
369 }
370
371 let grpc_path = format!("/{}/{}", service.full_name(), method.name());
372
373 if let Some((http_method, http_path)) = extract_http_rule(&method, &http_ext) {
374 entries.push(RouteEntry {
375 http_path,
376 http_method,
377 grpc_path,
378 method: method.clone(),
379 });
380 }
381 }
382 }
383
384 entries
385}
386
387fn extract_streaming_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
389 let http_ext = match pool.get_extension_by_name("google.api.http") {
390 Some(ext) => ext,
391 None => return Vec::new(),
392 };
393
394 let mut entries = Vec::new();
395
396 for service in pool.services() {
397 for method in service.methods() {
398 if !method.is_server_streaming() || method.is_client_streaming() {
399 continue;
400 }
401
402 let grpc_path = format!("/{}/{}", service.full_name(), method.name());
403
404 if let Some((http_method, http_path)) = extract_http_rule(&method, &http_ext) {
405 tracing::info!(
406 "Registering streaming route: {} {} → {}",
407 match http_method {
408 HttpMethod::Get => "GET",
409 HttpMethod::Post => "POST",
410 _ => "OTHER",
411 },
412 http_path,
413 grpc_path
414 );
415 entries.push(RouteEntry {
416 http_path,
417 http_method,
418 grpc_path,
419 method: method.clone(),
420 });
421 }
422 }
423 }
424
425 entries
426}
427
428fn extract_http_rule(
430 method: &MethodDescriptor,
431 http_ext: &prost_reflect::ExtensionDescriptor,
432) -> Option<(HttpMethod, String)> {
433 let options = method.options();
434
435 if !options.has_extension(http_ext) {
436 return None;
437 }
438
439 let http_rule = options.get_extension(http_ext);
440 if let prost_reflect::Value::Message(rule_msg) = http_rule.into_owned() {
441 for (method_name, http_method) in [
442 ("get", HttpMethod::Get),
443 ("post", HttpMethod::Post),
444 ("put", HttpMethod::Put),
445 ("delete", HttpMethod::Delete),
446 ("patch", HttpMethod::Patch),
447 ] {
448 if let Some(val) = rule_msg.get_field_by_name(method_name) {
449 if let prost_reflect::Value::String(path) = val.into_owned() {
450 if !path.is_empty() {
451 return Some((http_method, path));
452 }
453 }
454 }
455 }
456 }
457
458 None
459}
460
/// Rewrites a proto HTTP path template into axum's `:name` capture syntax.
///
/// Every `{` becomes `:` and every `}` is dropped; all other characters are
/// copied unchanged, so `/v1/profiles/{id}` becomes `/v1/profiles/:id`.
pub fn proto_path_to_axum(path: &str) -> String {
    path.chars()
        .filter(|ch| *ch != '}')
        .map(|ch| if ch == '{' { ':' } else { ch })
        .collect()
}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Brace-delimited proto variables become `:name` segments; paths
    /// without variables pass through unchanged.
    #[test]
    fn test_proto_path_to_axum() {
        let cases = [
            ("/v1/profiles/{id}", "/v1/profiles/:id"),
            (
                "/v1/admin/profiles/{profile_id}/metadata/{key}",
                "/v1/admin/profiles/:profile_id/metadata/:key",
            ),
            ("/v1/auth/login", "/v1/auth/login"),
        ];

        for (proto_path, expected) in cases {
            assert_eq!(proto_path_to_axum(proto_path), expected);
        }
    }
}