Skip to main content

rust_api/
middleware.rs

1//! Request-lifecycle middleware utilities and Tower layer factories.
2//!
3//! This module provides utilities that operate at **request time**, not at
4//! route registration time. These are Tower layer factories — they produce
5//! middleware that wraps individual routes or entire routers.
6//!
7//! # Module Boundaries
8//!
9//! - `pipeline.rs` — build-time: compose routes into a `Router`
10//! - `controller.rs` — build-time: declare which handlers belong to a
11//!   controller
12//! - `middleware.rs` — request-time: inspect/modify requests and responses
13//!
14//! # Protected Route Groups
15//!
16//! Auth is a **cross-cutting concern** — it belongs here as a router transform,
17//! not inside a controller handler. Apply it to a route group via `.map()`:
18//!
19//! ```ignore
20//! RouterPipeline::new()
21//!     .mount_guarded::<AdminController, _>(admin_svc, || { /* config check */ })
22//!     .map(require_bearer(admin_key))
23//! ```
24//!
25//! Or scoped to just a sub-group:
26//!
27//! ```ignore
28//! RouterPipeline::new()
29//!     .group("/admin", |g| g
30//!         .mount::<AdminController>(admin_svc)
31//!         .map(require_bearer(admin_key))   // only admin routes are protected
32//!     )
33//! ```
34
35use axum::{body::Body, http::Request, middleware::Next, response::IntoResponse, Router};
36
37// ---------------------------------------------------------------------------
38// require_bearer
39// ---------------------------------------------------------------------------
40
41/// Returns a `Router -> Router` transform that enforces `Authorization: Bearer
42/// <token>` on every request passing through the router it is applied to.
43///
44/// Returns `401 Unauthorized` if the header is absent, malformed, or if the
45/// token does not match `expected` (compared in **constant time** to prevent
46/// timing oracles).
47///
48/// # Usage
49///
50/// Pass directly to `.map()` — the function signature matches `.map()`'s
51/// expected `Fn(Router<()>) -> Router<()>`:
52///
53/// ```ignore
54/// use rust_api::prelude::*;
55///
56/// RouterPipeline::new()
57///     .mount_guarded::<AdminController, _>(admin_svc, || { /* config check */ })
58///     .map(require_bearer(admin_key))
59///     .build()?
60/// ```
61pub fn require_bearer(
62    expected: impl Into<String>,
63) -> impl Fn(Router<()>) -> Router<()> + Clone + Send + 'static {
64    let expected = expected.into();
65    move |router: Router<()>| {
66        let expected = expected.clone();
67        router.layer(axum::middleware::from_fn(
68            move |req: Request<Body>, next: Next| {
69                let expected = expected.clone();
70                async move {
71                    let authorized = req
72                        .headers()
73                        .get(axum::http::header::AUTHORIZATION)
74                        .and_then(|v| v.to_str().ok())
75                        .and_then(|v| v.strip_prefix("Bearer "))
76                        .map(|token| constant_time_eq(token.as_bytes(), expected.as_bytes()))
77                        .unwrap_or(false);
78
79                    if authorized {
80                        next.run(req).await
81                    } else {
82                        axum::http::StatusCode::UNAUTHORIZED.into_response()
83                    }
84                }
85            },
86        ))
87    }
88}
89
90// ---------------------------------------------------------------------------
91// guard
92// ---------------------------------------------------------------------------
93
94/// Returns a `Router -> Router` transform that guards every request with a
95/// predicate.
96///
97/// Returns `403 Forbidden` if `guard_fn(&request)` returns `false`. The
98/// predicate runs before any extractors, so it has access to headers, URI,
99/// and method.
100///
101/// For **authentication**, prefer [`require_bearer`] — it handles the
102/// `Authorization: Bearer` protocol correctly. `guard` is suited for
103/// non-auth predicates (e.g., IP allowlists, feature flags, method
104/// restrictions).
105///
106/// # Usage
107///
108/// Pass directly to `.map()` on the pipeline:
109///
110/// ```ignore
111/// use rust_api::prelude::*;
112///
113/// RouterPipeline::new()
114///     .mount::<MyController>(svc)
115///     .map(guard(|req| is_allowed_ip(req)))
116///     .build()?
117/// ```
118pub fn guard<G>(guard_fn: G) -> impl Fn(Router<()>) -> Router<()> + Clone + Send + 'static
119where
120    G: Fn(&Request<Body>) -> bool + Clone + Send + Sync + 'static,
121{
122    move |router: Router<()>| {
123        let guard_fn = guard_fn.clone();
124        router.layer(axum::middleware::from_fn(
125            move |req: Request<Body>, next: Next| {
126                let guard_fn = guard_fn.clone();
127                async move {
128                    if guard_fn(&req) {
129                        next.run(req).await
130                    } else {
131                        axum::http::StatusCode::FORBIDDEN.into_response()
132                    }
133                }
134            },
135        ))
136    }
137}
138
139// ---------------------------------------------------------------------------
140// Internal helpers
141// ---------------------------------------------------------------------------
142
/// Byte-slice equality without an early exit — hardens token comparison
/// against timing oracles.
///
/// Visits `max(a.len(), b.len())` positions, XORs the bytes found at each
/// position (a missing byte is read as zero) and ORs the results into an
/// accumulator. The loop never stops at the first mismatch, so a short or
/// partially correct token cannot short-circuit the comparison.
///
/// The length difference is folded into the accumulator as well. Zero-padding
/// alone would make a slice compare equal to the same slice followed by NUL
/// bytes (`b"abc"` vs `b"abc\0"`).
///
/// Returns `true` only when both slices have the same length and contents.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let len = a.len().max(b.len());
    // Non-zero from the start whenever the lengths differ.
    let mut diff: usize = a.len() ^ b.len();
    for i in 0..len {
        let ab = a.get(i).copied().unwrap_or(0);
        let bb = b.get(i).copied().unwrap_or(0);
        diff |= usize::from(ab ^ bb);
    }
    diff == 0
}
158
159// ---------------------------------------------------------------------------
160// Tests
161// ---------------------------------------------------------------------------
162
#[cfg(test)]
mod tests {
    use axum::{body::Body, http::Request, routing::get, Router};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    use super::*;

