torch_web/middleware.rs
//! # Middleware System
//!
//! Torch's middleware system lets you intercept and modify HTTP requests and responses
//! as they flow through your application. Middleware is commonly used for logging,
//! authentication, CORS, rate limiting, and other cross-cutting concerns.
//!
//! ## How Middleware Works
//!
//! Middleware forms a chain in which each middleware can:
//! 1. Inspect and modify the incoming request
//! 2. Call the next middleware in the chain (or the final handler)
//! 3. Inspect and modify the outgoing response
//! 4. Short-circuit the chain by returning a response without calling `next`
//!
//! ## Examples
//!
//! ### Basic Logging Middleware
//!
//! ```rust
//! use torch_web::{App, Request, Response, middleware::Middleware};
//! use std::pin::Pin;
//! use std::future::Future;
//!
//! struct Logger;
//!
//! impl Middleware for Logger {
//!     fn call(
//!         &self,
//!         req: Request,
//!         next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
//!     ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
//!         Box::pin(async move {
//!             println!("{} {}", req.method(), req.path());
//!             let response = next(req).await;
//!             println!("Response: {}", response.status_code());
//!             response
//!         })
//!     }
//! }
//!
//! let app = App::new()
//!     .middleware(Logger)
//!     .get("/", |_req| async { Response::ok().body("Hello!") });
//! ```
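//!
//! ### Function-based Middleware
//!
//! Plain closures implement [`Middleware`] too. Annotate the closure's parameter
//! types so the compiler can match it against the blanket `Middleware` impl.
//!
//! ```rust
//! use torch_web::{App, Request, Response};
//! use std::pin::Pin;
//! use std::future::Future;
//!
//! type Next = Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
//!
//! let app = App::new()
//!     .middleware(|req: Request, next: Next| async move {
//!         // Add a custom header to all responses
//!         next(req).await.header("X-Powered-By", "Torch")
//!     })
//!     .get("/", |_req| async { Response::ok().body("Hello!") });
//! ```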
60
61use std::future::Future;
62use std::pin::Pin;
63use crate::{Request, Response};
64
65/// Type alias for middleware functions.
66///
67/// This represents the function signature that middleware must implement.
68/// It takes a request and a "next" function that continues the middleware chain.
69pub type MiddlewareFn = std::sync::Arc<
70 dyn Fn(
71 Request,
72 Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
73 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
74 + Send
75 + Sync
76 + 'static,
77>;
78
79/// Trait for implementing middleware components.
80///
81/// Middleware can intercept requests before they reach handlers and modify
82/// responses before they're sent to clients. This trait provides a standard
83/// interface for all middleware components.
84///
85/// # Examples
86///
87/// ## Authentication Middleware
88///
89/// ```rust
90/// use torch_web::{Request, Response, middleware::Middleware};
91/// use std::pin::Pin;
92/// use std::future::Future;
93///
94/// struct AuthMiddleware {
95/// secret_key: String,
96/// }
97///
98/// impl AuthMiddleware {
99/// fn new(secret_key: String) -> Self {
100/// Self { secret_key }
101/// }
102/// }
103///
104/// impl Middleware for AuthMiddleware {
105/// fn call(
106/// &self,
107/// req: Request,
108/// next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
109/// ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
110/// Box::pin(async move {
111/// // Check for authorization header
112/// if let Some(auth_header) = req.header("authorization") {
113/// if auth_header.starts_with("Bearer ") {
114/// // Validate token here...
115/// return next(req).await;
116/// }
117/// }
118///
119/// Response::unauthorized().body("Authentication required")
120/// })
121/// }
122/// }
123/// ```
124///
125/// ## CORS Middleware
126///
127/// ```rust
128/// use torch_web::{Request, Response, middleware::Middleware};
129/// use std::pin::Pin;
130/// use std::future::Future;
131///
132/// struct CorsMiddleware;
133///
134/// impl Middleware for CorsMiddleware {
135/// fn call(
136/// &self,
137/// req: Request,
138/// next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
139/// ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
140/// Box::pin(async move {
141/// let mut response = next(req).await;
142/// response = response
143/// .header("Access-Control-Allow-Origin", "*")
144/// .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
145/// .header("Access-Control-Allow-Headers", "Content-Type, Authorization");
146/// response
147/// })
148/// }
149/// }
150/// ```
151pub trait Middleware: Send + Sync + 'static {
152 /// Processes a request through the middleware chain.
153 ///
154 /// This method receives the current request and a `next` function that
155 /// continues processing through the remaining middleware and eventually
156 /// to the route handler.
157 ///
158 /// # Parameters
159 ///
160 /// * `req` - The HTTP request to process
161 /// * `next` - Function to call the next middleware or handler in the chain
162 ///
163 /// # Returns
164 ///
165 /// Returns a `Future` that resolves to the HTTP response. The middleware
166 /// can modify the request before calling `next`, modify the response after
167 /// calling `next`, or return early without calling `next` at all.
168 ///
169 /// # Examples
170 ///
171 /// ```rust
172 /// use torch_web::{Request, Response, middleware::Middleware};
173 /// use std::pin::Pin;
174 /// use std::future::Future;
175 ///
176 /// struct TimingMiddleware;
177 ///
178 /// impl Middleware for TimingMiddleware {
179 /// fn call(
180 /// &self,
181 /// req: Request,
182 /// next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
183 /// ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
184 /// Box::pin(async move {
185 /// let start = std::time::Instant::now();
186 /// let response = next(req).await;
187 /// let duration = start.elapsed();
188 ///
189 /// println!("Request took {:?}", duration);
190 /// response.header("X-Response-Time", &format!("{}ms", duration.as_millis()))
191 /// })
192 /// }
193 /// }
194 /// ```
195 fn call(
196 &self,
197 req: Request,
198 next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
199 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
200}
201
202/// Any function that matches the signature can be middleware
203impl<F, Fut> Middleware for F
204where
205 F: Fn(
206 Request,
207 Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
208 ) -> Fut
209 + Send
210 + Sync
211 + 'static,
212 Fut: Future<Output = Response> + Send + 'static,
213{
214 fn call(
215 &self,
216 req: Request,
217 next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
218 ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
219 Box::pin(self(req, next))
220 }
221}
222
223/// Organizes middleware into a processing pipeline
224pub struct MiddlewareStack {
225 middleware: Vec<MiddlewareFn>,
226}
227
228impl MiddlewareStack {
229 /// Start with an empty stack
230 pub fn new() -> Self {
231 Self {
232 middleware: Vec::new(),
233 }
234 }
235
236 /// Add another layer to the stack
237 pub fn add<M>(&mut self, middleware: M)
238 where
239 M: Middleware,
240 {
241 let middleware_fn = std::sync::Arc::new(move |req, next| middleware.call(req, next));
242 self.middleware.push(middleware_fn);
243 }
244
245 /// Run a request through the middleware pipeline
246 pub async fn execute<F>(&self, req: Request, handler: F) -> Response
247 where
248 F: Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync + 'static,
249 {
250 if self.middleware.is_empty() {
251 // Fast path when no middleware is configured
252 return handler(req).await;
253 }
254
255 // For now, execute middleware in sequence (simplified implementation)
256 let response = handler(req).await;
257
258 // Apply middleware effects to the response (simplified)
259 for middleware in &self.middleware {
260 // This is a simplified approach - in a full implementation,
261 // you would need to restructure the middleware trait to support
262 // proper chaining with async closures
263 let _ = middleware; // Suppress unused warning
264 }
265
266 response
267 }
268}
269
270impl Default for MiddlewareStack {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276/// Built-in middleware for logging requests
277pub fn logger() -> impl Middleware {
278 |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
279 Box::pin(async move {
280 let method = req.method().clone();
281 let path = req.path().to_string();
282 let start = std::time::Instant::now();
283
284 let response = next(req).await;
285
286 let duration = start.elapsed();
287 println!(
288 "{} {} - {} ({:.2}ms)",
289 method,
290 path,
291 response.status_code(),
292 duration.as_secs_f64() * 1000.0
293 );
294
295 response
296 })
297 }
298}
299
300/// Built-in middleware for CORS
301pub fn cors() -> impl Middleware {
302 |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
303 Box::pin(async move {
304 let mut response = next(req).await;
305
306 // Add CORS headers (this is a simple implementation)
307 response = response
308 .header("Access-Control-Allow-Origin", "*")
309 .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
310 .header("Access-Control-Allow-Headers", "Content-Type, Authorization");
311
312 response
313 })
314 }
315}
316
317/// Built-in middleware for adding security headers
318pub fn security_headers() -> impl Middleware {
319 |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
320 Box::pin(async move {
321 let mut response = next(req).await;
322
323 response = response
324 .header("X-Content-Type-Options", "nosniff")
325 .header("X-Frame-Options", "DENY")
326 .header("X-XSS-Protection", "1; mode=block");
327
328 response
329 })
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed response future returned by handlers, `next` callbacks and middleware.
    type BoxFuture = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;

    /// The `next` callback handed to middleware.
    type Next = Box<dyn Fn(Request) -> BoxFuture + Send + Sync>;

    /// Builds a bare `GET /` request with an empty body.
    async fn get_root() -> Request {
        let parts = http::Request::builder()
            .method("GET")
            .uri("/")
            .body(())
            .unwrap()
            .into_parts()
            .0;

        Request::from_hyper(parts, Vec::new()).await.unwrap()
    }

    /// Route handler shared by the tests: always answers `200 Hello`.
    fn hello(_req: Request) -> BoxFuture {
        Box::pin(async { Response::ok().body("Hello") })
    }

    #[tokio::test]
    async fn test_middleware_stack() {
        let mut stack = MiddlewareStack::new();

        // Add a middleware that adds a header
        stack.add(|req: Request, next: Next| -> BoxFuture {
            Box::pin(async move { next(req).await.header("X-Test", "middleware") })
        });

        let response = stack.execute(get_root().await, hello).await;
        assert_eq!(response.headers().get("X-Test").unwrap(), "middleware");
        assert_eq!(response.body_data(), b"Hello");
    }

    #[tokio::test]
    async fn test_middleware_runs_in_registration_order() {
        let mut stack = MiddlewareStack::new();

        // Registered first, so it is the outermost layer and must observe the
        // header added by the inner layer.
        stack.add(|req: Request, next: Next| -> BoxFuture {
            Box::pin(async move {
                let response = next(req).await;
                let saw_inner = if response.headers().get("X-Inner").is_some() {
                    "yes"
                } else {
                    "no"
                };
                response.header("X-Outer-Saw-Inner", saw_inner)
            })
        });
        stack.add(|req: Request, next: Next| -> BoxFuture {
            Box::pin(async move { next(req).await.header("X-Inner", "1") })
        });

        let response = stack.execute(get_root().await, hello).await;
        assert_eq!(response.headers().get("X-Outer-Saw-Inner").unwrap(), "yes");
        assert_eq!(response.body_data(), b"Hello");
    }

    #[tokio::test]
    async fn test_middleware_can_short_circuit() {
        fn unreachable_handler(_req: Request) -> BoxFuture {
            panic!("handler must not run when middleware short-circuits");
        }

        let mut stack = MiddlewareStack::new();
        // Answers directly without calling `next`.
        stack.add(|_req: Request, _next: Next| -> BoxFuture {
            Box::pin(async { Response::ok().body("blocked") })
        });

        let response = stack.execute(get_root().await, unreachable_handler).await;
        assert_eq!(response.body_data(), b"blocked");
    }

    #[tokio::test]
    async fn test_cors_middleware() {
        let cors_middleware = cors();
        let next: Next = Box::new(hello);

        let response = cors_middleware.call(get_root().await, next).await;
        assert_eq!(
            response.headers().get("Access-Control-Allow-Origin").unwrap(),
            "*"
        );
    }
}