1use crate::config::SecurityConfig;
14use crate::core::{Middleware, Next};
15use crate::error::{Result, WebServerError};
16use crate::types::{Request, Response};
17use async_trait::async_trait;
18use sha1::{Digest, Sha1};
19use std::collections::HashMap;
20use std::sync::{Arc, Mutex, RwLock};
21use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
22use tracing::{error, info};
23use uuid::Uuid;
24
25#[cfg(feature = "security")]
26use rustls::{
27 ServerConfig,
28 pki_types::{CertificateDer, PrivatePkcs8KeyDer},
29};
30#[cfg(feature = "security")]
31use rustls_pemfile as pemfile;
32
33pub struct SecurityContext {
35 config: SecurityConfig,
36 csrf_tokens: Arc<RwLock<HashMap<String, CsrfToken>>>,
37 request_signatures: Arc<Mutex<HashMap<String, RequestSignature>>>,
38 security_monitor: SecurityMonitor,
39 rate_limiter: RateLimiter,
40}
41
42impl SecurityContext {
43 pub fn new(config: SecurityConfig) -> Self {
44 Self {
45 config,
46 csrf_tokens: Arc::new(RwLock::new(HashMap::new())),
47 request_signatures: Arc::new(Mutex::new(HashMap::new())),
48 security_monitor: SecurityMonitor::new(),
49 rate_limiter: RateLimiter::new(100, Duration::from_secs(60)), }
51 }
52
53 pub async fn validate_request(
55 &self,
56 request: &mut Request,
57 ) -> Result<SecurityValidationResult> {
58 let mut issues = Vec::new();
59 let client_ip = self.extract_client_ip(request);
60
61 if !self.rate_limiter.allow_request(&client_ip) {
63 let event = SecurityEvent {
64 timestamp: Instant::now(),
65 event_type: SecurityEventType::RateLimitExceeded,
66 severity: SecuritySeverity::Medium,
67 source_ip: client_ip.clone(),
68 details: "Rate limit exceeded".to_string(),
69 };
70 self.security_monitor.log_event(event);
71 issues.push(SecurityIssue::RateLimitExceeded);
72 }
73
74 if let Some(malicious_header) = self.check_malicious_headers(&request.headers) {
76 let event = SecurityEvent {
77 timestamp: Instant::now(),
78 event_type: SecurityEventType::SuspiciousRequest,
79 severity: SecuritySeverity::Medium,
80 source_ip: client_ip.clone(),
81 details: format!("Malicious header detected: {}", malicious_header),
82 };
83 self.security_monitor.log_event(event);
84 issues.push(SecurityIssue::MaliciousHeader(malicious_header));
85 }
86
87 if let Some(content_length) = request.headers.get("content-length") {
89 if let Ok(size) = content_length.parse::<usize>() {
90 if size > MAX_REQUEST_SIZE {
91 issues.push(SecurityIssue::RequestTooLarge(size));
92 }
93 }
94 }
95
96 let uri_path = request.uri.path();
98 if self.contains_sql_injection_patterns(uri_path) {
99 let event = SecurityEvent {
100 timestamp: Instant::now(),
101 event_type: SecurityEventType::SqlInjectionAttempt,
102 severity: SecuritySeverity::Critical,
103 source_ip: client_ip.clone(),
104 details: format!("SQL injection attempt in path: {}", uri_path),
105 };
106 self.security_monitor.log_event(event);
107 issues.push(SecurityIssue::SqlInjectionAttempt(uri_path.to_string()));
108 }
109
110 if self.contains_xss_patterns(uri_path) {
112 let event = SecurityEvent {
113 timestamp: Instant::now(),
114 event_type: SecurityEventType::XssAttempt,
115 severity: SecuritySeverity::High,
116 source_ip: client_ip.clone(),
117 details: format!("XSS attempt in path: {}", uri_path),
118 };
119 self.security_monitor.log_event(event);
120 issues.push(SecurityIssue::XssAttempt(uri_path.to_string()));
121 }
122
123 if self.config.enable_csrf_protection {
125 if let Err(csrf_issue) = self.validate_csrf_token(request).await {
126 let event = SecurityEvent {
127 timestamp: Instant::now(),
128 event_type: SecurityEventType::CsrfTokenValidation,
129 severity: SecuritySeverity::Medium,
130 source_ip: client_ip.clone(),
131 details: "CSRF token validation failed".to_string(),
132 };
133 self.security_monitor.log_event(event);
134 issues.push(csrf_issue);
135 }
136 }
137
138 if let Err(replay_issue) = self.check_replay_attack(request).await {
140 issues.push(replay_issue);
141 }
142
143 Ok(SecurityValidationResult {
144 is_valid: issues.is_empty(),
145 issues,
146 })
147 }
148
149 fn extract_client_ip(&self, request: &Request) -> String {
151 if let Some(forwarded) = request.headers.get("x-forwarded-for") {
153 if let Some(ip) = forwarded.split(',').next() {
154 return ip.trim().to_string();
155 }
156 }
157
158 if let Some(real_ip) = request.headers.get("x-real-ip") {
160 return real_ip.clone();
161 }
162
163 request
165 .headers
166 .get("remote-addr")
167 .cloned()
168 .unwrap_or_else(|| "unknown".to_string())
169 }
170
171 pub async fn generate_csrf_token(&self, session_id: &str) -> String {
173 let token = CsrfToken::new(Duration::from_hours(1));
174 let token_value = token.token.clone();
175
176 {
177 let mut tokens = self.csrf_tokens.write().unwrap();
178 tokens.insert(session_id.to_string(), token);
179
180 tokens.retain(|_, token| !token.is_expired());
182 }
183
184 token_value
185 }
186
187 async fn validate_csrf_token(
189 &self,
190 request: &Request,
191 ) -> std::result::Result<(), SecurityIssue> {
192 if matches!(
194 request.method,
195 crate::types::HttpMethod::GET
196 | crate::types::HttpMethod::HEAD
197 | crate::types::HttpMethod::OPTIONS
198 ) {
199 return Ok(());
200 }
201
202 let token = request
204 .headers
205 .get("x-csrf-token")
206 .or_else(|| request.headers.get("csrf-token"))
207 .cloned();
208
209 let session_id = request
210 .headers
211 .get("session-id")
212 .or_else(|| request.headers.get("authorization"))
213 .cloned();
214
215 match (token, session_id) {
216 (Some(token), Some(session_id)) => {
217 let tokens = self.csrf_tokens.read().unwrap();
218 if let Some(stored_token) = tokens.get(&session_id) {
219 if !stored_token.is_expired() && stored_token.token == token {
220 Ok(())
221 } else {
222 Err(SecurityIssue::InvalidCsrfToken)
223 }
224 } else {
225 Err(SecurityIssue::MissingCsrfToken)
226 }
227 }
228 _ => Err(SecurityIssue::MissingCsrfToken),
229 }
230 }
231
232 async fn check_replay_attack(
234 &self,
235 request: &Request,
236 ) -> std::result::Result<(), SecurityIssue> {
237 let signature = self.generate_request_signature(request);
238 let mut signatures = self.request_signatures.lock().unwrap();
239
240 if let Some(existing) = signatures.get(&signature) {
242 if existing.timestamp.elapsed() < Duration::from_secs(5 * 60) {
243 return Err(SecurityIssue::ReplayAttack);
244 }
245 }
246
247 signatures.insert(
249 signature.clone(),
250 RequestSignature {
251 signature,
252 timestamp: Instant::now(),
253 },
254 );
255
256 let one_hour_ago = Instant::now() - Duration::from_hours(1);
258 signatures.retain(|_, sig| sig.timestamp > one_hour_ago);
259
260 Ok(())
261 }
262
263 fn generate_request_signature(&self, request: &Request) -> String {
265 let mut hasher = Sha1::new();
266 hasher.update(request.method.to_string().as_bytes());
267 hasher.update(request.uri.path().as_bytes());
268
269 let timestamp = SystemTime::now()
271 .duration_since(UNIX_EPOCH)
272 .unwrap()
273 .as_secs()
274 / 60; hasher.update(timestamp.to_string().as_bytes());
276
277 if let Some(auth) = request.headers.get("authorization") {
279 hasher.update(auth.as_bytes());
280 }
281
282 let result = hasher.finalize();
283 hex::encode(result)
284 }
285
286 fn check_malicious_headers(&self, headers: &crate::types::Headers) -> Option<String> {
288 let malicious_patterns = [
289 "eval(",
290 "javascript:",
291 "<script",
292 "data:text/html",
293 "../",
294 "..\\",
295 "union select",
296 "drop table",
297 ];
298
299 for (name, value) in headers.iter() {
300 let combined = format!("{}: {}", name, value).to_lowercase();
301 for pattern in &malicious_patterns {
302 if combined.contains(pattern) {
303 return Some(format!("{}:{}", name, value));
304 }
305 }
306 }
307 None
308 }
309
310 fn contains_sql_injection_patterns(&self, input: &str) -> bool {
312 let sql_patterns = [
313 "union select",
314 "drop table",
315 "delete from",
316 "insert into",
317 "update set",
318 "or 1=1",
319 "and 1=1",
320 "' or '",
321 "\" or \"",
322 "; --",
323 "/*",
324 "*/",
325 "xp_",
326 "sp_",
327 "exec(",
328 "execute(",
329 ];
330
331 let input_lower = input.to_lowercase();
332 sql_patterns
333 .iter()
334 .any(|pattern| input_lower.contains(pattern))
335 }
336
337 fn contains_xss_patterns(&self, input: &str) -> bool {
339 let xss_patterns = [
340 "<script",
341 "</script>",
342 "javascript:",
343 "onload=",
344 "onerror=",
345 "onclick=",
346 "onmouseover=",
347 "data:text/html",
348 "eval(",
349 "expression(",
350 "url(javascript:",
351 "vbscript:",
352 ];
353
354 let input_lower = input.to_lowercase();
355 xss_patterns
356 .iter()
357 .any(|pattern| input_lower.contains(pattern))
358 }
359
360 pub fn add_security_headers(&self, response: &mut Response) {
362 response
364 .headers
365 .insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
366 response
367 .headers
368 .insert("X-Frame-Options".to_string(), "DENY".to_string());
369 response
370 .headers
371 .insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
372 response.headers.insert(
373 "Referrer-Policy".to_string(),
374 "strict-origin-when-cross-origin".to_string(),
375 );
376
377 let csp = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self'; connect-src 'self'; frame-ancestors 'none'";
379 response
380 .headers
381 .insert("Content-Security-Policy".to_string(), csp.to_string());
382
383 if self.config.tls.enabled {
385 response.headers.insert(
386 "Strict-Transport-Security".to_string(),
387 "max-age=31536000; includeSubDomains; preload".to_string(),
388 );
389 }
390
391 response.headers.insert(
393 "Permissions-Policy".to_string(),
394 "geolocation=(), microphone=(), camera=()".to_string(),
395 );
396 response.headers.insert(
397 "Cross-Origin-Embedder-Policy".to_string(),
398 "require-corp".to_string(),
399 );
400 response.headers.insert(
401 "Cross-Origin-Opener-Policy".to_string(),
402 "same-origin".to_string(),
403 );
404
405 response.headers.insert(
407 "Cache-Control".to_string(),
408 "no-store, no-cache, must-revalidate, private".to_string(),
409 );
410 response
411 .headers
412 .insert("Pragma".to_string(), "no-cache".to_string());
413 }
414
415 pub fn get_security_stats(&self) -> SecurityStats {
417 self.security_monitor.get_security_stats()
418 }
419
420 pub fn get_recent_events(&self, since: Instant) -> Vec<SecurityEvent> {
422 self.security_monitor.get_events_since(since)
423 }
424}
425
426#[derive(Debug, Clone)]
428struct CsrfToken {
429 token: String,
430 #[allow(dead_code)]
431 created_at: Instant,
432 expires_at: Instant,
433}
434
435impl CsrfToken {
436 fn new(duration: Duration) -> Self {
437 let now = Instant::now();
438 let token = generate_secure_token();
439 Self {
440 token,
441 created_at: now,
442 expires_at: now + duration,
443 }
444 }
445
446 fn is_expired(&self) -> bool {
447 Instant::now() > self.expires_at
448 }
449}
450
/// A request fingerprint and the moment it was recorded, kept for replay detection.
#[derive(Clone, Debug)]
struct RequestSignature {
    // Duplicates the map key under which this value is stored; kept for Debug output.
    #[allow(dead_code)]
    signature: String,
    /// Monotonic time at which the signature was recorded.
    timestamp: Instant,
}
458
/// Sliding-window rate limiter: at most `max_requests` admitted per client key within any
/// `window`-long span.
struct RateLimiter {
    /// Per-client admission timestamps plus the time of the last stale-client sweep.
    state: Arc<Mutex<RateLimiterState>>,
    /// Requests admitted per client within one window.
    max_requests: usize,
    /// Length of the sliding window.
    window: Duration,
}

