1use std::sync::Arc;
10
11use crate::app::AppInner;
12use crate::body::ReqBody;
13use crate::error::{Error, Result};
14use crate::response::Response;
15use crate::router::BoxFuture;
16
17pub mod body_limit;
18pub mod compression;
19pub mod cors;
20pub mod https_redirect;
21pub mod proxy_headers;
22pub mod request_id;
23pub mod security_headers;
24pub mod timeout;
25pub mod trace;
26pub mod trusted_host;
27
28pub use body_limit::BodyLimit;
29pub use compression::Compression;
30pub use cors::Cors;
31pub use https_redirect::HttpsRedirect;
32pub use proxy_headers::ProxyHeaders;
33pub use request_id::RequestId;
34pub use security_headers::SecurityHeaders;
35pub use timeout::Timeout;
36pub use trace::Trace;
37pub use trusted_host::TrustedHost;
38
/// Request type passed through the middleware chain: an `http::Request`
/// whose body is the crate's `ReqBody`.
pub type Request = http::Request<ReqBody>;
41
/// What `resolve_duplicates` does when a middleware is registered under a
/// name that is already present in the resolved list.
///
/// The policy consulted is the one reported by the *incoming* (later)
/// registration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DuplicatePolicy {
    /// Keep the duplicate: the new registration is appended silently.
    Allow,
    /// Keep the duplicate, but print a notice to stderr first.
    Warn,
    /// Fail resolution with an error carrying the code `DUPLICATE_MIDDLEWARE`.
    Reject,
    /// Put the new registration in the slot of the earlier one with the same
    /// name, discarding the earlier one.
    Replace,
}
55
/// A unit of request processing that wraps the rest of the middleware chain.
pub trait Middleware: Send + Sync + 'static {
    /// Processes `request` and produces a response future.
    ///
    /// `next` represents the remainder of the chain; calling `next.run(request)`
    /// forwards the request, while returning without calling it ends the
    /// chain at this middleware.
    fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>>;

    /// Name used to identify this middleware, e.g. when `resolve_duplicates`
    /// looks for repeated registrations.
    ///
    /// Defaults to whatever `std::any::type_name` reports for the
    /// implementing type.
    fn name(&self) -> &'static str {
        std::any::type_name::<Self>()
    }

    /// Policy applied when this middleware is registered under a name that is
    /// already present. Defaults to [`DuplicatePolicy::Allow`].
    fn duplicate_policy(&self) -> DuplicatePolicy {
        DuplicatePolicy::Allow
    }
}
78
/// Handle to the remainder of the middleware chain.
///
/// Consumed by [`Next::run`], which invokes the middleware at `index` or,
/// once the stack is exhausted, dispatches the request to the application.
pub struct Next {
    /// Application and middleware stack shared by every step of the chain.
    state: Arc<NextState>,
    /// Position in the stack of the middleware that `run` will invoke.
    index: usize,
}
87
/// Data shared between all [`Next`] values created for one request.
struct NextState {
    /// Application that receives the request after the last middleware.
    app: Arc<AppInner>,
    /// Middleware in invocation order; index 0 runs first.
    stack: Arc<[Arc<dyn Middleware>]>,
}
92
93impl Next {
94 pub(crate) fn new(app: Arc<AppInner>, stack: Arc<[Arc<dyn Middleware>]>) -> Self {
96 Self {
97 state: Arc::new(NextState { app, stack }),
98 index: 0,
99 }
100 }
101
102 pub fn run(self, request: Request) -> BoxFuture<'static, Result<Response>> {
107 match self.state.stack.get(self.index).cloned() {
108 Some(middleware) => {
109 let next = Next {
110 state: self.state,
111 index: self.index + 1,
112 };
113 middleware.handle(request, next)
114 }
115 None => {
116 let app = self.state.app.clone();
117 Box::pin(async move { Ok(app.dispatch(request).await) })
118 }
119 }
120 }
121}
122
123pub(crate) fn resolve_duplicates(
130 middleware: Vec<Arc<dyn Middleware>>,
131) -> Result<Vec<Arc<dyn Middleware>>> {
132 let mut resolved: Vec<Arc<dyn Middleware>> = Vec::with_capacity(middleware.len());
133
134 for entry in middleware {
135 let name = entry.name();
136 let existing = resolved.iter().position(|m| m.name() == name);
137
138 match (existing, entry.duplicate_policy()) {
139 (None, _) | (Some(_), DuplicatePolicy::Allow) => resolved.push(entry),
140 (Some(_), DuplicatePolicy::Warn) => {
141 eprintln!(
142 "tork: middleware `{}` registered more than once",
143 short_name(name)
144 );
145 resolved.push(entry);
146 }
147 (Some(index), DuplicatePolicy::Replace) => resolved[index] = entry,
148 (Some(_), DuplicatePolicy::Reject) => {
149 let short = short_name(name);
150 return Err(Error::internal(format!(
151 "Duplicate middleware detected: {short}\n\
152 {short} middleware can only be registered once per scope.\n\
153 Already registered at app level."
154 ))
155 .with_code("DUPLICATE_MIDDLEWARE"));
156 }
157 }
158 }
159
160 Ok(resolved)
161}
162
/// Strips the leading module path from a type name for use in diagnostics.
///
/// Only the path in front of the type itself is removed. Generic arguments
/// are left untouched, so `a::b::Wrapper<c::Inner>` becomes
/// `Wrapper<c::Inner>` rather than being cut inside the angle brackets.
/// Names without a `::` separator are returned unchanged.
fn short_name(name: &str) -> &str {
    // Search for the path separator only before the first `<`, so that `::`
    // inside generic arguments is never treated as the split point.
    let head_end = name.find('<').unwrap_or(name.len());
    match name[..head_end].rfind("::") {
        Some(sep) => &name[sep + 2..],
        None => name,
    }
}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::{App, AppInner};
    use crate::body::box_body;
    use crate::constants::TEXT_PLAIN_UTF8;
    use crate::extract::RequestContext;
    use crate::response::bytes_response;
    use crate::router::{HandlerFn, Route, Router};
    use crate::{Method, StatusCode};

    use bytes::Bytes;
    use http::HeaderValue;
    use http_body_util::{BodyExt, Full};

    /// Middleware that appends its label to the `x-mark` response header once
    /// the rest of the chain has produced a response.
    struct Mark(&'static str);

    impl Middleware for Mark {
        fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
            let header = self.0;
            Box::pin(async move {
                let mut response = next.run(request).await?;
                response
                    .headers_mut()
                    .append("x-mark", HeaderValue::from_static(header));
                Ok(response)
            })
        }
    }

    /// Middleware that never calls `next` and always fails with a forbidden error.
    struct ShortCircuit;

    impl Middleware for ShortCircuit {
        fn handle(&self, _request: Request, _next: Next) -> BoxFuture<'static, Result<Response>> {
            Box::pin(async { Err(crate::Error::forbidden("blocked")) })
        }
    }

    /// Handler that answers every request with `200 OK` and the body `pong`.
    fn pong_handler() -> HandlerFn {
        std::sync::Arc::new(
            |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
                Box::pin(async {
                    Ok(bytes_response(
                        StatusCode::OK,
                        TEXT_PLAIN_UTF8,
                        Bytes::from_static(b"pong"),
                    ))
                })
            },
        )
    }

    /// Builds an app with a single `GET /` route, then applies each
    /// configuration closure in order before building it.
    fn app_with(middlewares: Vec<Box<dyn FnOnce(App) -> App>>) -> std::sync::Arc<AppInner> {
        let mut app = App::new().include_router(Router::new().route(Route::new(
            Method::GET,
            "/",
            pong_handler(),
        )));
        for add in middlewares {
            app = add(app);
        }
        std::sync::Arc::new(app.build().unwrap())
    }

    /// An empty-bodied `GET /` request.
    fn request() -> Request {
        http::Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(box_body(Full::new(Bytes::new())))
            .unwrap()
    }

    /// Collects the whole response body and decodes it as UTF-8.
    async fn body_string(response: Response) -> String {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn chain_runs_outermost_first_and_reaches_dispatch() {
        let app = app_with(vec![
            Box::new(|a: App| a.middleware(Mark("outer"))),
            Box::new(|a: App| a.middleware(Mark("inner"))),
        ]);

        let response = app.handle(request()).await;
        assert_eq!(response.status(), StatusCode::OK);

        let marks: Vec<_> = response
            .headers()
            .get_all("x-mark")
            .iter()
            .map(|v| v.to_str().unwrap().to_owned())
            .collect();
        // `Mark` writes its header on the way back out, so the innermost
        // middleware appends first.
        assert_eq!(marks, vec!["inner", "outer"]);
        assert_eq!(body_string(response).await, "pong");
    }

    #[tokio::test]
    async fn middleware_can_short_circuit() {
        let app = app_with(vec![Box::new(|a: App| a.middleware(ShortCircuit))]);
        let response = app.handle(request()).await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    /// Pass-through middleware with a configurable name and duplicate policy,
    /// used to drive `resolve_duplicates`.
    struct Policy {
        name: &'static str,
        policy: DuplicatePolicy,
    }

    impl Middleware for Policy {
        fn handle(&self, request: Request, next: Next) -> BoxFuture<'static, Result<Response>> {
            next.run(request)
        }
        fn name(&self) -> &'static str {
            self.name
        }
        fn duplicate_policy(&self) -> DuplicatePolicy {
            self.policy
        }
    }

    /// Convenience constructor for a type-erased [`Policy`] middleware.
    fn policy(name: &'static str, policy: DuplicatePolicy) -> std::sync::Arc<dyn Middleware> {
        std::sync::Arc::new(Policy { name, policy })
    }

    #[test]
    fn resolve_duplicates_applies_each_policy() {
        // Allow: both registrations survive.
        let allowed = resolve_duplicates(vec![
            policy("a", DuplicatePolicy::Allow),
            policy("a", DuplicatePolicy::Allow),
        ])
        .unwrap();
        assert_eq!(allowed.len(), 2);

        // Warn: a notice goes to stderr, but both registrations survive.
        let warned = resolve_duplicates(vec![
            policy("w", DuplicatePolicy::Warn),
            policy("w", DuplicatePolicy::Warn),
        ])
        .unwrap();
        assert_eq!(warned.len(), 2);

        // Replace: only one registration remains.
        let replaced = resolve_duplicates(vec![
            policy("b", DuplicatePolicy::Replace),
            policy("b", DuplicatePolicy::Replace),
        ])
        .unwrap();
        assert_eq!(replaced.len(), 1);

        // Replace keeps the slot of the earlier registration, so the relative
        // order of the other middleware is unchanged.
        let reordered = resolve_duplicates(vec![
            policy("b", DuplicatePolicy::Replace),
            policy("other", DuplicatePolicy::Allow),
            policy("b", DuplicatePolicy::Replace),
        ])
        .unwrap();
        let names: Vec<_> = reordered.iter().map(|m| m.name()).collect();
        assert_eq!(names, vec!["b", "other"]);

        // Reject: a repeated name is an error.
        assert!(resolve_duplicates(vec![
            policy("c", DuplicatePolicy::Reject),
            policy("c", DuplicatePolicy::Reject)
        ])
        .is_err());

        // Reject only fires on an actual duplicate; distinct names pass.
        let distinct = resolve_duplicates(vec![
            policy("x", DuplicatePolicy::Reject),
            policy("y", DuplicatePolicy::Reject),
        ])
        .unwrap();
        assert_eq!(distinct.len(), 2);
    }
}