    // -----------------------------------------------------------------------
    // Shared fixtures
    // -----------------------------------------------------------------------

    /// Builds a bodiless request for `uri`, attaching one `(name, value)`
    /// header when `header` is given.
    fn request_to(uri: &str, header: Option<(&str, &str)>) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        if let Some((name, value)) = header {
            builder = builder.header(name, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    // -----------------------------------------------------------------------
    // constant_time_eq
    // -----------------------------------------------------------------------

    #[test]
    fn ct_eq_identical_slices() {
        let equal = constant_time_eq(b"secret", b"secret");
        assert!(equal);
    }

    #[test]
    fn ct_eq_different_slices() {
        let equal = constant_time_eq(b"secret", b"wrong!");
        assert!(!equal);
    }

    #[test]
    fn ct_eq_empty_slices() {
        let equal = constant_time_eq(b"", b"");
        assert!(equal);
    }

    #[test]
    fn ct_eq_different_lengths_short_a() {
        let equal = constant_time_eq(b"abc", b"abcd");
        assert!(!equal);
    }

    #[test]
    fn ct_eq_different_lengths_short_b() {
        let equal = constant_time_eq(b"abcd", b"abc");
        assert!(!equal);
    }

    #[test]
    fn ct_eq_empty_vs_nonempty() {
        let equal = constant_time_eq(b"", b"x");
        assert!(!equal);
    }

    // -----------------------------------------------------------------------
    // require_bearer
    // -----------------------------------------------------------------------

    /// A single `/protected` route wrapped in `require_bearer("correct-token")`.
    fn bearer_router() -> Router<()> {
        let protect = require_bearer("correct-token");
        protect(Router::new().route("/protected", get(|| async { "ok" })))
    }

    #[tokio::test]
    async fn bearer_accepts_correct_token() {
        let request = request_to(
            "/protected",
            Some(("Authorization", "Bearer correct-token")),
        );
        let response = bearer_router().oneshot(request).await.unwrap();

        assert_eq!(response.status(), 200);
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&bytes[..], b"ok");
    }

    #[tokio::test]
    async fn bearer_rejects_wrong_token() {
        let request = request_to("/protected", Some(("Authorization", "Bearer wrong-token")));
        let response = bearer_router().oneshot(request).await.unwrap();

        assert_eq!(response.status(), 401);
    }

    #[tokio::test]
    async fn bearer_rejects_missing_header() {
        let request = request_to("/protected", None);
        let response = bearer_router().oneshot(request).await.unwrap();

        assert_eq!(response.status(), 401);
    }

    #[tokio::test]
    async fn bearer_rejects_malformed_header() {
        // Token present but without the `Bearer ` scheme prefix.
        let request = request_to("/protected", Some(("Authorization", "correct-token")));
        let response = bearer_router().oneshot(request).await.unwrap();

        assert_eq!(response.status(), 401);
    }

    // -----------------------------------------------------------------------
    // guard
    // -----------------------------------------------------------------------

    /// A single `/guarded` route wrapped in `guard(predicate)`.
    fn guard_router(
        predicate: impl Fn(&Request<Body>) -> bool + Clone + Send + Sync + 'static,
    ) -> Router<()> {
        let protect = guard(predicate);
        protect(Router::new().route("/guarded", get(|| async { "ok" })))
    }

    #[tokio::test]
    async fn guard_allows_request_when_predicate_is_true() {
        let app = guard_router(|_req| true);
        let response = app.oneshot(request_to("/guarded", None)).await.unwrap();

        assert_eq!(response.status(), 200);
    }

    #[tokio::test]
    async fn guard_blocks_request_with_403_when_predicate_is_false() {
        let app = guard_router(|_req| false);
        let response = app.oneshot(request_to("/guarded", None)).await.unwrap();

        assert_eq!(response.status(), 403);
    }

    #[tokio::test]
    async fn guard_predicate_receives_live_request_headers() {
        let app = guard_router(|req| req.headers().contains_key("x-allowed"));

        // without header → 403
        let bare = request_to("/guarded", None);
        let blocked = app.clone().oneshot(bare).await.unwrap();
        assert_eq!(blocked.status(), 403);

        // with header → 200
        let flagged = request_to("/guarded", Some(("x-allowed", "yes")));
        let allowed = app.oneshot(flagged).await.unwrap();
        assert_eq!(allowed.status(), 200);
    }

    #[tokio::test]
    async fn guard_predicate_receives_live_request_uri() {
        // predicate inspects the URI path
        let app = guard_router(|req| req.uri().path().starts_with("/guarded"));

        let response = app.oneshot(request_to("/guarded", None)).await.unwrap();
        assert_eq!(response.status(), 200);
    }
}