1pub mod auth;
14pub mod config;
15pub mod oidc;
16pub mod openapi;
17pub mod shield;
18pub mod transcode;
19
20use axum::extract::State;
21use axum::http::{Request, StatusCode};
22use axum::middleware::Next;
23use axum::response::{IntoResponse, Response};
24use axum::routing::get;
25use axum::{Json, Router};
26use prost_reflect::DescriptorPool;
27use std::net::SocketAddr;
28use tower_http::cors::{AllowOrigin, CorsLayer};
29use tower_http::trace::TraceLayer;
30
31use config::{DescriptorSource, ProxyConfig};
32
/// Shared state made available to request handlers and middleware through
/// axum's `State` extractor.
///
/// One instance is assembled in `ProxyServer::router` from the `ProxyConfig`
/// and attached to the router with `with_state`.
#[derive(Clone, Debug)]
pub struct ProxyState {
    /// Service name, copied from `config.service.name`.
    pub service_name: String,
    /// Upstream gRPC URL, copied from `config.upstream.default`.
    pub grpc_upstream: String,
    /// Channel to `grpc_upstream`, created with `connect_lazy` in
    /// `ProxyServer::router`; the `/health/ready` handler uses it for the
    /// upstream health check.
    pub grpc_channel: tonic::transport::Channel,
    /// When `true`, `maintenance_middleware` answers every non-exempt path
    /// with `503 Service Unavailable`.
    pub maintenance_mode: bool,
    /// Paths that stay reachable during maintenance: either an exact path or a
    /// pattern ending in `/**`, which is matched as a prefix.
    pub maintenance_exempt: Vec<String>,
    /// Response body sent with the maintenance `503`.
    pub maintenance_message: String,
    /// Copied from `config.forwarded_headers`.
    // NOTE(review): presumably the request headers forwarded to the upstream;
    // the consumer is not in this file — confirm.
    pub forwarded_headers: Vec<String>,
    /// Service name with every `-` replaced by `_`.
    pub metrics_namespace: String,
    /// Copied from `config.metrics_classes`.
    pub metrics_classes: Vec<config::MetricsClassConfig>,
}
55
/// The proxy server: a `ProxyConfig` plus an optional, pre-supplied protobuf
/// descriptor pool.
pub struct ProxyServer {
    /// Configuration every route, layer and listener is derived from.
    config: ProxyConfig,
    /// Pool injected through `with_descriptors`. When present it is used as-is
    /// and the descriptor sources listed in `config` are not read.
    descriptor_pool: Option<DescriptorPool>,
}
62
63impl ProxyServer {
64 pub fn from_config(config: ProxyConfig) -> Self {
66 Self {
67 config,
68 descriptor_pool: None,
69 }
70 }
71
72 pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
74 self.descriptor_pool = Some(pool);
75 self
76 }
77
78 fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
83 if let Some(pool) = &self.descriptor_pool {
84 return Ok(pool.clone());
85 }
86
87 let mut pool = DescriptorPool::new();
88
89 for source in &self.config.descriptors {
90 match source {
91 DescriptorSource::File { file } => {
92 let bytes = std::fs::read(file).map_err(|e| {
93 anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
94 })?;
95 pool.decode_file_descriptor_set(bytes.as_slice())
96 .map_err(|e| {
97 anyhow::anyhow!("Failed to decode descriptor file {:?}: {}", file, e)
98 })?;
99 tracing::info!("Loaded descriptor from {:?}", file);
100 }
101 DescriptorSource::Reflection { reflection } => {
102 tracing::warn!(
103 "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
104 reflection
105 );
106 }
107 DescriptorSource::Embedded { bytes } => {
108 pool.decode_file_descriptor_set(*bytes).map_err(|e| {
109 anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
110 })?;
111 }
112 }
113 }
114
115 Ok(pool)
116 }
117
    /// Assembles the complete HTTP router for the proxy.
    ///
    /// The router combines, in this order: health and metrics endpoints, the
    /// optional OpenAPI routes, the optional OIDC discovery routes and the
    /// transcoding routes generated from the descriptor pool. Middleware is then
    /// stacked on top; see the inline comments for why the order of the
    /// `merge` / `layer` calls matters.
    ///
    /// # Errors
    ///
    /// Returns an error when the descriptors cannot be loaded, the upstream URL
    /// is rejected by `Channel::from_shared`, or the authz, OIDC discovery,
    /// shield or auth configuration fails to build.
    pub fn router(&self) -> anyhow::Result<Router> {
        let pool = self.load_descriptors()?;

        let grpc_upstream = self.config.upstream.default.clone();
        // `connect_lazy` does not connect here, so only a malformed URL can fail
        // at this point. Both the connect and per-request timeouts are 5 seconds.
        let grpc_channel = tonic::transport::Channel::from_shared(grpc_upstream.clone())
            .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
            .connect_timeout(std::time::Duration::from_secs(5))
            .timeout(std::time::Duration::from_secs(5))
            .connect_lazy();

        let service_name = self.config.service.name.clone();
        // The namespace is the service name with dashes turned into underscores.
        let metrics_namespace = service_name.replace('-', "_");

        let state = ProxyState {
            service_name: service_name.clone(),
            grpc_upstream,
            grpc_channel,
            maintenance_mode: self.config.maintenance.enabled,
            maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
            maintenance_message: self.config.maintenance.message.clone(),
            forwarded_headers: self.config.forwarded_headers.clone(),
            metrics_namespace,
            metrics_classes: self.config.metrics_classes.clone(),
        };

        let cors = self.build_cors();

        let mut transcode_routes = transcode::routes(&pool, &self.config.aliases);

        // Authorization is layered onto the transcoding routes only, before they
        // are merged with the other route groups.
        let authz = match self.config.auth.as_ref().and_then(|a| a.authz.as_ref()) {
            Some(cfg) => auth::authz::Authz::build(cfg)
                .map_err(|e| anyhow::anyhow!("invalid authz config: {e}"))?,
            None => None,
        };
        if let Some(authz) = authz {
            transcode_routes = transcode_routes.layer(axum::middleware::from_fn_with_state(
                authz,
                auth::authz::middleware,
            ));
        }

        let health_service_name = service_name.clone();
        let health_routes = Router::new()
            // Static status document that echoes the service name.
            .route(
                "/health",
                get({
                    let name = health_service_name.clone();
                    move || async move {
                        Json(serde_json::json!({
                            "status": "ok",
                            "service": name,
                        }))
                    }
                }),
            )
            // Liveness: unconditionally 200.
            .route("/health/live", get(|| async { StatusCode::OK }))
            // Readiness: 200 only when the upstream answers the gRPC health
            // `Check` (empty service name) with `Serving`; any other status or
            // an RPC error yields 503.
            .route(
                "/health/ready",
                get(|State(state): State<ProxyState>| async move {
                    let mut client =
                        tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
                    match client
                        .check(tonic_health::pb::HealthCheckRequest {
                            service: String::new(),
                        })
                        .await
                    {
                        Ok(resp) => {
                            let status = resp.into_inner().status;
                            if status
                                == tonic_health::pb::health_check_response::ServingStatus::Serving
                                    as i32
                            {
                                StatusCode::OK
                            } else {
                                StatusCode::SERVICE_UNAVAILABLE
                            }
                        }
                        Err(_) => StatusCode::SERVICE_UNAVAILABLE,
                    }
                }),
            )
            // Startup: unconditionally 200.
            .route("/health/startup", get(|| async { StatusCode::OK }))
            // Text exposition of the default Prometheus registry; 500 when
            // encoding fails.
            .route(
                "/metrics",
                get(|| async {
                    let encoder = prometheus::TextEncoder::new();
                    let metric_families = prometheus::default_registry().gather();
                    match encoder.encode_to_string(&metric_families) {
                        Ok(text) => (
                            StatusCode::OK,
                            [(
                                axum::http::header::CONTENT_TYPE,
                                "text/plain; version=0.0.4; charset=utf-8",
                            )],
                            text,
                        )
                            .into_response(),
                        Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
                    }
                }),
            );

        let openapi_routes = self.build_openapi_routes(&pool);

        // No `oidc_discovery` section, or a build that yields `None`, both
        // result in an empty router.
        let oidc_routes = match &self.config.oidc_discovery {
            Some(cfg) => oidc::Oidc::build(cfg)
                .map_err(|e| anyhow::anyhow!("invalid oidc_discovery config: {e}"))?
                .map(|o| o.routes())
                .unwrap_or_default(),
            None => Router::new(),
        };

        let shield = match &self.config.shield {
            Some(cfg) => shield::Shield::build(cfg)
                .map_err(|e| anyhow::anyhow!("invalid shield config: {e}"))?,
            None => None,
        };

        let auth = match &self.config.auth {
            Some(cfg) => {
                auth::Auth::build(cfg).map_err(|e| anyhow::anyhow!("invalid auth config: {e}"))?
            }
            None => None,
        };

        // CORS is applied to the four route groups merged so far.
        let mut router = Router::new()
            .merge(health_routes)
            .merge(openapi_routes)
            .merge(oidc_routes)
            .merge(transcode_routes)
            .layer(cors);

        // Built from a clone of `auth` before `auth` itself is moved into the
        // middleware layer below.
        let forward_auth = auth.as_ref().and_then(|built| {
            auth::forward::ForwardAuth::build(self.config.auth.as_ref()?, built.clone())
        });

        if let Some(auth) = auth {
            router = router.layer(axum::middleware::from_fn_with_state(auth, auth::middleware));
        }

        // NOTE(review): axum's `Router::layer` only wraps routes that already
        // exist when it is called, so the forward-auth routes merged here are
        // outside the CORS and auth layers above but inside the shield,
        // maintenance and trace layers below. Presumably intentional — confirm.
        if let Some(forward_auth) = &forward_auth {
            router = router.merge(forward_auth.routes());
        }

        if let Some(shield) = shield {
            router = router.layer(axum::middleware::from_fn_with_state(
                shield,
                shield::middleware,
            ));
        }

        // Added last, so maintenance handling and request tracing wrap every
        // route and every layer registered above.
        let router = router
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                maintenance_middleware,
            ))
            .layer(TraceLayer::new_for_http())
            .with_state(state);

        Ok(router)
    }