/// Mutable state of a [`RateLimiter`], guarded by a single mutex.
struct RateLimiterState {
    /// Timestamps of each client's admitted requests that may still fall inside the window.
    requests: HashMap<String, Vec<Instant>>,
    /// When clients without recent requests were last evicted.
    last_sweep: Instant,
}

impl RateLimiter {
    /// Creates a limiter admitting `max_requests` requests per client per `window`.
    fn new(max_requests: usize, window: Duration) -> Self {
        Self {
            state: Arc::new(Mutex::new(RateLimiterState {
                requests: HashMap::new(),
                last_sweep: Instant::now(),
            })),
            max_requests,
            window,
        }
    }

    /// Records a request from `client_ip` and reports whether it is within the limit.
    ///
    /// Rejected requests are not recorded, so a client regains capacity as its admitted
    /// requests age out of the window. At most once per window, clients with no request
    /// left inside the window are evicted, so the map does not keep every address ever seen.
    fn allow_request(&self, client_ip: &str) -> bool {
        let now = Instant::now();
        let window = self.window;
        let in_window = move |stamp: &Instant| now.duration_since(*stamp) <= window;

        // A panic elsewhere while holding the lock leaves the map usable; recover it.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        if now.duration_since(state.last_sweep) > window {
            state.requests.retain(|_, stamps| {
                stamps.retain(in_window);
                !stamps.is_empty()
            });
            state.last_sweep = now;
        }

