1use crate::{
7 config::{CompressionConfig, CorsConfig, SecurityConfig},
8 core::{Middleware, Next},
9 error::{Result, WebServerError},
10 types::{Headers, Request, Response},
11};
12use async_trait::async_trait;
13use bytes::Bytes;
14use flate2::{Compression, write::GzEncoder};
15use std::io::Write;
16use std::{
17 collections::HashMap,
18 sync::{Arc, Mutex},
19 time::{Duration, Instant},
20};
21
22#[async_trait]
24pub trait EnhancedMiddleware: Send + Sync {
25 async fn before_request(&self, request: &mut Request) -> Result<Option<Response>>;
27
28 async fn after_response(&self, response: &mut Response) -> Result<()>;
30
31 fn name(&self) -> &'static str;
33
34 fn is_enabled(&self) -> bool {
36 true
37 }
38}
39
40pub struct CorsMiddleware {
42 config: CorsConfig,
43}
44
45impl CorsMiddleware {
46 pub fn new(config: CorsConfig) -> Self {
47 Self { config }
48 }
49
50 fn is_origin_allowed(&self, origin: &str) -> bool {
52 self.config.allowed_origins.contains(&"*".to_string())
53 || self.config.allowed_origins.contains(&origin.to_string())
54 }
55
56 fn get_allowed_methods(&self) -> String {
58 self.config.allowed_methods.join(", ")
59 }
60
61 fn get_allowed_headers(&self) -> String {
63 self.config.allowed_headers.join(", ")
64 }
65}
66
67#[async_trait]
68impl EnhancedMiddleware for CorsMiddleware {
69 async fn before_request(&self, request: &mut Request) -> Result<Option<Response>> {
70 if !self.config.enabled {
71 return Ok(None);
72 }
73
74 if request.method == crate::types::HttpMethod::OPTIONS {
76 let origin = request.headers.get("origin").cloned().unwrap_or_default();
77
78 if self.is_origin_allowed(&origin) {
79 let mut response = Response::new(crate::types::StatusCode::OK);
80
81 response
83 .headers
84 .insert("Access-Control-Allow-Origin".to_string(), origin);
85 response.headers.insert(
86 "Access-Control-Allow-Methods".to_string(),
87 self.get_allowed_methods(),
88 );
89 response.headers.insert(
90 "Access-Control-Allow-Headers".to_string(),
91 self.get_allowed_headers(),
92 );
93
94 if self.config.max_age > 0 {
95 response.headers.insert(
96 "Access-Control-Max-Age".to_string(),
97 self.config.max_age.to_string(),
98 );
99 }
100
101 return Ok(Some(response));
102 }
103 }
104
105 Ok(None)
106 }
107
108 async fn after_response(&self, response: &mut Response) -> Result<()> {
109 if !self.config.enabled {
110 return Ok(());
111 }
112
113 if let Some(origin) = response.headers.get("origin") {
115 if self.is_origin_allowed(origin) {
116 response
117 .headers
118 .insert("Access-Control-Allow-Origin".to_string(), origin.clone());
119 }
120 } else if self.config.allowed_origins.contains(&"*".to_string()) {
121 response
122 .headers
123 .insert("Access-Control-Allow-Origin".to_string(), "*".to_string());
124 }
125
126 Ok(())
127 }
128
129 fn name(&self) -> &'static str {
130 "CORS"
131 }
132
133 fn is_enabled(&self) -> bool {
134 self.config.enabled
135 }
136}
137
138pub struct CompressionMiddleware {
140 config: CompressionConfig,
141}
142
143impl CompressionMiddleware {
144 pub fn new(config: CompressionConfig) -> Self {
145 Self { config }
146 }
147
148 #[allow(dead_code)]
150 fn accepts_gzip(&self, request: &Request) -> bool {
151 if let Some(accept_encoding) = request.headers.get("accept-encoding") {
152 accept_encoding.contains("gzip")
153 } else {
154 false
155 }
156 }
157
158 fn compress_gzip(&self, data: &[u8]) -> Result<Vec<u8>> {
160 let mut encoder = GzEncoder::new(Vec::new(), Compression::new(6)); encoder
162 .write_all(data)
163 .map_err(|e| WebServerError::custom(format!("Compression failed: {}", e)))?;
164 encoder
165 .finish()
166 .map_err(|e| WebServerError::custom(format!("Compression finish failed: {}", e)))
167 }
168}
169
170#[async_trait]
171impl EnhancedMiddleware for CompressionMiddleware {
172 async fn before_request(&self, _request: &mut Request) -> Result<Option<Response>> {
173 Ok(None)
175 }
176
177 async fn after_response(&self, response: &mut Response) -> Result<()> {
178 if !self.config.enabled {
179 return Ok(());
180 }
181
182 let body_bytes = response.body.bytes().await?;
184
185 if body_bytes.len() < self.config.min_size {
187 return Ok(());
188 }
189
190 if response.headers.get("content-encoding").is_some() {
192 return Ok(());
193 }
194
195 let compressed = self.compress_gzip(&body_bytes)?;
197
198 response.body = crate::types::Body::from_bytes(Bytes::from(compressed));
200 response
201 .headers
202 .insert("Content-Encoding".to_string(), "gzip".to_string());
203 response.headers.insert(
204 "Content-Length".to_string(),
205 response.body.bytes().await?.len().to_string(),
206 );
207
208 Ok(())
209 }
210
211 fn name(&self) -> &'static str {
212 "Compression"
213 }
214
215 fn is_enabled(&self) -> bool {
216 self.config.enabled
217 }
218}
219
220#[async_trait]
221impl Middleware for CompressionMiddleware {
222 async fn call(&self, req: Request, next: Next) -> Result<Response> {
223 let mut response = next.run(req).await?;
225
226 if self.is_enabled() {
228 self.after_response(&mut response).await?;
229 }
230
231 Ok(response)
232 }
233}
234
235pub struct RateLimitMiddleware {
237 config: SecurityConfig,
238 requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
239}
240
241impl RateLimitMiddleware {
242 pub fn new(config: SecurityConfig) -> Self {
243 Self {
244 config,
245 requests: Arc::new(Mutex::new(HashMap::new())),
246 }
247 }
248
249 fn is_rate_limited(&self, client_ip: &str) -> bool {
251 let rate_limit = match self.config.rate_limit_per_minute {
252 Some(limit) => limit,
253 None => return false, };
255
256 let mut requests = self.requests.lock().unwrap();
257 let now = Instant::now();
258 let one_minute_ago = now - Duration::from_secs(60);
259
260 let client_requests = requests.entry(client_ip.to_string()).or_default();
262
263 client_requests.retain(|&request_time| request_time > one_minute_ago);
265
266 if client_requests.len() >= rate_limit as usize {
268 return true;
269 }
270
271 client_requests.push(now);
273 false
274 }
275
276 fn get_client_ip(&self, request: &Request) -> String {
278 if let Some(forwarded) = request.headers.get("x-forwarded-for") {
280 if let Some(ip) = forwarded.split(',').next() {
281 return ip.trim().to_string();
282 }
283 }
284
285 if let Some(real_ip) = request.headers.get("x-real-ip") {
286 return real_ip.clone();
287 }
288
289 "unknown".to_string()
291 }
292}
293
294#[async_trait]
295impl EnhancedMiddleware for RateLimitMiddleware {
296 async fn before_request(&self, request: &mut Request) -> Result<Option<Response>> {
297 if self.config.rate_limit_per_minute.is_none() {
298 return Ok(None);
299 }
300
301 let client_ip = self.get_client_ip(request);
302
303 if self.is_rate_limited(&client_ip) {
304 let mut response = Response::new(crate::types::StatusCode::TOO_MANY_REQUESTS);
305 response
306 .headers
307 .insert("Retry-After".to_string(), "60".to_string());
308 response.body = crate::types::Body::from_string("Rate limit exceeded");
309 return Ok(Some(response));
310 }
311
312 Ok(None)
313 }
314
315 async fn after_response(&self, _response: &mut Response) -> Result<()> {
316 Ok(())
317 }
318
319 fn name(&self) -> &'static str {
320 "RateLimit"
321 }
322
323 fn is_enabled(&self) -> bool {
324 self.config.rate_limit_per_minute.is_some()
325 }
326}
327
328pub struct SecurityHeadersMiddleware {
330 config: SecurityConfig,
331}
332
333impl SecurityHeadersMiddleware {
334 pub fn new(config: SecurityConfig) -> Self {
335 Self { config }
336 }
337}
338
339#[async_trait]
340impl EnhancedMiddleware for SecurityHeadersMiddleware {
341 async fn before_request(&self, _request: &mut Request) -> Result<Option<Response>> {
342 Ok(None)
343 }
344
345 async fn after_response(&self, response: &mut Response) -> Result<()> {
346 response
348 .headers
349 .insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
350 response
351 .headers
352 .insert("X-Frame-Options".to_string(), "DENY".to_string());
353 response
354 .headers
355 .insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
356 response.headers.insert(
357 "Referrer-Policy".to_string(),
358 "strict-origin-when-cross-origin".to_string(),
359 );
360
361 if self.config.tls.enabled {
363 response.headers.insert(
364 "Strict-Transport-Security".to_string(),
365 "max-age=31536000; includeSubDomains".to_string(),
366 );
367 }
368
369 if self.config.enable_csrf_protection {
371 response.headers.insert(
372 "Content-Security-Policy".to_string(),
373 "default-src 'self'".to_string(),
374 );
375 }
376
377 Ok(())
378 }
379
380 fn name(&self) -> &'static str {
381 "SecurityHeaders"
382 }
383}
384
385pub struct MiddlewareStack {
387 middlewares: Vec<Box<dyn EnhancedMiddleware>>,
388}
389
390impl MiddlewareStack {
391 pub fn new() -> Self {
392 Self {
393 middlewares: Vec::new(),
394 }
395 }
396
397 pub fn add_middleware(&mut self, middleware: Box<dyn EnhancedMiddleware>) {
399 self.middlewares.push(middleware);
400 }
401
402 pub fn from_config(
404 cors_config: CorsConfig,
405 compression_config: CompressionConfig,
406 security_config: SecurityConfig,
407 ) -> Self {
408 let mut stack = Self::new();
409
410 stack.add_middleware(Box::new(SecurityHeadersMiddleware::new(
412 security_config.clone(),
413 )));
414 stack.add_middleware(Box::new(RateLimitMiddleware::new(security_config)));
415 stack.add_middleware(Box::new(CorsMiddleware::new(cors_config)));
416 stack.add_middleware(Box::new(CompressionMiddleware::new(compression_config)));
417
418 stack
419 }
420
421 pub async fn process_request(&self, request: &mut Request) -> Result<Option<Response>> {
423 for middleware in &self.middlewares {
424 if !middleware.is_enabled() {
425 continue;
426 }
427
428 if let Some(response) = middleware.before_request(request).await? {
429 return Ok(Some(response));
430 }
431 }
432 Ok(None)
433 }
434
435 pub async fn process_response(&self, response: &mut Response) -> Result<()> {
437 for middleware in self.middlewares.iter().rev() {
438 if !middleware.is_enabled() {
439 continue;
440 }
441
442 middleware.after_response(response).await?;
443 }
444 Ok(())
445 }
446
447 pub fn get_enabled_middlewares(&self) -> Vec<&'static str> {
449 self.middlewares
450 .iter()
451 .filter(|m| m.is_enabled())
452 .map(|m| m.name())
453 .collect()
454 }
455}
456
457impl Default for MiddlewareStack {
458 fn default() -> Self {
459 Self::new()
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{HttpMethod, StatusCode};

    /// Builds a request for `https://example.com/test` that carries exactly one header.
    fn request_with_header(method: HttpMethod, name: &str, value: &str) -> Request {
        let mut headers = Headers::new();
        headers.insert(name.to_string(), value.to_string());
        Request {
            method,
            uri: http::Uri::from_static("https://example.com/test"),
            version: http::Version::HTTP_11,
            headers,
            body: crate::types::Body::empty(),
            extensions: std::collections::HashMap::new(),
            path_params: std::collections::HashMap::new(),
            cookies: std::collections::HashMap::new(),
            form_data: None,
            multipart: None,
        }
    }

    #[tokio::test]
    async fn test_cors_middleware() {
        let middleware = CorsMiddleware::new(CorsConfig {
            enabled: true,
            allowed_origins: vec!["https://example.com".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
            allowed_headers: vec!["content-type".to_string()],
            credentials: false,
            max_age: 3600,
        });
        let mut preflight =
            request_with_header(HttpMethod::OPTIONS, "origin", "https://example.com");

        let response = middleware
            .before_request(&mut preflight)
            .await
            .unwrap()
            .expect("an allowed preflight should be answered directly");

        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(
            response.headers.get("Access-Control-Allow-Origin"),
            Some(&"https://example.com".to_string())
        );
    }

    #[tokio::test]
    async fn test_rate_limit_middleware() {
        let middleware = RateLimitMiddleware::new(SecurityConfig {
            rate_limit_per_minute: Some(2),
            ..Default::default()
        });
        let mut request = request_with_header(HttpMethod::GET, "x-forwarded-for", "192.168.1.1");

        // The first two requests fit inside the per-minute budget.
        for _ in 0..2 {
            let outcome = middleware.before_request(&mut request).await.unwrap();
            assert!(outcome.is_none());
        }

        // The third one is rejected.
        let rejected = middleware
            .before_request(&mut request)
            .await
            .unwrap()
            .expect("the third request should be rate limited");
        assert_eq!(rejected.status, StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn test_middleware_stack() {
        let stack = MiddlewareStack::from_config(
            CorsConfig::default(),
            CompressionConfig::default(),
            SecurityConfig::default(),
        );

        let enabled = stack.get_enabled_middlewares();
        for expected in ["SecurityHeaders", "CORS", "Compression"].iter() {
            assert!(enabled.contains(expected));
        }
    }
}