umbral_core/middleware.rs
1//! A framework-level request/response middleware contract (feature #68).
2//!
3//! axum/tower already give you `Layer` + `Service`, but writing one
4//! correctly means understanding poll-readiness, `BoxFuture`, and the
5//! `Service` trait's ownership rules. Most application middleware only
6//! wants two things: *look at the request before the handler*, and
7//! *look at the response after*. The [`Middleware`] trait is that
8//! narrow, ergonomic surface: a request-side hook and a
9//! response-side hook, typed for Rust.
10//!
11//! Plugins contribute middleware via `Plugin::middleware`; an app adds
12//! its own via `AppBuilder::middleware`. `App::build` collects them all
13//! into one [`MiddlewareStack`] and installs it as a single axum layer.
14//!
15//! ## Composition (the onion)
16//!
17//! `before_request` hooks run in ascending [`Middleware::order`], ties
18//! broken by registration order; `after_response` hooks run in reverse,
19//! so each middleware wraps the ones that follow it — the standard onion
20//! model that makes composition predictable. A `before_request` may
21//! short-circuit by returning `Err(response)`: the handler and every
22//! later middleware are skipped, and only the `after_response` hooks of
23//! the middleware that already ran (in reverse) see that response.
24
25use std::sync::Arc;
26
27use async_trait::async_trait;
28use axum::extract::{Request, State};
29use axum::middleware::Next;
30use axum::response::Response;
31
32/// A typed request/response middleware. Implement either hook (or both);
33/// the defaults pass through untouched.
34///
35/// ```ignore
36/// use umbral::prelude::*;
37/// use axum::extract::Request;
38/// use axum::response::Response;
39///
40/// struct RequestId;
41///
42/// #[umbral::async_trait]
43/// impl Middleware for RequestId {
44/// async fn before_request(&self, mut req: Request) -> Result<Request, Response> {
45/// req.headers_mut().insert("x-request-id", new_id().parse().unwrap());
46/// Ok(req)
47/// }
48/// }
49/// ```
50#[async_trait]
51pub trait Middleware: Send + Sync + 'static {
52 /// A short label for diagnostics. Defaults to the type name.
53 fn name(&self) -> &'static str {
54 std::any::type_name::<Self>()
55 }
56
57 /// Declarative position in the chain. **Lower values are OUTER** — the
58 /// middleware's `before_request` runs earlier and its `after_response`
59 /// runs later (onion order). Middleware with equal `order` keep their
60 /// registration order (app-level before plugin-level; plugins in
61 /// dependency order). `MiddlewareStack::apply` stable-sorts by this
62 /// before installing, so a middleware can place itself relative to
63 /// others (e.g. a session loader at `-100`, an auth gate at `-50`)
64 /// without depending on registration timing. Default `0`.
65 fn order(&self) -> i32 {
66 0
67 }
68
69 /// Inspect or modify the request before it reaches the handler.
70 ///
71 /// Return `Ok(req)` to continue (with the possibly-modified request),
72 /// or `Err(response)` to short-circuit: the handler and all later
73 /// middleware are skipped, and the response unwinds back out through
74 /// the `after_response` hooks of the middleware that already ran.
75 ///
76 /// Default: pass the request through unchanged.
77 async fn before_request(&self, req: Request) -> Result<Request, Response> {
78 Ok(req)
79 }
80
81 /// Inspect or modify the response on the way out.
82 ///
83 /// Default: pass the response through unchanged.
84 async fn after_response(&self, res: Response) -> Response {
85 res
86 }
87}
88
89/// An ordered set of [`Middleware`], collected from the app builder and
90/// every plugin, installed as one axum layer by `App::build`.
91#[derive(Clone, Default)]
92pub struct MiddlewareStack {
93 middleware: Vec<Arc<dyn Middleware>>,
94}
95
96impl MiddlewareStack {
97 /// An empty stack.
98 pub fn new() -> Self {
99 Self::default()
100 }
101
102 /// Append one middleware to the end of the stack. Its `before_request`
103 /// runs after every middleware already in the stack; its
104 /// `after_response` runs before them (onion order).
105 pub fn push(&mut self, mw: Arc<dyn Middleware>) {
106 self.middleware.push(mw);
107 }
108
109 /// Append every middleware from `other`, preserving order.
110 pub fn extend(&mut self, other: impl IntoIterator<Item = Arc<dyn Middleware>>) {
111 self.middleware.extend(other);
112 }
113
114 /// True when no middleware is registered — `App::build` skips
115 /// installing the layer entirely in that case.
116 pub fn is_empty(&self) -> bool {
117 self.middleware.is_empty()
118 }
119
120 /// Number of registered middleware.
121 pub fn len(&self) -> usize {
122 self.middleware.len()
123 }
124
125 /// Wrap `router` with this stack as a single axum middleware layer.
126 /// A no-op (returns the router unchanged) when the stack is empty.
127 pub fn apply(mut self, router: axum::Router) -> axum::Router {
128 if self.middleware.is_empty() {
129 return router;
130 }
131 // Declarative ordering: stable-sort by `Middleware::order` (lower =
132 // outer) so chain position is controllable independent of
133 // registration timing. `sort_by_key` is stable, so equal-`order`
134 // middleware keep their insertion order (app before plugins).
135 self.middleware.sort_by_key(|mw| mw.order());
136 let state = Arc::new(self.middleware);
137 router.layer(axum::middleware::from_fn_with_state(state, run_stack))
138 }
139}
140
141/// The axum middleware fn that drives one [`MiddlewareStack`] per request:
142/// run the `before_request` hooks in order (short-circuiting on the first
143/// `Err`), invoke the handler, then run the `after_response` hooks of the
144/// middleware that ran, in reverse.
145async fn run_stack(
146 State(stack): State<Arc<Vec<Arc<dyn Middleware>>>>,
147 req: Request,
148 next: Next,
149) -> Response {
150 // `Option` so the request can be moved into each `before_request` and
151 // handed back, without the borrow checker tripping on the short-
152 // circuit (`Err`) path where it isn't returned.
153 let mut req_opt = Some(req);
154 let mut ran = 0usize;
155 let mut short_circuit: Option<Response> = None;
156
157 for mw in stack.iter() {
158 let req = req_opt
159 .take()
160 .expect("request present for each before hook");
161 match mw.before_request(req).await {
162 Ok(modified) => {
163 req_opt = Some(modified);
164 ran += 1;
165 }
166 Err(resp) => {
167 short_circuit = Some(resp);
168 break;
169 }
170 }
171 }
172
173 let mut res = match short_circuit {
174 Some(resp) => resp,
175 None => {
176 next.run(
177 req_opt
178 .take()
179 .expect("request present when not short-circuited"),
180 )
181 .await
182 }
183 };
184
185 // Only the middleware whose `before_request` ran get an
186 // `after_response`, in reverse (onion unwind).
187 for mw in stack.iter().take(ran).rev() {
188 res = mw.after_response(res).await;
189 }
190 res
191}