        let stamps = state.requests.entry(client_ip.to_string()).or_default();
        stamps.retain(in_window);

        if stamps.len() < self.max_requests {
            stamps.push(now);
            true
        } else {
            false
        }
    }
}
494
495pub struct SecurityMonitor {
497 events: Arc<Mutex<Vec<SecurityEvent>>>,
498}
499
500impl SecurityMonitor {
501 pub fn new() -> Self {
502 Self {
503 events: Arc::new(Mutex::new(Vec::new())),
504 }
505 }
506
507 pub fn log_event(&self, event: SecurityEvent) {
509 info!("Security event: {:?}", event);
510
511 let mut events = self.events.lock().unwrap();
512 events.push(event);
513
514 if events.len() > 1000 {
516 let excess = events.len() - 1000;
517 events.drain(0..excess);
518 }
519 }
520
521 pub fn get_events_since(&self, since: Instant) -> Vec<SecurityEvent> {
523 let events = self.events.lock().unwrap();
524 events
525 .iter()
526 .filter(|event| event.timestamp > since)
527 .cloned()
528 .collect()
529 }
530
531 pub fn get_security_stats(&self) -> SecurityStats {
533 let events = self.events.lock().unwrap();
534 let total_events = events.len();
535
536 let mut stats_by_severity = HashMap::new();
537 for event in events.iter() {
538 let counter = stats_by_severity.entry(event.severity.clone()).or_insert(0);
539 *counter += 1;
540 }
541
542 SecurityStats {
543 total_events,
544 events_by_severity: stats_by_severity,
545 last_event: events.last().map(|e| e.timestamp),
546 }
547 }
548}
549
550impl Default for SecurityMonitor {
551 fn default() -> Self {
552 Self::new()
553 }
554}
555
556#[derive(Debug, Clone)]
558pub struct SecurityEvent {
559 pub timestamp: Instant,
560 pub event_type: SecurityEventType,
561 pub severity: SecuritySeverity,
562 pub source_ip: String,
563 pub details: String,
564}
565
/// Category of a [`SecurityEvent`].
#[derive(Clone, Debug)]
pub enum SecurityEventType {
    /// The URI path matched a SQL-injection pattern.
    SqlInjectionAttempt,
    /// The URI path matched an XSS pattern.
    XssAttempt,
    /// CSRF token validation failed.
    CsrfTokenValidation,
    /// A client exceeded its request rate.
    RateLimitExceeded,
    /// Request metadata (such as a header) looked malicious.
    SuspiciousRequest,
    /// Authentication failed (not emitted by the checks in this module).
    AuthenticationFailure,
    /// Authorization failed (not emitted by the checks in this module).
    AuthorizationFailure,
}
576
/// Severity of a security event or issue.
///
/// Variants are declared from least to most severe; the derived `Ord` depends on this
/// order, so `max_by_key` over severities picks the most serious one.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SecuritySeverity {
    /// Informational or easily recoverable.
    Low,
    /// Worth attention.
    Medium,
    /// Likely an attack attempt.
    High,
    /// Severe attack attempt.
    Critical,
}
585
586#[derive(Debug)]
588pub struct SecurityValidationResult {
589 pub is_valid: bool,
590 pub issues: Vec<SecurityIssue>,
591}
592
593#[derive(Debug, Clone)]
595pub enum SecurityIssue {
596 MaliciousHeader(String),
597 RequestTooLarge(usize),
598 SqlInjectionAttempt(String),
599 XssAttempt(String),
600 InvalidCsrfToken,
601 MissingCsrfToken,
602 ReplayAttack,
603 RateLimitExceeded,
604 InvalidSignature,
605}
606
607impl SecurityIssue {
608 pub fn severity(&self) -> SecuritySeverity {
609 match self {
610 SecurityIssue::SqlInjectionAttempt(_) => SecuritySeverity::Critical,
611 SecurityIssue::XssAttempt(_) => SecuritySeverity::High,
612 SecurityIssue::ReplayAttack => SecuritySeverity::High,
613 SecurityIssue::InvalidCsrfToken => SecuritySeverity::Medium,
614 SecurityIssue::MaliciousHeader(_) => SecuritySeverity::Medium,
615 SecurityIssue::RateLimitExceeded => SecuritySeverity::Low,
616 SecurityIssue::RequestTooLarge(_) => SecuritySeverity::Low,
617 SecurityIssue::MissingCsrfToken => SecuritySeverity::Low,
618 SecurityIssue::InvalidSignature => SecuritySeverity::Medium,
619 }
620 }
621
622 pub fn to_http_status(&self) -> crate::types::StatusCode {
623 match self {
624 SecurityIssue::SqlInjectionAttempt(_) | SecurityIssue::XssAttempt(_) => {
625 crate::types::StatusCode::BAD_REQUEST
626 }
627 SecurityIssue::InvalidCsrfToken | SecurityIssue::MissingCsrfToken => {
628 crate::types::StatusCode::FORBIDDEN
629 }
630 SecurityIssue::ReplayAttack => crate::types::StatusCode::CONFLICT,
631 SecurityIssue::RateLimitExceeded => crate::types::StatusCode::TOO_MANY_REQUESTS,
632 SecurityIssue::RequestTooLarge(_) => crate::types::StatusCode::PAYLOAD_TOO_LARGE,
633 SecurityIssue::MaliciousHeader(_) => crate::types::StatusCode::BAD_REQUEST,
634 SecurityIssue::InvalidSignature => crate::types::StatusCode::UNAUTHORIZED,
635 }
636 }
637}
638
639#[derive(Debug)]
641pub struct SecurityStats {
642 pub total_events: usize,
643 pub events_by_severity: HashMap<SecuritySeverity, usize>,
644 pub last_event: Option<Instant>,
645}
646
/// Builds a rustls server configuration from the certificate and key files named in the
/// TLS section of [`SecurityConfig`].
#[cfg(feature = "security")]
pub struct TlsManager {
    config: SecurityConfig,
}

