torch_web/
middleware.rs

1//! # Middleware System
2//!
3//! Torch's middleware system allows you to intercept and modify HTTP requests and responses
4//! as they flow through your application. Middleware can be used for logging, authentication,
5//! CORS, rate limiting, and many other cross-cutting concerns.
6//!
7//! ## How Middleware Works
8//!
9//! Middleware forms a chain where each middleware can:
10//! 1. Inspect and modify the incoming request
11//! 2. Call the next middleware in the chain (or the final handler)
12//! 3. Inspect and modify the outgoing response
13//! 4. Short-circuit the chain by returning early
14//!
15//! ## Examples
16//!
17//! ### Basic Logging Middleware
18//!
19//! ```rust
20//! use torch_web::{App, Request, Response, middleware::Middleware};
21//! use std::pin::Pin;
22//! use std::future::Future;
23//!
24//! struct Logger;
25//!
26//! impl Middleware for Logger {
27//!     fn call(
28//!         &self,
29//!         req: Request,
30//!         next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
31//!     ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
32//!         Box::pin(async move {
33//!             println!("{} {}", req.method(), req.path());
34//!             let response = next(req).await;
35//!             println!("Response: {}", response.status_code());
36//!             response
37//!         })
38//!     }
39//! }
40//!
41//! let app = App::new()
42//!     .middleware(Logger)
43//!     .get("/", |_req| async { Response::ok().body("Hello!") });
44//! ```
45//!
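//! ### Function-based Middleware
//!
//! Any closure or function with the right signature is middleware too. The type of
//! `next` is spelled out because closure argument types are not inferred through the
//! blanket `impl Middleware for F`.
//!
//! ```rust
//! use torch_web::{App, Request, Response};
//! use std::pin::Pin;
//! use std::future::Future;
//!
//! type Next = Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>;
//!
//! let app = App::new()
//!     .middleware(|req: Request, next: Next| async move {
//!         // Add a custom header to all responses
//!         next(req).await.header("X-Powered-By", "Torch")
//!     })
//!     .get("/", |_req| async { Response::ok().body("Hello!") });
//! ```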
60
61use std::future::Future;
62use std::pin::Pin;
63use crate::{Request, Response};
64
65/// Type alias for middleware functions.
66///
67/// This represents the function signature that middleware must implement.
68/// It takes a request and a "next" function that continues the middleware chain.
69pub type MiddlewareFn = std::sync::Arc<
70    dyn Fn(
71            Request,
72            Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
73        ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
74        + Send
75        + Sync
76        + 'static,
77>;
78
79/// Trait for implementing middleware components.
80///
81/// Middleware can intercept requests before they reach handlers and modify
82/// responses before they're sent to clients. This trait provides a standard
83/// interface for all middleware components.
84///
85/// # Examples
86///
87/// ## Authentication Middleware
88///
89/// ```rust
90/// use torch_web::{Request, Response, middleware::Middleware};
91/// use std::pin::Pin;
92/// use std::future::Future;
93///
94/// struct AuthMiddleware {
95///     secret_key: String,
96/// }
97///
98/// impl AuthMiddleware {
99///     fn new(secret_key: String) -> Self {
100///         Self { secret_key }
101///     }
102/// }
103///
104/// impl Middleware for AuthMiddleware {
105///     fn call(
106///         &self,
107///         req: Request,
108///         next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
109///     ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
110///         Box::pin(async move {
111///             // Check for authorization header
112///             if let Some(auth_header) = req.header("authorization") {
113///                 if auth_header.starts_with("Bearer ") {
114///                     // Validate token here...
115///                     return next(req).await;
116///                 }
117///             }
118///
119///             Response::unauthorized().body("Authentication required")
120///         })
121///     }
122/// }
123/// ```
124///
125/// ## CORS Middleware
126///
127/// ```rust
128/// use torch_web::{Request, Response, middleware::Middleware};
129/// use std::pin::Pin;
130/// use std::future::Future;
131///
132/// struct CorsMiddleware;
133///
134/// impl Middleware for CorsMiddleware {
135///     fn call(
136///         &self,
137///         req: Request,
138///         next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
139///     ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
140///         Box::pin(async move {
141///             let mut response = next(req).await;
142///             response = response
143///                 .header("Access-Control-Allow-Origin", "*")
144///                 .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
145///                 .header("Access-Control-Allow-Headers", "Content-Type, Authorization");
146///             response
147///         })
148///     }
149/// }
150/// ```
151pub trait Middleware: Send + Sync + 'static {
152    /// Processes a request through the middleware chain.
153    ///
154    /// This method receives the current request and a `next` function that
155    /// continues processing through the remaining middleware and eventually
156    /// to the route handler.
157    ///
158    /// # Parameters
159    ///
160    /// * `req` - The HTTP request to process
161    /// * `next` - Function to call the next middleware or handler in the chain
162    ///
163    /// # Returns
164    ///
165    /// Returns a `Future` that resolves to the HTTP response. The middleware
166    /// can modify the request before calling `next`, modify the response after
167    /// calling `next`, or return early without calling `next` at all.
168    ///
169    /// # Examples
170    ///
171    /// ```rust
172    /// use torch_web::{Request, Response, middleware::Middleware};
173    /// use std::pin::Pin;
174    /// use std::future::Future;
175    ///
176    /// struct TimingMiddleware;
177    ///
178    /// impl Middleware for TimingMiddleware {
179    ///     fn call(
180    ///         &self,
181    ///         req: Request,
182    ///         next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
183    ///     ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
184    ///         Box::pin(async move {
185    ///             let start = std::time::Instant::now();
186    ///             let response = next(req).await;
187    ///             let duration = start.elapsed();
188    ///
189    ///             println!("Request took {:?}", duration);
190    ///             response.header("X-Response-Time", &format!("{}ms", duration.as_millis()))
191    ///         })
192    ///     }
193    /// }
194    /// ```
195    fn call(
196        &self,
197        req: Request,
198        next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
199    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
200}
201
202/// Any function that matches the signature can be middleware
203impl<F, Fut> Middleware for F
204where
205    F: Fn(
206            Request,
207            Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
208        ) -> Fut
209        + Send
210        + Sync
211        + 'static,
212    Fut: Future<Output = Response> + Send + 'static,
213{
214    fn call(
215        &self,
216        req: Request,
217        next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>,
218    ) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> {
219        Box::pin(self(req, next))
220    }
221}
222
223/// Organizes middleware into a processing pipeline
224pub struct MiddlewareStack {
225    middleware: Vec<MiddlewareFn>,
226}
227
228impl MiddlewareStack {
229    /// Start with an empty stack
230    pub fn new() -> Self {
231        Self {
232            middleware: Vec::new(),
233        }
234    }
235
236    /// Add another layer to the stack
237    pub fn add<M>(&mut self, middleware: M)
238    where
239        M: Middleware,
240    {
241        let middleware_fn = std::sync::Arc::new(move |req, next| middleware.call(req, next));
242        self.middleware.push(middleware_fn);
243    }
244
245    /// Run a request through the middleware pipeline
246    pub async fn execute<F>(&self, req: Request, handler: F) -> Response
247    where
248        F: Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync + 'static,
249    {
250        if self.middleware.is_empty() {
251            // Fast path when no middleware is configured
252            return handler(req).await;
253        }
254
255        // For now, execute middleware in sequence (simplified implementation)
256        let response = handler(req).await;
257
258        // Apply middleware effects to the response (simplified)
259        for middleware in &self.middleware {
260            // This is a simplified approach - in a full implementation,
261            // you would need to restructure the middleware trait to support
262            // proper chaining with async closures
263            let _ = middleware; // Suppress unused warning
264        }
265
266        response
267    }
268}
269
270impl Default for MiddlewareStack {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276/// Built-in middleware for logging requests
277pub fn logger() -> impl Middleware {
278    |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
279        Box::pin(async move {
280            let method = req.method().clone();
281            let path = req.path().to_string();
282            let start = std::time::Instant::now();
283
284            let response = next(req).await;
285
286            let duration = start.elapsed();
287            println!(
288                "{} {} - {} ({:.2}ms)",
289                method,
290                path,
291                response.status_code(),
292                duration.as_secs_f64() * 1000.0
293            );
294
295            response
296        })
297    }
298}
299
300/// Built-in middleware for CORS
301pub fn cors() -> impl Middleware {
302    |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
303        Box::pin(async move {
304            let mut response = next(req).await;
305
306            // Add CORS headers (this is a simple implementation)
307            response = response
308                .header("Access-Control-Allow-Origin", "*")
309                .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
310                .header("Access-Control-Allow-Headers", "Content-Type, Authorization");
311
312            response
313        })
314    }
315}
316
317/// Built-in middleware for adding security headers
318pub fn security_headers() -> impl Middleware {
319    |req: Request, next: Box<dyn Fn(Request) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>> + Send + Sync>| {
320        Box::pin(async move {
321            let mut response = next(req).await;
322
323            response = response
324                .header("X-Content-Type-Options", "nosniff")
325                .header("X-Frame-Options", "DENY")
326                .header("X-XSS-Protection", "1; mode=block");
327
328            response
329        })
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Boxed future type produced by handlers and `next` continuations.
    type BoxedResponse = Pin<Box<dyn Future<Output = Response> + Send + 'static>>;
    /// Continuation type handed to each middleware.
    type Next = Box<dyn Fn(Request) -> BoxedResponse + Send + Sync>;

    /// Builds a body-less `GET` request for `uri`.
    async fn get_request(uri: &str) -> Request {
        let parts = http::Request::builder()
            .method("GET")
            .uri(uri)
            .body(())
            .unwrap()
            .into_parts()
            .0;
        Request::from_hyper(parts, Vec::new()).await.unwrap()
    }

    /// Route handler that always answers `200 OK` with body `Hello`.
    fn hello(_req: Request) -> BoxedResponse {
        Box::pin(async { Response::ok().body("Hello") })
    }

    #[tokio::test]
    async fn test_middleware_stack() {
        let mut stack = MiddlewareStack::new();

        // Add a middleware that adds a header
        stack.add(|req: Request, next: Next| -> BoxedResponse {
            Box::pin(async move { next(req).await.header("X-Test", "middleware") })
        });

        let response = stack.execute(get_request("/").await, hello).await;
        assert_eq!(response.headers().get("X-Test").unwrap(), "middleware");
        assert_eq!(response.body_data(), b"Hello");
    }

    #[tokio::test]
    async fn test_empty_stack_calls_handler() {
        let stack = MiddlewareStack::new();
        let response = stack.execute(get_request("/").await, hello).await;
        assert_eq!(response.body_data(), b"Hello");
    }

    #[tokio::test]
    async fn test_middleware_runs_in_insertion_order() {
        let mut stack = MiddlewareStack::new();

        // Outermost layer: records whether the inner layer had already run
        // by the time the response came back through it.
        stack.add(|req: Request, next: Next| -> BoxedResponse {
            Box::pin(async move {
                let response = next(req).await;
                let saw_inner = if response.headers().get("X-Inner").is_some() {
                    "yes"
                } else {
                    "no"
                };
                response.header("X-Outer-Saw-Inner", saw_inner)
            })
        });
        // Inner layer: runs closer to the handler.
        stack.add(|req: Request, next: Next| -> BoxedResponse {
            Box::pin(async move { next(req).await.header("X-Inner", "1") })
        });

        let response = stack.execute(get_request("/").await, hello).await;
        assert_eq!(response.headers().get("X-Inner").unwrap(), "1");
        assert_eq!(response.headers().get("X-Outer-Saw-Inner").unwrap(), "yes");
        assert_eq!(response.body_data(), b"Hello");
    }

    #[tokio::test]
    async fn test_middleware_can_short_circuit() {
        let handler_called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&handler_called);

        let mut stack = MiddlewareStack::new();
        stack.add(|_req: Request, _next: Next| -> BoxedResponse {
            Box::pin(async { Response::ok().body("blocked") })
        });

        let response = stack
            .execute(get_request("/").await, move |req: Request| -> BoxedResponse {
                flag.store(true, Ordering::SeqCst);
                hello(req)
            })
            .await;

        assert_eq!(response.body_data(), b"blocked");
        assert!(!handler_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_cors_middleware() {
        let cors_middleware = cors();
        let response = cors_middleware
            .call(get_request("/").await, Box::new(hello))
            .await;
        assert_eq!(
            response.headers().get("Access-Control-Allow-Origin").unwrap(),
            "*"
        );
    }
}