1pub mod config;
14pub mod openapi;
15pub mod transcode;
16
17use axum::extract::State;
18use axum::http::{Request, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::routing::get;
22use axum::{Json, Router};
23use prost_reflect::DescriptorPool;
24use std::net::SocketAddr;
25use tower_http::cors::{AllowOrigin, CorsLayer};
26use tower_http::trace::TraceLayer;
27
28use config::{DescriptorSource, ProxyConfig};
29
30#[derive(Clone, Debug)]
32pub struct ProxyState {
33 pub service_name: String,
35 pub grpc_upstream: String,
37 pub grpc_channel: tonic::transport::Channel,
39 pub maintenance_mode: bool,
41 pub maintenance_exempt: Vec<String>,
43 pub maintenance_message: String,
45 pub forwarded_headers: Vec<String>,
47 pub metrics_namespace: String,
49 pub metrics_classes: Vec<config::MetricsClassConfig>,
51}
52
53pub struct ProxyServer {
55 config: ProxyConfig,
56 descriptor_pool: Option<DescriptorPool>,
58}
59
60impl ProxyServer {
61 pub fn from_config(config: ProxyConfig) -> Self {
63 Self {
64 config,
65 descriptor_pool: None,
66 }
67 }
68
69 pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
71 self.descriptor_pool = Some(pool);
72 self
73 }
74
75 fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
80 if let Some(pool) = &self.descriptor_pool {
81 return Ok(pool.clone());
82 }
83
84 let mut pool = DescriptorPool::new();
85
86 for source in &self.config.descriptors {
87 match source {
88 DescriptorSource::File { file } => {
89 let bytes = std::fs::read(file).map_err(|e| {
90 anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
91 })?;
92 pool.decode_file_descriptor_set(bytes.as_slice()).map_err(|e| {
93 anyhow::anyhow!(
94 "Failed to decode descriptor file {:?}: {}",
95 file,
96 e
97 )
98 })?;
99 tracing::info!("Loaded descriptor from {:?}", file);
100 }
101 DescriptorSource::Reflection { reflection } => {
102 tracing::warn!(
103 "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
104 reflection
105 );
106 }
107 DescriptorSource::Embedded { bytes } => {
108 pool.decode_file_descriptor_set(*bytes).map_err(|e| {
109 anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
110 })?;
111 }
112 }
113 }
114
115 Ok(pool)
116 }
117
118 pub fn router(&self) -> anyhow::Result<Router> {
120 let pool = self.load_descriptors()?;
121
122 let grpc_upstream = self.config.upstream.default.clone();
123 let grpc_channel =
124 tonic::transport::Channel::from_shared(grpc_upstream.clone())
125 .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
126 .connect_timeout(std::time::Duration::from_secs(5))
127 .timeout(std::time::Duration::from_secs(5))
128 .connect_lazy();
129
130 let service_name = self.config.service.name.clone();
131 let metrics_namespace = service_name.replace('-', "_");
132
133 let state = ProxyState {
134 service_name: service_name.clone(),
135 grpc_upstream,
136 grpc_channel,
137 maintenance_mode: self.config.maintenance.enabled,
138 maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
139 maintenance_message: self.config.maintenance.message.clone(),
140 forwarded_headers: self.config.forwarded_headers.clone(),
141 metrics_namespace,
142 metrics_classes: self.config.metrics_classes.clone(),
143 };
144
145 let cors = self.build_cors();
146
147 let transcode_routes = transcode::routes(&pool, &self.config.aliases);
149
150 let health_service_name = service_name.clone();
152 let health_routes = Router::new()
153 .route(
154 "/health",
155 get({
156 let name = health_service_name.clone();
157 move || async move {
158 Json(serde_json::json!({
159 "status": "ok",
160 "service": name,
161 }))
162 }
163 }),
164 )
165 .route("/health/live", get(|| async { StatusCode::OK }))
166 .route(
167 "/health/ready",
168 get(|State(state): State<ProxyState>| async move {
169 let mut client =
170 tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
171 match client
172 .check(tonic_health::pb::HealthCheckRequest {
173 service: String::new(),
174 })
175 .await
176 {
177 Ok(resp) => {
178 let status = resp.into_inner().status;
179 if status
180 == tonic_health::pb::health_check_response::ServingStatus::Serving
181 as i32
182 {
183 StatusCode::OK
184 } else {
185 StatusCode::SERVICE_UNAVAILABLE
186 }
187 }
188 Err(_) => StatusCode::SERVICE_UNAVAILABLE,
189 }
190 }),
191 )
192 .route("/health/startup", get(|| async { StatusCode::OK }))
193 .route(
194 "/metrics",
195 get(|| async {
196 let encoder = prometheus::TextEncoder::new();
197 let metric_families = prometheus::default_registry().gather();
198 match encoder.encode_to_string(&metric_families) {
199 Ok(text) => (
200 StatusCode::OK,
201 [(
202 axum::http::header::CONTENT_TYPE,
203 "text/plain; version=0.0.4; charset=utf-8",
204 )],
205 text,
206 )
207 .into_response(),
208 Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
209 }
210 }),
211 );
212
213 let openapi_routes = self.build_openapi_routes(&pool);
215
216 let router = Router::new()
217 .merge(health_routes)
218 .merge(openapi_routes)
219 .merge(transcode_routes)
220 .layer(cors)
221 .layer(axum::middleware::from_fn_with_state(
222 state.clone(),
223 maintenance_middleware,
224 ))
225 .layer(TraceLayer::new_for_http())
226 .with_state(state);
227
228 Ok(router)
229 }
230
231 fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
232 let openapi_config = match &self.config.openapi {
233 Some(cfg) if cfg.enabled => cfg,
234 _ => return Router::new(),
235 };
236
237 let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
238 let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
239 let openapi_path = openapi_config.path.clone();
240 let docs_path = openapi_config.docs_path.clone();
241 let title = openapi_config
242 .title
243 .clone()
244 .unwrap_or_else(|| self.config.service.name.clone());
245 let openapi_path_for_docs = openapi_path.clone();
246
247 tracing::info!(
248 "OpenAPI spec at {}, docs at {}",
249 openapi_path,
250 docs_path,
251 );
252
253 Router::new()
254 .route(
255 &openapi_path,
256 get(move || async move {
257 (
258 StatusCode::OK,
259 [(
260 axum::http::header::CONTENT_TYPE,
261 "application/json; charset=utf-8",
262 )],
263 spec_json,
264 )
265 }),
266 )
267 .route(
268 &docs_path,
269 get(move || async move {
270 let html = openapi::docs_html(&openapi_path_for_docs, &title);
271 (
272 StatusCode::OK,
273 [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
274 html,
275 )
276 }),
277 )
278 }
279
280 fn build_cors(&self) -> CorsLayer {
281 if self.config.cors.origins.is_empty() {
282 tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
283 CorsLayer::permissive()
284 } else {
285 let origins: Vec<_> = self
286 .config
287 .cors
288 .origins
289 .iter()
290 .filter_map(|o| o.parse().ok())
291 .collect();
292 CorsLayer::new()
293 .allow_origin(AllowOrigin::list(origins))
294 .allow_methods(tower_http::cors::Any)
295 .allow_headers(tower_http::cors::Any)
296 .allow_credentials(true)
297 .expose_headers([
298 "grpc-status".parse().unwrap(),
299 "grpc-message".parse().unwrap(),
300 ])
301 }
302 }
303
304 pub async fn serve(&self) -> anyhow::Result<()> {
306 let router = self.router()?;
307 let app = router.into_make_service_with_connect_info::<SocketAddr>();
308 let addr: SocketAddr = self.config.listen.http.parse()?;
309 let listener = tokio::net::TcpListener::bind(addr).await?;
310
311 tracing::info!(
312 "{} listening on {}",
313 self.config.service.name,
314 addr
315 );
316 axum::serve(listener, app).await?;
317 Ok(())
318 }
319}
320
321async fn maintenance_middleware(
323 State(state): State<ProxyState>,
324 request: Request<axum::body::Body>,
325 next: Next,
326) -> Response {
327 if state.maintenance_mode {
328 let path = request.uri().path();
329 let exempt = state.maintenance_exempt.iter().any(|pattern| {
330 if pattern.ends_with("/**") {
331 let prefix = &pattern[..pattern.len() - 3];
332 path.starts_with(prefix)
333 } else {
334 path == pattern
335 }
336 });
337 if !exempt {
338 return (
339 StatusCode::SERVICE_UNAVAILABLE,
340 [("retry-after", "300")],
341 state.maintenance_message.clone(),
342 )
343 .into_response();
344 }
345 }
346 next.run(request).await
347}
348
/// Test-only helper: a lazily connecting channel aimed at `http://127.0.0.1:1`
/// with a 100 ms connect timeout.
// NOTE(review): port 1 is presumably chosen because nothing listens there — confirm.
#[cfg(test)]
pub(crate) fn test_channel() -> tonic::transport::Channel {
    let connect_timeout = std::time::Duration::from_millis(100);
    let endpoint = tonic::transport::Channel::from_static("http://127.0.0.1:1");
    endpoint.connect_timeout(connect_timeout).connect_lazy()
}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// A configuration with only an upstream yields a server without a
    /// pre-loaded descriptor pool.
    #[test]
    fn test_minimal_config_server() {
        let yaml = r#"
upstream:
 default: "http://127.0.0.1:50051"
"#;
        let parsed: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
        let server = ProxyServer::from_config(parsed);
        assert!(server.descriptor_pool.is_none());
    }

    /// Checks the exemption rule against the patterns stored in a `ProxyState`:
    /// `/**` patterns match by prefix, everything else by equality.
    // NOTE(review): kept as a tokio test — building the lazy channel presumably
    // needs a runtime; confirm before converting to a plain `#[test]`.
    #[tokio::test]
    async fn test_maintenance_exempt_matching() {
        let state = ProxyState {
            service_name: "test".into(),
            grpc_upstream: "http://localhost:50051".into(),
            grpc_channel: test_channel(),
            maintenance_mode: true,
            maintenance_exempt: vec![
                "/health/**".into(),
                "/.well-known/**".into(),
                "/metrics".into(),
            ],
            maintenance_message: "Down".into(),
            forwarded_headers: vec![],
            metrics_namespace: "test".into(),
            metrics_classes: vec![],
        };

        let check = |path: &str| -> bool {
            state
                .maintenance_exempt
                .iter()
                .any(|pattern| match pattern.strip_suffix("/**") {
                    Some(prefix) => path.starts_with(prefix),
                    None => path == pattern.as_str(),
                })
        };

        let exempt_paths = [
            "/health",
            "/health/ready",
            "/.well-known/openid-configuration",
            "/metrics",
        ];
        for path in exempt_paths {
            assert!(check(path), "{path} should be exempt");
        }

        let blocked_paths = ["/v1/auth/login", "/oauth2/token"];
        for path in blocked_paths {
            assert!(!check(path), "{path} should not be exempt");
        }
    }
}