1use crate::core::{Middleware, Next};
4use crate::error::Result;
5use crate::types::{Request, Response, StatusCode};
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant, SystemTime};
10
/// Middleware that prints one line per incoming request and one per outcome
/// to stdout.
///
/// Output can be switched off with [`LoggingMiddleware::enabled`];
/// [`LoggingMiddleware::log_bodies`] adds a body byte count to the
/// incoming-request line.
pub struct LoggingMiddleware {
    /// When `false`, requests are forwarded without printing anything.
    pub enabled: bool,
    /// When `true`, the incoming-request line also carries a body byte count.
    pub log_bodies: bool,
}

impl Default for LoggingMiddleware {
    fn default() -> Self {
        LoggingMiddleware::new()
    }
}

impl LoggingMiddleware {
    /// Creates a logger that is switched on and does not report bodies.
    pub fn new() -> Self {
        LoggingMiddleware {
            log_bodies: false,
            enabled: true,
        }
    }

    /// Builder: switches logging on or off.
    pub fn enabled(self, enabled: bool) -> Self {
        LoggingMiddleware { enabled, ..self }
    }

    /// Builder: switches body reporting on or off.
    pub fn log_bodies(self, log_bodies: bool) -> Self {
        LoggingMiddleware { log_bodies, ..self }
    }
}
41
42#[async_trait]
43impl Middleware for LoggingMiddleware {
44 async fn call(&self, req: Request, next: Next) -> Result<Response> {
45 if !self.enabled {
46 return next.run(req).await;
47 }
48
49 let start = Instant::now();
50 let method = req.method;
51 let path = req.path().to_string();
52
53 if self.log_bodies {
54 let body_preview = req.body.len().min(100);
55 println!("-> {:?} {} (body: {} bytes)", method, path, body_preview);
56 } else {
57 println!("-> {:?} {}", method, path);
58 }
59
60 let response = next.run(req).await;
61
62 let duration = start.elapsed();
63 match &response {
64 Ok(_resp) => {
65 println!("<- {:?} {} - 200 OK ({:?})", method, path, duration);
66 }
67 Err(err) => {
68 println!("<- {:?} {} - ERROR: {} ({:?})", method, path, err, duration);
69 }
70 }
71
72 response
73 }
74}
75
/// Cross-Origin Resource Sharing (CORS) middleware configuration.
///
/// The `Middleware` impl answers `OPTIONS` requests itself and adds the
/// configured `Access-Control-*` headers to other successful responses.
pub struct CorsMiddleware {
    /// Value sent as `Access-Control-Allow-Origin`.
    pub allow_origin: String,
    /// Methods joined with `", "` into `Access-Control-Allow-Methods`.
    pub allow_methods: Vec<String>,
    /// Headers joined with `", "` into `Access-Control-Allow-Headers`.
    pub allow_headers: Vec<String>,
    /// Whether `Access-Control-Allow-Credentials: true` is emitted.
    pub allow_credentials: bool,
    /// Headers joined into `Access-Control-Expose-Headers`; omitted when empty.
    pub expose_headers: Vec<String>,
    /// Preflight lifetime, sent in whole seconds as `Access-Control-Max-Age`.
    pub max_age: Option<Duration>,
}

impl Default for CorsMiddleware {
    fn default() -> Self {
        CorsMiddleware::new()
    }
}

