1pub mod body;
9pub mod codec;
10pub mod error;
11pub mod metadata;
12
13use axum::extract::{Path, State};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::{IntoResponse, Response};
16use axum::routing::{delete, get, patch, post, put, MethodRouter};
17use axum::{Json, Router};
18use futures::StreamExt;
19use prost_reflect::{DescriptorPool, DynamicMessage, MethodDescriptor, SerializeOptions};
20use tonic::client::Grpc;
21
22use crate::config::AliasConfig;
23
24pub trait TranscodeState: Clone + Send + Sync + 'static {
29 fn grpc_channel(&self) -> tonic::transport::Channel;
31 fn forwarded_headers(&self) -> &[String];
33}
34
35impl TranscodeState for crate::ProxyState {
36 fn grpc_channel(&self) -> tonic::transport::Channel {
37 self.grpc_channel.clone()
38 }
39 fn forwarded_headers(&self) -> &[String] {
40 &self.forwarded_headers
41 }
42}
43
44#[derive(Debug, Clone)]
46struct RouteEntry {
47 http_path: String,
49 http_method: HttpMethod,
51 grpc_path: String,
53 method: MethodDescriptor,
55}
56
/// HTTP verbs that a `google.api.http` rule can bind an RPC to in this transcoder.
#[derive(Clone, Copy, Debug)]
enum HttpMethod {
    /// Bound through the rule's `get` field.
    Get,
    /// Bound through the rule's `post` field.
    Post,
    /// Bound through the rule's `put` field.
    Put,
    /// Bound through the rule's `patch` field.
    Patch,
    /// Bound through the rule's `delete` field.
    Delete,
}
65
66pub fn routes<S: TranscodeState>(pool: &DescriptorPool, aliases: &[AliasConfig]) -> Router<S> {
71 let entries = extract_routes(pool);
72 if entries.is_empty() {
73 tracing::warn!("No HTTP-annotated RPCs found in proto descriptors");
74 return Router::new();
75 }
76
77 tracing::info!("Registering {} transcoded REST→gRPC routes", entries.len());
78
79 let mut router: Router<S> = Router::new();
80 for entry in &entries {
81 let entry_clone = entry.clone();
82
83 let handler = move |proxy_state: State<S>,
84 headers: HeaderMap,
85 path_params: Path<std::collections::HashMap<String, String>>,
86 body: axum::body::Bytes| {
87 transcode_handler(proxy_state, headers, path_params, body, entry_clone)
88 };
89
90 let method_router: MethodRouter<S> = match entry.http_method {
91 HttpMethod::Get => get(handler),
92 HttpMethod::Post => post(handler),
93 HttpMethod::Put => put(handler),
94 HttpMethod::Patch => patch(handler),
95 HttpMethod::Delete => delete(handler),
96 };
97
98 let axum_path = proto_path_to_axum(&entry.http_path);
99 router = router.route(&axum_path, method_router);
100
101 for alias in aliases {
103 if let Some(suffix) = entry.http_path.strip_prefix(&alias.to) {
104 let alias_path = if alias.from.ends_with("/{path}") {
106 let prefix = alias.from.trim_end_matches("/{path}");
107 format!("{}{}", prefix, suffix)
108 } else {
109 continue;
110 };
111
112 let alias_entry = entry.clone();
113 let alias_handler =
114 move |proxy_state: State<S>,
115 headers: HeaderMap,
116 path_params: Path<std::collections::HashMap<String, String>>,
117 body: axum::body::Bytes| {
118 transcode_handler(proxy_state, headers, path_params, body, alias_entry)
119 };
120 let alias_method: MethodRouter<S> = match entry.http_method {
121 HttpMethod::Get => get(alias_handler),
122 HttpMethod::Post => post(alias_handler),
123 HttpMethod::Put => put(alias_handler),
124 HttpMethod::Patch => patch(alias_handler),
125 HttpMethod::Delete => delete(alias_handler),
126 };
127 router = router.route(&alias_path, alias_method);
128 }
129 }
130 }
131
132 let streaming_entries = extract_streaming_routes(pool);
134 for entry in &streaming_entries {
135 let entry_clone = entry.clone();
136 let axum_path = proto_path_to_axum(&entry.http_path);
137
138 let handler = move |proxy_state: State<S>, headers: HeaderMap| {
139 streaming_handler(proxy_state, headers, entry_clone)
140 };
141
142 let method_router: MethodRouter<S> = match entry.http_method {
143 HttpMethod::Get => get(handler),
144 HttpMethod::Post => post(handler),
145 _ => continue,
146 };
147
148 router = router.route(&axum_path, method_router);
149 }
150
151 router
152}
153
154async fn streaming_handler<S: TranscodeState>(
156 State(proxy_state): State<S>,
157 headers: HeaderMap,
158 entry: RouteEntry,
159) -> Response {
160 let channel = proxy_state.grpc_channel();
161
162 let input_desc = entry.method.input();
163 let request_msg = DynamicMessage::new(input_desc);
164
165 let grpc_metadata =
166 metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
167 let mut grpc_request = tonic::Request::new(request_msg);
168 *grpc_request.metadata_mut() = grpc_metadata;
169
170 let output_desc = entry.method.output();
171 let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
172 let grpc_path: axum::http::uri::PathAndQuery = match entry.grpc_path.parse() {
173 Ok(p) => p,
174 Err(e) => {
175 tracing::error!("Invalid gRPC path '{}': {e}", entry.grpc_path);
176 return (
177 StatusCode::INTERNAL_SERVER_ERROR,
178 Json(serde_json::json!({
179 "error": "INTERNAL",
180 "message": "invalid gRPC path configuration",
181 })),
182 )
183 .into_response();
184 }
185 };
186
187 let mut grpc_client = Grpc::new(channel);
188 if let Err(e) = grpc_client.ready().await {
189 return (
190 StatusCode::SERVICE_UNAVAILABLE,
191 Json(serde_json::json!({
192 "error": "UNAVAILABLE",
193 "message": format!("gRPC upstream not ready: {e}"),
194 })),
195 )
196 .into_response();
197 }
198
199 match grpc_client
200 .server_streaming(grpc_request, grpc_path, grpc_codec)
201 .await
202 {
203 Ok(response) => {
204 let stream = response.into_inner();
205 let serialize_opts = SerializeOptions::new()
206 .skip_default_fields(false)
207 .stringify_64_bit_integers(true);
208
209 let byte_stream = stream.map(move |result| match result {
210 Ok(msg) => {
211 match msg.serialize_with_options(serde_json::value::Serializer, &serialize_opts)
212 {
213 Ok(json_value) => {
214 let mut bytes = serde_json::to_vec(&json_value).unwrap_or_default();
215 bytes.push(b'\n');
216 Ok::<axum::body::Bytes, std::io::Error>(axum::body::Bytes::from(bytes))
217 }
218 Err(e) => Err(std::io::Error::other(format!("serialization error: {e}"))),
219 }
220 }
221 Err(status) => Err(std::io::Error::other(format!(
222 "gRPC stream error: {status}"
223 ))),
224 });
225
226 let body = axum::body::Body::from_stream(byte_stream);
227 Response::builder()
228 .status(StatusCode::OK)
229 .header("content-type", "application/x-ndjson")
230 .header("transfer-encoding", "chunked")
231 .body(body)
232 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
233 }
234 Err(status) => error::status_to_response(status),
235 }
236}
237
238async fn transcode_handler<S: TranscodeState>(
240 State(proxy_state): State<S>,
241 headers: HeaderMap,
242 Path(path_params): Path<std::collections::HashMap<String, String>>,
243 body_bytes: axum::body::Bytes,
244 entry: RouteEntry,
245) -> Response {
246 let channel = proxy_state.grpc_channel();
247
248 let ct = body::content_type(&headers);
249 let mut json_body = match body::parse_body(ct, &body_bytes) {
250 Ok(v) => v,
251 Err(e) => {
252 return (
253 StatusCode::BAD_REQUEST,
254 Json(serde_json::json!({
255 "error": "INVALID_ARGUMENT",
256 "message": format!("failed to parse request body: {e}"),
257 })),
258 )
259 .into_response();
260 }
261 };
262
263 if !path_params.is_empty() {
264 if let Some(obj) = json_body.as_object_mut() {
265 for (key, value) in &path_params {
266 obj.insert(key.clone(), serde_json::Value::String(value.clone()));
267 }
268 }
269 }
270
271 let input_desc = entry.method.input();
272 let request_msg = match DynamicMessage::deserialize(input_desc, json_body) {
273 Ok(msg) => msg,
274 Err(e) => {
275 return (
276 StatusCode::BAD_REQUEST,
277 Json(serde_json::json!({
278 "error": "INVALID_ARGUMENT",
279 "message": format!("failed to decode request: {e}"),
280 })),
281 )
282 .into_response();
283 }
284 };
285
286 let grpc_metadata =
287 metadata::http_headers_to_grpc_metadata(&headers, proxy_state.forwarded_headers());
288 let mut grpc_request = tonic::Request::new(request_msg);
289 *grpc_request.metadata_mut() = grpc_metadata;
290
291 let output_desc = entry.method.output();
292 let grpc_codec = codec::DynamicCodec::new(output_desc.clone());
293 let grpc_path: axum::http::uri::PathAndQuery = match entry.grpc_path.parse() {
294 Ok(p) => p,
295 Err(e) => {
296 tracing::error!("Invalid gRPC path '{}': {e}", entry.grpc_path);
297 return (
298 StatusCode::INTERNAL_SERVER_ERROR,
299 Json(serde_json::json!({
300 "error": "INTERNAL",
301 "message": "invalid gRPC path configuration",
302 })),
303 )
304 .into_response();
305 }
306 };
307
308 let mut grpc_client = Grpc::new(channel);
309 if let Err(e) = grpc_client.ready().await {
310 return (
311 StatusCode::SERVICE_UNAVAILABLE,
312 Json(serde_json::json!({
313 "error": "UNAVAILABLE",
314 "message": format!("gRPC upstream not ready: {e}"),
315 })),
316 )
317 .into_response();
318 }
319
320 match grpc_client.unary(grpc_request, grpc_path, grpc_codec).await {
321 Ok(response) => {
322 let response_msg = response.into_inner();
323 let serialize_opts = SerializeOptions::new()
324 .skip_default_fields(false)
325 .stringify_64_bit_integers(true);
326 match response_msg
327 .serialize_with_options(serde_json::value::Serializer, &serialize_opts)
328 {
329 Ok(json_value) => (StatusCode::OK, Json(json_value)).into_response(),
330 Err(e) => {
331 tracing::error!("Failed to serialize gRPC response: {e}");
332 (
333 StatusCode::INTERNAL_SERVER_ERROR,
334 Json(serde_json::json!({
335 "error": "INTERNAL",
336 "message": "failed to serialize response",
337 })),
338 )
339 .into_response()
340 }
341 }
342 }
343 Err(status) => error::status_to_response(status),
344 }
345}
346
347fn extract_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
349 let http_ext = match pool.get_extension_by_name("google.api.http") {
350 Some(ext) => ext,
351 None => {
352 tracing::warn!("google.api.http extension not found in descriptor pool");
353 return Vec::new();
354 }
355 };
356
357 let mut entries = Vec::new();
358
359 for service in pool.services() {
360 for method in service.methods() {
361 if method.is_client_streaming() || method.is_server_streaming() {
362 continue;
363 }
364
365 let grpc_path = format!("/{}/{}", service.full_name(), method.name());
366
367 if let Some((http_method, http_path)) = extract_http_rule(&method, &http_ext) {
368 entries.push(RouteEntry {
369 http_path,
370 http_method,
371 grpc_path,
372 method: method.clone(),
373 });
374 }
375 }
376 }
377
378 entries
379}
380
381fn extract_streaming_routes(pool: &DescriptorPool) -> Vec<RouteEntry> {
383 let http_ext = match pool.get_extension_by_name("google.api.http") {
384 Some(ext) => ext,
385 None => return Vec::new(),
386 };
387
388 let mut entries = Vec::new();
389
390 for service in pool.services() {
391 for method in service.methods() {
392 if !method.is_server_streaming() || method.is_client_streaming() {
393 continue;
394 }
395
396 let grpc_path = format!("/{}/{}", service.full_name(), method.name());
397
398 if let Some((http_method, http_path)) = extract_http_rule(&method, &http_ext) {
399 tracing::info!(
400 "Registering streaming route: {} {} → {}",
401 match http_method {
402 HttpMethod::Get => "GET",
403 HttpMethod::Post => "POST",
404 _ => "OTHER",
405 },
406 http_path,
407 grpc_path
408 );
409 entries.push(RouteEntry {
410 http_path,
411 http_method,
412 grpc_path,
413 method: method.clone(),
414 });
415 }
416 }
417 }
418
419 entries
420}
421
422fn extract_http_rule(
424 method: &MethodDescriptor,
425 http_ext: &prost_reflect::ExtensionDescriptor,
426) -> Option<(HttpMethod, String)> {
427 let options = method.options();
428
429 if !options.has_extension(http_ext) {
430 return None;
431 }
432
433 let http_rule = options.get_extension(http_ext);
434 if let prost_reflect::Value::Message(rule_msg) = http_rule.into_owned() {
435 for (method_name, http_method) in [
436 ("get", HttpMethod::Get),
437 ("post", HttpMethod::Post),
438 ("put", HttpMethod::Put),
439 ("delete", HttpMethod::Delete),
440 ("patch", HttpMethod::Patch),
441 ] {
442 if let Some(val) = rule_msg.get_field_by_name(method_name) {
443 if let prost_reflect::Value::String(path) = val.into_owned() {
444 if !path.is_empty() {
445 return Some((http_method, path));
446 }
447 }
448 }
449 }
450 }
451
452 None
453}
454
/// Converts a `google.api.http` path template into axum 0.7 route syntax.
///
/// * `{field}` and `{field=*}` (exactly one path segment) become `:field`.
/// * `{field=**}` (the rest of the path) becomes the wildcard `*field`; axum
///   accepts this only as the final segment.
/// * Any other `{field=pattern}` becomes `:field`. This is best effort: axum
///   cannot capture a multi-segment pattern such as `shelves/*` in a single
///   parameter, so such a route only matches one segment there. Before this
///   mapping, the raw pattern was copied into the route and could make axum
///   reject it.
///
/// Literal text outside braces, including any custom-verb suffix such as
/// `:cancel`, is copied unchanged. A stray `}` is dropped, and an unterminated
/// `{` treats the rest of the template as the parameter name.
pub fn proto_path_to_axum(path: &str) -> String {
    let mut result = String::with_capacity(path.len());
    let mut chars = path.chars();

    while let Some(ch) = chars.next() {
        match ch {
            '{' => {
                let variable: String = chars.by_ref().take_while(|&c| c != '}').collect();
                let (name, pattern) = match variable.split_once('=') {
                    Some((name, pattern)) => (name, Some(pattern)),
                    None => (variable.as_str(), None),
                };
                result.push(if pattern == Some("**") { '*' } else { ':' });
                result.push_str(name);
            }
            '}' => {}
            other => result.push(other),
        }
    }

    result
}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain `{var}` templates map to axum `:var` captures; literal paths pass through.
    #[test]
    fn test_proto_path_to_axum() {
        let cases = [
            ("/v1/profiles/{id}", "/v1/profiles/:id"),
            (
                "/v1/admin/profiles/{profile_id}/metadata/{key}",
                "/v1/admin/profiles/:profile_id/metadata/:key",
            ),
            ("/v1/auth/login", "/v1/auth/login"),
        ];

        for (template, expected) in cases {
            assert_eq!(proto_path_to_axum(template), expected, "template: {template}");
        }
    }
}