#[cfg(feature = "security")]
impl TlsManager {
    /// Wraps `config`; no files are read until [`TlsManager::load_tls_config`] is called.
    pub fn new(config: SecurityConfig) -> Self {
        Self { config }
    }

    /// Loads the certificate chain and private key and builds a [`ServerConfig`] that does
    /// not request client certificates.
    ///
    /// # Errors
    /// Fails when TLS is disabled, when either file cannot be opened or parsed, when the
    /// certificate file contains no certificate, when the key file contains no PKCS#8
    /// private key, or when rustls rejects the certificate/key pair.
    pub fn load_tls_config(&self) -> Result<ServerConfig> {
        if !self.config.tls.enabled {
            return Err(WebServerError::custom("TLS not enabled"));
        }

        let certs = self.load_certificates()?;
        let private_key = self.load_private_key()?;

        ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(certs, private_key.into())
            .map_err(|e| WebServerError::custom(format!("TLS configuration error: {}", e)))
    }

    /// Reads every PEM certificate from the configured certificate file, in file order.
    fn load_certificates(&self) -> Result<Vec<CertificateDer<'static>>> {
        let cert_file = std::fs::File::open(&self.config.tls.cert_file).map_err(|e| {
            WebServerError::custom(format!("Failed to open certificate file: {}", e))
        })?;