impl CorsMiddleware {
    /// Permissive defaults: any origin, the common verbs and request headers,
    /// no credentials, nothing exposed, and a one-day preflight cache.
    pub fn new() -> Self {
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| item.to_string()).collect()
        }

        CorsMiddleware {
            allow_origin: String::from("*"),
            allow_methods: owned(&["GET", "POST", "PUT", "DELETE", "OPTIONS"]),
            allow_headers: owned(&[
                "Content-Type",
                "Authorization",
                "Accept",
                "Origin",
                "X-Requested-With",
            ]),
            allow_credentials: false,
            expose_headers: Vec::new(),
            max_age: Some(Duration::from_secs(24 * 60 * 60)),
        }
    }

    /// Builder: sets the allowed origin.
    pub fn allow_origin(self, origin: impl Into<String>) -> Self {
        CorsMiddleware {
            allow_origin: origin.into(),
            ..self
        }
    }

    /// Builder: replaces the allowed methods.
    pub fn allow_methods(self, methods: Vec<String>) -> Self {
        CorsMiddleware {
            allow_methods: methods,
            ..self
        }
    }

    /// Builder: replaces the allowed request headers.
    pub fn allow_headers(self, headers: Vec<String>) -> Self {
        CorsMiddleware {
            allow_headers: headers,
            ..self
        }
    }

    /// Builder: toggles `Access-Control-Allow-Credentials`.
    pub fn allow_credentials(self, allow: bool) -> Self {
        CorsMiddleware {
            allow_credentials: allow,
            ..self
        }
    }

    /// Builder: replaces the exposed response headers.
    pub fn expose_headers(self, headers: Vec<String>) -> Self {
        CorsMiddleware {
            expose_headers: headers,
            ..self
        }
    }

    /// Builder: sets the preflight cache lifetime.
    pub fn max_age(self, max_age: Duration) -> Self {
        CorsMiddleware {
            max_age: Some(max_age),
            ..self
        }
    }
}
146
147#[async_trait]
148impl Middleware for CorsMiddleware {
149 async fn call(&self, req: Request, next: Next) -> Result<Response> {
150 if req.method == crate::types::HttpMethod::OPTIONS {
152 let mut response = Response::new(StatusCode::OK)
153 .header("Access-Control-Allow-Origin", &self.allow_origin)
154 .header(
155 "Access-Control-Allow-Methods",
156 self.allow_methods.join(", "),
157 )
158 .header(
159 "Access-Control-Allow-Headers",
160 self.allow_headers.join(", "),
161 );
162
163 if self.allow_credentials {
164 response = response.header("Access-Control-Allow-Credentials", "true");
165 }
166
167 if let Some(max_age) = self.max_age {
168 response = response.header("Access-Control-Max-Age", max_age.as_secs().to_string());
169 }
170
171 return Ok(response);
172 }
173
174 let response = next.run(req).await?;
175
176 let mut cors_response = response
177 .header("Access-Control-Allow-Origin", &self.allow_origin)
178 .header(
179 "Access-Control-Allow-Methods",
180 self.allow_methods.join(", "),
181 )
182 .header(
183 "Access-Control-Allow-Headers",
184 self.allow_headers.join(", "),
185 );
186
187 if self.allow_credentials {
188 cors_response = cors_response.header("Access-Control-Allow-Credentials", "true");
189 }
190
191 if !self.expose_headers.is_empty() {
192 cors_response = cors_response.header(
193 "Access-Control-Expose-Headers",
194 self.expose_headers.join(", "),
195 );
196 }
197
198 Ok(cors_response)
199 }
200}
201
/// Middleware configured with a per-request time budget.
pub struct TimeoutMiddleware {
    /// The configured budget; budgets under 100 ms trigger a warning when a
    /// request is processed.
    pub timeout: Duration,
}

