1use crate::error::{P2PError, P2pResult};
54use parking_lot::RwLock;
55use std::collections::HashMap;
56use std::net::{IpAddr, SocketAddr};
57use std::path::Path;
58use std::sync::Arc;
59use std::time::{Duration, Instant};
60use thiserror::Error;
61
/// Longest accepted peer ID, measured with `str::len` (bytes).
const MAX_PEER_ID_LENGTH: usize = 64;
/// Shortest accepted peer ID, measured with `str::len` (bytes).
const MIN_PEER_ID_LENGTH: usize = 16;
/// Default cap on a network message payload: 16 MiB.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
/// Longest accepted path, measured on its lossy UTF-8 form.
const MAX_PATH_LENGTH: usize = 4096;
/// Default cap on a DHT key: 1 MiB.
const MAX_KEY_SIZE: usize = 1 << 20;
/// Default cap on a DHT value: 10 MiB.
const MAX_VALUE_SIZE: usize = 10 << 20;
/// Longest accepted single path component.
const MAX_FILE_NAME_LENGTH: usize = 255;

/// Default rate-limiting window.
const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_secs(60);
/// Default number of requests admitted per window.
const DEFAULT_MAX_REQUESTS_PER_WINDOW: u32 = 1000;
/// Default token-bucket capacity (back-to-back requests allowed).
const DEFAULT_BURST_SIZE: u32 = 100;
75
76#[derive(Debug, Error)]
80pub enum ValidationError {
81 #[error("Invalid peer ID format: {0}")]
82 InvalidPeerId(String),
83
84 #[error("Invalid network address: {0}")]
85 InvalidAddress(String),
86
87 #[error("Message size exceeds limit: {size} > {limit}")]
88 MessageTooLarge { size: usize, limit: usize },
89
90 #[error("Invalid file path: {0}")]
91 InvalidPath(String),
92
93 #[error("Path traversal attempt detected: {0}")]
94 PathTraversal(String),
95
96 #[error("Invalid key size: {size} bytes (max: {max})")]
97 InvalidKeySize { size: usize, max: usize },
98
99 #[error("Invalid value size: {size} bytes (max: {max})")]
100 InvalidValueSize { size: usize, max: usize },
101
102 #[error("Invalid cryptographic parameter: {0}")]
103 InvalidCryptoParam(String),
104
105 #[error("Rate limit exceeded for {identifier}")]
106 RateLimitExceeded { identifier: String },
107
108 #[error("Invalid format: {0}")]
109 InvalidFormat(String),
110
111 #[error("Value out of range: {value} (min: {min}, max: {max})")]
112 OutOfRange { value: i64, min: i64, max: i64 },
113}
114
115impl From<ValidationError> for P2PError {
116 fn from(err: ValidationError) -> Self {
117 P2PError::Validation(err.to_string().into())
118 }
119}
120
121#[derive(Debug, Clone)]
123pub struct ValidationContext {
124 pub max_message_size: usize,
125 pub max_key_size: usize,
126 pub max_value_size: usize,
127 pub max_path_length: usize,
128 pub allow_localhost: bool,
129 pub allow_private_ips: bool,
130 pub rate_limiter: Option<Arc<RateLimiter>>,
131}
132
133impl Default for ValidationContext {
134 fn default() -> Self {
135 Self {
136 max_message_size: MAX_MESSAGE_SIZE,
137 max_key_size: MAX_KEY_SIZE,
138 max_value_size: MAX_VALUE_SIZE,
139 max_path_length: MAX_PATH_LENGTH,
140 allow_localhost: false,
141 allow_private_ips: false,
142 rate_limiter: None,
143 }
144 }
145}
146
147impl ValidationContext {
148 pub fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn with_rate_limiting(mut self, limiter: Arc<RateLimiter>) -> Self {
155 self.rate_limiter = Some(limiter);
156 self
157 }
158
159 pub fn allow_localhost(mut self) -> Self {
161 self.allow_localhost = true;
162 self
163 }
164
165 pub fn allow_private_ips(mut self) -> Self {
167 self.allow_private_ips = true;
168 self
169 }
170}
171
/// Types that can check their own contents against a [`ValidationContext`].
pub trait Validate {
    /// Validates `self` under the limits and policy in `ctx`.
    ///
    /// # Errors
    ///
    /// Returns an error describing the violated rule. The implementations in
    /// this module return on the first failing check.
    fn validate(&self, ctx: &ValidationContext) -> P2pResult<()>;
}

