Skip to main content

zeph_gateway/
server.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::net::SocketAddr;
5use std::time::Instant;
6
7use tokio::sync::{mpsc, watch};
8
9use crate::error::GatewayError;
10use crate::router::build_router;
11
/// Shared state threaded through every axum handler.
///
/// Cloned cheaply for each request because all fields are either `Clone` or
/// wrapped in `Arc`-backed primitives.
///
/// A single instance is constructed at the top of [`GatewayServer::serve`] and
/// handed to the router builder; handlers receive clones of it.
#[derive(Clone)]
pub(crate) struct AppState {
    /// Channel used to forward sanitised webhook messages to the agent.
    ///
    /// This is the sender that was passed to [`GatewayServer::new`].
    pub webhook_tx: mpsc::Sender<String>,
    /// Monotonic timestamp recorded when the server started, used by `/health`.
    ///
    /// Captured with `Instant::now()` in [`GatewayServer::serve`], before the
    /// TCP listener is bound.
    pub started_at: Instant,
}
23
/// HTTP gateway server with bearer-auth, rate limiting, and body-size enforcement.
///
/// Build the server with [`GatewayServer::new`], apply optional configuration via
/// the builder methods, then drive it with [`GatewayServer::serve`].
///
/// # Defaults
///
/// | Setting | Default |
/// |---|---|
/// | Bearer auth | disabled (open) |
/// | Rate limit | 120 requests / 60 s per IP |
/// | Max body size | 1 MiB (1 048 576 bytes) |
///
/// # Example
///
/// ```no_run
/// use tokio::sync::{mpsc, watch};
/// use zeph_gateway::GatewayServer;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let (tx, _rx) = mpsc::channel::<String>(64);
///     let (_stx, srx) = watch::channel(false);
///
///     GatewayServer::new("127.0.0.1", 9000, tx, srx)
///         .with_auth(Some("hunter2".into()))
///         .with_rate_limit(30)
///         .with_max_body_size(512 * 1024)
///         .serve()
///         .await?;
///
///     Ok(())
/// }
/// ```
pub struct GatewayServer {
    /// Socket address the listener binds to; resolved from `bind` + `port` in `new`.
    addr: SocketAddr,
    /// Bearer token required on `POST /webhook`; `None` means auth is disabled.
    auth_token: Option<String>,
    /// Maximum requests per remote IP per 60-second window (default 120).
    rate_limit: u32,
    /// Maximum accepted request body size in bytes (default 1 048 576).
    max_body_size: usize,
    /// Sender for sanitised webhook messages; moved into `AppState` by `serve`.
    webhook_tx: mpsc::Sender<String>,
    /// Shutdown signal; the server begins graceful shutdown once it observes `true`.
    shutdown_rx: watch::Receiver<bool>,
    /// Prometheus metrics registry and endpoint path (feature-gated).
    #[cfg(feature = "prometheus")]
    metrics_registry: Option<(
        std::sync::Arc<prometheus_client::registry::Registry>,
        String,
    )>,
}
72
73impl GatewayServer {
74    /// Create a new gateway server.
75    ///
76    /// `bind` is parsed as an IP address string (e.g. `"127.0.0.1"` or `"0.0.0.0"`).
77    /// If parsing fails, the server falls back to `127.0.0.1:<port>` and emits a warning.
78    ///
79    /// `webhook_tx` receives every valid, sanitised webhook message as a formatted
80    /// `"[sender@channel] body"` string.
81    ///
82    /// `shutdown_rx` is a [`watch::Receiver<bool>`] that signals graceful shutdown
83    /// when its value transitions to `true`.  Sending `true` causes the server to
84    /// stop accepting new connections and drain in-flight requests.
85    ///
86    /// # Panics
87    ///
88    /// Does not panic. Invalid `bind` values fall back to `127.0.0.1` with a log warning.
89    #[must_use]
90    pub fn new(
91        bind: &str,
92        port: u16,
93        webhook_tx: mpsc::Sender<String>,
94        shutdown_rx: watch::Receiver<bool>,
95    ) -> Self {
96        let addr: SocketAddr = format!("{bind}:{port}").parse().unwrap_or_else(|e| {
97            tracing::warn!("invalid bind '{bind}': {e}, falling back to 127.0.0.1:{port}");
98            SocketAddr::from(([127, 0, 0, 1], port))
99        });
100
101        if bind == "0.0.0.0" {
102            tracing::warn!("gateway binding to 0.0.0.0 — ensure this is intended for production");
103        }
104
105        Self {
106            addr,
107            auth_token: None,
108            rate_limit: 120,
109            max_body_size: 1_048_576,
110            webhook_tx,
111            shutdown_rx,
112            #[cfg(feature = "prometheus")]
113            metrics_registry: None,
114        }
115    }
116
117    /// Set the bearer token required on `POST /webhook` requests.
118    ///
119    /// When `token` is `Some`, every request to `/webhook` must carry an
120    /// `Authorization: Bearer <token>` header.  The comparison is performed
121    /// in constant time (BLAKE3 + `subtle::ConstantTimeEq`) to prevent
122    /// timing-oracle attacks.
123    ///
124    /// When `token` is `None`, bearer authentication is disabled and a warning
125    /// is logged at startup.
126    ///
127    /// # Example
128    ///
129    /// ```
130    /// use tokio::sync::{mpsc, watch};
131    /// use zeph_gateway::GatewayServer;
132    ///
133    /// let (tx, _rx) = mpsc::channel::<String>(1);
134    /// let (_stx, srx) = watch::channel(false);
135    ///
136    /// let server = GatewayServer::new("127.0.0.1", 8080, tx, srx)
137    ///     .with_auth(Some("super-secret".into()));
138    /// ```
139    #[must_use]
140    pub fn with_auth(mut self, token: Option<String>) -> Self {
141        self.auth_token = token;
142        self
143    }
144
145    /// Set the per-IP rate limit for `POST /webhook`.
146    ///
147    /// `limit` is the maximum number of requests allowed per remote IP in a
148    /// 60-second fixed window.  Setting `limit` to `0` disables rate limiting.
149    ///
150    /// # Example
151    ///
152    /// ```
153    /// use tokio::sync::{mpsc, watch};
154    /// use zeph_gateway::GatewayServer;
155    ///
156    /// let (tx, _rx) = mpsc::channel::<String>(1);
157    /// let (_stx, srx) = watch::channel(false);
158    ///
159    /// // Allow at most 30 webhook posts per minute per IP.
160    /// let server = GatewayServer::new("127.0.0.1", 8080, tx, srx)
161    ///     .with_rate_limit(30);
162    /// ```
163    #[must_use]
164    pub fn with_rate_limit(mut self, limit: u32) -> Self {
165        self.rate_limit = limit;
166        self
167    }
168
169    /// Set the maximum allowed request body size in bytes.
170    ///
171    /// Requests whose body exceeds this size are rejected with `413 Content Too Large`
172    /// before any handler is invoked. The default is 1 MiB (1 048 576 bytes).
173    ///
174    /// # Example
175    ///
176    /// ```
177    /// use tokio::sync::{mpsc, watch};
178    /// use zeph_gateway::GatewayServer;
179    ///
180    /// let (tx, _rx) = mpsc::channel::<String>(1);
181    /// let (_stx, srx) = watch::channel(false);
182    ///
183    /// // Restrict bodies to 64 KiB.
184    /// let server = GatewayServer::new("127.0.0.1", 8080, tx, srx)
185    ///     .with_max_body_size(64 * 1024);
186    /// ```
187    #[must_use]
188    pub fn with_max_body_size(mut self, size: usize) -> Self {
189        self.max_body_size = size;
190        self
191    }
192
193    /// Attach a Prometheus metrics registry to the gateway.
194    ///
195    /// When set, the server mounts an additional route at `path` that returns the registry
196    /// contents encoded as `OpenMetrics` 1.0.0 text.  The endpoint is unauthenticated and
197    /// bypasses rate limiting.
198    ///
199    /// Requires the `prometheus` feature.
200    ///
201    /// # Example
202    ///
203    /// ```no_run
204    /// # #[cfg(feature = "prometheus")]
205    /// # {
206    /// use std::sync::Arc;
207    /// use prometheus_client::registry::Registry;
208    /// use tokio::sync::{mpsc, watch};
209    /// use zeph_gateway::GatewayServer;
210    ///
211    /// let (tx, _rx) = mpsc::channel::<String>(1);
212    /// let (_stx, srx) = watch::channel(false);
213    /// let registry = Arc::new(Registry::default());
214    ///
215    /// let server = GatewayServer::new("127.0.0.1", 8080, tx, srx)
216    ///     .with_metrics_registry(registry, "/metrics");
217    /// # }
218    /// ```
219    #[cfg(feature = "prometheus")]
220    #[must_use]
221    pub fn with_metrics_registry(
222        mut self,
223        registry: std::sync::Arc<prometheus_client::registry::Registry>,
224        path: impl Into<String>,
225    ) -> Self {
226        self.metrics_registry = Some((registry, path.into()));
227        self
228    }
229
230    /// Start the HTTP gateway server and block until shutdown is signalled.
231    ///
232    /// Binds a TCP listener on the configured address, installs middleware
233    /// (body-size limit → auth → rate limiting), and serves requests until
234    /// the [`watch::Receiver`] supplied to [`GatewayServer::new`] transitions
235    /// to `true`.
236    ///
237    /// # Errors
238    ///
239    /// - [`GatewayError::Bind`] — the listener could not be bound (port in use,
240    ///   permission denied, etc.).
241    /// - [`GatewayError::Server`] — the server encountered a fatal I/O error
242    ///   after binding.
243    pub async fn serve(self) -> Result<(), GatewayError> {
244        let state = AppState {
245            webhook_tx: self.webhook_tx,
246            started_at: Instant::now(),
247        };
248
249        if self.auth_token.is_none() {
250            tracing::warn!(
251                "gateway running without bearer auth — ensure firewall or upstream proxy enforces access control"
252            );
253        }
254
255        let router = build_router(
256            state,
257            self.auth_token.as_deref(),
258            self.rate_limit,
259            self.max_body_size,
260        );
261
262        #[cfg(feature = "prometheus")]
263        let router = if let Some((registry, path)) = self.metrics_registry {
264            let metrics_route = axum::Router::new()
265                .route(&path, axum::routing::get(crate::handlers::metrics_handler))
266                .with_state(registry);
267            router.merge(metrics_route)
268        } else {
269            router
270        };
271
272        let listener = tokio::net::TcpListener::bind(self.addr)
273            .await
274            .map_err(|e| GatewayError::Bind(self.addr.to_string(), e))?;
275        tracing::info!("gateway listening on {}", self.addr);
276
277        let mut shutdown_rx = self.shutdown_rx;
278        axum::serve(
279            listener,
280            router.into_make_service_with_connect_info::<SocketAddr>(),
281        )
282        .with_graceful_shutdown(async move {
283            while !*shutdown_rx.borrow_and_update() {
284                if shutdown_rx.changed().await.is_err() {
285                    std::future::pending::<()>().await;
286                }
287            }
288            tracing::info!("gateway shutting down");
289        })
290        .await
291        .map_err(|e| GatewayError::Server(format!("{e}")))?;
292
293        Ok(())
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// The metrics route, mounted the same way `serve` mounts it, must answer
    /// `GET` with an OpenMetrics content type and a terminating `# EOF` line.
    #[cfg(feature = "prometheus")]
    #[tokio::test]
    async fn test_metrics_endpoint_returns_openmetrics() {
        use axum::body::Body;
        use http_body_util::BodyExt;
        use prometheus_client::registry::Registry;
        use tower::ServiceExt;

        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("127.0.0.1", 19999, tx, srx)
            .with_metrics_registry(std::sync::Arc::new(Registry::default()), "/metrics");

        // Build the router directly without binding a port
        let state = AppState {
            webhook_tx: server.webhook_tx,
            started_at: Instant::now(),
        };
        let router = crate::router::build_router(
            state,
            server.auth_token.as_deref(),
            server.rate_limit,
            server.max_body_size,
        );

        // Use what the builder actually stored, so the test also proves that
        // `with_metrics_registry` recorded the registry and path.
        let (registry, path) = server
            .metrics_registry
            .expect("with_metrics_registry should store the registry and path");
        assert_eq!(path, "/metrics");

        let metrics_route = axum::Router::new()
            .route(&path, axum::routing::get(crate::handlers::metrics_handler))
            .with_state(registry);
        let router = router.merge(metrics_route);

        let req = axum::http::Request::builder()
            .method("GET")
            .uri("/metrics")
            .body(Body::empty())
            .unwrap();

        let response = router.oneshot(req).await.unwrap();
        assert_eq!(response.status(), axum::http::StatusCode::OK);

        let ct = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("application/openmetrics-text"),
            "unexpected content-type: {ct}"
        );

        let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
        let body = String::from_utf8(body_bytes.to_vec()).unwrap();
        assert!(body.ends_with("# EOF\n"), "missing EOF marker in:\n{body}");
    }

    /// A freshly constructed server carries the documented defaults.
    #[test]
    fn server_defaults() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("127.0.0.1", 8090, tx, srx);

        assert_eq!(server.rate_limit, 120);
        assert_eq!(server.max_body_size, 1_048_576);
        assert!(server.auth_token.is_none());
    }

    /// Every builder method stores exactly the value it was given.
    #[test]
    fn server_builder_chain() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("127.0.0.1", 8090, tx, srx)
            .with_auth(Some("token".into()))
            .with_rate_limit(60)
            .with_max_body_size(512);

        assert_eq!(server.rate_limit, 60);
        assert_eq!(server.max_body_size, 512);
        assert_eq!(server.auth_token.as_deref(), Some("token"));
    }

    /// A valid IPv4 bind string is used verbatim together with the port.
    #[test]
    fn server_valid_bind_is_used() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("127.0.0.1", 8090, tx, srx);

        assert_eq!(server.addr, SocketAddr::from(([127, 0, 0, 1], 8090)));
    }

    /// An unparsable bind string falls back to loopback while keeping the port.
    #[test]
    fn server_invalid_bind_fallback() {
        let (tx, _rx) = mpsc::channel(1);
        let (_stx, srx) = watch::channel(false);
        let server = GatewayServer::new("not_an_ip", 9999, tx, srx);

        // Compare the whole address: checking only the port would not notice a
        // wrong fallback IP.
        assert_eq!(server.addr, SocketAddr::from(([127, 0, 0, 1], 9999)));
    }
}