use hyper::HeaderMap;
use hyper::header::{
    CONTENT_SECURITY_POLICY, STRICT_TRANSPORT_SECURITY, X_FRAME_OPTIONS,
    X_CONTENT_TYPE_OPTIONS, X_XSS_PROTECTION, REFERRER_POLICY,
    CONTENT_LENGTH, HeaderValue
};
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
46
/// Settings for the per-IP request rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests accepted from one IP within a single window.
    pub max_requests: usize,

    /// Length of one rate-limit window, in seconds.
    pub window_secs: u64,

    /// Whether rate limiting is applied at all.
    pub enabled: bool,
}

impl Default for RateLimitConfig {
    /// Enabled, allowing 100 requests per 60-second window.
    fn default() -> Self {
        let max_requests = 100;
        let window_secs = 60;
        Self {
            max_requests,
            window_secs,
            enabled: true,
        }
    }
}
69
/// IP-based access control settings.
///
/// Entries in either list may be a single address (`"192.168.1.10"`,
/// `"2001:db8::1"`) or a CIDR range (`"10.0.0.0/8"`, `"2001:db8::/32"`).
///
/// The default is two empty lists with access control disabled, which is
/// exactly the derived `Default`.
#[derive(Debug, Clone, Default)]
pub struct IpAccessConfig {
    /// Addresses/ranges that are permitted. When non-empty, any address that
    /// matches no entry is denied.
    pub allowlist: Vec<String>,

    /// Addresses/ranges that are always denied; consulted before the allowlist.
    pub blocklist: Vec<String>,

    /// Whether access control is enforced at all.
    pub enabled: bool,
}
92
/// Which security-related response headers to emit, and with what values.
///
/// A field set to `None` means that header is not emitted.
#[derive(Debug, Clone)]
pub struct SecurityHeadersConfig {
    /// Value for `Content-Security-Policy`.
    pub content_security_policy: Option<String>,

    /// `max-age` (in seconds) for `Strict-Transport-Security`.
    pub hsts_max_age: Option<u64>,

    /// Whether the HSTS header carries `includeSubDomains`.
    pub hsts_include_subdomains: bool,

    /// Value for `X-Frame-Options`.
    pub x_frame_options: Option<String>,

    /// Value for `X-Content-Type-Options`.
    pub x_content_type_options: Option<String>,

    /// Value for `X-XSS-Protection`.
    pub x_xss_protection: Option<String>,

    /// Value for `Referrer-Policy`.
    pub referrer_policy: Option<String>,

    /// Whether any security headers are added.
    pub enabled: bool,
}

impl Default for SecurityHeadersConfig {
    fn default() -> Self {
        fn text(value: &str) -> Option<String> {
            Some(value.to_owned())
        }

        // One year, expressed in seconds (31_536_000).
        let one_year_secs: u64 = 365 * 24 * 60 * 60;

        Self {
            content_security_policy: text("default-src 'self'"),
            hsts_max_age: Some(one_year_secs),
            hsts_include_subdomains: true,
            x_frame_options: text("SAMEORIGIN"),
            x_content_type_options: text("nosniff"),
            // NOTE(review): current OWASP/MDN guidance is to send "0" or omit
            // this header; kept as-is to preserve existing behavior.
            x_xss_protection: text("1; mode=block"),
            referrer_policy: text("strict-origin-when-cross-origin"),
            enabled: true,
        }
    }
}
135
/// Limits on the size and shape of incoming requests.
#[derive(Debug, Clone)]
pub struct RequestSizeConfig {
    /// Largest acceptable request body, in bytes, as declared by `Content-Length`.
    pub max_body_size: usize,

    /// Maximum number of header values a request may carry.
    pub max_headers: usize,

    /// Maximum size, in bytes, of a single header line.
    pub max_header_line_size: usize,

    /// Whether these limits are checked at all.
    pub enabled: bool,
}