        let certs = rustls_pemfile::certs(&mut std::io::BufReader::new(cert_file))
            .map(|cert| {
                cert.map_err(|e| {
                    WebServerError::custom(format!("Failed to parse certificate: {}", e))
                })
            })
            .collect::<crate::error::Result<Vec<_>>>()?;

        // Fail early with a clear message instead of handing rustls an empty chain.
        if certs.is_empty() {
            return Err(WebServerError::custom("No certificates found"));
        }
        Ok(certs)
    }

    /// Returns the first PKCS#8 private key in the configured key file. Every key in the
    /// file must parse.
    fn load_private_key(&self) -> Result<PrivatePkcs8KeyDer<'static>> {
        let key_file = std::fs::File::open(&self.config.tls.key_file).map_err(|e| {
            WebServerError::custom(format!("Failed to open private key file: {}", e))
        })?;

        let keys = rustls_pemfile::pkcs8_private_keys(&mut std::io::BufReader::new(key_file))
            .map(|key| {
                key.map_err(|e| {
                    WebServerError::custom(format!("Failed to parse private key: {}", e))
                })
            })
            .collect::<crate::error::Result<Vec<_>>>()?;

        keys.into_iter()
            .next()
            .ok_or_else(|| WebServerError::custom("No private key found"))
    }
}
708
709const MAX_REQUEST_SIZE: usize = 100 * 1024 * 1024; fn generate_secure_token() -> String {
714 use std::collections::hash_map::DefaultHasher;
715 use std::hash::{Hash, Hasher};
716
717 let mut hasher = DefaultHasher::new();
718 SystemTime::now().hash(&mut hasher);
719 std::thread::current().id().hash(&mut hasher);
720
721 let random_value = hasher.finish();
722 format!("{:x}{}", random_value, Uuid::new_v4().simple())
723}
724
725pub struct SecurityMiddleware {
727 context: Arc<SecurityContext>,
728}
729
730impl SecurityMiddleware {
731 pub fn new(config: SecurityConfig) -> Self {
732 Self {
733 context: Arc::new(SecurityContext::new(config)),
734 }
735 }
736
737 pub fn get_context(&self) -> Arc<SecurityContext> {
738 self.context.clone()
739 }
740}
741
742#[async_trait]
743impl Middleware for SecurityMiddleware {
744 async fn call(&self, mut request: Request, next: Next) -> crate::Result<Response> {
745 let validation_result = self.context.validate_request(&mut request).await?;
747
748 if !validation_result.is_valid {
749 let most_severe = validation_result
751 .issues
752 .iter()
753 .max_by_key(|issue| issue.severity())
754 .unwrap();
755
756 let status = most_severe.to_http_status();
757 let message = format!("Security validation failed: {:?}", most_severe);
758
759 error!("Security validation failed for request: {}", message);
760
761 let mut response = Response::new(status).body(message);
762 self.context.add_security_headers(&mut response);
763 return Ok(response);
764 }
765
766 let mut response = next.run(request).await?;
768
769 self.context.add_security_headers(&mut response);
771
772 Ok(response)
773 }
774}
775
776#[derive(Debug)]
778pub struct CsrfMiddleware {
779 secret_key: String,
780 token_name: String,
781 cookie_name: String,
782 header_name: String,
783 exclude_paths: Vec<String>,
784 token_store: Arc<RwLock<HashMap<String, (String, SystemTime)>>>,
785 token_lifetime: Duration,
786}
787impl CsrfMiddleware {
788 pub fn new(secret_key: String) -> Self {
789 Self {
790 secret_key,
791 token_name: "csrf_token".to_string(),
792 cookie_name: "csrf_token".to_string(),
793 header_name: "X-CSRF-Token".to_string(),
794 exclude_paths: vec![],
795 token_store: Arc::new(RwLock::new(HashMap::new())),
796 token_lifetime: Duration::from_secs(3600), }
798 }
799
800 pub fn token_name(mut self, name: String) -> Self {
802 self.token_name = name;
803 self
804 }
805
806 pub fn cookie_name(mut self, name: String) -> Self {
808 self.cookie_name = name;
809 self
810 }
811
812 pub fn header_name(mut self, name: String) -> Self {
814 self.header_name = name;
815 self
816 }
817
818 pub fn exclude_path(mut self, path: String) -> Self {
820 self.exclude_paths.push(path);
821 self
822 }
823
824 pub fn token_lifetime(mut self, lifetime: Duration) -> Self {
826 self.token_lifetime = lifetime;
827 self
828 }
829
830 fn generate_token(&self, session_id: &str) -> String {
832 let timestamp = SystemTime::now()
833 .duration_since(SystemTime::UNIX_EPOCH)
834 .unwrap()
835 .as_secs();
836
837 let raw_token = format!("{}:{}:{}", session_id, timestamp, self.secret_key);
838 let mut hasher = Sha1::new();
839 hasher.update(raw_token.as_bytes());
840 let hash = hasher.finalize();
841
842 format!("{}:{}", timestamp, hex::encode(hash))
843 }
844
845 fn validate_token(&self, token: &str, session_id: &str) -> bool {
847 let parts: Vec<&str> = token.split(':').collect();
848 if parts.len() != 2 {
849 return false;
850 }
851
852 let timestamp_str = parts[0];
853 let hash_str = parts[1];
854
855 if let Ok(timestamp) = timestamp_str.parse::<u64>() {
856 let token_time = SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp);
857 let now = SystemTime::now();
858
859 if now.duration_since(token_time).unwrap_or(Duration::MAX) > self.token_lifetime {
861 return false;
862 }
863
864 let raw_token = format!("{}:{}:{}", session_id, timestamp, self.secret_key);
866 let mut hasher = Sha1::new();
867 hasher.update(raw_token.as_bytes());
868 let expected_hash = hex::encode(hasher.finalize());
869
870 return hash_str == expected_hash;
871 }
872
873 false
874 }
875
876 fn cleanup_expired_tokens(&self) {
878 let mut store = self.token_store.write().unwrap();
879 let now = SystemTime::now();
880 store.retain(|_, (_, created_at)| {
881 now.duration_since(*created_at).unwrap_or(Duration::MAX) <= self.token_lifetime
882 });
883 }
884}
885
886#[async_trait]
887impl Middleware for CsrfMiddleware {
888 async fn call(&self, mut request: Request, next: Next) -> crate::Result<Response> {
889 let path = request.uri.path();
890
891 if self.exclude_paths.iter().any(|p| path.starts_with(p)) {
893 return next.run(request).await;
894 }
895
896 self.cleanup_expired_tokens();
898
899 if matches!(
901 request.method,
902 crate::types::HttpMethod::GET
903 | crate::types::HttpMethod::HEAD
904 | crate::types::HttpMethod::OPTIONS
905 ) {
906 let session_id = request
908 .cookie("session_id")
909 .map(|c| c.value.clone())
910 .unwrap_or_else(|| Uuid::new_v4().to_string());
911
912 let token = self.generate_token(&session_id);
913
914 {
916 let mut store = self.token_store.write().unwrap();
917 store.insert(session_id.clone(), (token.clone(), SystemTime::now()));
918 }
919
920 request
922 .extensions
923 .insert("csrf_token".to_string(), token.clone());
924
925 let mut response = next.run(request).await?;
926
927 response.headers.insert("X-CSRF-Token".to_string(), token);
929
930 return Ok(response);
931 }
932
933 let session_id = request
935 .cookie("session_id")
936 .map(|c| c.value.clone())
937 .unwrap_or_default();
938
939 if session_id.is_empty() {
940 return Ok(
941 Response::new(crate::types::StatusCode::FORBIDDEN).body("CSRF: Missing session")
942 );
943 }
944
945 let token = request.headers.get(&self.header_name).cloned().or_else(|| {
947 request.form(&self.token_name).map(|s| s.to_string())
949 });
950
951 let token = match token {
952 Some(t) => t,
953 None => {
954 return Ok(
955 Response::new(crate::types::StatusCode::FORBIDDEN).body("CSRF: Missing token")
956 );
957 }
958 };
959
960 if !self.validate_token(&token, &session_id) {
962 return Ok(
963 Response::new(crate::types::StatusCode::FORBIDDEN).body("CSRF: Invalid token")
964 );
965 }
966
967 next.run(request).await
968 }
969}
970
/// Middleware that sets `X-XSS-Protection`, `X-Content-Type-Options` and `X-Frame-Options`
/// on responses.
///
/// NOTE(review): current browsers ignore `X-XSS-Protection`, and OWASP recommends sending
/// `0`; `filtering(false)` selects that value.
#[derive(Debug)]
pub struct XssProtectionMiddleware {
    /// When false, the header is sent as `0` (filter disabled).
    enable_filtering: bool,
    /// While filtering, chooses `1; mode=block` (true) over plain `1` (false).
    block_mode: bool,
}

