// umbral_core/middleware.rs

//! A framework-level request/response middleware contract (feature #68).
//!
//! axum/tower already give you `Layer` + `Service`, but writing one
//! correctly means understanding poll-readiness, `BoxFuture`, and the
//! `Service` trait's ownership rules. Most application middleware only
//! wants two things: *look at the request before the handler*, and
//! *look at the response after*. The [`Middleware`] trait is that
//! narrow, ergonomic surface: an async request-side hook and an async
//! response-side hook, working directly on axum's `Request`/`Response`.
//!
//! Plugins contribute middleware via `Plugin::middleware`; an app adds
//! its own via `AppBuilder::middleware`. `App::build` collects them all
//! into one [`MiddlewareStack`] and installs it as a single axum layer.
//!
//! ## Composition (the onion)
//!
//! The stack is stable-sorted by [`Middleware::order`] (lower = outer),
//! with registration order breaking ties. `before_request` hooks run in
//! that order; `after_response` hooks run in the *reverse* order, so each
//! middleware wraps the ones after it — the standard onion model that makes
//! composition predictable. A `before_request` may short-circuit by
//! returning `Err(response)`: the handler and every later middleware are
//! skipped, and the response unwinds, in reverse, through the
//! `after_response` hooks of the middleware *before* it (those whose
//! `before_request` returned `Ok`). The short-circuiting middleware's own
//! `after_response` is not called.
24
25use std::sync::Arc;
26
27use async_trait::async_trait;
28use axum::extract::{Request, State};
29use axum::middleware::Next;
30use axum::response::Response;
31
32/// A typed request/response middleware. Implement either hook (or both);
33/// the defaults pass through untouched.
34///
35/// ```ignore
36/// use umbral::prelude::*;
37/// use axum::extract::Request;
38/// use axum::response::Response;
39///
40/// struct RequestId;
41///
42/// #[umbral::async_trait]
43/// impl Middleware for RequestId {
44///     async fn before_request(&self, mut req: Request) -> Result<Request, Response> {
45///         req.headers_mut().insert("x-request-id", new_id().parse().unwrap());
46///         Ok(req)
47///     }
48/// }
49/// ```
50#[async_trait]
51pub trait Middleware: Send + Sync + 'static {
52    /// A short label for diagnostics. Defaults to the type name.
53    fn name(&self) -> &'static str {
54        std::any::type_name::<Self>()
55    }
56
57    /// Declarative position in the chain. **Lower values are OUTER** — the
58    /// middleware's `before_request` runs earlier and its `after_response`
59    /// runs later (onion order). Middleware with equal `order` keep their
60    /// registration order (app-level before plugin-level; plugins in
61    /// dependency order). `MiddlewareStack::apply` stable-sorts by this
62    /// before installing, so a middleware can place itself relative to
63    /// others (e.g. a session loader at `-100`, an auth gate at `-50`)
64    /// without depending on registration timing. Default `0`.
65    fn order(&self) -> i32 {
66        0
67    }
68
69    /// Inspect or modify the request before it reaches the handler.
70    ///
71    /// Return `Ok(req)` to continue (with the possibly-modified request),
72    /// or `Err(response)` to short-circuit: the handler and all later
73    /// middleware are skipped, and the response unwinds back out through
74    /// the `after_response` hooks of the middleware that already ran.
75    ///
76    /// Default: pass the request through unchanged.
77    async fn before_request(&self, req: Request) -> Result<Request, Response> {
78        Ok(req)
79    }
80
81    /// Inspect or modify the response on the way out.
82    ///
83    /// Default: pass the response through unchanged.
84    async fn after_response(&self, res: Response) -> Response {
85        res
86    }
87}
88
89/// An ordered set of [`Middleware`], collected from the app builder and
90/// every plugin, installed as one axum layer by `App::build`.
91#[derive(Clone, Default)]
92pub struct MiddlewareStack {
93    middleware: Vec<Arc<dyn Middleware>>,
94}
95
96impl MiddlewareStack {
97    /// An empty stack.
98    pub fn new() -> Self {
99        Self::default()
100    }
101
102    /// Append one middleware to the end of the stack. Its `before_request`
103    /// runs after every middleware already in the stack; its
104    /// `after_response` runs before them (onion order).
105    pub fn push(&mut self, mw: Arc<dyn Middleware>) {
106        self.middleware.push(mw);
107    }
108
109    /// Append every middleware from `other`, preserving order.
110    pub fn extend(&mut self, other: impl IntoIterator<Item = Arc<dyn Middleware>>) {
111        self.middleware.extend(other);
112    }
113
114    /// True when no middleware is registered — `App::build` skips
115    /// installing the layer entirely in that case.
116    pub fn is_empty(&self) -> bool {
117        self.middleware.is_empty()
118    }
119
120    /// Number of registered middleware.
121    pub fn len(&self) -> usize {
122        self.middleware.len()
123    }
124
125    /// Wrap `router` with this stack as a single axum middleware layer.
126    /// A no-op (returns the router unchanged) when the stack is empty.
127    pub fn apply(mut self, router: axum::Router) -> axum::Router {
128        if self.middleware.is_empty() {
129            return router;
130        }
131        // Declarative ordering: stable-sort by `Middleware::order` (lower =
132        // outer) so chain position is controllable independent of
133        // registration timing. `sort_by_key` is stable, so equal-`order`
134        // middleware keep their insertion order (app before plugins).
135        self.middleware.sort_by_key(|mw| mw.order());
136        let state = Arc::new(self.middleware);
137        router.layer(axum::middleware::from_fn_with_state(state, run_stack))
138    }
139}
140
141/// The axum middleware fn that drives one [`MiddlewareStack`] per request:
142/// run the `before_request` hooks in order (short-circuiting on the first
143/// `Err`), invoke the handler, then run the `after_response` hooks of the
144/// middleware that ran, in reverse.
145async fn run_stack(
146    State(stack): State<Arc<Vec<Arc<dyn Middleware>>>>,
147    req: Request,
148    next: Next,
149) -> Response {
150    // `Option` so the request can be moved into each `before_request` and
151    // handed back, without the borrow checker tripping on the short-
152    // circuit (`Err`) path where it isn't returned.
153    let mut req_opt = Some(req);
154    let mut ran = 0usize;
155    let mut short_circuit: Option<Response> = None;
156
157    for mw in stack.iter() {
158        let req = req_opt
159            .take()
160            .expect("request present for each before hook");
161        match mw.before_request(req).await {
162            Ok(modified) => {
163                req_opt = Some(modified);
164                ran += 1;
165            }
166            Err(resp) => {
167                short_circuit = Some(resp);
168                break;
169            }
170        }
171    }
172
173    let mut res = match short_circuit {
174        Some(resp) => resp,
175        None => {
176            next.run(
177                req_opt
178                    .take()
179                    .expect("request present when not short-circuited"),
180            )
181            .await
182        }
183    };
184
185    // Only the middleware whose `before_request` ran get an
186    // `after_response`, in reverse (onion unwind).
187    for mw in stack.iter().take(ran).rev() {
188        res = mw.after_response(res).await;
189    }
190    res
191}