impl TimeoutMiddleware {
    /// Wraps the given time budget.
    pub fn new(timeout: Duration) -> Self {
        TimeoutMiddleware { timeout }
    }
}
212
213#[async_trait]
214impl Middleware for TimeoutMiddleware {
215 async fn call(&self, req: Request, next: Next) -> Result<Response> {
216 if self.timeout.as_millis() < 100 {
219 println!("Warning: Very short timeout configured: {:?}", self.timeout);
220 }
221 next.run(req).await
222 }
223}
224
225pub struct RateLimitMiddleware {
227 pub max_requests: u32,
228 pub window: Duration,
229 pub store: Arc<Mutex<HashMap<String, (u32, SystemTime)>>>,
230}
231
232impl RateLimitMiddleware {
233 pub fn new(max_requests: u32, window: Duration) -> Self {
234 Self {
235 max_requests,
236 window,
237 store: Arc::new(Mutex::new(HashMap::new())),
238 }
239 }
240
241 fn get_client_key(&self, req: &Request) -> String {
242 format!("default:{}", req.path())
245 }
246
247 fn is_rate_limited(&self, key: &str) -> bool {
248 let mut store = self.store.lock().unwrap();
249 let now = SystemTime::now();
250
251 match store.get_mut(key) {
252 Some((count, last_reset)) => {
253 if now.duration_since(*last_reset).unwrap_or(Duration::ZERO) >= self.window {
255 *count = 1;
256 *last_reset = now;
257 false
258 } else if *count >= self.max_requests {
259 true
260 } else {
261 *count += 1;
262 false
263 }
264 }
265 None => {
266 store.insert(key.to_string(), (1, now));
267 false
268 }
269 }
270 }
271}
272
273#[async_trait]
274impl Middleware for RateLimitMiddleware {
275 async fn call(&self, req: Request, next: Next) -> Result<Response> {
276 let key = self.get_client_key(&req);
277
278 if self.is_rate_limited(&key) {
279 return Ok(Response::new(StatusCode(429))
280 .header("Content-Type", "application/json")
281 .body(r#"{"error": "Rate limit exceeded"}"#));
282 }
283
284 next.run(req).await
285 }
286}
287
/// Bearer-token authentication middleware.
pub struct AuthMiddleware {
    /// When `true` (the default), requests lacking an `authorization` header
    /// are rejected with `401`.
    pub require_auth: bool,
    /// Accepted bearer tokens. The `Middleware` impl skips token validation
    /// entirely when this list is empty.
    pub bearer_tokens: Arc<Vec<String>>,
}

impl Default for AuthMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

impl AuthMiddleware {
    /// Creates a middleware that requires authentication and has no tokens.
    pub fn new() -> Self {
        Self {
            require_auth: true,
            bearer_tokens: Arc::new(vec![]),
        }
    }

    /// Makes authentication optional: every request is forwarded.
    pub fn optional(mut self) -> Self {
        self.require_auth = false;
        self
    }

    /// Replaces the accepted bearer tokens.
    pub fn with_bearer_tokens(mut self, tokens: Vec<String>) -> Self {
        self.bearer_tokens = Arc::new(tokens);
        self
    }

    /// Returns `true` when `authorization` is a `Bearer` credential carrying
    /// one of the configured tokens.
    ///
    /// The scheme name is matched case-insensitively (RFC 7235 §2.1) and may
    /// be followed by one or more spaces (RFC 6750 §2.1). An empty token
    /// never matches.
    fn validate_token(&self, authorization: &str) -> bool {
        let Some((scheme, credentials)) = authorization.split_once(' ') else {
            return false;
        };
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return false;
        }
        let token = credentials.trim_start_matches(' ');
        if token.is_empty() {
            return false;
        }
        // Compare against every configured token with no early exit, and
        // without allocating a `String` per check.
        self.bearer_tokens
            .iter()
            .fold(false, |matched, expected| {
                matched | Self::tokens_match(expected, token)
            })
    }

    /// Byte-wise token comparison that does not stop at the first mismatching
    /// byte, so timing reveals less about a partially correct guess. This is
    /// best effort; lengths are still compared up front.
    fn tokens_match(expected: &str, presented: &str) -> bool {
        let (expected, presented) = (expected.as_bytes(), presented.as_bytes());
        if expected.len() != presented.len() {
            return false;
        }
        expected
            .iter()
            .zip(presented)
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0
    }
}
326
327#[async_trait]
328impl Middleware for AuthMiddleware {
329 async fn call(&self, req: Request, next: Next) -> Result<Response> {
330 if self.require_auth {
331 if let Some(auth_header) = req.headers.get("authorization") {
332 if !self.bearer_tokens.is_empty() && !self.validate_token(auth_header) {
333 return Ok(Response::new(StatusCode::UNAUTHORIZED)
334 .header("Content-Type", "application/json")
335 .body(r#"{"error": "Invalid token"}"#));
336 }
337 } else {
338 return Ok(Response::new(StatusCode::UNAUTHORIZED)
339 .header("Content-Type", "application/json")
340 .body(r#"{"error": "Authentication required"}"#));
341 }
342 }
343
344 next.run(req).await
345 }
346}
347
/// Settings for the response-compression middleware.
pub struct CompressionMiddleware {
    /// Master switch; when `false`, responses pass through untouched.
    pub enabled: bool,
    /// Smallest body, as reported by the body's `len()`, that the middleware
    /// acts on.
    pub min_size: usize,
}

impl Default for CompressionMiddleware {
    fn default() -> Self {
        CompressionMiddleware::new()
    }
}

impl CompressionMiddleware {
    /// Enabled, with a 1024 threshold.
    pub fn new() -> Self {
        const DEFAULT_MIN_SIZE: usize = 1 << 10;
        CompressionMiddleware {
            min_size: DEFAULT_MIN_SIZE,
            enabled: true,
        }
    }

    /// Builder: sets the minimum body size.
    pub fn min_size(self, size: usize) -> Self {
        CompressionMiddleware {
            min_size: size,
            ..self
        }
    }
}
373
374#[async_trait]
375impl Middleware for CompressionMiddleware {
376 async fn call(&self, req: Request, next: Next) -> Result<Response> {
377 let response = next.run(req).await?;
378
379 if !self.enabled {
380 return Ok(response);
381 }
382
383 if response.body.len() >= self.min_size {
385 let compressed_response = response
387 .header("Content-Encoding", "gzip")
388 .header("Vary", "Accept-Encoding");
389 Ok(compressed_response)
390 } else {
391 Ok(response)
392 }
393 }
394}
395
/// Adds common security-related headers to successful responses.
pub struct SecurityHeadersMiddleware {
    /// Emit `Strict-Transport-Security`.
    pub add_hsts: bool,
    /// Emit `X-Frame-Options`.
    pub add_frame_options: bool,
    /// Emit `X-Content-Type-Options`.
    pub add_content_type_options: bool,
    /// Emit `X-XSS-Protection`.
    pub add_xss_protection: bool,
}

impl Default for SecurityHeadersMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

impl SecurityHeadersMiddleware {
    /// Creates a middleware with every header enabled.
    pub fn new() -> Self {
        Self {
            add_hsts: true,
            add_frame_options: true,
            add_content_type_options: true,
            add_xss_protection: true,
        }
    }

    /// Builder: toggles `Strict-Transport-Security`.
    pub fn with_hsts(mut self, enabled: bool) -> Self {
        self.add_hsts = enabled;
        self
    }

    /// Builder: toggles `X-Frame-Options`.
    pub fn with_frame_options(mut self, enabled: bool) -> Self {
        self.add_frame_options = enabled;
        self
    }

    /// Builder: toggles `X-Content-Type-Options`. Added so every flag has a
    /// matching builder.
    pub fn with_content_type_options(mut self, enabled: bool) -> Self {
        self.add_content_type_options = enabled;
        self
    }

    /// Builder: toggles `X-XSS-Protection`. Added so every flag has a
    /// matching builder.
    pub fn with_xss_protection(mut self, enabled: bool) -> Self {
        self.add_xss_protection = enabled;
        self
    }
}
430
431#[async_trait]
432impl Middleware for SecurityHeadersMiddleware {
433 async fn call(&self, req: Request, next: Next) -> Result<Response> {
434 let mut response = next.run(req).await?;
435
436 if self.add_hsts {
437 response = response.header(
438 "Strict-Transport-Security",
439 "max-age=31536000; includeSubDomains",
440 );
441 }
442
443 if self.add_frame_options {
444 response = response.header("X-Frame-Options", "DENY");
445 }
446
447 if self.add_content_type_options {
448 response = response.header("X-Content-Type-Options", "nosniff");
449 }
450
451 if self.add_xss_protection {
452 response = response.header("X-XSS-Protection", "1; mode=block");
453 }
454
455 Ok(response)
456 }
457}
458
/// Collects simple request metrics: the request count, the error count and
/// the accumulated time spent downstream.
pub struct MetricsMiddleware {
    /// Master switch; when `false`, requests are forwarded without recording.
    pub enabled: bool,
    /// When `true`, each request's downstream time is added to `total_duration`.
    pub collect_timing: bool,
    /// When `true`, requests whose handler returned `Err` bump `error_count`.
    pub collect_errors: bool,
    /// Requests seen while enabled.
    pub request_count: Arc<Mutex<u64>>,
    /// Requests whose handler returned `Err`.
    pub error_count: Arc<Mutex<u64>>,
    /// Accumulated downstream time.
    pub total_duration: Arc<Mutex<Duration>>,
}

impl Default for MetricsMiddleware {
    fn default() -> Self {
        MetricsMiddleware::new()
    }
}

impl MetricsMiddleware {
    /// Creates a collector with every feature on and all counters at zero.
    pub fn new() -> Self {
        let zeroed_counter = || Arc::new(Mutex::new(0_u64));
        MetricsMiddleware {
            enabled: true,
            collect_timing: true,
            collect_errors: true,
            request_count: zeroed_counter(),
            error_count: zeroed_counter(),
            total_duration: Arc::new(Mutex::new(Duration::ZERO)),
        }
    }

    /// Returns `(request_count, error_count, total_duration)`.
    ///
    /// Each counter is locked and read on its own, in that order, so the
    /// triple is not an atomic snapshot of all three.
    ///
    /// # Panics
    /// Panics if any counter's mutex is poisoned.
    pub fn get_stats(&self) -> (u64, u64, Duration) {
        fn read<T: Copy>(cell: &Arc<Mutex<T>>) -> T {
            *cell.lock().unwrap()
        }
        (
            read(&self.request_count),
            read(&self.error_count),
            read(&self.total_duration),
        )
    }
}
494
495#[async_trait]
496impl Middleware for MetricsMiddleware {
497 async fn call(&self, req: Request, next: Next) -> Result<Response> {
498 if !self.enabled {
499 return next.run(req).await;
500 }
501
502 let start = if self.collect_timing {
503 Some(Instant::now())
504 } else {
505 None
506 };
507
508 *self.request_count.lock().unwrap() += 1;
510
511 let result = next.run(req).await;
512
513 if let Some(start_time) = start {
515 let duration = start_time.elapsed();
516 *self.total_duration.lock().unwrap() += duration;
517 }
518
519 if self.collect_errors && result.is_err() {
521 *self.error_count.lock().unwrap() += 1;
522 }
523
524 result
525 }
526}
527
528pub struct CacheMiddleware {
530 pub enabled: bool,
531 pub cache_duration: Duration,
532 pub cache: Arc<Mutex<HashMap<String, (Response, SystemTime)>>>,
533}
534
535impl CacheMiddleware {
536 pub fn new(cache_duration: Duration) -> Self {
537 Self {
538 enabled: true,
539 cache_duration,
540 cache: Arc::new(Mutex::new(HashMap::new())),
541 }
542 }
543
544 fn cache_key(&self, req: &Request) -> String {
545 format!("{}:{}", req.method.as_str(), req.path())
546 }
547
548 fn get_cached(&self, key: &str) -> Option<Response> {
549 let mut cache = self.cache.lock().unwrap();
550
551 if let Some((response, timestamp)) = cache.get(key) {
552 let now = SystemTime::now();
553 if now.duration_since(*timestamp).unwrap_or(Duration::MAX) < self.cache_duration {
554 return Some(response.clone());
555 } else {
556 cache.remove(key);
557 }
558 }
559
560 None
561 }
562
563 fn cache_response(&self, key: String, response: &Response) {
564 if response.status.0 == 200 {
565 let mut cache = self.cache.lock().unwrap();
566 cache.insert(key, (response.clone(), SystemTime::now()));
567 }
568 }
569}
570
571#[async_trait]
572impl Middleware for CacheMiddleware {
573 async fn call(&self, req: Request, next: Next) -> Result<Response> {
574 if !self.enabled || req.method != crate::types::HttpMethod::GET {
575 return next.run(req).await;
576 }
577
578 let cache_key = self.cache_key(&req);
579
580 if let Some(cached_response) = self.get_cached(&cache_key) {
582 return Ok(cached_response.header("X-Cache", "HIT"));
583 }
584
585 let response = next.run(req).await?;
587
588 self.cache_response(cache_key, &response);
590
591 Ok(response.header("X-Cache", "MISS"))
592 }
593}
594
595pub struct PathParameterMiddleware {
598 route_patterns: Vec<(String, crate::types::HttpMethod)>,
599}
600
601impl PathParameterMiddleware {
602 pub fn new(route_patterns: Vec<(String, crate::types::HttpMethod)>) -> Self {
603 Self { route_patterns }
604 }
605
606 fn match_dynamic_path(
608 &self,
609 pattern: &str,
610 path: &str,
611 ) -> Option<std::collections::HashMap<String, String>> {
612 let route_parts: Vec<&str> = pattern.split('/').collect();
613 let path_parts: Vec<&str> = path.split('/').collect();
614
615 if route_parts.len() != path_parts.len() {
616 if let Some(last_part) = route_parts.last()
618 && last_part.starts_with('*')
619 && route_parts.len() <= path_parts.len()
620 {
621 let mut params = std::collections::HashMap::new();
623 let param_name = last_part.trim_start_matches('*');
624 if !param_name.is_empty() {
625 let remaining_path = path_parts[route_parts.len() - 1..].join("/");
626 params.insert(param_name.to_string(), remaining_path);
627 }
628 return Some(params);
629 }
630 return None;
631 }
632
633 let mut params = std::collections::HashMap::new();
634
635 for (route_part, path_part) in route_parts.iter().zip(path_parts.iter()) {
636 if route_part.starts_with(':') {
637 let param_name = route_part.trim_start_matches(':');
639 params.insert(param_name.to_string(), path_part.to_string());
640 } else if route_part.starts_with('*') {
641 let param_name = route_part.trim_start_matches('*');
643 if !param_name.is_empty() {
644 params.insert(param_name.to_string(), path_part.to_string());
645 }
646 } else if route_part != path_part {
647 return None;
649 }
650 }
651
652 Some(params)
653 }
654}
655
656#[async_trait]
657impl Middleware for PathParameterMiddleware {
658 async fn call(&self, mut req: Request, next: Next) -> Result<Response> {
659 for (pattern, method) in &self.route_patterns {
661 if *method == req.method
662 && let Some(params) = self.match_dynamic_path(pattern, req.path())
663 {
664 req.set_params(params);
665 break;
666 }
667 }
668
669 next.run(req).await
670 }
671}
672
673pub struct TransformMiddleware<F, G>
676where
677 F: Fn(Request) -> Request + Send + Sync + 'static,
678 G: Fn(Response) -> Response + Send + Sync + 'static,
679{
680 request_transform: F,
681 response_transform: G,
682}
683
684impl<F, G> TransformMiddleware<F, G>
685where
686 F: Fn(Request) -> Request + Send + Sync + 'static,
687 G: Fn(Response) -> Response + Send + Sync + 'static,
688{
689 pub fn new(request_transform: F, response_transform: G) -> Self {
690 Self {
691 request_transform,
692 response_transform,
693 }
694 }
695}
696
697#[async_trait]
698impl<F, G> Middleware for TransformMiddleware<F, G>
699where
700 F: Fn(Request) -> Request + Send + Sync + 'static,
701 G: Fn(Response) -> Response + Send + Sync + 'static,
702{
703 async fn call(&self, req: Request, next: Next) -> Result<Response> {
704 let transformed_req = (self.request_transform)(req);
705 let response = next.run(transformed_req).await?;
706 Ok((self.response_transform)(response))
707 }
708}
709
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Body, Headers, Response, StatusCode};
    use std::collections::HashMap;
    use std::time::Duration;

    /// A hand-built handler response keeps the status and body it was given.
    /// (The original version built the handler but never ran or checked it.)
    #[tokio::test]
    async fn handler_response_keeps_status_and_body() {
        let handler = || async {
            Ok::<Response, crate::error::WebServerError>(Response {
                status: StatusCode::OK,
                headers: Headers::new(),
                body: Body::from("test response"),
            })
        };

        let Ok(response) = handler().await else {
            panic!("handler is infallible");
        };
        assert_eq!(response.status.0, 200);
        assert_eq!(response.body.len(), "test response".len());
    }

    #[test]
    fn rate_limiter_blocks_after_max_requests_in_window() {
        let limiter = RateLimitMiddleware::new(2, Duration::from_secs(60));
        assert!(!limiter.is_rate_limited("client"));
        assert!(!limiter.is_rate_limited("client"));
        assert!(limiter.is_rate_limited("client"));
        // Keys are tracked independently.
        assert!(!limiter.is_rate_limited("other"));
    }

    #[test]
    fn bearer_token_validation() {
        let auth = AuthMiddleware::new().with_bearer_tokens(vec!["secret".to_string()]);
        assert!(auth.validate_token("Bearer secret"));
        assert!(!auth.validate_token("Bearer wrong"));
        assert!(!auth.validate_token("Basic secret"));
        assert!(!auth.validate_token("secret"));
    }

    #[test]
    fn path_parameters_are_extracted() {
        let matcher = PathParameterMiddleware::new(Vec::new());

        let Some(params) = matcher.match_dynamic_path("/users/:id", "/users/42") else {
            panic!("pattern should match");
        };
        let mut expected = HashMap::new();
        expected.insert("id".to_string(), "42".to_string());
        assert_eq!(params, expected);

        assert!(matcher.match_dynamic_path("/users/:id", "/posts/42").is_none());
        assert!(matcher
            .match_dynamic_path("/users/:id", "/users/42/extra")
            .is_none());
    }

    #[test]
    fn metrics_start_at_zero() {
        assert_eq!(
            MetricsMiddleware::new().get_stats(),
            (0, 0, Duration::ZERO)
        );
    }
}