impl XssProtectionMiddleware {
    /// Creates the middleware with filtering enabled in block mode.
    pub fn new() -> Self {
        Self {
            enable_filtering: true,
            block_mode: true,
        }
    }

    /// Builder: enables or disables the filter.
    pub fn filtering(self, enable: bool) -> Self {
        Self {
            enable_filtering: enable,
            ..self
        }
    }

    /// Builder: chooses block mode while filtering.
    pub fn block_mode(self, block: bool) -> Self {
        Self {
            block_mode: block,
            ..self
        }
    }
}

impl Default for XssProtectionMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
1005
1006#[async_trait]
1007impl Middleware for XssProtectionMiddleware {
1008 async fn call(&self, request: Request, next: Next) -> crate::Result<Response> {
1009 let mut response = next.run(request).await?;
1010
1011 if self.enable_filtering {
1013 let header_value = if self.block_mode {
1014 "1; mode=block"
1015 } else {
1016 "1"
1017 };
1018 response
1019 .headers
1020 .insert("X-XSS-Protection".to_string(), header_value.to_string());
1021 } else {
1022 response
1023 .headers
1024 .insert("X-XSS-Protection".to_string(), "0".to_string());
1025 }
1026
1027 response
1029 .headers
1030 .insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
1031
1032 response
1034 .headers
1035 .insert("X-Frame-Options".to_string(), "DENY".to_string());
1036
1037 Ok(response)
1038 }
1039}
1040
/// Middleware that sets a Content Security Policy header built from configurable directives.
#[derive(Debug)]
pub struct CspMiddleware {
    /// Directive name to its source values.
    directives: HashMap<String, Vec<String>>,
    /// Emit `Content-Security-Policy-Report-Only` instead of the enforcing header.
    report_only: bool,
}

