Skip to main content

structured_proxy/
lib.rs

1//! Universal gRPC→REST transcoding proxy.
2//!
3//! Config-driven: same binary, different YAML = different product proxy.
4//! Works with ANY gRPC service via proto descriptors as config.
5//!
6//! ## Usage
7//!
8//! ```bash
9//! structured-proxy --config sid-proxy.yaml
10//! structured-proxy --config sflow-proxy.yaml
11//! ```
12
13pub mod auth;
14pub mod config;
15pub mod oidc;
16pub mod openapi;
17pub mod shield;
18pub mod transcode;
19
20use axum::extract::State;
21use axum::http::{Request, StatusCode};
22use axum::middleware::Next;
23use axum::response::{IntoResponse, Response};
24use axum::routing::get;
25use axum::{Json, Router};
26use prost_reflect::DescriptorPool;
27use std::net::SocketAddr;
28use tower_http::cors::{AllowOrigin, CorsLayer};
29use tower_http::trace::TraceLayer;
30
31use config::{DescriptorSource, ProxyConfig};
32
33/// Shared state for all proxy handlers.
34#[derive(Clone, Debug)]
35pub struct ProxyState {
36    /// Service name from config.
37    pub service_name: String,
38    /// gRPC upstream address.
39    pub grpc_upstream: String,
40    /// Lazy gRPC channel to upstream service.
41    pub grpc_channel: tonic::transport::Channel,
42    /// Maintenance mode active.
43    pub maintenance_mode: bool,
44    /// Maintenance exempt path patterns.
45    pub maintenance_exempt: Vec<String>,
46    /// Maintenance message.
47    pub maintenance_message: String,
48    /// Headers to forward from HTTP to gRPC.
49    pub forwarded_headers: Vec<String>,
50    /// Metrics namespace (derived from service name).
51    pub metrics_namespace: String,
52    /// Path class patterns for metrics.
53    pub metrics_classes: Vec<config::MetricsClassConfig>,
54}
55
56/// Universal proxy server.
57pub struct ProxyServer {
58    config: ProxyConfig,
59    /// Optional pre-loaded descriptor pool (for embedded mode).
60    descriptor_pool: Option<DescriptorPool>,
61}
62
63impl ProxyServer {
64    /// Create from YAML config file.
65    pub fn from_config(config: ProxyConfig) -> Self {
66        Self {
67            config,
68            descriptor_pool: None,
69        }
70    }
71
72    /// Create with an embedded descriptor pool (for sid-proxy backward compat).
73    pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
74        self.descriptor_pool = Some(pool);
75        self
76    }
77
78    /// Load descriptor pool from configured sources.
79    ///
80    /// Multiple descriptor files are merged into a single pool,
81    /// enabling multi-service proxying from one binary.
82    fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
83        if let Some(pool) = &self.descriptor_pool {
84            return Ok(pool.clone());
85        }
86
87        let mut pool = DescriptorPool::new();
88
89        for source in &self.config.descriptors {
90            match source {
91                DescriptorSource::File { file } => {
92                    let bytes = std::fs::read(file).map_err(|e| {
93                        anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
94                    })?;
95                    pool.decode_file_descriptor_set(bytes.as_slice())
96                        .map_err(|e| {
97                            anyhow::anyhow!("Failed to decode descriptor file {:?}: {}", file, e)
98                        })?;
99                    tracing::info!("Loaded descriptor from {:?}", file);
100                }
101                DescriptorSource::Reflection { reflection } => {
102                    tracing::warn!(
103                        "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
104                        reflection
105                    );
106                }
107                DescriptorSource::Embedded { bytes } => {
108                    pool.decode_file_descriptor_set(*bytes).map_err(|e| {
109                        anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
110                    })?;
111                }
112            }
113        }
114
115        Ok(pool)
116    }
117
118    /// Build the axum router with all endpoints.
119    pub fn router(&self) -> anyhow::Result<Router> {
120        let pool = self.load_descriptors()?;
121
122        let grpc_upstream = self.config.upstream.default.clone();
123        let grpc_channel = tonic::transport::Channel::from_shared(grpc_upstream.clone())
124            .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
125            .connect_timeout(std::time::Duration::from_secs(5))
126            .timeout(std::time::Duration::from_secs(5))
127            .connect_lazy();
128
129        let service_name = self.config.service.name.clone();
130        let metrics_namespace = service_name.replace('-', "_");
131
132        let state = ProxyState {
133            service_name: service_name.clone(),
134            grpc_upstream,
135            grpc_channel,
136            maintenance_mode: self.config.maintenance.enabled,
137            maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
138            maintenance_message: self.config.maintenance.message.clone(),
139            forwarded_headers: self.config.forwarded_headers.clone(),
140            metrics_namespace,
141            metrics_classes: self.config.metrics_classes.clone(),
142        };
143
144        let cors = self.build_cors();
145
146        // Build transcoding routes from descriptor pool.
147        let mut transcode_routes = transcode::routes(&pool, &self.config.aliases);
148
149        // External authorization (Envoy ext_authz) gates only the proxied API
150        // routes, never health / metrics / discovery. It runs inside the auth
151        // layer below, so the Check call sees the identity headers the JWT
152        // middleware injected.
153        let authz = match self.config.auth.as_ref().and_then(|a| a.authz.as_ref()) {
154            Some(cfg) => auth::authz::Authz::build(cfg)
155                .map_err(|e| anyhow::anyhow!("invalid authz config: {e}"))?,
156            None => None,
157        };
158        if let Some(authz) = authz {
159            transcode_routes = transcode_routes.layer(axum::middleware::from_fn_with_state(
160                authz,
161                auth::authz::middleware,
162            ));
163        }
164
165        // Health routes
166        let health_service_name = service_name.clone();
167        let health_routes = Router::new()
168            .route(
169                "/health",
170                get({
171                    let name = health_service_name.clone();
172                    move || async move {
173                        Json(serde_json::json!({
174                            "status": "ok",
175                            "service": name,
176                        }))
177                    }
178                }),
179            )
180            .route("/health/live", get(|| async { StatusCode::OK }))
181            .route(
182                "/health/ready",
183                get(|State(state): State<ProxyState>| async move {
184                    let mut client =
185                        tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
186                    match client
187                        .check(tonic_health::pb::HealthCheckRequest {
188                            service: String::new(),
189                        })
190                        .await
191                    {
192                        Ok(resp) => {
193                            let status = resp.into_inner().status;
194                            if status
195                                == tonic_health::pb::health_check_response::ServingStatus::Serving
196                                    as i32
197                            {
198                                StatusCode::OK
199                            } else {
200                                StatusCode::SERVICE_UNAVAILABLE
201                            }
202                        }
203                        Err(_) => StatusCode::SERVICE_UNAVAILABLE,
204                    }
205                }),
206            )
207            .route("/health/startup", get(|| async { StatusCode::OK }))
208            .route(
209                "/metrics",
210                get(|| async {
211                    let encoder = prometheus::TextEncoder::new();
212                    let metric_families = prometheus::default_registry().gather();
213                    match encoder.encode_to_string(&metric_families) {
214                        Ok(text) => (
215                            StatusCode::OK,
216                            [(
217                                axum::http::header::CONTENT_TYPE,
218                                "text/plain; version=0.0.4; charset=utf-8",
219                            )],
220                            text,
221                        )
222                            .into_response(),
223                        Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
224                    }
225                }),
226            );
227
228        // OpenAPI + docs routes (if enabled).
229        let openapi_routes = self.build_openapi_routes(&pool);
230
231        // OIDC discovery routes (if enabled). Public, like the health endpoints.
232        let oidc_routes = match &self.config.oidc_discovery {
233            Some(cfg) => oidc::Oidc::build(cfg)
234                .map_err(|e| anyhow::anyhow!("invalid oidc_discovery config: {e}"))?
235                .map(|o| o.routes())
236                .unwrap_or_default(),
237            None => Router::new(),
238        };
239
240        // Rate limiting (Shield), if configured and enabled.
241        let shield = match &self.config.shield {
242            Some(cfg) => shield::Shield::build(cfg)
243                .map_err(|e| anyhow::anyhow!("invalid shield config: {e}"))?,
244            None => None,
245        };
246
247        // JWT auth, if configured (auth.mode == "jwt").
248        let auth = match &self.config.auth {
249            Some(cfg) => {
250                auth::Auth::build(cfg).map_err(|e| anyhow::anyhow!("invalid auth config: {e}"))?
251            }
252            None => None,
253        };
254
255        let mut router = Router::new()
256            .merge(health_routes)
257            .merge(openapi_routes)
258            .merge(oidc_routes)
259            .merge(transcode_routes)
260            .layer(cors);
261
262        // Forward-auth verification endpoint, sharing the built Auth. Mounted
263        // after the auth layer below so the endpoint itself is not gated by the
264        // JWT middleware (it answers the gate, it isn't behind it).
265        let forward_auth = auth.as_ref().and_then(|built| {
266            auth::forward::ForwardAuth::build(self.config.auth.as_ref()?, built.clone())
267        });
268
269        // Auth runs inside Shield (added first = inner): rate limiting sheds
270        // load before any signature verification work.
271        if let Some(auth) = auth {
272            router = router.layer(axum::middleware::from_fn_with_state(auth, auth::middleware));
273        }
274
275        if let Some(forward_auth) = &forward_auth {
276            router = router.merge(forward_auth.routes());
277        }
278
279        // Shield is added before maintenance so maintenance wraps it (outer
280        // layers run first): a request rejected by the maintenance gate must
281        // not be charged against its rate-limit budget.
282        if let Some(shield) = shield {
283            router = router.layer(axum::middleware::from_fn_with_state(
284                shield,
285                shield::middleware,
286            ));
287        }
288
289        let router = router
290            .layer(axum::middleware::from_fn_with_state(
291                state.clone(),
292                maintenance_middleware,
293            ))
294            .layer(TraceLayer::new_for_http())
295            .with_state(state);
296
297        Ok(router)
298    }
299
300    fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
301        let openapi_config = match &self.config.openapi {
302            Some(cfg) if cfg.enabled => cfg,
303            _ => return Router::new(),
304        };
305
306        let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
307        let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
308        let openapi_path = openapi_config.path.clone();
309        let docs_path = openapi_config.docs_path.clone();
310        let title = openapi_config
311            .title
312            .clone()
313            .unwrap_or_else(|| self.config.service.name.clone());
314        let openapi_path_for_docs = openapi_path.clone();
315
316        tracing::info!("OpenAPI spec at {}, docs at {}", openapi_path, docs_path,);
317
318        Router::new()
319            .route(
320                &openapi_path,
321                get(move || async move {
322                    (
323                        StatusCode::OK,
324                        [(
325                            axum::http::header::CONTENT_TYPE,
326                            "application/json; charset=utf-8",
327                        )],
328                        spec_json,
329                    )
330                }),
331            )
332            .route(
333                &docs_path,
334                get(move || async move {
335                    let html = openapi::docs_html(&openapi_path_for_docs, &title);
336                    (
337                        StatusCode::OK,
338                        [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
339                        html,
340                    )
341                }),
342            )
343    }
344
345    fn build_cors(&self) -> CorsLayer {
346        if self.config.cors.origins.is_empty() {
347            tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
348            CorsLayer::permissive()
349        } else {
350            let origins: Vec<_> = self
351                .config
352                .cors
353                .origins
354                .iter()
355                .filter_map(|o| o.parse().ok())
356                .collect();
357            CorsLayer::new()
358                .allow_origin(AllowOrigin::list(origins))
359                .allow_methods(tower_http::cors::Any)
360                .allow_headers(tower_http::cors::Any)
361                .allow_credentials(true)
362                .expose_headers([
363                    "grpc-status".parse().unwrap(),
364                    "grpc-message".parse().unwrap(),
365                ])
366        }
367    }
368
369    /// Start serving on configured address.
370    pub async fn serve(&self) -> anyhow::Result<()> {
371        let router = self.router()?;
372        let app = router.into_make_service_with_connect_info::<SocketAddr>();
373        let addr: SocketAddr = self.config.listen.http.parse()?;
374        let listener = tokio::net::TcpListener::bind(addr).await?;
375
376        tracing::info!("{} listening on {}", self.config.service.name, addr);
377        axum::serve(listener, app).await?;
378        Ok(())
379    }
380}
381
382/// Maintenance mode middleware.
383async fn maintenance_middleware(
384    State(state): State<ProxyState>,
385    request: Request<axum::body::Body>,
386    next: Next,
387) -> Response {
388    if state.maintenance_mode {
389        let path = request.uri().path();
390        let exempt = state.maintenance_exempt.iter().any(|pattern| {
391            if pattern.ends_with("/**") {
392                let prefix = &pattern[..pattern.len() - 3];
393                path.starts_with(prefix)
394            } else {
395                path == pattern
396            }
397        });
398        if !exempt {
399            return (
400                StatusCode::SERVICE_UNAVAILABLE,
401                [("retry-after", "300")],
402                state.maintenance_message.clone(),
403            )
404                .into_response();
405        }
406    }
407    next.run(request).await
408}
409
410/// Create a lazy gRPC channel for testing (connects to nowhere).
411#[cfg(test)]
412pub(crate) fn test_channel() -> tonic::transport::Channel {
413    tonic::transport::Channel::from_static("http://127.0.0.1:1")
414        .connect_timeout(std::time::Duration::from_millis(100))
415        .connect_lazy()
416}
417
418#[cfg(test)]
419mod tests {
420    use super::*;
421
422    #[test]
423    fn test_minimal_config_server() {
424        let yaml = r#"
425upstream:
426  default: "http://127.0.0.1:50051"
427"#;
428        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
429        let server = ProxyServer::from_config(config);
430        assert!(server.descriptor_pool.is_none());
431    }
432
433    #[tokio::test]
434    async fn test_maintenance_exempt_matching() {
435        let state = ProxyState {
436            service_name: "test".into(),
437            grpc_upstream: "http://localhost:50051".into(),
438            grpc_channel: test_channel(),
439            maintenance_mode: true,
440            maintenance_exempt: vec![
441                "/health/**".into(),
442                "/.well-known/**".into(),
443                "/metrics".into(),
444            ],
445            maintenance_message: "Down".into(),
446            forwarded_headers: vec![],
447            metrics_namespace: "test".into(),
448            metrics_classes: vec![],
449        };
450
451        let check = |path: &str| -> bool {
452            state.maintenance_exempt.iter().any(|pattern| {
453                if pattern.ends_with("/**") {
454                    let prefix = &pattern[..pattern.len() - 3];
455                    path.starts_with(prefix)
456                } else {
457                    path == pattern
458                }
459            })
460        };
461
462        assert!(check("/health"));
463        assert!(check("/health/ready"));
464        assert!(check("/.well-known/openid-configuration"));
465        assert!(check("/metrics"));
466        assert!(!check("/v1/auth/login"));
467        assert!(!check("/oauth2/token"));
468    }
469}