Skip to main content

tork_core/middleware/
mod.rs

1//! The middleware layer.
2//!
3//! A middleware is a layer that wraps request handling: it receives the request
4//! and a [`Next`] handle, may inspect or modify the request, calls `next` to run
5//! the rest of the chain (or short-circuits), and may inspect or modify the
6//! response. Middlewares run in registration order, outermost first, and the
7//! innermost `next` invokes the route dispatch.
8
9use std::sync::Arc;
10
11use crate::app::AppInner;
12use crate::body::ReqBody;
13use crate::error::{Error, Result};
14use crate::response::Response;
15use crate::router::BoxFuture;
16
17pub mod body_limit;
18pub mod compression;
19pub mod cors;
20pub mod https_redirect;
21pub mod proxy_headers;
22pub mod request_id;
23pub mod security_headers;
24pub mod timeout;
25pub mod trace;
26pub mod trusted_host;
27
28pub use body_limit::BodyLimit;
29pub use compression::Compression;
30pub use cors::Cors;
31pub use https_redirect::HttpsRedirect;
32pub use proxy_headers::ProxyHeaders;
33pub use request_id::RequestId;
34pub use security_headers::SecurityHeaders;
35pub use timeout::Timeout;
36pub use trace::Trace;
37pub use trusted_host::TrustedHost;
38
/// The request type threaded through the middleware chain.
///
/// This is an [`http::Request`] whose body type is the crate's [`ReqBody`]. It is
/// the request type accepted by [`Middleware::handle`] and [`Next::run`].
pub type Request = http::Request<ReqBody>;
41
/// How a repeated registration of a middleware is handled.
///
/// Two registrations are duplicates when their [`name`](Middleware::name) values
/// are equal. A middleware reports its own policy through
/// [`Middleware::duplicate_policy`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DuplicatePolicy {
    /// Every registration is kept.
    Allow,
    /// Every registration is kept and a warning is logged.
    Warn,
    /// The application configuration is rejected with an error.
    Reject,
    /// Only the most recent registration is kept.
    Replace,
}
55
56/// A request/response middleware layer.
57///
58/// Built-in middlewares implement this directly; custom middlewares are usually
59/// written as an `async fn` annotated with `#[tork::middleware]`, which generates
60/// the implementation.
61pub trait Middleware: Send + Sync + 'static {
62    /// Processes `request`, optionally calling `next` to continue the chain.
63    fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>>;
64
65    /// A stable name used for duplicate detection and diagnostics.
66    ///
67    /// Built-in middlewares override this with a short name (for example
68    /// `"Cors"`); the default is the fully-qualified type name.
69    fn name(&self) -> &'static str {
70        std::any::type_name::<Self>()
71    }
72
73    /// Controls what happens if another middleware with the same name is added.
74    fn duplicate_policy(&self) -> DuplicatePolicy {
75        DuplicatePolicy::Allow
76    }
77}
78
79/// A handle to the remainder of the middleware chain.
80///
81/// Calling [`run`](Next::run) advances to the next middleware, or, once the
82/// chain is exhausted, dispatches to the matched route handler.
83pub struct Next {
84    state: Arc<NextState>,
85    index: usize,
86}
87
88struct NextState {
89    app: Arc<AppInner>,
90    stack: Arc<[Arc<dyn Middleware>]>,
91}
92
93impl Next {
94    /// Creates a chain handle positioned at the first middleware.
95    pub(crate) fn new(app: Arc<AppInner>, stack: Arc<[Arc<dyn Middleware>]>) -> Self {
96        Self {
97            state: Arc::new(NextState { app, stack }),
98            index: 0,
99        }
100    }
101
102    /// Runs the rest of the chain and returns the response.
103    ///
104    /// If more middlewares remain, the next one is invoked; otherwise the request
105    /// is dispatched to its route handler.
106    pub fn run(self, request: Request) -> BoxFuture<'static, Result<Response>> {
107        match self.state.stack.get(self.index).cloned() {
108            Some(middleware) => {
109                let next = Next {
110                    state: self.state,
111                    index: self.index + 1,
112                };
113                middleware.handle(request, next)
114            }
115            None => {
116                let app = self.state.app.clone();
117                Box::pin(async move { Ok(app.dispatch(request).await) })
118            }
119        }
120    }
121}
122
123/// Resolves duplicate registrations according to each middleware's policy.
124///
125/// # Errors
126///
127/// Returns an error if a middleware whose policy is [`DuplicatePolicy::Reject`]
128/// is registered more than once.
129pub(crate) fn resolve_duplicates(
130    middleware: Vec<Arc<dyn Middleware>>,
131) -> Result<Vec<Arc<dyn Middleware>>> {
132    let mut resolved: Vec<Arc<dyn Middleware>> = Vec::with_capacity(middleware.len());
133
134    for entry in middleware {
135        let name = entry.name();
136        let existing = resolved.iter().position(|m| m.name() == name);
137
138        match (existing, entry.duplicate_policy()) {
139            (None, _) | (Some(_), DuplicatePolicy::Allow) => resolved.push(entry),
140            (Some(_), DuplicatePolicy::Warn) => {
141                eprintln!(
142                    "tork: middleware `{}` registered more than once",
143                    short_name(name)
144                );
145                resolved.push(entry);
146            }
147            (Some(index), DuplicatePolicy::Replace) => resolved[index] = entry,
148            (Some(_), DuplicatePolicy::Reject) => {
149                let short = short_name(name);
150                return Err(Error::internal(format!(
151                    "Duplicate middleware detected: {short}\n\
152                     {short} middleware can only be registered once per scope.\n\
153                     Already registered at app level."
154                ))
155                .with_code("DUPLICATE_MIDDLEWARE"));
156            }
157        }
158    }
159
160    Ok(resolved)
161}
162
/// Returns a type name with its leading module path removed.
///
/// Only the path in front of the first `<` is considered, so generic arguments
/// are kept intact: `app::Wrapper<other::Inner>` becomes `Wrapper<other::Inner>`
/// rather than being cut inside the argument list. A name without `::` in that
/// prefix is returned unchanged.
fn short_name(name: &str) -> &str {
    // `<` is ASCII, so slicing at its byte offset is always on a char boundary.
    let path_end = name.find('<').unwrap_or(name.len());
    name[..path_end]
        .rfind("::")
        .map_or(name, |separator| &name[separator + "::".len()..])
}
167
168#[cfg(test)]
169mod tests {
170    use super::*;
171    use crate::app::{App, AppInner};
172    use crate::body::box_body;
173    use crate::constants::TEXT_PLAIN_UTF8;
174    use crate::extract::RequestContext;
175    use crate::response::bytes_response;
176    use crate::router::{HandlerFn, Route, Router};
177    use crate::{Method, StatusCode};
178
179    use bytes::Bytes;
180    use http::HeaderValue;
181    use http_body_util::{BodyExt, Full};
182
183    /// Middleware that records that it ran on the response, then calls `next`.
184    struct Mark(&'static str);
185    impl Middleware for Mark {
186        fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
187            let header = self.0;
188            Box::pin(async move {
189                let mut response = next.run(request).await?;
190                response
191                    .headers_mut()
192                    .append("x-mark", HeaderValue::from_static(header));
193                Ok(response)
194            })
195        }
196    }
197
198    /// Middleware that short-circuits without calling `next`.
199    struct ShortCircuit;
200    impl Middleware for ShortCircuit {
201        fn handle(&self, _request: Request, _next: Next) -> BoxFuture<'static, Result<Response>> {
202            Box::pin(async { Err(crate::Error::forbidden("blocked")) })
203        }
204    }
205
206    fn pong_handler() -> HandlerFn {
207        std::sync::Arc::new(
208            |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
209                Box::pin(async {
210                    Ok(bytes_response(
211                        StatusCode::OK,
212                        TEXT_PLAIN_UTF8,
213                        Bytes::from_static(b"pong"),
214                    ))
215                })
216            },
217        )
218    }
219
220    fn app_with(middlewares: Vec<Box<dyn FnOnce(App) -> App>>) -> std::sync::Arc<AppInner> {
221        let mut app = App::new().include_router(Router::new().route(Route::new(
222            Method::GET,
223            "/",
224            pong_handler(),
225        )));
226        for add in middlewares {
227            app = add(app);
228        }
229        std::sync::Arc::new(app.build().unwrap())
230    }
231
232    fn request() -> Request {
233        http::Request::builder()
234            .method(Method::GET)
235            .uri("/")
236            .body(box_body(Full::new(Bytes::new())))
237            .unwrap()
238    }
239
240    async fn body_string(response: Response) -> String {
241        let bytes = response.into_body().collect().await.unwrap().to_bytes();
242        String::from_utf8(bytes.to_vec()).unwrap()
243    }
244
245    #[tokio::test]
246    async fn chain_runs_outermost_first_and_reaches_dispatch() {
247        let app = app_with(vec![
248            Box::new(|a: App| a.middleware(Mark("outer"))),
249            Box::new(|a: App| a.middleware(Mark("inner"))),
250        ]);
251
252        let response = app.handle(request()).await;
253        assert_eq!(response.status(), StatusCode::OK);
254
255        // Both layers ran (order: inner appends first on the way out, then outer).
256        let marks: Vec<_> = response
257            .headers()
258            .get_all("x-mark")
259            .iter()
260            .map(|v| v.to_str().unwrap().to_owned())
261            .collect();
262        assert_eq!(marks, vec!["inner", "outer"]);
263        assert_eq!(body_string(response).await, "pong");
264    }
265
266    #[tokio::test]
267    async fn middleware_can_short_circuit() {
268        let app = app_with(vec![Box::new(|a: App| a.middleware(ShortCircuit))]);
269        let response = app.handle(request()).await;
270        assert_eq!(response.status(), StatusCode::FORBIDDEN);
271    }
272
273    /// A middleware with a configurable name and duplicate policy.
274    struct Policy {
275        name: &'static str,
276        policy: DuplicatePolicy,
277    }
278    impl Middleware for Policy {
279        fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
280            next.run(request)
281        }
282        fn name(&self) -> &'static str {
283            self.name
284        }
285        fn duplicate_policy(&self) -> DuplicatePolicy {
286            self.policy
287        }
288    }
289
290    fn policy(name: &'static str, policy: DuplicatePolicy) -> std::sync::Arc<dyn Middleware> {
291        std::sync::Arc::new(Policy { name, policy })
292    }
293
294    #[test]
295    fn resolve_duplicates_applies_each_policy() {
296        // Allow keeps every registration.
297        let allowed = resolve_duplicates(vec![
298            policy("a", DuplicatePolicy::Allow),
299            policy("a", DuplicatePolicy::Allow),
300        ])
301        .unwrap();
302        assert_eq!(allowed.len(), 2);
303
304        // Replace keeps only the most recent.
305        let replaced = resolve_duplicates(vec![
306            policy("b", DuplicatePolicy::Replace),
307            policy("b", DuplicatePolicy::Replace),
308        ])
309        .unwrap();
310        assert_eq!(replaced.len(), 1);
311
312        // Reject fails the configuration.
313        assert!(resolve_duplicates(vec![
314            policy("c", DuplicatePolicy::Reject),
315            policy("c", DuplicatePolicy::Reject)
316        ])
317        .is_err());
318
319        // Distinct names never collide.
320        let distinct = resolve_duplicates(vec![
321            policy("x", DuplicatePolicy::Reject),
322            policy("y", DuplicatePolicy::Reject),
323        ])
324        .unwrap();
325        assert_eq!(distinct.len(), 2);
326    }
327}