impl CspMiddleware {
    /// Creates a middleware with no directives, in enforcing mode.
    pub fn new() -> Self {
        Self {
            directives: HashMap::new(),
            report_only: false,
        }
    }

    /// A same-origin policy that still allows inline scripts/styles and `data:` images and
    /// forbids framing.
    pub fn default_policy() -> Self {
        let mut csp = Self::new();
        csp.directive("default-src", vec!["'self'".to_string()]);
        csp.directive(
            "script-src",
            vec!["'self'".to_string(), "'unsafe-inline'".to_string()],
        );
        csp.directive(
            "style-src",
            vec!["'self'".to_string(), "'unsafe-inline'".to_string()],
        );
        csp.directive("img-src", vec!["'self'".to_string(), "data:".to_string()]);
        csp.directive("font-src", vec!["'self'".to_string()]);
        csp.directive("connect-src", vec!["'self'".to_string()]);
        csp.directive("frame-ancestors", vec!["'none'".to_string()]);
        csp
    }

    /// Sets (or replaces) directive `name` with `values`. Pass an empty `values` for
    /// valueless directives such as `upgrade-insecure-requests`.
    pub fn directive(&mut self, name: &str, values: Vec<String>) -> &mut Self {
        self.directives.insert(name.to_string(), values);
        self
    }

    /// Builder: choose report-only mode.
    pub fn report_only(mut self, report_only: bool) -> Self {
        self.report_only = report_only;
        self
    }

    /// Serializes the directives into a header value.
    ///
    /// Directives are emitted sorted by name, because `HashMap` iteration order is not
    /// stable and sorting makes the header identical across runs. A directive without
    /// values is emitted as its bare name, with no trailing space.
    fn build_header_value(&self) -> String {
        let mut entries: Vec<(&String, &Vec<String>)> = self.directives.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

        entries
            .into_iter()
            .map(|(name, values)| {
                if values.is_empty() {
                    name.clone()
                } else {
                    format!("{} {}", name, values.join(" "))
                }
            })
            .collect::<Vec<_>>()
            .join("; ")
    }
}