/// Types that can produce a cleaned copy of themselves.
pub trait Sanitize {
    /// Takes `&self` and returns a new, sanitized value of the same type.
    fn sanitize(&self) -> Self;
}
183
184pub fn validate_network_address(addr: &SocketAddr, ctx: &ValidationContext) -> P2pResult<()> {
188 let ip = addr.ip();
189
190 if ip.is_loopback() && !ctx.allow_localhost {
192 return Err(
193 ValidationError::InvalidAddress("Localhost addresses not allowed".to_string()).into(),
194 );
195 }
196
197 if is_private_ip(&ip) && !ctx.allow_private_ips {
199 return Err(ValidationError::InvalidAddress(
200 "Private IP addresses not allowed".to_string(),
201 )
202 .into());
203 }
204
205 if addr.port() == 0 {
207 return Err(ValidationError::InvalidAddress("Port 0 is not allowed".to_string()).into());
208 }
209
210 Ok(())
211}
212
/// Returns `true` when `ip` lies in a private or link-local range.
///
/// * IPv4: the RFC 1918 private ranges (`Ipv4Addr::is_private`) and
///   link-local `169.254.0.0/16` (`Ipv4Addr::is_link_local`). The latter
///   includes cloud metadata endpoints such as `169.254.169.254`, and
///   mirrors the IPv6 rule, which already treats link-local as private.
/// * IPv6: unique-local `fc00::/7` and unicast link-local `fe80::/10`.
/// * IPv4-mapped IPv6 (`::ffff:a.b.c.d`) is classified by its embedded IPv4
///   address, so the mapped spelling cannot bypass the IPv4 rules.
fn is_private_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => ipv4.is_private() || ipv4.is_link_local(),
        IpAddr::V6(ipv6) => match ipv6.to_ipv4_mapped() {
            Some(mapped) => mapped.is_private() || mapped.is_link_local(),
            None => ipv6.is_unique_local() || ipv6.is_unicast_link_local(),
        },
    }
}
220
221pub fn validate_peer_id(peer_id: &str) -> P2pResult<()> {
225 if peer_id.len() < MIN_PEER_ID_LENGTH || peer_id.len() > MAX_PEER_ID_LENGTH {
227 return Err(ValidationError::InvalidPeerId(format!(
228 "Length must be between {} and {} characters",
229 MIN_PEER_ID_LENGTH, MAX_PEER_ID_LENGTH
230 ))
231 .into());
232 }
233
234 if !peer_id
235 .chars()
236 .all(|ch| ch.is_alphanumeric() || ch == '_' || ch == '-')
237 {
238 return Err(ValidationError::InvalidPeerId(
239 "Must contain only alphanumeric characters, hyphens, and underscores".to_string(),
240 )
241 .into());
242 }
243
244 Ok(())
245}
246
247pub fn validate_message_size(size: usize, max_size: usize) -> P2pResult<()> {
251 if size > max_size {
252 return Err(ValidationError::MessageTooLarge {
253 size,
254 limit: max_size,
255 }
256 .into());
257 }
258 Ok(())
259}
260
261pub fn validate_file_path(path: &Path) -> P2pResult<()> {
265 let path_str = path.to_string_lossy();
266
267 if path_str.len() > MAX_PATH_LENGTH {
269 return Err(ValidationError::InvalidPath(format!(
270 "Path too long: {} > {}",
271 path_str.len(),
272 MAX_PATH_LENGTH
273 ))
274 .into());
275 }
276
277 let decoded = path_str
279 .replace("%2e", ".")
280 .replace("%2f", "/")
281 .replace("%5c", "\\");
282
283 let traversal_patterns = ["../", "..\\", "..", "..;", "....//", "%2e%2e", "%252e%252e"];
285 for pattern in &traversal_patterns {
286 if path_str.contains(pattern) || decoded.contains(pattern) {
287 return Err(ValidationError::PathTraversal(path_str.to_string()).into());
288 }
289 }
290
291 if path_str.contains('\0') {
293 return Err(ValidationError::InvalidPath("Path contains null bytes".to_string()).into());
294 }
295
296 let dangerous_chars = ['|', '&', ';', '$', '`', '\n'];
298 if path_str.chars().any(|c| dangerous_chars.contains(&c)) {
299 return Err(
300 ValidationError::InvalidPath("Path contains dangerous characters".to_string()).into(),
301 );
302 }
303
304 for component in path.components() {
306 if let Some(name) = component.as_os_str().to_str() {
307 if name.len() > MAX_FILE_NAME_LENGTH {
308 return Err(ValidationError::InvalidPath(format!(
309 "Component '{}' exceeds maximum length",
310 name
311 ))
312 .into());
313 }
314
315 if name.contains('\0') {
317 return Err(ValidationError::InvalidPath(format!(
318 "Component '{}' contains invalid characters",
319 name
320 ))
321 .into());
322 }
323 }
324 }
325
326 Ok(())
327}
328
329pub fn validate_key_size(size: usize, expected: usize) -> P2pResult<()> {
333 if size != expected {
334 return Err(ValidationError::InvalidCryptoParam(format!(
335 "Invalid key size: expected {} bytes, got {}",
336 expected, size
337 ))
338 .into());
339 }
340 Ok(())
341}
342
343pub fn validate_nonce_size(size: usize, expected: usize) -> P2pResult<()> {
345 if size != expected {
346 return Err(ValidationError::InvalidCryptoParam(format!(
347 "Invalid nonce size: expected {} bytes, got {}",
348 expected, size
349 ))
350 .into());
351 }
352 Ok(())
353}
354
355pub fn validate_dht_key(key: &[u8], ctx: &ValidationContext) -> P2pResult<()> {
359 if key.is_empty() {
360 return Err(ValidationError::InvalidFormat("DHT key cannot be empty".to_string()).into());
361 }
362
363 if key.len() > ctx.max_key_size {
364 return Err(ValidationError::InvalidKeySize {
365 size: key.len(),
366 max: ctx.max_key_size,
367 }
368 .into());
369 }
370
371 Ok(())
372}
373
374pub fn validate_dht_value(value: &[u8], ctx: &ValidationContext) -> P2pResult<()> {
376 if value.len() > ctx.max_value_size {
377 return Err(ValidationError::InvalidValueSize {
378 size: value.len(),
379 max: ctx.max_value_size,
380 }
381 .into());
382 }
383
384 Ok(())
385}
386
/// Rate limiter with one bucket shared by all callers plus one bucket per
/// client IP.
///
/// Each bucket combines a token bucket with a fixed-window request counter
/// (see `RateLimitBucket`).
#[derive(Debug)]
pub struct RateLimiter {
    /// Per-IP buckets. Inserted on an IP's first `check_ip` call; removed by
    /// `cleanup`.
    ip_limits: RwLock<HashMap<IpAddr, RateLimitBucket>>,
    /// Bucket shared by every caller.
    global_limit: RwLock<RateLimitBucket>,
    /// Limits applied to both the global bucket and each per-IP bucket.
    config: RateLimitConfig,
}
399
400#[derive(Debug, Clone)]
402pub struct RateLimitConfig {
403 pub window: Duration,
405 pub max_requests: u32,
407 pub burst_size: u32,
409 pub adaptive: bool,
411 pub cleanup_interval: Duration,
413}
414
415impl Default for RateLimitConfig {
416 fn default() -> Self {
417 Self {
418 window: DEFAULT_RATE_LIMIT_WINDOW,
419 max_requests: DEFAULT_MAX_REQUESTS_PER_WINDOW,
420 burst_size: DEFAULT_BURST_SIZE,
421 adaptive: true,
422 cleanup_interval: Duration::from_secs(300), }
424 }
425}
426
427#[derive(Debug)]
429struct RateLimitBucket {
430 tokens: f64,
431 last_update: Instant,
432 requests_in_window: u32,
433 window_start: Instant,
434}
435
436impl RateLimitBucket {
437 fn new(initial_tokens: f64) -> Self {
438 let now = Instant::now();
439 Self {
440 tokens: initial_tokens,
441 last_update: now,
442 requests_in_window: 0,
443 window_start: now,
444 }
445 }
446
447 fn try_consume(&mut self, config: &RateLimitConfig) -> bool {
449 let now = Instant::now();
450
451 if now.duration_since(self.window_start) > config.window {
453 self.window_start = now;
454 self.requests_in_window = 0;
455 }
456
457 let elapsed = now.duration_since(self.last_update).as_secs_f64();
459 let refill_rate = config.max_requests as f64 / config.window.as_secs_f64();
460 self.tokens += elapsed * refill_rate;
461 self.tokens = self.tokens.min(config.burst_size as f64);
462 self.last_update = now;
463
464 if self.tokens >= 1.0 && self.requests_in_window < config.max_requests {
466 self.tokens -= 1.0;
467 self.requests_in_window += 1;
468 true
469 } else {
470 false
471 }
472 }
473}
474
475impl RateLimiter {
476 pub fn new(config: RateLimitConfig) -> Self {
478 Self {
479 ip_limits: RwLock::new(HashMap::new()),
480 global_limit: RwLock::new(RateLimitBucket::new(config.burst_size as f64)),
481 config,
482 }
483 }
484
485 pub fn check_ip(&self, ip: &IpAddr) -> P2pResult<()> {
487 {
489 let mut global = self.global_limit.write();
490 if !global.try_consume(&self.config) {
491 return Err(ValidationError::RateLimitExceeded {
492 identifier: "global".to_string(),
493 }
494 .into());
495 }
496 }
497
498 {
500 let mut limits = self.ip_limits.write();
501 let bucket = limits
502 .entry(*ip)
503 .or_insert_with(|| RateLimitBucket::new(self.config.burst_size as f64));
504
505 if !bucket.try_consume(&self.config) {
506 return Err(ValidationError::RateLimitExceeded {
507 identifier: ip.to_string(),
508 }
509 .into());
510 }
511 }
512
513 Ok(())
514 }
515
516 pub fn cleanup(&self) {
518 let mut limits = self.ip_limits.write();
519 let now = Instant::now();
520
521 limits.retain(|_, bucket| {
522 now.duration_since(bucket.last_update) < self.config.cleanup_interval
523 });
524 }
525}
526
527#[derive(Debug)]
531pub struct NetworkMessage {
532 pub peer_id: String,
533 pub payload: Vec<u8>,
534 pub timestamp: u64,
535}
536
537impl Validate for NetworkMessage {
538 fn validate(&self, ctx: &ValidationContext) -> P2pResult<()> {
539 validate_peer_id(&self.peer_id)?;
541
542 validate_message_size(self.payload.len(), ctx.max_message_size)?;
544
545 let now = std::time::SystemTime::now()
547 .duration_since(std::time::UNIX_EPOCH)
548 .map_err(|e| P2PError::Internal(format!("System time error: {}", e).into()))?
549 .as_secs();
550
551 if self.timestamp > now + 300 {
552 return Err(
554 ValidationError::InvalidFormat("Timestamp too far in future".to_string()).into(),
555 );
556 }
557
558 Ok(())
559 }
560}
561
562#[derive(Debug)]
564pub struct ApiRequest {
565 pub method: String,
566 pub path: String,
567 pub params: HashMap<String, String>,
568}
569
570impl Validate for ApiRequest {
571 fn validate(&self, _ctx: &ValidationContext) -> P2pResult<()> {
572 match self.method.as_str() {
574 "GET" | "POST" | "PUT" | "DELETE" => {}
575 _ => {
576 return Err(ValidationError::InvalidFormat(format!(
577 "Invalid HTTP method: {}",
578 self.method
579 ))
580 .into());
581 }
582 }
583
584 if !self.path.starts_with('/') {
586 return Err(
587 ValidationError::InvalidFormat("Path must start with /".to_string()).into(),
588 );
589 }
590
591 if self.path.contains("..") {
592 return Err(ValidationError::PathTraversal(self.path.clone()).into());
593 }
594
595 for (key, value) in &self.params {
597 if key.is_empty() {
598 return Err(
599 ValidationError::InvalidFormat("Empty parameter key".to_string()).into(),
600 );
601 }
602
603 let lower_value = value.to_lowercase();
605 let sql_patterns = [
606 "select ", "insert ", "update ", "delete ", "drop ", "union ", "exec ", "--", "/*",
607 "*/", "'", "\"", " or ", " and ", "1=1", "1='1",
608 ];
609
610 for pattern in &sql_patterns {
611 if lower_value.contains(pattern) {
612 return Err(ValidationError::InvalidFormat(
613 "Suspicious parameter value: potential SQL injection".to_string(),
614 )
615 .into());
616 }
617 }
618
619 let dangerous_chars = ['|', '&', ';', '$', '`', '\n', '\0'];
621 if value.chars().any(|c| dangerous_chars.contains(&c)) {
622 return Err(ValidationError::InvalidFormat(
623 "Dangerous characters in parameter value".to_string(),
624 )
625 .into());
626 }
627 }
628
629 Ok(())
630 }
631}
632
633pub fn validate_config_value<T>(value: &str, min: Option<T>, max: Option<T>) -> P2pResult<T>
635where
636 T: std::str::FromStr + PartialOrd + std::fmt::Display,
637{
638 let parsed = value
639 .parse::<T>()
640 .map_err(|_| ValidationError::InvalidFormat(format!("Failed to parse value: {}", value)))?;
641
642 if let Some(min_val) = min
643 && parsed < min_val
644 {
645 return Err(ValidationError::InvalidFormat(format!(
646 "Value {} is less than minimum {}",
647 parsed, min_val
648 ))
649 .into());
650 }
651
652 if let Some(max_val) = max
653 && parsed > max_val
654 {
655 return Err(ValidationError::InvalidFormat(format!(
656 "Value {} is greater than maximum {}",
657 parsed, max_val
658 ))
659 .into());
660 }
661
662 Ok(parsed)
663}
664
/// Reduces `input` to a conservative, identifier-like string of at most
/// `max_length` characters.
///
/// Processing order:
/// 1. Drop `<`, `>` and invisible/filler code points: zero-width space and
///    joiners, word joiner, BOM, and the Hangul fillers.
/// 2. Repeatedly strip the blocked keywords (`script`, `javascript:`,
///    `onerror`, `onload`, `onclick`, `alert`, `iframe`) until none remain.
/// 3. Keep only alphanumeric characters, `_`, `-` and `.`, truncated to
///    `max_length` characters (not bytes).
///
/// Keyword matching is case-sensitive. Step 3 uses Unicode
/// `char::is_alphanumeric`, so non-ASCII letters and digits survive.
pub fn sanitize_string(input: &str, max_length: usize) -> String {
    // Removed *before* keyword stripping, so they cannot split a keyword
    // (`scr\u{200b}ipt`) that would then reassemble once they are gone. The
    // Hangul fillers are Unicode letters, so the final alphanumeric filter
    // would otherwise keep them.
    const INVISIBLE_CHARS: [char; 9] = [
        '\u{200b}', // zero width space
        '\u{200c}', // zero width non-joiner
        '\u{200d}', // zero width joiner
        '\u{2060}', // word joiner
        '\u{feff}', // zero width no-break space / BOM
        '\u{115f}', // Hangul choseong filler
        '\u{1160}', // Hangul jungseong filler
        '\u{3164}', // Hangul filler
        '\u{ffa0}', // halfwidth Hangul filler
    ];
    const BLOCKED_WORDS: [&str; 7] = [
        "script",
        "javascript:",
        "onerror",
        "onload",
        "onclick",
        "alert",
        "iframe",
    ];

    let mut cleaned: String = input
        .chars()
        .filter(|c| *c != '<' && *c != '>' && !INVISIBLE_CHARS.contains(c))
        .collect();

    // One pass is not enough: removing an inner keyword can assemble an
    // outer one (`scrscriptipt` -> `script`). Every removal shortens the
    // string, so this loop terminates.
    loop {
        let before = cleaned.len();
        for word in BLOCKED_WORDS {
            cleaned = cleaned.replace(word, "");
        }
        if cleaned.len() == before {
            break;
        }
    }

    cleaned
        .chars()
        .filter(|c| c.is_alphanumeric() || matches!(*c, '_' | '-' | '.'))
        .take(max_length)
        .collect()
}
692
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time as whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_peer_id_validation() {
        for accepted in ["valid_peer_id_123", "PEER-ID-WITH-CAPS"] {
            assert!(
                validate_peer_id(accepted).is_ok(),
                "{accepted:?} should be accepted"
            );
        }

        let oversized = "x".repeat(100);
        let rejected = ["short", oversized.as_str(), "invalid peer id", "peer@id"];
        for candidate in rejected {
            assert!(
                validate_peer_id(candidate).is_err(),
                "{candidate:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_network_address_validation() {
        let strict = ValidationContext::default();
        let public: SocketAddr = "8.8.8.8:53".parse().unwrap();
        let loopback: SocketAddr = "127.0.0.1:80".parse().unwrap();

        assert!(validate_network_address(&public, &strict).is_ok());
        assert!(validate_network_address(&loopback, &strict).is_err());

        let permissive = ValidationContext::default().allow_localhost();
        assert!(validate_network_address(&loopback, &permissive).is_ok());
    }

    #[test]
    fn test_file_path_validation() {
        for accepted in ["data/file.txt", "/usr/local/bin"] {
            assert!(
                validate_file_path(Path::new(accepted)).is_ok(),
                "{accepted:?} should be accepted"
            );
        }
        for rejected in ["../etc/passwd", "file\0name"] {
            assert!(
                validate_file_path(Path::new(rejected)).is_err(),
                "{rejected:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(RateLimitConfig {
            window: Duration::from_secs(1),
            max_requests: 10,
            burst_size: 5,
            ..Default::default()
        });
        let client: IpAddr = "192.168.1.1".parse().unwrap();

        // Buckets start full, so a whole burst is admitted immediately.
        (0..5).for_each(|_| assert!(limiter.check_ip(&client).is_ok()));

        // At 10 requests per second, 100 ms refills one token.
        std::thread::sleep(Duration::from_millis(100));
        assert!(limiter.check_ip(&client).is_ok());
    }

    #[test]
    fn test_message_validation() {
        let ctx = ValidationContext::default();

        let fresh = NetworkMessage {
            peer_id: "valid_peer_id_123".to_string(),
            payload: vec![0u8; 1024],
            timestamp: unix_now(),
        };
        assert!(fresh.validate(&ctx).is_ok());

        let bad_peer = NetworkMessage {
            peer_id: "short".to_string(),
            payload: vec![0u8; 1024],
            timestamp: 0,
        };
        assert!(bad_peer.validate(&ctx).is_err());
    }

    #[test]
    fn test_sanitization() {
        let cases = [
            ("hello world!", 20, "helloworld"),
            ("test@#$%123", 20, "test123"),
            ("very_long_string_that_exceeds_limit", 10, "very_long_"),
        ];
        for (input, limit, expected) in cases {
            assert_eq!(sanitize_string(input, limit), expected);
        }
    }
}