Skip to main content

structured_proxy/
lib.rs

1//! Universal gRPC→REST transcoding proxy.
2//!
3//! Config-driven: same binary, different YAML = different product proxy.
4//! Works with ANY gRPC service via proto descriptors as config.
5//!
6//! ## Usage
7//!
8//! ```bash
9//! structured-proxy --config sid-proxy.yaml
10//! structured-proxy --config sflow-proxy.yaml
11//! ```
12//!
13//! ## JWT crypto backend
14//!
15//! Exactly one crypto backend feature must be enabled (they are mutually
16//! exclusive): `rust_crypto` (default, pure Rust) or `aws_lc_rs` (opt-in,
17//! constant-time / FIPS-capable, links aws-lc via C FFI). Enabling both or
18//! neither is rejected at compile time by the guards below.
19
20// jsonwebtoken selects its provider from these features and would otherwise
21// panic at runtime on an invalid combination; turn that into a build error.
22#[cfg(all(feature = "rust_crypto", feature = "aws_lc_rs"))]
23compile_error!("features `rust_crypto` and `aws_lc_rs` are mutually exclusive; enable exactly one");
24
25#[cfg(not(any(feature = "rust_crypto", feature = "aws_lc_rs")))]
26compile_error!("exactly one JWT crypto backend must be enabled: `rust_crypto` or `aws_lc_rs`");
27
28pub mod auth;
29pub mod config;
30pub mod oidc;
31pub mod openapi;
32pub mod shield;
33pub mod transcode;
34
35use axum::extract::State;
36use axum::http::{Request, StatusCode};
37use axum::middleware::Next;
38use axum::response::{IntoResponse, Response};
39use axum::routing::get;
40use axum::{Json, Router};
41use prost_reflect::DescriptorPool;
42use std::net::SocketAddr;
43use tower_http::cors::{AllowOrigin, CorsLayer};
44use tower_http::trace::TraceLayer;
45
46use config::{DescriptorSource, ProxyConfig};
47
48/// Shared state for all proxy handlers.
49#[derive(Clone, Debug)]
50pub struct ProxyState {
51    /// Service name from config.
52    pub service_name: String,
53    /// gRPC upstream address.
54    pub grpc_upstream: String,
55    /// Lazy gRPC channel to upstream service.
56    pub grpc_channel: tonic::transport::Channel,
57    /// Maintenance mode active.
58    pub maintenance_mode: bool,
59    /// Maintenance exempt path patterns.
60    pub maintenance_exempt: Vec<String>,
61    /// Maintenance message.
62    pub maintenance_message: String,
63    /// Headers to forward from HTTP to gRPC.
64    pub forwarded_headers: Vec<String>,
65    /// Metrics namespace (derived from service name).
66    pub metrics_namespace: String,
67    /// Path class patterns for metrics.
68    pub metrics_classes: Vec<config::MetricsClassConfig>,
69    /// SSE keep-alive interval (seconds) for server-streaming responses.
70    pub sse_keep_alive_secs: u64,
71}
72
73/// Universal proxy server.
74pub struct ProxyServer {
75    config: ProxyConfig,
76    /// Optional pre-loaded descriptor pool (for embedded mode).
77    descriptor_pool: Option<DescriptorPool>,
78}
79
80impl ProxyServer {
81    /// Create from YAML config file.
82    pub fn from_config(config: ProxyConfig) -> Self {
83        Self {
84            config,
85            descriptor_pool: None,
86        }
87    }
88
89    /// Create with an embedded descriptor pool (for sid-proxy backward compat).
90    pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
91        self.descriptor_pool = Some(pool);
92        self
93    }
94
95    /// Load descriptor pool from configured sources.
96    ///
97    /// Multiple descriptor files are merged into a single pool,
98    /// enabling multi-service proxying from one binary.
99    fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
100        if let Some(pool) = &self.descriptor_pool {
101            return Ok(pool.clone());
102        }
103
104        let mut pool = DescriptorPool::new();
105
106        for source in &self.config.descriptors {
107            match source {
108                DescriptorSource::File { file } => {
109                    let bytes = std::fs::read(file).map_err(|e| {
110                        anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
111                    })?;
112                    pool.decode_file_descriptor_set(bytes.as_slice())
113                        .map_err(|e| {
114                            anyhow::anyhow!("Failed to decode descriptor file {:?}: {}", file, e)
115                        })?;
116                    tracing::info!("Loaded descriptor from {:?}", file);
117                }
118                DescriptorSource::Reflection { reflection } => {
119                    tracing::warn!(
120                        "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
121                        reflection
122                    );
123                }
124                DescriptorSource::Embedded { bytes } => {
125                    pool.decode_file_descriptor_set(*bytes).map_err(|e| {
126                        anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
127                    })?;
128                }
129            }
130        }
131
132        Ok(pool)
133    }
134
135    /// Build the axum router with all endpoints.
136    pub fn router(&self) -> anyhow::Result<Router> {
137        // Enforce cross-field invariants on the embedded path too, where the
138        // config is built directly instead of through `from_yaml_str`.
139        self.config.validate()?;
140        let pool = self.load_descriptors()?;
141
142        let grpc_upstream = self.config.upstream.default.clone();
143        let grpc_channel = tonic::transport::Channel::from_shared(grpc_upstream.clone())
144            .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
145            .connect_timeout(std::time::Duration::from_secs(5))
146            .timeout(std::time::Duration::from_secs(5))
147            .connect_lazy();
148
149        let service_name = self.config.service.name.clone();
150        let metrics_namespace = service_name.replace('-', "_");
151
152        let state = ProxyState {
153            service_name: service_name.clone(),
154            grpc_upstream,
155            grpc_channel,
156            maintenance_mode: self.config.maintenance.enabled,
157            maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
158            maintenance_message: self.config.maintenance.message.clone(),
159            forwarded_headers: self.config.forwarded_headers.clone(),
160            metrics_namespace,
161            metrics_classes: self.config.metrics_classes.clone(),
162            sse_keep_alive_secs: self.config.streaming.sse_keep_alive_secs,
163        };
164
165        let cors = self.build_cors();
166
167        // Build transcoding routes from descriptor pool.
168        let mut transcode_routes = transcode::routes(&pool, &self.config.aliases);
169
170        // External authorization (Envoy ext_authz) gates only the proxied API
171        // routes, never health / metrics / discovery. It runs inside the auth
172        // layer below, so the Check call sees the identity headers the JWT
173        // middleware injected.
174        let authz = match self.config.auth.as_ref().and_then(|a| a.authz.as_ref()) {
175            Some(cfg) => auth::authz::Authz::build(cfg)
176                .map_err(|e| anyhow::anyhow!("invalid authz config: {e}"))?,
177            None => None,
178        };
179        if let Some(authz) = authz {
180            transcode_routes = transcode_routes.layer(axum::middleware::from_fn_with_state(
181                authz,
182                auth::authz::middleware,
183            ));
184        }
185
186        // Health routes
187        let health_service_name = service_name.clone();
188        let health_routes = Router::new()
189            .route(
190                "/health",
191                get({
192                    let name = health_service_name.clone();
193                    move || async move {
194                        Json(serde_json::json!({
195                            "status": "ok",
196                            "service": name,
197                        }))
198                    }
199                }),
200            )
201            .route("/health/live", get(|| async { StatusCode::OK }))
202            .route(
203                "/health/ready",
204                get(|State(state): State<ProxyState>| async move {
205                    let mut client =
206                        tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
207                    match client
208                        .check(tonic_health::pb::HealthCheckRequest {
209                            service: String::new(),
210                        })
211                        .await
212                    {
213                        Ok(resp) => {
214                            let status = resp.into_inner().status;
215                            if status
216                                == tonic_health::pb::health_check_response::ServingStatus::Serving
217                                    as i32
218                            {
219                                StatusCode::OK
220                            } else {
221                                StatusCode::SERVICE_UNAVAILABLE
222                            }
223                        }
224                        Err(_) => StatusCode::SERVICE_UNAVAILABLE,
225                    }
226                }),
227            )
228            .route("/health/startup", get(|| async { StatusCode::OK }))
229            .route(
230                "/metrics",
231                get(|| async {
232                    let encoder = prometheus::TextEncoder::new();
233                    let metric_families = prometheus::default_registry().gather();
234                    match encoder.encode_to_string(&metric_families) {
235                        Ok(text) => (
236                            StatusCode::OK,
237                            [(
238                                axum::http::header::CONTENT_TYPE,
239                                "text/plain; version=0.0.4; charset=utf-8",
240                            )],
241                            text,
242                        )
243                            .into_response(),
244                        Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
245                    }
246                }),
247            );
248
249        // OpenAPI + docs routes (if enabled).
250        let openapi_routes = self.build_openapi_routes(&pool);
251
252        // OIDC discovery routes (if enabled). Public, like the health endpoints.
253        let oidc_routes = match &self.config.oidc_discovery {
254            Some(cfg) => oidc::Oidc::build(cfg)
255                .map_err(|e| anyhow::anyhow!("invalid oidc_discovery config: {e}"))?
256                .map(|o| o.routes())
257                .unwrap_or_default(),
258            None => Router::new(),
259        };
260
261        // Rate limiting (Shield), if configured and enabled.
262        let shield = match &self.config.shield {
263            Some(cfg) => shield::Shield::build(cfg)
264                .map_err(|e| anyhow::anyhow!("invalid shield config: {e}"))?,
265            None => None,
266        };
267
268        // JWT auth, if configured (auth.mode == "jwt").
269        let auth = match &self.config.auth {
270            Some(cfg) => {
271                auth::Auth::build(cfg).map_err(|e| anyhow::anyhow!("invalid auth config: {e}"))?
272            }
273            None => None,
274        };
275
276        let mut router = Router::new()
277            .merge(health_routes)
278            .merge(openapi_routes)
279            .merge(oidc_routes)
280            .merge(transcode_routes)
281            .layer(cors);
282
283        // Forward-auth verification endpoint, sharing the built Auth. Mounted
284        // after the auth layer below so the endpoint itself is not gated by the
285        // JWT middleware (it answers the gate, it isn't behind it).
286        let forward_auth = auth.as_ref().and_then(|built| {
287            auth::forward::ForwardAuth::build(self.config.auth.as_ref()?, built.clone())
288        });
289
290        // Auth runs inside Shield (added first = inner): rate limiting sheds
291        // load before any signature verification work.
292        if let Some(auth) = auth {
293            router = router.layer(axum::middleware::from_fn_with_state(auth, auth::middleware));
294        }
295
296        if let Some(forward_auth) = &forward_auth {
297            router = router.merge(forward_auth.routes());
298        }
299
300        // Shield is added before maintenance so maintenance wraps it (outer
301        // layers run first): a request rejected by the maintenance gate must
302        // not be charged against its rate-limit budget.
303        if let Some(shield) = shield {
304            router = router.layer(axum::middleware::from_fn_with_state(
305                shield,
306                shield::middleware,
307            ));
308        }
309
310        let router = router
311            .layer(axum::middleware::from_fn_with_state(
312                state.clone(),
313                maintenance_middleware,
314            ))
315            .layer(TraceLayer::new_for_http())
316            .with_state(state);
317
318        Ok(router)
319    }
320
321    fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
322        let openapi_config = match &self.config.openapi {
323            Some(cfg) if cfg.enabled => cfg,
324            _ => return Router::new(),
325        };
326
327        let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
328        let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
329        let openapi_path = openapi_config.path.clone();
330        let docs_path = openapi_config.docs_path.clone();
331        let title = openapi_config
332            .title
333            .clone()
334            .unwrap_or_else(|| self.config.service.name.clone());
335        let openapi_path_for_docs = openapi_path.clone();
336
337        tracing::info!("OpenAPI spec at {}, docs at {}", openapi_path, docs_path,);
338
339        Router::new()
340            .route(
341                &openapi_path,
342                get(move || async move {
343                    (
344                        StatusCode::OK,
345                        [(
346                            axum::http::header::CONTENT_TYPE,
347                            "application/json; charset=utf-8",
348                        )],
349                        spec_json,
350                    )
351                }),
352            )
353            .route(
354                &docs_path,
355                get(move || async move {
356                    let html = openapi::docs_html(&openapi_path_for_docs, &title);
357                    (
358                        StatusCode::OK,
359                        [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
360                        html,
361                    )
362                }),
363            )
364    }
365
366    fn build_cors(&self) -> CorsLayer {
367        if self.config.cors.origins.is_empty() {
368            tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
369            CorsLayer::permissive()
370        } else {
371            let origins: Vec<_> = self
372                .config
373                .cors
374                .origins
375                .iter()
376                .filter_map(|o| o.parse().ok())
377                .collect();
378            CorsLayer::new()
379                .allow_origin(AllowOrigin::list(origins))
380                .allow_methods(tower_http::cors::Any)
381                .allow_headers(tower_http::cors::Any)
382                .allow_credentials(true)
383                .expose_headers([
384                    "grpc-status".parse().unwrap(),
385                    "grpc-message".parse().unwrap(),
386                ])
387        }
388    }
389
390    /// Start serving on configured address.
391    pub async fn serve(&self) -> anyhow::Result<()> {
392        let router = self.router()?;
393        let app = router.into_make_service_with_connect_info::<SocketAddr>();
394        let addr: SocketAddr = self.config.listen.http.parse()?;
395        let listener = tokio::net::TcpListener::bind(addr).await?;
396
397        tracing::info!("{} listening on {}", self.config.service.name, addr);
398        axum::serve(listener, app).await?;
399        Ok(())
400    }
401}
402
403/// Maintenance mode middleware.
404async fn maintenance_middleware(
405    State(state): State<ProxyState>,
406    request: Request<axum::body::Body>,
407    next: Next,
408) -> Response {
409    if state.maintenance_mode {
410        let path = request.uri().path();
411        let exempt = state.maintenance_exempt.iter().any(|pattern| {
412            if pattern.ends_with("/**") {
413                let prefix = &pattern[..pattern.len() - 3];
414                path.starts_with(prefix)
415            } else {
416                path == pattern
417            }
418        });
419        if !exempt {
420            return (
421                StatusCode::SERVICE_UNAVAILABLE,
422                [("retry-after", "300")],
423                state.maintenance_message.clone(),
424            )
425                .into_response();
426        }
427    }
428    next.run(request).await
429}
430
/// Lazily-connected gRPC channel for unit tests.
///
/// It targets `127.0.0.1:1` (nothing is expected to listen there) with a
/// 100 ms connect timeout, so any accidental call fails quickly instead of
/// hanging.
#[cfg(test)]
pub(crate) fn test_channel() -> tonic::transport::Channel {
    let endpoint = tonic::transport::Channel::from_static("http://127.0.0.1:1")
        .connect_timeout(std::time::Duration::from_millis(100));
    endpoint.connect_lazy()
}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// A config containing only an upstream parses, and the resulting server
    /// has no embedded descriptor pool.
    #[test]
    fn test_minimal_config_server() {
        let yaml = r#"
upstream:
  default: "http://127.0.0.1:50051"
"#;
        let parsed: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
        let server = ProxyServer::from_config(parsed);
        assert!(server.descriptor_pool.is_none());
    }

    /// Exempt-path patterns: a `/**` suffix is a prefix match, anything else
    /// is an exact match.
    #[tokio::test]
    async fn test_maintenance_exempt_matching() {
        let state = ProxyState {
            service_name: "test".into(),
            grpc_upstream: "http://localhost:50051".into(),
            grpc_channel: test_channel(),
            maintenance_mode: true,
            maintenance_exempt: vec![
                "/health/**".into(),
                "/.well-known/**".into(),
                "/metrics".into(),
            ],
            maintenance_message: "Down".into(),
            forwarded_headers: vec![],
            metrics_namespace: "test".into(),
            metrics_classes: vec![],
            sse_keep_alive_secs: 15,
        };

        let is_exempt = |path: &str| {
            state
                .maintenance_exempt
                .iter()
                .any(|pattern| match pattern.strip_suffix("/**") {
                    Some(prefix) => path.starts_with(prefix),
                    None => path == pattern,
                })
        };

        for path in [
            "/health",
            "/health/ready",
            "/.well-known/openid-configuration",
            "/metrics",
        ] {
            assert!(is_exempt(path), "{path} should bypass maintenance");
        }
        for path in ["/v1/auth/login", "/oauth2/token"] {
            assert!(!is_exempt(path), "{path} should be blocked by maintenance");
        }
    }
}