web_server_abstraction/
middleware.rs

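//! Middleware implementations: logging, CORS, timeouts, rate limiting, authentication,
//! compression, security headers, metrics, response caching, path-parameter extraction,
//! and request/response transforms. Some of them (timeout, compression) are placeholders
//! that do not yet perform their namesake work.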
2
3use crate::core::{Middleware, Next};
4use crate::error::Result;
5use crate::types::{Request, Response, StatusCode};
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant, SystemTime};
10
/// Middleware that prints one line to stdout before and one after each request.
///
/// When `log_bodies` is set, the request line also carries a body-size figure.
pub struct LoggingMiddleware {
    /// Master switch; when `false`, requests pass through without any output.
    pub enabled: bool,
    /// Whether the "->" line should include a body-size figure.
    pub log_bodies: bool,
}

impl Default for LoggingMiddleware {
    fn default() -> Self {
        LoggingMiddleware::new()
    }
}

impl LoggingMiddleware {
    /// Creates an enabled logger that does not report body sizes.
    pub fn new() -> Self {
        let (enabled, log_bodies) = (true, false);
        LoggingMiddleware {
            enabled,
            log_bodies,
        }
    }

    /// Builder: turns logging on or off.
    pub fn enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Builder: toggles body-size reporting on the request line.
    pub fn log_bodies(self, log_bodies: bool) -> Self {
        Self { log_bodies, ..self }
    }
}
41
42#[async_trait]
43impl Middleware for LoggingMiddleware {
44    async fn call(&self, req: Request, next: Next) -> Result<Response> {
45        if !self.enabled {
46            return next.run(req).await;
47        }
48
49        let start = Instant::now();
50        let method = req.method;
51        let path = req.path().to_string();
52
53        if self.log_bodies {
54            let body_preview = req.body.len().min(100);
55            println!("-> {:?} {} (body: {} bytes)", method, path, body_preview);
56        } else {
57            println!("-> {:?} {}", method, path);
58        }
59
60        let response = next.run(req).await;
61
62        let duration = start.elapsed();
63        match &response {
64            Ok(_resp) => {
65                println!("<- {:?} {} - 200 OK ({:?})", method, path, duration);
66            }
67            Err(err) => {
68                println!("<- {:?} {} - ERROR: {} ({:?})", method, path, err, duration);
69            }
70        }
71
72        response
73    }
74}
75
/// Middleware that emits Cross-Origin Resource Sharing (CORS) response headers.
pub struct CorsMiddleware {
    /// Value sent as `Access-Control-Allow-Origin` (default `"*"`).
    pub allow_origin: String,
    /// Methods listed in `Access-Control-Allow-Methods`.
    pub allow_methods: Vec<String>,
    /// Request headers listed in `Access-Control-Allow-Headers`.
    pub allow_headers: Vec<String>,
    /// Whether to send `Access-Control-Allow-Credentials: true`.
    pub allow_credentials: bool,
    /// Response headers listed in `Access-Control-Expose-Headers` (omitted when empty).
    pub expose_headers: Vec<String>,
    /// Preflight cache lifetime sent as `Access-Control-Max-Age`, if any.
    pub max_age: Option<Duration>,
}

impl Default for CorsMiddleware {
    fn default() -> Self {
        CorsMiddleware::new()
    }
}

impl CorsMiddleware {
    /// Permissive defaults:
    /// - any origin;
    /// - the common REST verbs plus `OPTIONS`;
    /// - a typical set of request headers;
    /// - no credentials and no exposed headers;
    /// - a one-day preflight cache.
    pub fn new() -> Self {
        let methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"];
        let headers = [
            "Content-Type",
            "Authorization",
            "Accept",
            "Origin",
            "X-Requested-With",
        ];
        Self {
            allow_origin: String::from("*"),
            allow_methods: methods.iter().map(|m| m.to_string()).collect(),
            allow_headers: headers.iter().map(|h| h.to_string()).collect(),
            allow_credentials: false,
            expose_headers: Vec::new(),
            max_age: Some(Duration::from_secs(24 * 60 * 60)),
        }
    }

    /// Builder: sets the allowed origin.
    pub fn allow_origin(self, origin: impl Into<String>) -> Self {
        Self {
            allow_origin: origin.into(),
            ..self
        }
    }

    /// Builder: replaces the allowed methods.
    pub fn allow_methods(self, methods: Vec<String>) -> Self {
        Self {
            allow_methods: methods,
            ..self
        }
    }

    /// Builder: replaces the allowed request headers.
    pub fn allow_headers(self, headers: Vec<String>) -> Self {
        Self {
            allow_headers: headers,
            ..self
        }
    }

    /// Builder: toggles `Access-Control-Allow-Credentials`.
    pub fn allow_credentials(self, allow: bool) -> Self {
        Self {
            allow_credentials: allow,
            ..self
        }
    }

    /// Builder: replaces the exposed response headers.
    pub fn expose_headers(self, headers: Vec<String>) -> Self {
        Self {
            expose_headers: headers,
            ..self
        }
    }

    /// Builder: sets the preflight cache lifetime.
    pub fn max_age(self, max_age: Duration) -> Self {
        Self {
            max_age: Some(max_age),
            ..self
        }
    }
}
146
147#[async_trait]
148impl Middleware for CorsMiddleware {
149    async fn call(&self, req: Request, next: Next) -> Result<Response> {
150        // Handle preflight requests
151        if req.method == crate::types::HttpMethod::OPTIONS {
152            let mut response = Response::new(StatusCode::OK)
153                .header("Access-Control-Allow-Origin", &self.allow_origin)
154                .header(
155                    "Access-Control-Allow-Methods",
156                    self.allow_methods.join(", "),
157                )
158                .header(
159                    "Access-Control-Allow-Headers",
160                    self.allow_headers.join(", "),
161                );
162
163            if self.allow_credentials {
164                response = response.header("Access-Control-Allow-Credentials", "true");
165            }
166
167            if let Some(max_age) = self.max_age {
168                response = response.header("Access-Control-Max-Age", max_age.as_secs().to_string());
169            }
170
171            return Ok(response);
172        }
173
174        let response = next.run(req).await?;
175
176        let mut cors_response = response
177            .header("Access-Control-Allow-Origin", &self.allow_origin)
178            .header(
179                "Access-Control-Allow-Methods",
180                self.allow_methods.join(", "),
181            )
182            .header(
183                "Access-Control-Allow-Headers",
184                self.allow_headers.join(", "),
185            );
186
187        if self.allow_credentials {
188            cors_response = cors_response.header("Access-Control-Allow-Credentials", "true");
189        }
190
191        if !self.expose_headers.is_empty() {
192            cors_response = cors_response.header(
193                "Access-Control-Expose-Headers",
194                self.expose_headers.join(", "),
195            );
196        }
197
198        Ok(cors_response)
199    }
200}
201
/// Middleware meant to bound how long a request may run.
///
/// NOTE: the timeout is currently not enforced. Requests always run to
/// completion; the configured duration is only used to print a warning when it
/// is very short.
pub struct TimeoutMiddleware {
    /// Intended maximum run time of a request.
    pub timeout: Duration,
}

impl TimeoutMiddleware {
    /// Creates the middleware with the given timeout.
    pub fn new(timeout: Duration) -> Self {
        TimeoutMiddleware { timeout }
    }
}
212
213#[async_trait]
214impl Middleware for TimeoutMiddleware {
215    async fn call(&self, req: Request, next: Next) -> Result<Response> {
216        // Note: In a real implementation, you'd use tokio::time::timeout
217        // For now, we'll just pass through with a warning if timeout is very short
218        if self.timeout.as_millis() < 100 {
219            println!("Warning: Very short timeout configured: {:?}", self.timeout);
220        }
221        next.run(req).await
222    }
223}
224
225/// Rate limiting middleware with in-memory storage
226pub struct RateLimitMiddleware {
227    pub max_requests: u32,
228    pub window: Duration,
229    pub store: Arc<Mutex<HashMap<String, (u32, SystemTime)>>>,
230}
231
232impl RateLimitMiddleware {
233    pub fn new(max_requests: u32, window: Duration) -> Self {
234        Self {
235            max_requests,
236            window,
237            store: Arc::new(Mutex::new(HashMap::new())),
238        }
239    }
240
241    fn get_client_key(&self, req: &Request) -> String {
242        // In a real implementation, you'd extract the client IP or user ID
243        // For now, we'll use a simple key based on the path
244        format!("default:{}", req.path())
245    }
246
247    fn is_rate_limited(&self, key: &str) -> bool {
248        let mut store = self.store.lock().unwrap();
249        let now = SystemTime::now();
250
251        match store.get_mut(key) {
252            Some((count, last_reset)) => {
253                // Check if window has expired
254                if now.duration_since(*last_reset).unwrap_or(Duration::ZERO) >= self.window {
255                    *count = 1;
256                    *last_reset = now;
257                    false
258                } else if *count >= self.max_requests {
259                    true
260                } else {
261                    *count += 1;
262                    false
263                }
264            }
265            None => {
266                store.insert(key.to_string(), (1, now));
267                false
268            }
269        }
270    }
271}
272
273#[async_trait]
274impl Middleware for RateLimitMiddleware {
275    async fn call(&self, req: Request, next: Next) -> Result<Response> {
276        let key = self.get_client_key(&req);
277
278        if self.is_rate_limited(&key) {
279            return Ok(Response::new(StatusCode(429))
280                .header("Content-Type", "application/json")
281                .body(r#"{"error": "Rate limit exceeded"}"#));
282        }
283
284        next.run(req).await
285    }
286}
287
/// Bearer-token authentication middleware.
///
/// With `require_auth` set, requests without an `authorization` header are
/// rejected. When `bearer_tokens` is non-empty, the header must also carry one
/// of those tokens.
pub struct AuthMiddleware {
    /// Whether requests must present credentials at all.
    pub require_auth: bool,
    /// Accepted bearer tokens. When empty, the presence of an `authorization`
    /// header is all that is checked.
    pub bearer_tokens: Arc<Vec<String>>,
}

impl Default for AuthMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

impl AuthMiddleware {
    /// Creates a middleware that requires auth and has no configured tokens.
    pub fn new() -> Self {
        Self {
            require_auth: true,
            bearer_tokens: Arc::new(vec![]),
        }
    }

    /// Builder: makes authentication optional.
    pub fn optional(mut self) -> Self {
        self.require_auth = false;
        self
    }

    /// Builder: replaces the set of accepted bearer tokens.
    pub fn with_bearer_tokens(mut self, tokens: Vec<String>) -> Self {
        self.bearer_tokens = Arc::new(tokens);
        self
    }

    /// Returns `true` if `authorization` is a Bearer credential whose token is
    /// one of `bearer_tokens`.
    ///
    /// - The scheme name is matched case-insensitively (RFC 7235 §2.1).
    /// - One or more spaces may separate scheme and token (RFC 6750 §2.1).
    /// - An empty token never matches.
    fn validate_token(&self, authorization: &str) -> bool {
        let Some((scheme, token)) = authorization.split_once(' ') else {
            return false;
        };
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return false;
        }
        let token = token.trim_start_matches(' ');
        if token.is_empty() {
            return false;
        }
        // Compare against every configured token without short-circuiting, so
        // timing reveals less about which token (or prefix) matched.
        self.bearer_tokens.iter().fold(false, |matched, known| {
            matched | Self::timing_safe_eq(known.as_bytes(), token.as_bytes())
        })
    }

    /// Best-effort timing-safe equality. It does not return at the first
    /// differing byte, but it does reveal whether the lengths differ.
    fn timing_safe_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }
}
326
327#[async_trait]
328impl Middleware for AuthMiddleware {
329    async fn call(&self, req: Request, next: Next) -> Result<Response> {
330        if self.require_auth {
331            if let Some(auth_header) = req.headers.get("authorization") {
332                if !self.bearer_tokens.is_empty() && !self.validate_token(auth_header) {
333                    return Ok(Response::new(StatusCode::UNAUTHORIZED)
334                        .header("Content-Type", "application/json")
335                        .body(r#"{"error": "Invalid token"}"#));
336                }
337            } else {
338                return Ok(Response::new(StatusCode::UNAUTHORIZED)
339                    .header("Content-Type", "application/json")
340                    .body(r#"{"error": "Authentication required"}"#));
341            }
342        }
343
344        next.run(req).await
345    }
346}
347
/// Configuration for the response-compression middleware.
///
/// `enabled` is the master switch. `min_size` is the body length below which
/// responses are left alone.
pub struct CompressionMiddleware {
    /// Master switch for the middleware.
    pub enabled: bool,
    /// Smallest body length (as reported by `body.len()`) considered for compression.
    pub min_size: usize,
}

impl Default for CompressionMiddleware {
    fn default() -> Self {
        CompressionMiddleware::new()
    }
}

impl CompressionMiddleware {
    /// Creates an enabled instance with a `min_size` threshold of 1024.
    pub fn new() -> Self {
        const DEFAULT_MIN_SIZE: usize = 1024;
        CompressionMiddleware {
            min_size: DEFAULT_MIN_SIZE,
            enabled: true,
        }
    }

    /// Builder: sets the minimum body length eligible for compression.
    pub fn min_size(self, size: usize) -> Self {
        Self {
            min_size: size,
            ..self
        }
    }
}
373
374#[async_trait]
375impl Middleware for CompressionMiddleware {
376    async fn call(&self, req: Request, next: Next) -> Result<Response> {
377        let response = next.run(req).await?;
378
379        if !self.enabled {
380            return Ok(response);
381        }
382
383        // Check if response is large enough to compress
384        if response.body.len() >= self.min_size {
385            // In a real implementation, you'd actually compress the body
386            let compressed_response = response
387                .header("Content-Encoding", "gzip")
388                .header("Vary", "Accept-Encoding");
389            Ok(compressed_response)
390        } else {
391            Ok(response)
392        }
393    }
394}
395
/// Middleware that adds common security-related response headers.
///
/// Each header is controlled by its own flag; all four are enabled by default.
pub struct SecurityHeadersMiddleware {
    /// Add `Strict-Transport-Security`.
    pub add_hsts: bool,
    /// Add `X-Frame-Options`.
    pub add_frame_options: bool,
    /// Add `X-Content-Type-Options`.
    pub add_content_type_options: bool,
    /// Add `X-XSS-Protection`.
    pub add_xss_protection: bool,
}

impl Default for SecurityHeadersMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

impl SecurityHeadersMiddleware {
    /// Creates the middleware with every header enabled.
    pub fn new() -> Self {
        Self {
            add_hsts: true,
            add_frame_options: true,
            add_content_type_options: true,
            add_xss_protection: true,
        }
    }

    /// Builder: toggles `Strict-Transport-Security`.
    pub fn with_hsts(mut self, enabled: bool) -> Self {
        self.add_hsts = enabled;
        self
    }

    /// Builder: toggles `X-Frame-Options`.
    pub fn with_frame_options(mut self, enabled: bool) -> Self {
        self.add_frame_options = enabled;
        self
    }

    /// Builder: toggles `X-Content-Type-Options`. Before this builder existed,
    /// the flag could only be changed by writing the field directly.
    pub fn with_content_type_options(mut self, enabled: bool) -> Self {
        self.add_content_type_options = enabled;
        self
    }

    /// Builder: toggles `X-XSS-Protection`.
    pub fn with_xss_protection(mut self, enabled: bool) -> Self {
        self.add_xss_protection = enabled;
        self
    }
}
430
431#[async_trait]
432impl Middleware for SecurityHeadersMiddleware {
433    async fn call(&self, req: Request, next: Next) -> Result<Response> {
434        let mut response = next.run(req).await?;
435
436        if self.add_hsts {
437            response = response.header(
438                "Strict-Transport-Security",
439                "max-age=31536000; includeSubDomains",
440            );
441        }
442
443        if self.add_frame_options {
444            response = response.header("X-Frame-Options", "DENY");
445        }
446
447        if self.add_content_type_options {
448            response = response.header("X-Content-Type-Options", "nosniff");
449        }
450
451        if self.add_xss_protection {
452            response = response.header("X-XSS-Protection", "1; mode=block");
453        }
454
455        Ok(response)
456    }
457}
458
/// Middleware that counts requests and errors and accumulates request time.
///
/// Counters are shared behind `Arc<Mutex<_>>`, so clones of the handles can
/// read them from elsewhere.
pub struct MetricsMiddleware {
    /// Master switch; when `false`, nothing is recorded.
    pub enabled: bool,
    /// Whether to add each request's elapsed time to `total_duration`.
    pub collect_timing: bool,
    /// Whether to count `Err` results in `error_count`.
    pub collect_errors: bool,
    /// Number of requests seen while enabled.
    pub request_count: Arc<Mutex<u64>>,
    /// Number of requests whose chain returned an error.
    pub error_count: Arc<Mutex<u64>>,
    /// Sum of measured request durations.
    pub total_duration: Arc<Mutex<Duration>>,
}

impl Default for MetricsMiddleware {
    fn default() -> Self {
        MetricsMiddleware::new()
    }
}

impl MetricsMiddleware {
    /// Creates an enabled collector with timing and error counting turned on
    /// and all counters at zero.
    pub fn new() -> Self {
        let zero_counter = || Arc::new(Mutex::new(0));
        MetricsMiddleware {
            enabled: true,
            collect_timing: true,
            collect_errors: true,
            request_count: zero_counter(),
            error_count: zero_counter(),
            total_duration: Arc::new(Mutex::new(Duration::ZERO)),
        }
    }

    /// Returns `(requests, errors, total_duration)`.
    ///
    /// Each counter is read under its own lock, one after another.
    ///
    /// # Panics
    /// Panics if any counter's mutex is poisoned.
    pub fn get_stats(&self) -> (u64, u64, Duration) {
        let requests = *self.request_count.lock().unwrap();
        let errors = *self.error_count.lock().unwrap();
        let elapsed = *self.total_duration.lock().unwrap();
        (requests, errors, elapsed)
    }
}
494
495#[async_trait]
496impl Middleware for MetricsMiddleware {
497    async fn call(&self, req: Request, next: Next) -> Result<Response> {
498        if !self.enabled {
499            return next.run(req).await;
500        }
501
502        let start = if self.collect_timing {
503            Some(Instant::now())
504        } else {
505            None
506        };
507
508        // Increment request count
509        *self.request_count.lock().unwrap() += 1;
510
511        let result = next.run(req).await;
512
513        // Collect timing
514        if let Some(start_time) = start {
515            let duration = start_time.elapsed();
516            *self.total_duration.lock().unwrap() += duration;
517        }
518
519        // Collect errors
520        if self.collect_errors && result.is_err() {
521            *self.error_count.lock().unwrap() += 1;
522        }
523
524        result
525    }
526}
527
528/// Response caching middleware (simple in-memory cache)
529pub struct CacheMiddleware {
530    pub enabled: bool,
531    pub cache_duration: Duration,
532    pub cache: Arc<Mutex<HashMap<String, (Response, SystemTime)>>>,
533}
534
535impl CacheMiddleware {
536    pub fn new(cache_duration: Duration) -> Self {
537        Self {
538            enabled: true,
539            cache_duration,
540            cache: Arc::new(Mutex::new(HashMap::new())),
541        }
542    }
543
544    fn cache_key(&self, req: &Request) -> String {
545        format!("{}:{}", req.method.as_str(), req.path())
546    }
547
548    fn get_cached(&self, key: &str) -> Option<Response> {
549        let mut cache = self.cache.lock().unwrap();
550
551        if let Some((response, timestamp)) = cache.get(key) {
552            let now = SystemTime::now();
553            if now.duration_since(*timestamp).unwrap_or(Duration::MAX) < self.cache_duration {
554                return Some(response.clone());
555            } else {
556                cache.remove(key);
557            }
558        }
559
560        None
561    }
562
563    fn cache_response(&self, key: String, response: &Response) {
564        if response.status.0 == 200 {
565            let mut cache = self.cache.lock().unwrap();
566            cache.insert(key, (response.clone(), SystemTime::now()));
567        }
568    }
569}
570
571#[async_trait]
572impl Middleware for CacheMiddleware {
573    async fn call(&self, req: Request, next: Next) -> Result<Response> {
574        if !self.enabled || req.method != crate::types::HttpMethod::GET {
575            return next.run(req).await;
576        }
577
578        let cache_key = self.cache_key(&req);
579
580        // Try to get from cache
581        if let Some(cached_response) = self.get_cached(&cache_key) {
582            return Ok(cached_response.header("X-Cache", "HIT"));
583        }
584
585        // Execute request
586        let response = next.run(req).await?;
587
588        // Cache the response
589        self.cache_response(cache_key, &response);
590
591        Ok(response.header("X-Cache", "MISS"))
592    }
593}
594
595/// Path parameter extraction middleware
596/// Automatically extracts path parameters and adds them to the request
597pub struct PathParameterMiddleware {
598    route_patterns: Vec<(String, crate::types::HttpMethod)>,
599}
600
601impl PathParameterMiddleware {
602    pub fn new(route_patterns: Vec<(String, crate::types::HttpMethod)>) -> Self {
603        Self { route_patterns }
604    }
605
606    /// Match dynamic path patterns
607    fn match_dynamic_path(
608        &self,
609        pattern: &str,
610        path: &str,
611    ) -> Option<std::collections::HashMap<String, String>> {
612        let route_parts: Vec<&str> = pattern.split('/').collect();
613        let path_parts: Vec<&str> = path.split('/').collect();
614
615        if route_parts.len() != path_parts.len() {
616            // Handle wildcard at the end
617            if let Some(last_part) = route_parts.last()
618                && last_part.starts_with('*')
619                && route_parts.len() <= path_parts.len()
620            {
621                // Wildcard matches remaining path
622                let mut params = std::collections::HashMap::new();
623                let param_name = last_part.trim_start_matches('*');
624                if !param_name.is_empty() {
625                    let remaining_path = path_parts[route_parts.len() - 1..].join("/");
626                    params.insert(param_name.to_string(), remaining_path);
627                }
628                return Some(params);
629            }
630            return None;
631        }
632
633        let mut params = std::collections::HashMap::new();
634
635        for (route_part, path_part) in route_parts.iter().zip(path_parts.iter()) {
636            if route_part.starts_with(':') {
637                // Path parameter
638                let param_name = route_part.trim_start_matches(':');
639                params.insert(param_name.to_string(), path_part.to_string());
640            } else if route_part.starts_with('*') {
641                // Wildcard
642                let param_name = route_part.trim_start_matches('*');
643                if !param_name.is_empty() {
644                    params.insert(param_name.to_string(), path_part.to_string());
645                }
646            } else if route_part != path_part {
647                // Exact match required
648                return None;
649            }
650        }
651
652        Some(params)
653    }
654}
655
656#[async_trait]
657impl Middleware for PathParameterMiddleware {
658    async fn call(&self, mut req: Request, next: Next) -> Result<Response> {
659        // Find matching route and extract parameters
660        for (pattern, method) in &self.route_patterns {
661            if *method == req.method
662                && let Some(params) = self.match_dynamic_path(pattern, req.path())
663            {
664                req.set_params(params);
665                break;
666            }
667        }
668
669        next.run(req).await
670    }
671}
672
673/// Request/Response transformation middleware
674/// Allows custom transformations of requests and responses
675pub struct TransformMiddleware<F, G>
676where
677    F: Fn(Request) -> Request + Send + Sync + 'static,
678    G: Fn(Response) -> Response + Send + Sync + 'static,
679{
680    request_transform: F,
681    response_transform: G,
682}
683
684impl<F, G> TransformMiddleware<F, G>
685where
686    F: Fn(Request) -> Request + Send + Sync + 'static,
687    G: Fn(Response) -> Response + Send + Sync + 'static,
688{
689    pub fn new(request_transform: F, response_transform: G) -> Self {
690        Self {
691            request_transform,
692            response_transform,
693        }
694    }
695}
696
697#[async_trait]
698impl<F, G> Middleware for TransformMiddleware<F, G>
699where
700    F: Fn(Request) -> Request + Send + Sync + 'static,
701    G: Fn(Response) -> Response + Send + Sync + 'static,
702{
703    async fn call(&self, req: Request, next: Next) -> Result<Response> {
704        let transformed_req = (self.request_transform)(req);
705        let response = next.run(transformed_req).await?;
706        Ok((self.response_transform)(response))
707    }
708}
709
#[cfg(test)]
mod tests {
    use crate::types::{Body, Headers, Request, Response, StatusCode};

    /// Compile-level smoke test. It builds (but never invokes) a trivial
    /// handler, checking that the handler shape and `Response` fields line up.
    /// Behavioral middleware tests are still to be added.
    #[tokio::test]
    async fn test_middleware_chain() {
        let _handler = |_req: Request| async {
            let response = Response {
                status: StatusCode::OK,
                headers: Headers::new(),
                body: Body::from("test response"),
            };
            Ok::<Response, crate::error::WebServerError>(response)
        };
    }
}