web_server_abstraction/
enhanced_middleware.rs

//! Enhanced Middleware Integration
//!
//! This module provides a configurable middleware stack with built-in middleware
//! for common web server needs: CORS, gzip compression, rate limiting, and security headers.
5
6use crate::{
7    config::{CompressionConfig, CorsConfig, SecurityConfig},
8    core::{Middleware, Next},
9    error::{Result, WebServerError},
10    types::{Headers, Request, Response},
11};
12use async_trait::async_trait;
13use bytes::Bytes;
14use flate2::{Compression, write::GzEncoder};
15use std::io::Write;
16use std::{
17    collections::HashMap,
18    sync::{Arc, Mutex},
19    time::{Duration, Instant},
20};
21
22/// Enhanced middleware trait with configuration support
23#[async_trait]
24pub trait EnhancedMiddleware: Send + Sync {
25    /// Process a request before it reaches the handler
26    async fn before_request(&self, request: &mut Request) -> Result<Option<Response>>;
27
28    /// Process a response after the handler
29    async fn after_response(&self, response: &mut Response) -> Result<()>;
30
31    /// Get middleware name for debugging/logging
32    fn name(&self) -> &'static str;
33
34    /// Check if middleware is enabled
35    fn is_enabled(&self) -> bool {
36        true
37    }
38}
39
40/// CORS middleware with configurable options
41pub struct CorsMiddleware {
42    config: CorsConfig,
43}
44
45impl CorsMiddleware {
46    pub fn new(config: CorsConfig) -> Self {
47        Self { config }
48    }
49
50    /// Check if origin is allowed
51    fn is_origin_allowed(&self, origin: &str) -> bool {
52        self.config.allowed_origins.contains(&"*".to_string())
53            || self.config.allowed_origins.contains(&origin.to_string())
54    }
55
56    /// Get allowed methods as a string
57    fn get_allowed_methods(&self) -> String {
58        self.config.allowed_methods.join(", ")
59    }
60
61    /// Get allowed headers as a string
62    fn get_allowed_headers(&self) -> String {
63        self.config.allowed_headers.join(", ")
64    }
65}
66
67#[async_trait]
68impl EnhancedMiddleware for CorsMiddleware {
69    async fn before_request(&self, request: &mut Request) -> Result<Option<Response>> {
70        if !self.config.enabled {
71            return Ok(None);
72        }
73
74        // Handle preflight requests
75        if request.method == crate::types::HttpMethod::OPTIONS {
76            let origin = request.headers.get("origin").cloned().unwrap_or_default();
77
78            if self.is_origin_allowed(&origin) {
79                let mut response = Response::new(crate::types::StatusCode::OK);
80
81                // Add CORS headers
82                response
83                    .headers
84                    .insert("Access-Control-Allow-Origin".to_string(), origin);
85                response.headers.insert(
86                    "Access-Control-Allow-Methods".to_string(),
87                    self.get_allowed_methods(),
88                );
89                response.headers.insert(
90                    "Access-Control-Allow-Headers".to_string(),
91                    self.get_allowed_headers(),
92                );
93
94                if self.config.max_age > 0 {
95                    response.headers.insert(
96                        "Access-Control-Max-Age".to_string(),
97                        self.config.max_age.to_string(),
98                    );
99                }
100
101                return Ok(Some(response));
102            }
103        }
104
105        Ok(None)
106    }
107
108    async fn after_response(&self, response: &mut Response) -> Result<()> {
109        if !self.config.enabled {
110            return Ok(());
111        }
112
113        // Add CORS headers to all responses
114        if let Some(origin) = response.headers.get("origin") {
115            if self.is_origin_allowed(origin) {
116                response
117                    .headers
118                    .insert("Access-Control-Allow-Origin".to_string(), origin.clone());
119            }
120        } else if self.config.allowed_origins.contains(&"*".to_string()) {
121            response
122                .headers
123                .insert("Access-Control-Allow-Origin".to_string(), "*".to_string());
124        }
125
126        Ok(())
127    }
128
129    fn name(&self) -> &'static str {
130        "CORS"
131    }
132
133    fn is_enabled(&self) -> bool {
134        self.config.enabled
135    }
136}
137
138/// Compression middleware with configurable compression levels
139pub struct CompressionMiddleware {
140    config: CompressionConfig,
141}
142
143impl CompressionMiddleware {
144    pub fn new(config: CompressionConfig) -> Self {
145        Self { config }
146    }
147
148    /// Check if request accepts gzip compression
149    #[allow(dead_code)]
150    fn accepts_gzip(&self, request: &Request) -> bool {
151        if let Some(accept_encoding) = request.headers.get("accept-encoding") {
152            accept_encoding.contains("gzip")
153        } else {
154            false
155        }
156    }
157
158    /// Compress data using gzip
159    fn compress_gzip(&self, data: &[u8]) -> Result<Vec<u8>> {
160        let mut encoder = GzEncoder::new(Vec::new(), Compression::new(6)); // Default compression level
161        encoder
162            .write_all(data)
163            .map_err(|e| WebServerError::custom(format!("Compression failed: {}", e)))?;
164        encoder
165            .finish()
166            .map_err(|e| WebServerError::custom(format!("Compression finish failed: {}", e)))
167    }
168}
169
170#[async_trait]
171impl EnhancedMiddleware for CompressionMiddleware {
172    async fn before_request(&self, _request: &mut Request) -> Result<Option<Response>> {
173        // Compression is applied in after_response
174        Ok(None)
175    }
176
177    async fn after_response(&self, response: &mut Response) -> Result<()> {
178        if !self.config.enabled {
179            return Ok(());
180        }
181
182        // Get response body
183        let body_bytes = response.body.bytes().await?;
184
185        // Only compress if body is large enough
186        if body_bytes.len() < self.config.min_size {
187            return Ok(());
188        }
189
190        // Check if content is already compressed
191        if response.headers.get("content-encoding").is_some() {
192            return Ok(());
193        }
194
195        // Compress the body
196        let compressed = self.compress_gzip(&body_bytes)?;
197
198        // Update response
199        response.body = crate::types::Body::from_bytes(Bytes::from(compressed));
200        response
201            .headers
202            .insert("Content-Encoding".to_string(), "gzip".to_string());
203        response.headers.insert(
204            "Content-Length".to_string(),
205            response.body.bytes().await?.len().to_string(),
206        );
207
208        Ok(())
209    }
210
211    fn name(&self) -> &'static str {
212        "Compression"
213    }
214
215    fn is_enabled(&self) -> bool {
216        self.config.enabled
217    }
218}
219
220#[async_trait]
221impl Middleware for CompressionMiddleware {
222    async fn call(&self, req: Request, next: Next) -> Result<Response> {
223        // Call the next middleware/handler
224        let mut response = next.run(req).await?;
225
226        // Apply compression to the response
227        if self.is_enabled() {
228            self.after_response(&mut response).await?;
229        }
230
231        Ok(response)
232    }
233}
234
235/// Rate limiting middleware
236pub struct RateLimitMiddleware {
237    config: SecurityConfig,
238    requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
239}
240
241impl RateLimitMiddleware {
242    pub fn new(config: SecurityConfig) -> Self {
243        Self {
244            config,
245            requests: Arc::new(Mutex::new(HashMap::new())),
246        }
247    }
248
249    /// Check if request is rate limited
250    fn is_rate_limited(&self, client_ip: &str) -> bool {
251        let rate_limit = match self.config.rate_limit_per_minute {
252            Some(limit) => limit,
253            None => return false, // No rate limiting
254        };
255
256        let mut requests = self.requests.lock().unwrap();
257        let now = Instant::now();
258        let one_minute_ago = now - Duration::from_secs(60);
259
260        // Get or create request history for this IP
261        let client_requests = requests.entry(client_ip.to_string()).or_default();
262
263        // Remove requests older than 1 minute
264        client_requests.retain(|&request_time| request_time > one_minute_ago);
265
266        // Check if rate limit exceeded
267        if client_requests.len() >= rate_limit as usize {
268            return true;
269        }
270
271        // Add current request
272        client_requests.push(now);
273        false
274    }
275
276    /// Get client IP from request
277    fn get_client_ip(&self, request: &Request) -> String {
278        // Try various headers for client IP
279        if let Some(forwarded) = request.headers.get("x-forwarded-for") {
280            if let Some(ip) = forwarded.split(',').next() {
281                return ip.trim().to_string();
282            }
283        }
284
285        if let Some(real_ip) = request.headers.get("x-real-ip") {
286            return real_ip.clone();
287        }
288
289        // Fallback to "unknown" (in production you'd get this from the connection)
290        "unknown".to_string()
291    }
292}
293
294#[async_trait]
295impl EnhancedMiddleware for RateLimitMiddleware {
296    async fn before_request(&self, request: &mut Request) -> Result<Option<Response>> {
297        if self.config.rate_limit_per_minute.is_none() {
298            return Ok(None);
299        }
300
301        let client_ip = self.get_client_ip(request);
302
303        if self.is_rate_limited(&client_ip) {
304            let mut response = Response::new(crate::types::StatusCode::TOO_MANY_REQUESTS);
305            response
306                .headers
307                .insert("Retry-After".to_string(), "60".to_string());
308            response.body = crate::types::Body::from_string("Rate limit exceeded");
309            return Ok(Some(response));
310        }
311
312        Ok(None)
313    }
314
315    async fn after_response(&self, _response: &mut Response) -> Result<()> {
316        Ok(())
317    }
318
319    fn name(&self) -> &'static str {
320        "RateLimit"
321    }
322
323    fn is_enabled(&self) -> bool {
324        self.config.rate_limit_per_minute.is_some()
325    }
326}
327
328/// Security headers middleware
329pub struct SecurityHeadersMiddleware {
330    config: SecurityConfig,
331}
332
333impl SecurityHeadersMiddleware {
334    pub fn new(config: SecurityConfig) -> Self {
335        Self { config }
336    }
337}
338
339#[async_trait]
340impl EnhancedMiddleware for SecurityHeadersMiddleware {
341    async fn before_request(&self, _request: &mut Request) -> Result<Option<Response>> {
342        Ok(None)
343    }
344
345    async fn after_response(&self, response: &mut Response) -> Result<()> {
346        // Add security headers
347        response
348            .headers
349            .insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
350        response
351            .headers
352            .insert("X-Frame-Options".to_string(), "DENY".to_string());
353        response
354            .headers
355            .insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
356        response.headers.insert(
357            "Referrer-Policy".to_string(),
358            "strict-origin-when-cross-origin".to_string(),
359        );
360
361        // Add HSTS header if TLS is enabled
362        if self.config.tls.enabled {
363            response.headers.insert(
364                "Strict-Transport-Security".to_string(),
365                "max-age=31536000; includeSubDomains".to_string(),
366            );
367        }
368
369        // Add CSP header if CSRF protection is enabled
370        if self.config.enable_csrf_protection {
371            response.headers.insert(
372                "Content-Security-Policy".to_string(),
373                "default-src 'self'".to_string(),
374            );
375        }
376
377        Ok(())
378    }
379
380    fn name(&self) -> &'static str {
381        "SecurityHeaders"
382    }
383}
384
385/// Middleware stack manager
386pub struct MiddlewareStack {
387    middlewares: Vec<Box<dyn EnhancedMiddleware>>,
388}
389
390impl MiddlewareStack {
391    pub fn new() -> Self {
392        Self {
393            middlewares: Vec::new(),
394        }
395    }
396
397    /// Add middleware to the stack
398    pub fn add_middleware(&mut self, middleware: Box<dyn EnhancedMiddleware>) {
399        self.middlewares.push(middleware);
400    }
401
402    /// Create a default middleware stack from configuration
403    pub fn from_config(
404        cors_config: CorsConfig,
405        compression_config: CompressionConfig,
406        security_config: SecurityConfig,
407    ) -> Self {
408        let mut stack = Self::new();
409
410        // Add middlewares in order of execution
411        stack.add_middleware(Box::new(SecurityHeadersMiddleware::new(
412            security_config.clone(),
413        )));
414        stack.add_middleware(Box::new(RateLimitMiddleware::new(security_config)));
415        stack.add_middleware(Box::new(CorsMiddleware::new(cors_config)));
416        stack.add_middleware(Box::new(CompressionMiddleware::new(compression_config)));
417
418        stack
419    }
420
421    /// Process request through all middleware
422    pub async fn process_request(&self, request: &mut Request) -> Result<Option<Response>> {
423        for middleware in &self.middlewares {
424            if !middleware.is_enabled() {
425                continue;
426            }
427
428            if let Some(response) = middleware.before_request(request).await? {
429                return Ok(Some(response));
430            }
431        }
432        Ok(None)
433    }
434
435    /// Process response through all middleware (in reverse order)
436    pub async fn process_response(&self, response: &mut Response) -> Result<()> {
437        for middleware in self.middlewares.iter().rev() {
438            if !middleware.is_enabled() {
439                continue;
440            }
441
442            middleware.after_response(response).await?;
443        }
444        Ok(())
445    }
446
447    /// Get list of enabled middlewares
448    pub fn get_enabled_middlewares(&self) -> Vec<&'static str> {
449        self.middlewares
450            .iter()
451            .filter(|m| m.is_enabled())
452            .map(|m| m.name())
453            .collect()
454    }
455}
456
457impl Default for MiddlewareStack {
458    fn default() -> Self {
459        Self::new()
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{HttpMethod, StatusCode};

    /// Builds a request to `https://example.com/test` carrying a single header.
    fn request_with_header(method: HttpMethod, name: &str, value: &str) -> Request {
        let mut headers = Headers::new();
        headers.insert(name.to_string(), value.to_string());
        Request {
            method,
            uri: http::Uri::from_static("https://example.com/test"),
            version: http::Version::HTTP_11,
            headers,
            body: crate::types::Body::empty(),
            extensions: std::collections::HashMap::new(),
            path_params: std::collections::HashMap::new(),
            cookies: std::collections::HashMap::new(),
            form_data: None,
            multipart: None,
        }
    }

    #[tokio::test]
    async fn test_cors_middleware() {
        let middleware = CorsMiddleware::new(CorsConfig {
            enabled: true,
            allowed_origins: vec!["https://example.com".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
            allowed_headers: vec!["content-type".to_string()],
            credentials: false,
            max_age: 3600,
        });

        // A preflight from an allowed origin is answered directly.
        let mut preflight =
            request_with_header(HttpMethod::OPTIONS, "origin", "https://example.com");
        let response = middleware
            .before_request(&mut preflight)
            .await
            .unwrap()
            .expect("preflight from an allowed origin should be answered");

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin"),
            Some(&"https://example.com".to_string())
        );
    }

    #[tokio::test]
    async fn test_rate_limit_middleware() {
        let middleware = RateLimitMiddleware::new(SecurityConfig {
            rate_limit_per_minute: Some(2),
            ..Default::default()
        });
        let mut request =
            request_with_header(HttpMethod::GET, "x-forwarded-for", "192.168.1.1");

        // The first two requests fit within the per-minute budget.
        for _ in 0..2 {
            let outcome = middleware.before_request(&mut request).await.unwrap();
            assert!(outcome.is_none());
        }

        // The third one is rejected.
        let rejected = middleware
            .before_request(&mut request)
            .await
            .unwrap()
            .expect("third request should be rate limited");
        assert_eq!(rejected.status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn test_middleware_stack() {
        let stack = MiddlewareStack::from_config(
            CorsConfig::default(),
            CompressionConfig::default(),
            SecurityConfig::default(),
        );

        let enabled = stack.get_enabled_middlewares();
        for expected in ["SecurityHeaders", "CORS", "Compression"] {
            assert!(enabled.contains(&expected));
        }
    }
}