impl Default for RequestSizeConfig {
    /// Enabled; 10 MiB bodies, 100 headers, 8 KiB header lines.
    fn default() -> Self {
        let ten_mib = 10 << 20;
        Self {
            max_body_size: ten_mib,
            max_headers: 100,
            max_header_line_size: 8 * 1024,
            enabled: true,
        }
    }
}
162
163#[derive(Debug, Clone)]
165pub struct SecurityConfig {
166 pub rate_limit: RateLimitConfig,
168
169 pub ip_access: IpAccessConfig,
171
172 pub security_headers: SecurityHeadersConfig,
174
175 pub request_size: RequestSizeConfig,
177}
178
179impl Default for SecurityConfig {
180 fn default() -> Self {
181 Self {
182 rate_limit: RateLimitConfig::default(),
183 ip_access: IpAccessConfig::default(),
184 security_headers: SecurityHeadersConfig::default(),
185 request_size: RequestSizeConfig::default(),
186 }
187 }
188}
189
/// Rate-limit bookkeeping for a single client: a counter for the current
/// window plus the moment that window opened.
#[derive(Clone, Debug)]
struct RateLimitState {
    /// Requests counted so far in the current window.
    request_count: usize,
    /// When the current window began.
    window_start: Instant,
}
196
197pub struct SecurityLayer {
199 config: Arc<SecurityConfig>,
200 rate_limit_states: Arc<RwLock<HashMap<String, RateLimitState>>>,
201}
202
203impl SecurityLayer {
204 pub fn new(config: SecurityConfig) -> Self {
206 Self {
207 config: Arc::new(config),
208 rate_limit_states: Arc::new(RwLock::new(HashMap::new())),
209 }
210 }
211
212 pub fn is_ip_allowed(&self, ip: &str) -> bool {
214 if !self.config.ip_access.enabled {
215 return true;
216 }
217
218 if !self.config.ip_access.blocklist.is_empty() {
220 if self.config.ip_access.blocklist.iter().any(|blocked| self.match_ip(ip, blocked)) {
221 return false;
222 }
223 }
224
225 if !self.config.ip_access.allowlist.is_empty() {
227 return self.config.ip_access.allowlist.iter().any(|allowed| self.match_ip(ip, allowed));
228 }
229
230 true
232 }
233
234 fn match_ip(&self, ip: &str, pattern: &str) -> bool {
236 if ip == pattern {
238 return true;
239 }
240
241 if pattern.contains('/') {
243 if let Ok(cidr) = self.parse_cidr_v4(pattern) {
244 if let Ok(ip_addr) = ip.parse::<Ipv4Addr>() {
245 return self.is_ip_in_cidr_v4(&ip_addr, &cidr);
246 }
247 }
248
249 if let Ok(cidr) = self.parse_cidr_v6(pattern) {
251 if let Ok(ip_addr) = ip.parse::<Ipv6Addr>() {
252 return self.is_ip_in_cidr_v6(&ip_addr, &cidr);
253 }
254 }
255 }
256
257 false
258 }
259
260 fn parse_cidr_v4(&self, pattern: &str) -> Result<(Ipv4Addr, u8), ()> {
262 let parts: Vec<&str> = pattern.split('/').collect();
263 if parts.len() != 2 {
264 return Err(());
265 }
266
267 let ip = parts[0].parse::<Ipv4Addr>().map_err(|_| ())?;
268 let prefix_len = parts[1].parse::<u8>().map_err(|_| ())?;
269
270 if prefix_len > 32 {
271 return Err(());
272 }
273
274 Ok((ip, prefix_len))
275 }
276
277 fn parse_cidr_v6(&self, pattern: &str) -> Result<(Ipv6Addr, u8), ()> {
279 let parts: Vec<&str> = pattern.split('/').collect();
280 if parts.len() != 2 {
281 return Err(());
282 }
283
284 let ip = parts[0].parse::<Ipv6Addr>().map_err(|_| ())?;
285 let prefix_len = parts[1].parse::<u8>().map_err(|_| ())?;
286
287 if prefix_len > 128 {
288 return Err(());
289 }
290
291 Ok((ip, prefix_len))
292 }
293
294 fn is_ip_in_cidr_v4(&self, ip: &Ipv4Addr, cidr: &(Ipv4Addr, u8)) -> bool {
296 let mask = if cidr.1 == 0 {
297 0u32
298 } else {
299 u32::MAX << (32 - cidr.1)
300 };
301
302 let ip_u32 = u32::from(*ip);
303 let network_u32 = u32::from(cidr.0);
304
305 (ip_u32 & mask) == (network_u32 & mask)
306 }
307
308 fn is_ip_in_cidr_v6(&self, ip: &Ipv6Addr, cidr: &(Ipv6Addr, u8)) -> bool {
310 let mask_bits = cidr.1 as usize;
311 if mask_bits == 0 {
312 return true;
313 }
314
315 let ip_bytes = ip.octets();
316 let network_bytes = cidr.0.octets();
317
318 for i in 0..((mask_bits + 7) / 8) {
319 let bits_to_check = if i == (mask_bits / 8) {
320 mask_bits % 8
321 } else {
322 8
323 };
324
325 let mask = 0xFFu8 << (8 - bits_to_check);
326 if (ip_bytes[i] & mask) != (network_bytes[i] & mask) {
327 return false;
328 }
329 }
330
331 true
332 }
333
334 pub async fn check_rate_limit(&self, ip: &str) -> bool {
336 if !self.config.rate_limit.enabled {
337 return true;
338 }
339
340 let mut states = self.rate_limit_states.write().await;
341 let now = Instant::now();
342 let window = Duration::from_secs(self.config.rate_limit.window_secs);
343
344 let state = states.entry(ip.to_string()).or_insert_with(|| {
346 RateLimitState {
347 request_count: 0,
348 window_start: now,
349 }
350 });
351
352 if now.duration_since(state.window_start) >= window {
354 state.window_start = now;
355 state.request_count = 1;
356 return true;
357 }
358
359 if state.request_count >= self.config.rate_limit.max_requests {
361 return false;
362 }
363
364 state.request_count += 1;
365 true
366 }
367
368 pub fn check_request_size(&self, headers: &HeaderMap) -> bool {
370 if !self.config.request_size.enabled {
371 return true;
372 }
373
374 if headers.len() > self.config.request_size.max_headers {
376 return false;
377 }
378
379 if let Some(content_length) = headers.get(CONTENT_LENGTH) {
381 if let Ok(length) = content_length.to_str() {
382 if let Ok(size) = length.parse::<usize>() {
383 if size > self.config.request_size.max_body_size {
384 return false;
385 }
386 }
387 }
388 }
389
390 true
391 }
392
393 pub fn add_security_headers(&self, response_headers: &mut HeaderMap) {
395 if !self.config.security_headers.enabled {
396 return;
397 }
398
399 if let Some(csp) = &self.config.security_headers.content_security_policy {
401 if let Ok(value) = HeaderValue::from_str(csp) {
402 response_headers.insert(CONTENT_SECURITY_POLICY, value);
403 }
404 }
405
406 if let Some(max_age) = self.config.security_headers.hsts_max_age {
408 let hsts_value = if self.config.security_headers.hsts_include_subdomains {
409 format!("max-age={}; includeSubDomains", max_age)
410 } else {
411 format!("max-age={}", max_age)
412 };
413
414 if let Ok(value) = HeaderValue::from_str(&hsts_value) {
415 response_headers.insert(STRICT_TRANSPORT_SECURITY, value);
416 }
417 }
418
419 if let Some(frame_options) = &self.config.security_headers.x_frame_options {
421 if let Ok(value) = HeaderValue::from_str(frame_options) {
422 response_headers.insert(X_FRAME_OPTIONS, value);
423 }
424 }
425
426 if let Some(content_type_options) = &self.config.security_headers.x_content_type_options {
428 if let Ok(value) = HeaderValue::from_str(content_type_options) {
429 response_headers.insert(X_CONTENT_TYPE_OPTIONS, value);
430 }
431 }
432
433 if let Some(xss_protection) = &self.config.security_headers.x_xss_protection {
435 if let Ok(value) = HeaderValue::from_str(xss_protection) {
436 response_headers.insert(X_XSS_PROTECTION, value);
437 }
438 }
439
440 if let Some(referrer_policy) = &self.config.security_headers.referrer_policy {
442 if let Ok(value) = HeaderValue::from_str(referrer_policy) {
443 response_headers.insert(REFERRER_POLICY, value);
444 }
445 }
446 }
447
448 pub async fn cleanup_expired_states(&self) {
450 let mut states = self.rate_limit_states.write().await;
451 let now = Instant::now();
452 let window = Duration::from_secs(self.config.rate_limit.window_secs);
453
454 states.retain(|_, state| {
455 now.duration_since(state.window_start) < window * 2
456 });
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::header::HeaderName;

    /// Builds a header map holding `count` distinct `x-test-N` headers.
    ///
    /// `HeaderMap::insert` replaces an existing value for the same name, so
    /// distinct names are required to actually grow the map.
    fn distinct_headers(count: usize) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for i in 0..count {
            let name = HeaderName::from_bytes(format!("x-test-{i}").as_bytes())
                .expect("generated header names are valid tokens");
            headers.insert(name, HeaderValue::from_static("1"));
        }
        headers
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.max_requests, 100);
        assert_eq!(config.window_secs, 60);
        assert!(config.enabled);
    }

    #[test]
    fn test_ip_access_config_default() {
        let config = IpAccessConfig::default();
        assert!(config.allowlist.is_empty());
        assert!(config.blocklist.is_empty());
        assert!(!config.enabled);
    }

    #[test]
    fn test_security_headers_config_default() {
        let config = SecurityHeadersConfig::default();
        assert_eq!(config.content_security_policy, Some("default-src 'self'".to_string()));
        assert_eq!(config.hsts_max_age, Some(31536000));
        assert!(config.hsts_include_subdomains);
        assert_eq!(config.x_frame_options, Some("SAMEORIGIN".to_string()));
        assert_eq!(config.x_content_type_options, Some("nosniff".to_string()));
        assert_eq!(config.x_xss_protection, Some("1; mode=block".to_string()));
        assert_eq!(config.referrer_policy, Some("strict-origin-when-cross-origin".to_string()));
        assert!(config.enabled);
    }

    #[test]
    fn test_request_size_config_default() {
        let config = RequestSizeConfig::default();
        assert_eq!(config.max_body_size, 10 * 1024 * 1024);
        assert_eq!(config.max_headers, 100);
        assert_eq!(config.max_header_line_size, 8192);
        assert!(config.enabled);
    }

    #[test]
    fn test_ip_exact_match() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        assert!(layer.is_ip_allowed("192.168.1.1"));
    }

    #[test]
    fn test_ip_blocklist() {
        let config = SecurityConfig {
            ip_access: IpAccessConfig {
                blocklist: vec!["192.168.1.100".to_string()],
                enabled: true,
                ..Default::default()
            },
            ..Default::default()
        };
        let layer = SecurityLayer::new(config);

        assert!(!layer.is_ip_allowed("192.168.1.100"));
        assert!(layer.is_ip_allowed("192.168.1.1"));
    }

    #[test]
    fn test_ip_allowlist() {
        let config = SecurityConfig {
            ip_access: IpAccessConfig {
                allowlist: vec!["192.168.1.10".to_string(), "192.168.1.20".to_string()],
                enabled: true,
                ..Default::default()
            },
            ..Default::default()
        };
        let layer = SecurityLayer::new(config);

        assert!(layer.is_ip_allowed("192.168.1.10"));
        assert!(layer.is_ip_allowed("192.168.1.20"));
        assert!(!layer.is_ip_allowed("192.168.1.1"));
    }

    #[test]
    fn test_cidr_v4_range() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        assert!(layer.match_ip("192.168.1.10", "192.168.1.0/24"));
        assert!(layer.match_ip("192.168.1.255", "192.168.1.0/24"));
        assert!(!layer.match_ip("192.168.2.1", "192.168.1.0/24"));

        assert!(layer.match_ip("192.168.1.1", "192.168.0.0/16"));
        assert!(layer.match_ip("192.168.255.255", "192.168.0.0/16"));
        assert!(!layer.match_ip("192.167.255.255", "192.168.0.0/16"));

        assert!(layer.match_ip("192.168.1.1", "192.168.1.1/32"));
        assert!(!layer.match_ip("192.168.1.2", "192.168.1.1/32"));
    }

    #[test]
    fn test_cidr_v6_range() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        assert!(layer.match_ip("2001:db8::1", "2001:db8::/64"));
        assert!(layer.match_ip("2001:db8::ffff", "2001:db8::/64"));
        assert!(!layer.match_ip("2001:db8:1::1", "2001:db8::/64"));

        assert!(layer.match_ip("2001:db8::1", "2001:db8::1/128"));
        assert!(!layer.match_ip("2001:db8::2", "2001:db8::1/128"));
    }

    #[test]
    fn test_request_size_check() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_LENGTH, "1024".parse().unwrap());
        assert!(layer.check_request_size(&headers));

        // Exactly `max_headers` (100) is accepted; one more is rejected.
        assert!(layer.check_request_size(&distinct_headers(100)));
        assert!(!layer.check_request_size(&distinct_headers(101)));

        let mut large_body = HeaderMap::new();
        large_body.insert(CONTENT_LENGTH, "10485761".parse().unwrap());
        assert!(!layer.check_request_size(&large_body));
    }

    #[test]
    fn test_disabled_ip_access_control() {
        let config = SecurityConfig {
            ip_access: IpAccessConfig {
                blocklist: vec!["192.168.1.100".to_string()],
                enabled: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let layer = SecurityLayer::new(config);

        assert!(layer.is_ip_allowed("192.168.1.100"));
        assert!(layer.is_ip_allowed("192.168.1.1"));
    }

    #[test]
    fn test_disabled_size_limits() {
        let config = SecurityConfig {
            request_size: RequestSizeConfig {
                max_headers: 10,
                max_body_size: 100,
                enabled: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let layer = SecurityLayer::new(config);

        // Both limits are exceeded: 101 headers > 10 and a 1000-byte body > 100.
        let mut many_headers = distinct_headers(100);
        many_headers.insert(CONTENT_LENGTH, "1000".parse().unwrap());
        assert!(many_headers.len() > 10);
        assert!(layer.check_request_size(&many_headers));
    }

    #[test]
    fn test_security_headers_addition() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        let mut headers = HeaderMap::new();
        layer.add_security_headers(&mut headers);

        assert!(headers.contains_key(CONTENT_SECURITY_POLICY));
        assert!(headers.contains_key(STRICT_TRANSPORT_SECURITY));
        assert!(headers.contains_key(X_FRAME_OPTIONS));
        assert!(headers.contains_key(X_CONTENT_TYPE_OPTIONS));
        assert!(headers.contains_key(X_XSS_PROTECTION));
        assert!(headers.contains_key(REFERRER_POLICY));
        assert_eq!(
            headers.get(STRICT_TRANSPORT_SECURITY).unwrap(),
            "max-age=31536000; includeSubDomains"
        );
    }

    #[test]
    fn test_disabled_security_headers() {
        let config = SecurityConfig {
            security_headers: SecurityHeadersConfig {
                enabled: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let layer = SecurityLayer::new(config);

        let mut headers = HeaderMap::new();
        layer.add_security_headers(&mut headers);

        assert!(!headers.contains_key(CONTENT_SECURITY_POLICY));
        assert!(!headers.contains_key(STRICT_TRANSPORT_SECURITY));
    }

    #[tokio::test]
    async fn test_rate_limit_window_reset() {
        // A zero-length window elapses immediately, so every request starts a
        // new window and the one-request limit never trips.
        let config = SecurityConfig {
            rate_limit: RateLimitConfig {
                max_requests: 1,
                window_secs: 0,
                enabled: true,
            },
            ..Default::default()
        };
        let layer = SecurityLayer::new(config);

        let ip = "192.168.1.1";
        for _ in 0..5 {
            assert!(layer.check_rate_limit(ip).await);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_enforced() {
        let config = SecurityConfig {
            rate_limit: RateLimitConfig {
                max_requests: 2,
                window_secs: 60,
                enabled: true,
            },
            ..Default::default()
        };
        let layer = SecurityLayer::new(config);

        assert!(layer.check_rate_limit("10.0.0.1").await);
        assert!(layer.check_rate_limit("10.0.0.1").await);
        assert!(!layer.check_rate_limit("10.0.0.1").await);
        // Limits are tracked per IP.
        assert!(layer.check_rate_limit("10.0.0.2").await);
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let config = SecurityConfig {
            rate_limit: RateLimitConfig {
                enabled: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let layer = SecurityLayer::new(config);

        let ip = "192.168.1.1";
        for _ in 0..200 {
            assert!(layer.check_rate_limit(ip).await);
        }
    }

    #[tokio::test]
    async fn test_cleanup_expired_states() {
        let layer = SecurityLayer::new(SecurityConfig::default());
        assert!(layer.check_rate_limit("10.0.0.1").await);
        layer.cleanup_expired_states().await;
        // The entry is far younger than two 60-second windows, so it is kept.
        assert_eq!(layer.rate_limit_states.read().await.len(), 1);

        let zero_window = SecurityLayer::new(SecurityConfig {
            rate_limit: RateLimitConfig {
                window_secs: 0,
                ..Default::default()
            },
            ..Default::default()
        });
        assert!(zero_window.check_rate_limit("10.0.0.1").await);
        zero_window.cleanup_expired_states().await;
        assert_eq!(zero_window.rate_limit_states.read().await.len(), 0);
    }

    #[test]
    fn test_parse_cidr_v4_invalid_format() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        assert!(layer.parse_cidr_v4("192.168.1.0").is_err());
        assert!(layer.parse_cidr_v4("192.168.1.0/24/extra").is_err());
        assert!(layer.parse_cidr_v4("192.168.1.0/33").is_err());
    }

    #[test]
    fn test_parse_cidr_v6_invalid_format() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        assert!(layer.parse_cidr_v6("::1").is_err());
        assert!(layer.parse_cidr_v6("::1/129").is_err());
    }

    #[test]
    fn test_is_ip_in_cidr_v4_edge_cases() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        let ip = "192.168.1.1".parse().unwrap();
        let cidr = (ip, 0);
        assert!(layer.is_ip_in_cidr_v4(&ip, &cidr));

        let cidr = (ip, 32);
        assert!(layer.is_ip_in_cidr_v4(&ip, &cidr));
    }

    #[test]
    fn test_is_ip_in_cidr_v6_edge_cases() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        let ip = "2001:db8::1".parse().unwrap();
        let cidr = (ip, 0);
        assert!(layer.is_ip_in_cidr_v6(&ip, &cidr));

        let cidr = (ip, 128);
        assert!(layer.is_ip_in_cidr_v6(&ip, &cidr));
    }

    #[test]
    fn test_check_request_size_no_content_length() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        let headers = HeaderMap::new();
        assert!(layer.check_request_size(&headers));
    }

    #[test]
    fn test_check_request_size_invalid_content_length() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_LENGTH, "invalid".parse().unwrap());
        assert!(layer.check_request_size(&headers));
    }

    #[test]
    fn test_ip_in_cidr_v6_partial_byte() {
        let config = SecurityConfig::default();
        let layer = SecurityLayer::new(config);

        let network: Ipv6Addr = "2001:db8::".parse().unwrap();
        let ip_in: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let ip_out: Ipv6Addr = "2001:db9::1".parse().unwrap();

        assert!(layer.is_ip_in_cidr_v6(&ip_in, &(network, 33)));
        assert!(!layer.is_ip_in_cidr_v6(&ip_out, &(network, 33)));
    }
}