impl Default for CspMiddleware {
    fn default() -> Self {
        Self::default_policy()
    }
}
1103
1104#[async_trait]
1105impl Middleware for CspMiddleware {
1106 async fn call(&self, request: Request, next: Next) -> crate::Result<Response> {
1107 let mut response = next.run(request).await?;
1108
1109 let header_name = if self.report_only {
1110 "Content-Security-Policy-Report-Only"
1111 } else {
1112 "Content-Security-Policy"
1113 };
1114
1115 let header_value = self.build_header_value();
1116 response
1117 .headers
1118 .insert(header_name.to_string(), header_value);
1119
1120 Ok(response)
1121 }
1122}
1123
/// Helpers for escaping untrusted strings and for coarse input validation.
pub mod sanitize {
    /// Escapes `input` for use in HTML text or quoted attribute values.
    ///
    /// `&`, `<`, `>`, `"`, `'` and `/` become entities, in a single pass.
    pub fn html(input: &str) -> String {
        let mut escaped = String::with_capacity(input.len());
        for ch in input.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#x27;"),
                '/' => escaped.push_str("&#x2F;"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Doubles single quotes, double quotes and backslashes, and removes NUL characters.
    ///
    /// NOTE(review): doubling backslashes matches MySQL's default string syntax but corrupts
    /// data under standard-conforming SQL (e.g. PostgreSQL). Prefer parameterized queries to
    /// splicing this output into SQL text.
    pub fn sql(input: &str) -> String {
        input
            .replace('\'', "''")
            .replace('"', "\"\"")
            .replace('\\', "\\\\")
            .replace('\0', "")
    }

    /// Keeps only alphanumeric characters, `.`, `_` and `-`.
    ///
    /// Path separators are dropped, so the result cannot name a subdirectory. A result of `.`
    /// or `..` would refer to the current or parent directory, so it is returned as an empty
    /// string. Callers should treat an empty result as "no usable name".
    pub fn filename(input: &str) -> String {
        let cleaned: String = input
            .chars()
            .filter(|c| c.is_alphanumeric() || matches!(*c, '.' | '_' | '-'))
            .collect();

        if cleaned == "." || cleaned == ".." {
            String::new()
        } else {
            cleaned
        }
    }

    /// Coarse email check: contains `@` and is 4 to 254 bytes long.
    ///
    /// This does not validate the local part or domain; for example, `"@example.com"` passes.
    pub fn is_valid_email(email: &str) -> bool {
        email.contains('@') && email.len() > 3 && email.len() < 255
    }

    /// Whether `url` starts with `http://` or `https://` and has at least one character after it.
    ///
    /// The scheme is compared case-insensitively, as RFC 3986 specifies.
    pub fn is_valid_url(url: &str) -> bool {
        ["http://", "https://"].iter().any(|scheme| {
            url.len() > scheme.len()
                && url
                    .get(..scheme.len())
                    .is_some_and(|prefix| prefix.eq_ignore_ascii_case(scheme))
        })
    }
}
1164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_html_sanitization() {
        let input = "<script>alert('xss')</script>";
        let expected = "&lt;script&gt;alert(&#x27;xss&#x27;)&lt;&#x2F;script&gt;";
        assert_eq!(sanitize::html(input), expected);
        assert_eq!(sanitize::html("a & \"b\""), "a &amp; &quot;b&quot;");
    }

    #[test]
    fn test_filename_sanitization() {
        let input = "../../etc/passwd";
        let expected = "....etcpasswd";
        assert_eq!(sanitize::filename(input), expected);
    }

    #[test]
    fn test_email_validation() {
        assert!(sanitize::is_valid_email("test@example.com"));
        assert!(!sanitize::is_valid_email("invalid"));
        // The check is coarse: an empty local part is still accepted.
        assert!(sanitize::is_valid_email("@example.com"));
    }

    #[test]
    fn test_url_validation() {
        assert!(sanitize::is_valid_url("https://example.com"));
        assert!(sanitize::is_valid_url("http://example.com"));
        assert!(!sanitize::is_valid_url("ftp://example.com"));
        assert!(!sanitize::is_valid_url("example.com"));
    }

    #[test]
    fn test_csrf_token_generation() {
        let middleware = CsrfMiddleware::new("secret_key".to_string());
        let token = middleware.generate_token("session_123");
        assert!(!token.is_empty());
        assert!(token.contains(':'));
    }

    #[test]
    fn test_csrf_token_round_trip() {
        let middleware = CsrfMiddleware::new("secret_key".to_string());
        let token = middleware.generate_token("session_123");
        assert!(middleware.validate_token(&token, "session_123"));
        assert!(!middleware.validate_token(&token, "other_session"));
        assert!(!middleware.validate_token("not-a-token", "session_123"));
    }

    #[test]
    fn test_rate_limiter_caps_each_client() {
        let limiter = RateLimiter::new(2, std::time::Duration::from_secs(60));
        assert!(limiter.allow_request("10.0.0.1"));
        assert!(limiter.allow_request("10.0.0.1"));
        assert!(!limiter.allow_request("10.0.0.1"));
        assert!(limiter.allow_request("10.0.0.2"));
    }

    #[test]
    fn test_security_severity_ordering() {
        assert!(SecuritySeverity::Low < SecuritySeverity::Medium);
        assert!(SecuritySeverity::High < SecuritySeverity::Critical);
        assert_eq!(
            SecurityIssue::SqlInjectionAttempt(String::new()).severity(),
            SecuritySeverity::Critical
        );
    }

    #[test]
    fn test_csp_header_building() {
        let mut csp = CspMiddleware::new();
        csp.directive("default-src", vec!["'self'".to_string()]);
        csp.directive(
            "script-src",
            vec!["'self'".to_string(), "'unsafe-inline'".to_string()],
        );

        let header = csp.build_header_value();
        assert!(header.contains("default-src 'self'"));
        assert!(header.contains("script-src 'self' 'unsafe-inline'"));
    }
}