299
300 fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
301 let openapi_config = match &self.config.openapi {
302 Some(cfg) if cfg.enabled => cfg,
303 _ => return Router::new(),
304 };
305
306 let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
307 let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
308 let openapi_path = openapi_config.path.clone();
309 let docs_path = openapi_config.docs_path.clone();
310 let title = openapi_config
311 .title
312 .clone()
313 .unwrap_or_else(|| self.config.service.name.clone());
314 let openapi_path_for_docs = openapi_path.clone();
315
316 tracing::info!("OpenAPI spec at {}, docs at {}", openapi_path, docs_path,);
317
318 Router::new()
319 .route(
320 &openapi_path,
321 get(move || async move {
322 (
323 StatusCode::OK,
324 [(
325 axum::http::header::CONTENT_TYPE,
326 "application/json; charset=utf-8",
327 )],
328 spec_json,
329 )
330 }),
331 )
332 .route(
333 &docs_path,
334 get(move || async move {
335 let html = openapi::docs_html(&openapi_path_for_docs, &title);
336 (
337 StatusCode::OK,
338 [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
339 html,
340 )
341 }),
342 )
343 }
344
345 fn build_cors(&self) -> CorsLayer {
346 if self.config.cors.origins.is_empty() {
347 tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
348 CorsLayer::permissive()
349 } else {
350 let origins: Vec<_> = self
351 .config
352 .cors
353 .origins
354 .iter()
355 .filter_map(|o| o.parse().ok())
356 .collect();
357 CorsLayer::new()
358 .allow_origin(AllowOrigin::list(origins))
359 .allow_methods(tower_http::cors::Any)
360 .allow_headers(tower_http::cors::Any)
361 .allow_credentials(true)
362 .expose_headers([
363 "grpc-status".parse().unwrap(),
364 "grpc-message".parse().unwrap(),
365 ])
366 }
367 }
368
369 pub async fn serve(&self) -> anyhow::Result<()> {
371 let router = self.router()?;
372 let app = router.into_make_service_with_connect_info::<SocketAddr>();
373 let addr: SocketAddr = self.config.listen.http.parse()?;
374 let listener = tokio::net::TcpListener::bind(addr).await?;
375
376 tracing::info!("{} listening on {}", self.config.service.name, addr);
377 axum::serve(listener, app).await?;
378 Ok(())
379 }
380}
381
382async fn maintenance_middleware(
384 State(state): State<ProxyState>,
385 request: Request<axum::body::Body>,
386 next: Next,
387) -> Response {
388 if state.maintenance_mode {
389 let path = request.uri().path();
390 let exempt = state.maintenance_exempt.iter().any(|pattern| {
391 if pattern.ends_with("/**") {
392 let prefix = &pattern[..pattern.len() - 3];
393 path.starts_with(prefix)
394 } else {
395 path == pattern
396 }
397 });
398 if !exempt {
399 return (
400 StatusCode::SERVICE_UNAVAILABLE,
401 [("retry-after", "300")],
402 state.maintenance_message.clone(),
403 )
404 .into_response();
405 }
406 }
407 next.run(request).await
408}
409
410#[cfg(test)]
412pub(crate) fn test_channel() -> tonic::transport::Channel {
413 tonic::transport::Channel::from_static("http://127.0.0.1:1")
414 .connect_timeout(std::time::Duration::from_millis(100))
415 .connect_lazy()
416}
417
418#[cfg(test)]
419mod tests {
420 use super::*;
421
422 #[test]
423 fn test_minimal_config_server() {
424 let yaml = r#"
425upstream:
426 default: "http://127.0.0.1:50051"
427"#;
428 let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
429 let server = ProxyServer::from_config(config);
430 assert!(server.descriptor_pool.is_none());
431 }
432
433 #[tokio::test]
434 async fn test_maintenance_exempt_matching() {
435 let state = ProxyState {
436 service_name: "test".into(),
437 grpc_upstream: "http://localhost:50051".into(),
438 grpc_channel: test_channel(),
439 maintenance_mode: true,
440 maintenance_exempt: vec![
441 "/health/**".into(),
442 "/.well-known/**".into(),
443 "/metrics".into(),
444 ],
445 maintenance_message: "Down".into(),
446 forwarded_headers: vec![],
447 metrics_namespace: "test".into(),
448 metrics_classes: vec![],
449 };
450
451 let check = |path: &str| -> bool {
452 state.maintenance_exempt.iter().any(|pattern| {
453 if pattern.ends_with("/**") {
454 let prefix = &pattern[..pattern.len() - 3];
455 path.starts_with(prefix)
456 } else {
457 path == pattern
458 }
459 })
460 };
461
462 assert!(check("/health"));
463 assert!(check("/health/ready"));
464 assert!(check("/.well-known/openid-configuration"));
465 assert!(check("/metrics"));
466 assert!(!check("/v1/auth/login"));
467 assert!(!check("/oauth2/token"));
468 }
469}