Skip to main content

structured_proxy/
lib.rs

1//! Universal gRPC→REST transcoding proxy.
2//!
3//! Config-driven: same binary, different YAML = different product proxy.
4//! Works with ANY gRPC service via proto descriptors as config.
5//!
6//! ## Usage
7//!
8//! ```bash
9//! structured-proxy --config sid-proxy.yaml
10//! structured-proxy --config sflow-proxy.yaml
11//! ```
12
13pub mod config;
14pub mod openapi;
15pub mod transcode;
16
17use axum::extract::State;
18use axum::http::{Request, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use axum::routing::get;
22use axum::{Json, Router};
23use prost_reflect::DescriptorPool;
24use std::net::SocketAddr;
25use tower_http::cors::{AllowOrigin, CorsLayer};
26use tower_http::trace::TraceLayer;
27
28use config::{DescriptorSource, ProxyConfig};
29
30/// Shared state for all proxy handlers.
31#[derive(Clone, Debug)]
32pub struct ProxyState {
33    /// Service name from config.
34    pub service_name: String,
35    /// gRPC upstream address.
36    pub grpc_upstream: String,
37    /// Lazy gRPC channel to upstream service.
38    pub grpc_channel: tonic::transport::Channel,
39    /// Maintenance mode active.
40    pub maintenance_mode: bool,
41    /// Maintenance exempt path patterns.
42    pub maintenance_exempt: Vec<String>,
43    /// Maintenance message.
44    pub maintenance_message: String,
45    /// Headers to forward from HTTP to gRPC.
46    pub forwarded_headers: Vec<String>,
47    /// Metrics namespace (derived from service name).
48    pub metrics_namespace: String,
49    /// Path class patterns for metrics.
50    pub metrics_classes: Vec<config::MetricsClassConfig>,
51}
52
53/// Universal proxy server.
54pub struct ProxyServer {
55    config: ProxyConfig,
56    /// Optional pre-loaded descriptor pool (for embedded mode).
57    descriptor_pool: Option<DescriptorPool>,
58}
59
60impl ProxyServer {
61    /// Create from YAML config file.
62    pub fn from_config(config: ProxyConfig) -> Self {
63        Self {
64            config,
65            descriptor_pool: None,
66        }
67    }
68
69    /// Create with an embedded descriptor pool (for sid-proxy backward compat).
70    pub fn with_descriptors(mut self, pool: DescriptorPool) -> Self {
71        self.descriptor_pool = Some(pool);
72        self
73    }
74
75    /// Load descriptor pool from configured sources.
76    ///
77    /// Multiple descriptor files are merged into a single pool,
78    /// enabling multi-service proxying from one binary.
79    fn load_descriptors(&self) -> anyhow::Result<DescriptorPool> {
80        if let Some(pool) = &self.descriptor_pool {
81            return Ok(pool.clone());
82        }
83
84        let mut pool = DescriptorPool::new();
85
86        for source in &self.config.descriptors {
87            match source {
88                DescriptorSource::File { file } => {
89                    let bytes = std::fs::read(file).map_err(|e| {
90                        anyhow::anyhow!("Failed to read descriptor file {:?}: {}", file, e)
91                    })?;
92                    pool.decode_file_descriptor_set(bytes.as_slice()).map_err(|e| {
93                        anyhow::anyhow!(
94                            "Failed to decode descriptor file {:?}: {}",
95                            file,
96                            e
97                        )
98                    })?;
99                    tracing::info!("Loaded descriptor from {:?}", file);
100                }
101                DescriptorSource::Reflection { reflection } => {
102                    tracing::warn!(
103                        "gRPC reflection client not supported — use descriptor files instead (reflection endpoint: {})",
104                        reflection
105                    );
106                }
107                DescriptorSource::Embedded { bytes } => {
108                    pool.decode_file_descriptor_set(*bytes).map_err(|e| {
109                        anyhow::anyhow!("Failed to decode embedded descriptors: {}", e)
110                    })?;
111                }
112            }
113        }
114
115        Ok(pool)
116    }
117
118    /// Build the axum router with all endpoints.
119    pub fn router(&self) -> anyhow::Result<Router> {
120        let pool = self.load_descriptors()?;
121
122        let grpc_upstream = self.config.upstream.default.clone();
123        let grpc_channel =
124            tonic::transport::Channel::from_shared(grpc_upstream.clone())
125                .map_err(|e| anyhow::anyhow!("invalid gRPC upstream URL: {}", e))?
126                .connect_timeout(std::time::Duration::from_secs(5))
127                .timeout(std::time::Duration::from_secs(5))
128                .connect_lazy();
129
130        let service_name = self.config.service.name.clone();
131        let metrics_namespace = service_name.replace('-', "_");
132
133        let state = ProxyState {
134            service_name: service_name.clone(),
135            grpc_upstream,
136            grpc_channel,
137            maintenance_mode: self.config.maintenance.enabled,
138            maintenance_exempt: self.config.maintenance.exempt_paths.clone(),
139            maintenance_message: self.config.maintenance.message.clone(),
140            forwarded_headers: self.config.forwarded_headers.clone(),
141            metrics_namespace,
142            metrics_classes: self.config.metrics_classes.clone(),
143        };
144
145        let cors = self.build_cors();
146
147        // Build transcoding routes from descriptor pool
148        let transcode_routes = transcode::routes(&pool, &self.config.aliases);
149
150        // Health routes
151        let health_service_name = service_name.clone();
152        let health_routes = Router::new()
153            .route(
154                "/health",
155                get({
156                    let name = health_service_name.clone();
157                    move || async move {
158                        Json(serde_json::json!({
159                            "status": "ok",
160                            "service": name,
161                        }))
162                    }
163                }),
164            )
165            .route("/health/live", get(|| async { StatusCode::OK }))
166            .route(
167                "/health/ready",
168                get(|State(state): State<ProxyState>| async move {
169                    let mut client =
170                        tonic_health::pb::health_client::HealthClient::new(state.grpc_channel);
171                    match client
172                        .check(tonic_health::pb::HealthCheckRequest {
173                            service: String::new(),
174                        })
175                        .await
176                    {
177                        Ok(resp) => {
178                            let status = resp.into_inner().status;
179                            if status
180                                == tonic_health::pb::health_check_response::ServingStatus::Serving
181                                    as i32
182                            {
183                                StatusCode::OK
184                            } else {
185                                StatusCode::SERVICE_UNAVAILABLE
186                            }
187                        }
188                        Err(_) => StatusCode::SERVICE_UNAVAILABLE,
189                    }
190                }),
191            )
192            .route("/health/startup", get(|| async { StatusCode::OK }))
193            .route(
194                "/metrics",
195                get(|| async {
196                    let encoder = prometheus::TextEncoder::new();
197                    let metric_families = prometheus::default_registry().gather();
198                    match encoder.encode_to_string(&metric_families) {
199                        Ok(text) => (
200                            StatusCode::OK,
201                            [(
202                                axum::http::header::CONTENT_TYPE,
203                                "text/plain; version=0.0.4; charset=utf-8",
204                            )],
205                            text,
206                        )
207                            .into_response(),
208                        Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
209                    }
210                }),
211            );
212
213        // OpenAPI + docs routes (if enabled).
214        let openapi_routes = self.build_openapi_routes(&pool);
215
216        let router = Router::new()
217            .merge(health_routes)
218            .merge(openapi_routes)
219            .merge(transcode_routes)
220            .layer(cors)
221            .layer(axum::middleware::from_fn_with_state(
222                state.clone(),
223                maintenance_middleware,
224            ))
225            .layer(TraceLayer::new_for_http())
226            .with_state(state);
227
228        Ok(router)
229    }
230
231    fn build_openapi_routes(&self, pool: &DescriptorPool) -> Router<ProxyState> {
232        let openapi_config = match &self.config.openapi {
233            Some(cfg) if cfg.enabled => cfg,
234            _ => return Router::new(),
235        };
236
237        let spec = openapi::generate(pool, openapi_config, &self.config.aliases);
238        let spec_json = serde_json::to_string_pretty(&spec).unwrap_or_default();
239        let openapi_path = openapi_config.path.clone();
240        let docs_path = openapi_config.docs_path.clone();
241        let title = openapi_config
242            .title
243            .clone()
244            .unwrap_or_else(|| self.config.service.name.clone());
245        let openapi_path_for_docs = openapi_path.clone();
246
247        tracing::info!(
248            "OpenAPI spec at {}, docs at {}",
249            openapi_path,
250            docs_path,
251        );
252
253        Router::new()
254            .route(
255                &openapi_path,
256                get(move || async move {
257                    (
258                        StatusCode::OK,
259                        [(
260                            axum::http::header::CONTENT_TYPE,
261                            "application/json; charset=utf-8",
262                        )],
263                        spec_json,
264                    )
265                }),
266            )
267            .route(
268                &docs_path,
269                get(move || async move {
270                    let html = openapi::docs_html(&openapi_path_for_docs, &title);
271                    (
272                        StatusCode::OK,
273                        [(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
274                        html,
275                    )
276                }),
277            )
278    }
279
280    fn build_cors(&self) -> CorsLayer {
281        if self.config.cors.origins.is_empty() {
282            tracing::warn!("CORS origins not set — using permissive CORS (dev mode)");
283            CorsLayer::permissive()
284        } else {
285            let origins: Vec<_> = self
286                .config
287                .cors
288                .origins
289                .iter()
290                .filter_map(|o| o.parse().ok())
291                .collect();
292            CorsLayer::new()
293                .allow_origin(AllowOrigin::list(origins))
294                .allow_methods(tower_http::cors::Any)
295                .allow_headers(tower_http::cors::Any)
296                .allow_credentials(true)
297                .expose_headers([
298                    "grpc-status".parse().unwrap(),
299                    "grpc-message".parse().unwrap(),
300                ])
301        }
302    }
303
304    /// Start serving on configured address.
305    pub async fn serve(&self) -> anyhow::Result<()> {
306        let router = self.router()?;
307        let app = router.into_make_service_with_connect_info::<SocketAddr>();
308        let addr: SocketAddr = self.config.listen.http.parse()?;
309        let listener = tokio::net::TcpListener::bind(addr).await?;
310
311        tracing::info!(
312            "{} listening on {}",
313            self.config.service.name,
314            addr
315        );
316        axum::serve(listener, app).await?;
317        Ok(())
318    }
319}
320
321/// Maintenance mode middleware.
322async fn maintenance_middleware(
323    State(state): State<ProxyState>,
324    request: Request<axum::body::Body>,
325    next: Next,
326) -> Response {
327    if state.maintenance_mode {
328        let path = request.uri().path();
329        let exempt = state.maintenance_exempt.iter().any(|pattern| {
330            if pattern.ends_with("/**") {
331                let prefix = &pattern[..pattern.len() - 3];
332                path.starts_with(prefix)
333            } else {
334                path == pattern
335            }
336        });
337        if !exempt {
338            return (
339                StatusCode::SERVICE_UNAVAILABLE,
340                [("retry-after", "300")],
341                state.maintenance_message.clone(),
342            )
343                .into_response();
344        }
345    }
346    next.run(request).await
347}
348
349/// Create a lazy gRPC channel for testing (connects to nowhere).
350#[cfg(test)]
351pub(crate) fn test_channel() -> tonic::transport::Channel {
352    tonic::transport::Channel::from_static("http://127.0.0.1:1")
353        .connect_timeout(std::time::Duration::from_millis(100))
354        .connect_lazy()
355}
356
357#[cfg(test)]
358mod tests {
359    use super::*;
360
361    #[test]
362    fn test_minimal_config_server() {
363        let yaml = r#"
364upstream:
365  default: "http://127.0.0.1:50051"
366"#;
367        let config: ProxyConfig = serde_yaml::from_str(yaml).unwrap();
368        let server = ProxyServer::from_config(config);
369        assert!(server.descriptor_pool.is_none());
370    }
371
372    #[tokio::test]
373    async fn test_maintenance_exempt_matching() {
374        let state = ProxyState {
375            service_name: "test".into(),
376            grpc_upstream: "http://localhost:50051".into(),
377            grpc_channel: test_channel(),
378            maintenance_mode: true,
379            maintenance_exempt: vec![
380                "/health/**".into(),
381                "/.well-known/**".into(),
382                "/metrics".into(),
383            ],
384            maintenance_message: "Down".into(),
385            forwarded_headers: vec![],
386            metrics_namespace: "test".into(),
387            metrics_classes: vec![],
388        };
389
390        let check = |path: &str| -> bool {
391            state.maintenance_exempt.iter().any(|pattern| {
392                if pattern.ends_with("/**") {
393                    let prefix = &pattern[..pattern.len() - 3];
394                    path.starts_with(prefix)
395                } else {
396                    path == pattern
397                }
398            })
399        };
400
401        assert!(check("/health"));
402        assert!(check("/health/ready"));
403        assert!(check("/.well-known/openid-configuration"));
404        assert!(check("/metrics"));
405        assert!(!check("/v1/auth/login"));
406        assert!(!check("/oauth2/token"));
407    }
408}