1pub mod config;
14pub mod openapi;
15pub mod transcode;
16
17use axum::extract::State;
18use axum::http::{Request, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::routing::get;
22use axum::{Json, Router};
23use prost_reflect::DescriptorPool;
24use std::net::SocketAddr;
25use tower_http::cors::{AllowOrigin, CorsLayer};
26use tower_http::trace::TraceLayer;
27
28use config::{DescriptorSource, ProxyConfig};
29
30#[derive(Clone, Debug)]
32pub struct ProxyState {
33 pub service_name: String,
35 pub grpc_upstream: String,
37 pub grpc_channel: tonic::transport::Channel,
39 pub maintenance_mode: bool,
41 pub maintenance_exempt: Vec<String>,
43 pub maintenance_message: String,
45 pub forwarded_headers: Vec<String>,
47 pub metrics_namespace: String,
49 pub metrics_classes: Vec<config::MetricsClassConfig>,
51}
52
53pub struct ProxyServer {
55 config: ProxyConfig,
56 descriptor_pool: Option<DescriptorPool>,
58}
59
60impl ProxyServer {
61 pub fn from_config(config: ProxyConfig) -> Self {
63 Self {
64 config,
65 descriptor_pool: None,
66 }
67 }
68
69 pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
71 self.descriptor_pool = Some(pool);
72 self
73 }
74
75 fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
80 if let Some(pool) = &self.descriptor_pool {
81 return Ok(pool.clone());
82 }
83
84 let mut pool = DescriptorPool::new();
85
86 for source in &self.config.descriptors {
87 match source {
88 DescriptorSource::File { file } => {
89 let bytes = std::fs::read(file).map_err(|e| {
90 anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
91 })?;
92 pool.decode_file_descriptor_set(bytes.as_slice())
93 .map_err(|e| {
94 anyhow::anyhow!("Failed to decode descriptor file {:?}: {}", file, e)
95 })?;
96 tracing::info!("Loaded descriptor from {:?}", file);
97 }
98 DescriptorSource::Reflection { reflection } => {
99 tracing::warn!(
100 "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
101 reflection
102 );
103 }
104 DescriptorSource::Embedded { bytes } => {
105 pool.decode_file_descriptor_set(*bytes).map_err(|e| {
106 anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
107 })?;
108 }
109 }
110 }
111
112 Ok(pool)
113 }
114
115 pub fn router(&self) -> anyhow::Result<Router> {
117 let pool = self.load_descriptors()?;
118
119 let grpc_upstream = self.config.upstream.default.clone();
120 let grpc_channel = tonic::transport::Channel::from_shared(grpc_upstream.clone())
121 .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
122 .connect_timeout(std::time::Duration::from_secs(5))
123 .timeout(std::time::Duration::from_secs(5))
124 .connect_lazy();
125
126 let service_name = self.config.service.name.clone();
127 let metrics_namespace = service_name.replace('-', "_");
128
129 let state = ProxyState {
130 service_name: service_name.clone(),
131 grpc_upstream,
132 grpc_channel,
133 maintenance_mode: self.config.maintenance.enabled,
134 maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
135 maintenance_message: self.config.maintenance.message.clone(),
136 forwarded_headers: self.config.forwarded_headers.clone(),
137 metrics_namespace,
138 metrics_classes: self.config.metrics_classes.clone(),
139 };
140
141 let cors = self.build_cors();
142
143 let transcode_routes = transcode::routes(&pool, &self.config.aliases);
145
146 let health_service_name = service_name.clone();
148 let health_routes = Router::new()
149 .route(
150 "/health",
151 get({
152 let name = health_service_name.clone();
153 move || async move {
154 Json(serde_json::json!({
155 "status": "ok",
156 "service": name,
157 }))
158 }
159 }),
160 )
161 .route("/health/live", get(|| async { StatusCode::OK }))
162 .route(
163 "/health/ready",
164 get(|State(state): State<ProxyState>| async move {
165 let mut client =
166 tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
167 match client
168 .check(tonic_health::pb::HealthCheckRequest {
169 service: String::new(),
170 })
171 .await
172 {
173 Ok(resp) => {
174 let status = resp.into_inner().status;
175 if status
176 == tonic_health::pb::health_check_response::ServingStatus::Serving
177 as i32
178 {
179 StatusCode::OK
180 } else {
181 StatusCode::SERVICE_UNAVAILABLE
182 }
183 }
184 Err(_) => StatusCode::SERVICE_UNAVAILABLE,
185 }
186 }),
187 )
188 .route("/health/startup", get(|| async { StatusCode::OK }))
189 .route(
190 "/metrics",
191 get(|| async {
192 let encoder = prometheus::TextEncoder::new();
193 let metric_families = prometheus::default_registry().gather();
194 match encoder.encode_to_string(&metric_families) {
195 Ok(text) => (
196 StatusCode::OK,
197 [(
198 axum::http::header::CONTENT_TYPE,
199 "text/plain; version=0.0.4; charset=utf-8",
200 )],
201 text,
202 )
203 .into_response(),
204 Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
205 }
206 }),
207 );
208
209 let openapi_routes = self.build_openapi_routes(&pool);
211
212 let router = Router::new()
213 .merge(health_routes)
214 .merge(openapi_routes)
215 .merge(transcode_routes)
216 .layer(cors)
217 .layer(axum::middleware::from_fn_with_state(
218 state.clone(),
219 maintenance_middleware,
220 ))
221 .layer(TraceLayer::new_for_http())
222 .with_state(state);
223
224 Ok(router)
225 }
226
227 fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
228 let openapi_config = match &self.config.openapi {
229 Some(cfg) if cfg.enabled => cfg,
230 _ => return Router::new(),
231 };
232
233 let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
234 let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
235 let openapi_path = openapi_config.path.clone();
236 let docs_path = openapi_config.docs_path.clone();
237 let title = openapi_config
238 .title
239 .clone()
240 .unwrap_or_else(|| self.config.service.name.clone());
241 let openapi_path_for_docs = openapi_path.clone();
242
243 tracing::info!("OpenAPI spec at {}, docs at {}", openapi_path, docs_path,);
244
245 Router::new()
246 .route(
247 &openapi_path,
248 get(move || async move {
249 (
250 StatusCode::OK,
251 [(
252 axum::http::header::CONTENT_TYPE,
253 "application/json; charset=utf-8",
254 )],
255 spec_json,
256 )
257 }),
258 )
259 .route(
260 &docs_path,
261 get(move || async move {
262 let html = openapi::docs_html(&openapi_path_for_docs, &title);
263 (
264 StatusCode::OK,
265 [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
266 html,
267 )
268 }),
269 )
270 }
271
272 fn build_cors(&self) -> CorsLayer {
273 if self.config.cors.origins.is_empty() {
274 tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
275 CorsLayer::permissive()
276 } else {
277 let origins: Vec<_> = self
278 .config
279 .cors
280 .origins
281 .iter()
282 .filter_map(|o| o.parse().ok())
283 .collect();
284 CorsLayer::new()
285 .allow_origin(AllowOrigin::list(origins))
286 .allow_methods(tower_http::cors::Any)
287 .allow_headers(tower_http::cors::Any)
288 .allow_credentials(true)
289 .expose_headers([
290 "grpc-status".parse().unwrap(),
291 "grpc-message".parse().unwrap(),
292 ])
293 }
294 }
295
296 pub async fn serve(&self) -> anyhow::Result<()> {
298 let router = self.router()?;
299 let app = router.into_make_service_with_connect_info::<SocketAddr>();
300 let addr: SocketAddr = self.config.listen.http.parse()?;
301 let listener = tokio::net::TcpListener::bind(addr).await?;
302
303 tracing::info!("{} listening on {}", self.config.service.name, addr);
304 axum::serve(listener, app).await?;
305 Ok(())
306 }
307}
308
309async fn maintenance_middleware(
311 State(state): State<ProxyState>,
312 request: Request<axum::body::Body>,
313 next: Next,
314) -> Response {
315 if state.maintenance_mode {
316 let path = request.uri().path();
317 let exempt = state.maintenance_exempt.iter().any(|pattern| {
318 if pattern.ends_with("/**") {
319 let prefix = &pattern[..pattern.len() - 3];
320 path.starts_with(prefix)
321 } else {
322 path == pattern
323 }
324 });
325 if !exempt {
326 return (
327 StatusCode::SERVICE_UNAVAILABLE,
328 [("retry-after", "300")],
329 state.maintenance_message.clone(),
330 )
331 .into_response();
332 }
333 }
334 next.run(request).await
335}
336
/// Test-only channel aimed at an address expected to be unreachable (`127.0.0.1:1`).
///
/// Built lazily with a 100 ms connect timeout, so tests can construct a `ProxyState`
/// without a running upstream; any RPC over it should fail fast.
#[cfg(test)]
pub(crate) fn test_channel() -> tonic::transport::Channel {
    let endpoint = tonic::transport::Channel::from_static("http://127.0.0.1:1")
        .connect_timeout(std::time::Duration::from_millis(100));
    endpoint.connect_lazy()
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// A config containing only an upstream parses, and the server starts with no
    /// injected descriptor pool.
    #[test]
    fn test_minimal_config_server() {
        const MINIMAL_YAML: &str = r#"
upstream:
  default: "http://127.0.0.1:50051"
"#;
        let parsed: ProxyConfig = serde_yaml::from_str(MINIMAL_YAML).unwrap();
        let server = ProxyServer::from_config(parsed);
        assert!(server.descriptor_pool.is_none());
    }

    /// Exemption patterns: `/**` suffix means prefix match, otherwise exact match.
    #[tokio::test]
    async fn test_maintenance_exempt_matching() {
        let state = ProxyState {
            service_name: "test".into(),
            grpc_upstream: "http://localhost:50051".into(),
            grpc_channel: test_channel(),
            maintenance_mode: true,
            maintenance_exempt: vec![
                "/health/**".into(),
                "/.well-known/**".into(),
                "/metrics".into(),
            ],
            maintenance_message: "Down".into(),
            forwarded_headers: vec![],
            metrics_namespace: "test".into(),
            metrics_classes: vec![],
        };

        let is_exempt = |path: &str| {
            state
                .maintenance_exempt
                .iter()
                .any(|pattern| match pattern.strip_suffix("/**") {
                    Some(prefix) => path.starts_with(prefix),
                    None => path == pattern,
                })
        };

        for path in [
            "/health",
            "/health/ready",
            "/.well-known/openid-configuration",
            "/metrics",
        ] {
            assert!(is_exempt(path), "{path} should be exempt");
        }
        for path in ["/v1/auth/login", "/oauth2/token"] {
            assert!(!is_exempt(path), "{path} should not be exempt");
        }
    }
}