1use std::future::Future;
2use std::pin::Pin;
3use crate::{Request, Response};
4
5pub type MiddlewareFn = std::sync::Arc<
7 dyn Fn(
8 Request,
9 Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
10 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
11 + Send
12 + Sync
13 + 'static,
14>;
15
16pub trait Middleware: Send + Sync + 'static {
18 fn call(
20 &self,
21 req: Request,
22 next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
23 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
24}
25
26impl<F, Fut> Middleware for F
28where
29 F: Fn(
30 Request,
31 Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
32 ) -> Fut
33 + Send
34 + Sync
35 + 'static,
36 Fut: Future<Output = Response> + Send + 'static,
37{
38 fn call(
39 &self,
40 req: Request,
41 next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
42 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
43 Box::pin(self(req, next))
44 }
45}
46
47pub struct MiddlewareStack {
49 middleware: Vec<MiddlewareFn>,
50}
51
52impl MiddlewareStack {
53 pub fn new() -> Self {
55 Self {
56 middleware: Vec::new(),
57 }
58 }
59
60 pub fn add<M>(&mut self, middleware: M)
62 where
63 M: Middleware,
64 {
65 let middleware_fn = std::sync::Arc::new(move |req, next| middleware.call(req, next));
66 self.middleware.push(middleware_fn);
67 }
68
69 pub async fn execute<F>(&self, req: Request, handler: F) -> Response
71 where
72 F: Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync + 'static,
73 {
74 if self.middleware.is_empty() {
75 return handler(req).await;
77 }
78
79 let response = handler(req).await;
81
82 for middleware in &self.middleware {
84 let _ = middleware; }
89
90 response
91 }
92}
93
94impl Default for MiddlewareStack {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100pub fn logger() -> impl Middleware {
102 |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
103 Box::pin(async move {
104 let method = req.method().clone();
105 let path = req.path().to_string();
106 let start = std::time::Instant::now();
107
108 let response = next(req).await;
109
110 let duration = start.elapsed();
111 println!(
112 "{} {} - {} ({:.2}ms)",
113 method,
114 path,
115 response.status_code(),
116 duration.as_secs_f64() * 1000.0
117 );
118
119 response
120 })
121 }
122}
123
124pub fn cors() -> impl Middleware {
126 |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
127 Box::pin(async move {
128 let mut response = next(req).await;
129
130 response = response
132 .header("Access-Control-Allow-Origin", "*")
133 .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
134 .header("Access-Control-Allow-Headers", "Content-Type, Authorization");
135
136 response
137 })
138 }
139}
140
141pub fn security_headers() -> impl Middleware {
143 |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
144 Box::pin(async move {
145 let mut response = next(req).await;
146
147 response = response
148 .header("X-Content-Type-Options", "nosniff")
149 .header("X-Frame-Options", "DENY")
150 .header("X-XSS-Protection", "1; mode=block");
151
152 response
153 })
154 }
155}
156
// `disabled_for_now` is not a cfg that a normal build sets, so this module is compiled out
// and neither test runs. Switch the attribute to `#[cfg(test)]` to re-enable them.
#[cfg(disabled_for_now)]
mod tests {
    use super::*;
    use crate::Response;

    /// Boxed response future: the shape shared by handlers and `next` continuations.
    type BoxedResponse = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// Builds an empty-bodied `GET /` request through `Request::from_hyper`.
    async fn get_root_request() -> Request {
        let (parts, _body) = http::Request::builder()
            .method("GET")
            .uri("/")
            .body(())
            .unwrap()
            .into_parts();
        Request::from_hyper(parts, Vec::new()).await.unwrap()
    }

    /// Handler that ignores the request and answers `Response::ok()` with body `Hello`.
    fn hello(_req: Request) -> BoxedResponse {
        Box::pin(async { Response::ok().body("Hello") })
    }

    #[tokio::test]
    async fn test_middleware_stack() {
        let mut stack = MiddlewareStack::new();
        stack.add(
            |req: Request, next: Box<dyn Fn(Request) -> BoxedResponse + Send + Sync>| -> BoxedResponse {
                Box::pin(async move { next(req).await.header("X-Test", "middleware") })
            },
        );

        let response = stack.execute(get_root_request().await, hello).await;
        assert_eq!(response.headers().get("X-Test").unwrap(), "middleware");
        assert_eq!(response.body_data(), b"Hello");
    }

    #[tokio::test]
    async fn test_cors_middleware() {
        let response = cors().call(get_root_request().await, Box::new(hello)).await;
        assert_eq!(
            response.headers().get("Access-Control-Allow-Origin").unwrap(),
            "*"
        );
    }
}