1#[cfg(all(feature = "rust_crypto", feature = "aws_lc_rs"))]
23compile_error!("features `rust_crypto` and `aws_lc_rs` are mutually exclusive; enable exactly one");
24
25#[cfg(not(any(feature = "rust_crypto", feature = "aws_lc_rs")))]
26compile_error!("exactly one JWT crypto backend must be enabled: `rust_crypto` or `aws_lc_rs`");
27
28pub mod auth;
29pub mod config;
30pub mod oidc;
31pub mod openapi;
32pub mod shield;
33pub mod transcode;
34
35use axum::extract::State;
36use axum::http::{Request, StatusCode};
37use axum::middleware::Next;
38use axum::response::{IntoResponse, Response};
39use axum::routing::get;
40use axum::{Json, Router};
41use prost_reflect::DescriptorPool;
42use std::net::SocketAddr;
43use tower_http::cors::{AllowOrigin, CorsLayer};
44use tower_http::trace::TraceLayer;
45
46use config::{DescriptorSource, ProxyConfig};
47
/// Shared state attached to the router and handed to handlers and middleware
/// through axum's `State` extractor.
///
/// NOTE(review): axum's `State` extractor clones this value on every extraction,
/// and several fields are owned `String`/`Vec`s. Consider `Arc`-wrapping them if
/// this shows up in profiles.
#[derive(Clone, Debug)]
pub struct ProxyState {
    /// Service name, copied from `config.service.name`.
    pub service_name: String,
    /// Upstream gRPC URL (`config.upstream.default`) that `grpc_channel` was built from.
    pub grpc_upstream: String,
    /// Lazily connected tonic channel to `grpc_upstream`; `/health/ready` issues
    /// its gRPC health check over this channel.
    pub grpc_channel: tonic::transport::Channel,
    /// When `true`, `maintenance_middleware` answers non-exempt requests with 503.
    pub maintenance_mode: bool,
    /// Paths that bypass maintenance mode. An entry ending in `/**` is matched by
    /// prefix; any other entry must equal the request path exactly (see
    /// `maintenance_middleware` for the precise rule).
    pub maintenance_exempt: Vec<String>,
    /// Response body sent with the maintenance-mode 503.
    pub maintenance_message: String,
    /// Header names from `config.forwarded_headers`.
    /// NOTE(review): presumably copied from the HTTP request onto the upstream
    /// gRPC call; the consumer is outside this file, so confirm.
    pub forwarded_headers: Vec<String>,
    /// `service_name` with every `-` replaced by `_`.
    pub metrics_namespace: String,
    /// Metric classification rules, copied from `config.metrics_classes`.
    pub metrics_classes: Vec<config::MetricsClassConfig>,
    /// Copied from `config.streaming.sse_keep_alive_secs`; presumably the SSE
    /// keep-alive interval in seconds (the consumer is outside this file).
    pub sse_keep_alive_secs: u64,
}
72
/// HTTP front end for a gRPC upstream, configured from a [`ProxyConfig`].
///
/// [`ProxyServer::router`] builds the axum router; [`ProxyServer::serve`] binds
/// `config.listen.http` and runs it.
pub struct ProxyServer {
    /// Parsed configuration; validated when `router()` is called.
    config: ProxyConfig,
    /// Descriptor pool injected via `with_descriptors`. When `None`, descriptors
    /// are loaded from `config.descriptors` instead.
    descriptor_pool: Option<DescriptorPool>,
}
79
80impl ProxyServer {
81 pub fn from_config(config: ProxyConfig) -> Self {
83 Self {
84 config,
85 descriptor_pool: None,
86 }
87 }
88
89 pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
91 self.descriptor_pool = Some(pool);
92 self
93 }
94
95 fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
100 if let Some(pool) = &self.descriptor_pool {
101 return Ok(pool.clone());
102 }
103
104 let mut pool = DescriptorPool::new();
105
106 for source in &self.config.descriptors {
107 match source {
108 DescriptorSource::File { file } => {
109 let bytes = std::fs::read(file).map_err(|e| {
110 anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
111 })?;
112 pool.decode_file_descriptor_set(bytes.as_slice())
113 .map_err(|e| {
114 anyhow::anyhow!("Failed to decode descriptor file {:?}: {}", file, e)
115 })?;
116 tracing::info!("Loaded descriptor from {:?}", file);
117 }
118 DescriptorSource::Reflection { reflection } => {
119 tracing::warn!(
120 "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
121 reflection
122 );
123 }
124 DescriptorSource::Embedded { bytes } => {
125 pool.decode_file_descriptor_set(*bytes).map_err(|e| {
126 anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
127 })?;
128 }
129 }
130 }
131
132 Ok(pool)
133 }
134
135 pub fn router(&self) -> anyhow::Result<Router> {
137 self.config.validate()?;
140 let pool = self.load_descriptors()?;
141
142 let grpc_upstream = self.config.upstream.default.clone();
143 let grpc_channel = tonic::transport::Channel::from_shared(grpc_upstream.clone())
144 .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
145 .connect_timeout(std::time::Duration::from_secs(5))
146 .timeout(std::time::Duration::from_secs(5))
147 .connect_lazy();
148
149 let service_name = self.config.service.name.clone();
150 let metrics_namespace = service_name.replace('-', "_");
151
152 let state = ProxyState {
153 service_name: service_name.clone(),
154 grpc_upstream,
155 grpc_channel,
156 maintenance_mode: self.config.maintenance.enabled,
157 maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
158 maintenance_message: self.config.maintenance.message.clone(),
159 forwarded_headers: self.config.forwarded_headers.clone(),
160 metrics_namespace,
161 metrics_classes: self.config.metrics_classes.clone(),
162 sse_keep_alive_secs: self.config.streaming.sse_keep_alive_secs,
163 };
164
165 let cors = self.build_cors();
166
167 let mut transcode_routes = transcode::routes(&pool, &self.config.aliases);
169
170 let authz = match self.config.auth.as_ref().and_then(|a| a.authz.as_ref()) {
175 Some(cfg) => auth::authz::Authz::build(cfg)
176 .map_err(|e| anyhow::anyhow!("invalid authz config: {e}"))?,
177 None => None,
178 };
179 if let Some(authz) = authz {
180 transcode_routes = transcode_routes.layer(axum::middleware::from_fn_with_state(
181 authz,
182 auth::authz::middleware,
183 ));
184 }
185
186 let health_service_name = service_name.clone();
188 let health_routes = Router::new()
189 .route(
190 "/health",
191 get({
192 let name = health_service_name.clone();
193 move || async move {
194 Json(serde_json::json!({
195 "status": "ok",
196 "service": name,
197 }))
198 }
199 }),
200 )
201 .route("/health/live", get(|| async { StatusCode::OK }))
202 .route(
203 "/health/ready",
204 get(|State(state): State<ProxyState>| async move {
205 let mut client =
206 tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
207 match client
208 .check(tonic_health::pb::HealthCheckRequest {
209 service: String::new(),
210 })
211 .await
212 {
213 Ok(resp) => {
214 let status = resp.into_inner().status;
215 if status
216 == tonic_health::pb::health_check_response::ServingStatus::Serving
217 as i32
218 {
219 StatusCode::OK
220 } else {
221 StatusCode::SERVICE_UNAVAILABLE
222 }
223 }
224 Err(_) => StatusCode::SERVICE_UNAVAILABLE,
225 }
226 }),
227 )
228 .route("/health/startup", get(|| async { StatusCode::OK }))
229 .route(
230 "/metrics",
231 get(|| async {
232 let encoder = prometheus::TextEncoder::new();
233 let metric_families = prometheus::default_registry().gather();
234 match encoder.encode_to_string(&metric_families) {
235 Ok(text) => (
236 StatusCode::OK,
237 [(
238 axum::http::header::CONTENT_TYPE,
239 "text/plain; version=0.0.4; charset=utf-8",
240 )],
241 text,
242 )
243 .into_response(),
244 Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
245 }
246 }),
247 );
248
249 let openapi_routes = self.build_openapi_routes(&pool);
251
252 let oidc_routes = match &self.config.oidc_discovery {
254 Some(cfg) => oidc::Oidc::build(cfg)
255 .map_err(|e| anyhow::anyhow!("invalid oidc_discovery config: {e}"))?
256 .map(|o| o.routes())
257 .unwrap_or_default(),
258 None => Router::new(),
259 };
260
261 let shield = match &self.config.shield {
263 Some(cfg) => shield::Shield::build(cfg)
264 .map_err(|e| anyhow::anyhow!("invalid shield config: {e}"))?,
265 None => None,
266 };
267
268 let auth = match &self.config.auth {
270 Some(cfg) => {
271 auth::Auth::build(cfg).map_err(|e| anyhow::anyhow!("invalid auth config: {e}"))?
272 }
273 None => None,
274 };
275
276 let mut router = Router::new()
277 .merge(health_routes)
278 .merge(openapi_routes)
279 .merge(oidc_routes)
280 .merge(transcode_routes)
281 .layer(cors);
282
283 let forward_auth = auth.as_ref().and_then(|built| {
287 auth::forward::ForwardAuth::build(self.config.auth.as_ref()?, built.clone())
288 });
289
290 if let Some(auth) = auth {
293 router = router.layer(axum::middleware::from_fn_with_state(auth, auth::middleware));
294 }
295
296 if let Some(forward_auth) = &forward_auth {
297 router = router.merge(forward_auth.routes());
298 }
299
300 if let Some(shield) = shield {
304 router = router.layer(axum::middleware::from_fn_with_state(
305 shield,
306 shield::middleware,
307 ));
308 }
309
310 let router = router
311 .layer(axum::middleware::from_fn_with_state(
312 state.clone(),
313 maintenance_middleware,
314 ))
315 .layer(TraceLayer::new_for_http())
316 .with_state(state);
317
318 Ok(router)
319 }
320
321 fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
322 let openapi_config = match &self.config.openapi {
323 Some(cfg) if cfg.enabled => cfg,
324 _ => return Router::new(),
325 };
326
327 let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
328 let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
329 let openapi_path = openapi_config.path.clone();
330 let docs_path = openapi_config.docs_path.clone();
331 let title = openapi_config
332 .title
333 .clone()
334 .unwrap_or_else(|| self.config.service.name.clone());
335 let openapi_path_for_docs = openapi_path.clone();
336
337 tracing::info!("OpenAPI spec at {}, docs at {}", openapi_path, docs_path,);
338
339 Router::new()
340 .route(
341 &openapi_path,
342 get(move || async move {
343 (
344 StatusCode::OK,
345 [(
346 axum::http::header::CONTENT_TYPE,
347 "application/json; charset=utf-8",
348 )],
349 spec_json,
350 )
351 }),
352 )
353 .route(
354 &docs_path,
355 get(move || async move {
356 let html = openapi::docs_html(&openapi_path_for_docs, &title);
357 (
358 StatusCode::OK,
359 [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
360 html,
361 )
362 }),
363 )
364 }
365
366 fn build_cors(&self) -> CorsLayer {
367 if self.config.cors.origins.is_empty() {
368 tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
369 CorsLayer::permissive()
370 } else {
371 let origins: Vec<_> = self
372 .config
373 .cors
374 .origins
375 .iter()
376 .filter_map(|o| o.parse().ok())
377 .collect();
378 CorsLayer::new()
379 .allow_origin(AllowOrigin::list(origins))
380 .allow_methods(tower_http::cors::Any)
381 .allow_headers(tower_http::cors::Any)
382 .allow_credentials(true)
383 .expose_headers([
384 "grpc-status".parse().unwrap(),
385 "grpc-message".parse().unwrap(),
386 ])
387 }
388 }
389
390 pub async fn serve(&self) -> anyhow::Result<()> {
392 let router = self.router()?;
393 let app = router.into_make_service_with_connect_info::<SocketAddr>();
394 let addr: SocketAddr = self.config.listen.http.parse()?;
395 let listener = tokio::net::TcpListener::bind(addr).await?;
396
397 tracing::info!("{} listening on {}", self.config.service.name, addr);
398 axum::serve(listener, app).await?;
399 Ok(())
400 }
401}
402
403async fn maintenance_middleware(
405 State(state): State<ProxyState>,
406 request: Request<axum::body::Body>,
407 next: Next,
408) -> Response {
409 if state.maintenance_mode {
410 let path = request.uri().path();
411 let exempt = state.maintenance_exempt.iter().any(|pattern| {
412 if pattern.ends_with("/**") {
413 let prefix = &pattern[..pattern.len() - 3];
414 path.starts_with(prefix)
415 } else {
416 path == pattern
417 }
418 });
419 if !exempt {
420 return (
421 StatusCode::SERVICE_UNAVAILABLE,
422 [("retry-after", "300")],
423 state.maintenance_message.clone(),
424 )
425 .into_response();
426 }
427 }
428 next.run(request).await
429}
430
/// Test-only channel aimed at `127.0.0.1:1`, presumably an address nothing
/// listens on.
///
/// Built with `connect_lazy`, so creating it never dials. Any RPC sent through
/// it is expected to fail after the 100 ms connect timeout.
#[cfg(test)]
pub(crate) fn test_channel() -> tonic::transport::Channel {
    let endpoint = tonic::transport::Channel::from_static("http://127.0.0.1:1");
    let connect_timeout = std::time::Duration::from_millis(100);
    endpoint.connect_timeout(connect_timeout).connect_lazy()
}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// A config containing only an upstream yields a server with no injected pool.
    #[test]
    fn test_minimal_config_server() {
        let config: ProxyConfig = serde_yaml::from_str(
            r#"
upstream:
  default: "http://127.0.0.1:50051"
"#,
        )
        .unwrap();
        let server = ProxyServer::from_config(config);
        assert!(server.descriptor_pool.is_none());
    }

    /// Exercises a local copy of the exemption matcher against a typical
    /// exempt list.
    #[tokio::test]
    async fn test_maintenance_exempt_matching() {
        let state = ProxyState {
            service_name: "test".into(),
            grpc_upstream: "http://localhost:50051".into(),
            grpc_channel: test_channel(),
            maintenance_mode: true,
            maintenance_exempt: vec![
                "/health/**".into(),
                "/.well-known/**".into(),
                "/metrics".into(),
            ],
            maintenance_message: "Down".into(),
            forwarded_headers: vec![],
            metrics_namespace: "test".into(),
            metrics_classes: vec![],
            sse_keep_alive_secs: 15,
        };

        // `/**` patterns match by string prefix; anything else must match exactly.
        let is_exempt = |path: &str| {
            state
                .maintenance_exempt
                .iter()
                .any(|pattern| match pattern.strip_suffix("/**") {
                    Some(prefix) => path.starts_with(prefix),
                    None => path == pattern,
                })
        };

        for path in [
            "/health",
            "/health/ready",
            "/.well-known/openid-configuration",
            "/metrics",
        ] {
            assert!(is_exempt(path));
        }
        for path in ["/v1/auth/login", "/oauth2/token"] {
            assert!(!is_exempt(path));
        }
    }
}
491}