Skip to main content

saorsa_core/
validation.rs

1// Copyright (c) 2025 Saorsa Labs Limited
2
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: david@saorsalabs.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14// This program is distributed in the hope that it will be useful,
15// but WITHOUT ANY WARRANTY; without even the implied warranty of
16// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17// GNU Affero General Public License for more details.
18
19// You should have received a copy of the GNU Affero General Public License
20// along with this program. If not, see <https://www.gnu.org/licenses/>.
21
22//! Comprehensive input validation framework for P2P Foundation
23//!
24//! This module provides a robust validation system for all external inputs,
25//! including network messages, API parameters, file paths, and cryptographic parameters.
26//!
27//! # Features
28//!
29//! - **Type-safe validation traits**: Extensible validation system
30//! - **Rate limiting**: Per-IP and global rate limiting with adaptive throttling
31//! - **Performance optimized**: < 5% overhead for validation operations
32//! - **Security hardened**: Protection against common attack vectors
33//! - **Comprehensive logging**: All validation failures are logged
34//!
35//! # Usage
36//!
37//! ```rust,ignore
//! use saorsa_core::error::P2pResult;
//! use saorsa_core::validation::{Validate, ValidationContext};
//! use saorsa_core::validation::{validate_peer_id, validate_message_size};
//!
//! #[derive(Debug)]
//! struct NetworkMessage {
//!     peer_id: String,
//!     payload: Vec<u8>,
//! }
//!
//! impl Validate for NetworkMessage {
//!     fn validate(&self, ctx: &ValidationContext) -> P2pResult<()> {
49//!         // Validate peer ID format
50//!         validate_peer_id(&self.peer_id)?;
51//!
52//!         // Validate payload size
53//!         validate_message_size(self.payload.len(), ctx.max_message_size)?;
54//!
55//!         Ok(())
56//!     }
57//! }
58//! ```
59
60use crate::error::{P2PError, P2pResult};
61
62use std::collections::HashMap;
63use std::net::{IpAddr, SocketAddr};
64use std::path::Path;
65use std::sync::Arc;
66use std::time::Duration;
67use thiserror::Error;
68
// Constants for validation rules

/// Longest accepted peer ID, in bytes.
const MAX_PEER_ID_LENGTH: usize = 64;
/// Shortest accepted peer ID, in bytes.
const MIN_PEER_ID_LENGTH: usize = 16;
/// Default cap on a single message payload: 16 MiB.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
/// Longest accepted file path, in bytes.
const MAX_PATH_LENGTH: usize = 4096;
/// Default cap on a DHT key: 1 MiB.
const MAX_KEY_SIZE: usize = 1 << 20;
/// Default cap on a DHT value: 10 MiB.
const MAX_VALUE_SIZE: usize = 10 << 20;
/// Longest accepted single path component, in bytes.
const MAX_FILE_NAME_LENGTH: usize = 255;

// Rate limiting constants

/// Default rate-limiting window.
const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_secs(60);
/// Default number of requests admitted per window.
const DEFAULT_MAX_REQUESTS_PER_WINDOW: u32 = 1000;
/// Default burst allowance.
const DEFAULT_BURST_SIZE: u32 = 100;
82
83// Validation functions below operate without panicking and avoid global regexes
84
85/// Validation errors specific to input validation
86#[derive(Debug, Error)]
87pub enum ValidationError {
88    #[error("Invalid peer ID format: {0}")]
89    InvalidPeerId(String),
90
91    #[error("Invalid network address: {0}")]
92    InvalidAddress(String),
93
94    #[error("Message size exceeds limit: {size} > {limit}")]
95    MessageTooLarge { size: usize, limit: usize },
96
97    #[error("Invalid file path: {0}")]
98    InvalidPath(String),
99
100    #[error("Path traversal attempt detected: {0}")]
101    PathTraversal(String),
102
103    #[error("Invalid key size: {size} bytes (max: {max})")]
104    InvalidKeySize { size: usize, max: usize },
105
106    #[error("Invalid value size: {size} bytes (max: {max})")]
107    InvalidValueSize { size: usize, max: usize },
108
109    #[error("Invalid cryptographic parameter: {0}")]
110    InvalidCryptoParam(String),
111
112    #[error("Rate limit exceeded for {identifier}")]
113    RateLimitExceeded { identifier: String },
114
115    #[error("Invalid format: {0}")]
116    InvalidFormat(String),
117
118    #[error("Value out of range: {value} (min: {min}, max: {max})")]
119    OutOfRange { value: i64, min: i64, max: i64 },
120}
121
122impl From<ValidationError> for P2PError {
123    fn from(err: ValidationError) -> Self {
124        P2PError::Validation(err.to_string().into())
125    }
126}
127
128/// Context for validation operations
129#[derive(Debug, Clone)]
130pub struct ValidationContext {
131    pub max_message_size: usize,
132    pub max_key_size: usize,
133    pub max_value_size: usize,
134    pub max_path_length: usize,
135    pub allow_localhost: bool,
136    pub allow_private_ips: bool,
137    pub rate_limiter: Option<Arc<RateLimiter>>,
138}
139
140impl Default for ValidationContext {
141    fn default() -> Self {
142        Self {
143            max_message_size: MAX_MESSAGE_SIZE,
144            max_key_size: MAX_KEY_SIZE,
145            max_value_size: MAX_VALUE_SIZE,
146            max_path_length: MAX_PATH_LENGTH,
147            allow_localhost: false,
148            allow_private_ips: false,
149            rate_limiter: None,
150        }
151    }
152}
153
154impl ValidationContext {
155    /// Create a new validation context with custom settings
156    pub fn new() -> Self {
157        Self::default()
158    }
159
160    /// Enable rate limiting
161    pub fn with_rate_limiting(mut self, limiter: Arc<RateLimiter>) -> Self {
162        self.rate_limiter = Some(limiter);
163        self
164    }
165
166    /// Allow localhost connections
167    pub fn allow_localhost(mut self) -> Self {
168        self.allow_localhost = true;
169        self
170    }
171
172    /// Allow private IP addresses
173    pub fn allow_private_ips(mut self) -> Self {
174        self.allow_private_ips = true;
175        self
176    }
177}
178
179/// Core validation trait
180pub trait Validate {
181    /// Validate the object with the given context
182    fn validate(&self, ctx: &ValidationContext) -> P2pResult<()>;
183}
184
/// Trait for types that can produce a cleaned-up copy of themselves.
pub trait Sanitize {
    /// Return a sanitized version of `self`.
    ///
    /// The receiver is borrowed, so the cleaned value comes back as a new
    /// instance.
    fn sanitize(&self) -> Self;
}
190
191// ===== Network Address Validation =====
192
193/// Validate a network address
194pub fn validate_network_address(addr: &SocketAddr, ctx: &ValidationContext) -> P2pResult<()> {
195    let ip = addr.ip();
196
197    // Check for localhost
198    if ip.is_loopback() && !ctx.allow_localhost {
199        return Err(
200            ValidationError::InvalidAddress("Localhost addresses not allowed".to_string()).into(),
201        );
202    }
203
204    // Check for private IPs
205    if is_private_ip(&ip) && !ctx.allow_private_ips {
206        return Err(ValidationError::InvalidAddress(
207            "Private IP addresses not allowed".to_string(),
208        )
209        .into());
210    }
211
212    // Validate port
213    if addr.port() == 0 {
214        return Err(ValidationError::InvalidAddress("Port 0 is not allowed".to_string()).into());
215    }
216
217    Ok(())
218}
219
/// Check whether an IP address lies in a private or link-local range.
///
/// - IPv4: RFC 1918 private ranges and `169.254.0.0/16` link-local. Link-local
///   is included for parity with the IPv6 rule below; it also covers the
///   `169.254.169.254` metadata endpoint used by common cloud providers.
/// - IPv6: unique-local (`fc00::/7`) and unicast link-local (`fe80::/10`).
/// - IPv4-mapped IPv6 (`::ffff:a.b.c.d`): judged by the embedded IPv4 address.
fn is_private_ip(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => ipv4.is_private() || ipv4.is_link_local(),
        IpAddr::V6(ipv6) => match ipv6.to_ipv4_mapped() {
            // A mapped address reaches the embedded IPv4 host, so apply the IPv4 rules.
            Some(mapped) => is_private_ip(&IpAddr::V4(mapped)),
            None => ipv6.is_unique_local() || ipv6.is_unicast_link_local(),
        },
    }
}
227
228// ===== Peer ID Validation =====
229
230/// Validate a peer ID
231pub fn validate_peer_id(peer_id: &str) -> P2pResult<()> {
232    // Simple length and character set validation; constant-time not required here
233    if peer_id.len() < MIN_PEER_ID_LENGTH || peer_id.len() > MAX_PEER_ID_LENGTH {
234        return Err(ValidationError::InvalidPeerId(format!(
235            "Length must be between {} and {} characters",
236            MIN_PEER_ID_LENGTH, MAX_PEER_ID_LENGTH
237        ))
238        .into());
239    }
240
241    if !peer_id
242        .chars()
243        .all(|ch| ch.is_alphanumeric() || ch == '_' || ch == '-')
244    {
245        return Err(ValidationError::InvalidPeerId(
246            "Must contain only alphanumeric characters, hyphens, and underscores".to_string(),
247        )
248        .into());
249    }
250
251    Ok(())
252}
253
254// ===== Message Size Validation =====
255
256/// Validate message size
257pub fn validate_message_size(size: usize, max_size: usize) -> P2pResult<()> {
258    if size > max_size {
259        return Err(ValidationError::MessageTooLarge {
260            size,
261            limit: max_size,
262        }
263        .into());
264    }
265    Ok(())
266}
267
268// ===== File Path Validation =====
269
270/// Validate a file path for security
271pub fn validate_file_path(path: &Path) -> P2pResult<()> {
272    let path_str = path.to_string_lossy();
273
274    // Check path length
275    if path_str.len() > MAX_PATH_LENGTH {
276        return Err(ValidationError::InvalidPath(format!(
277            "Path too long: {} > {}",
278            path_str.len(),
279            MAX_PATH_LENGTH
280        ))
281        .into());
282    }
283
284    // URL decode to catch encoded traversal attempts
285    let decoded = path_str
286        .replace("%2e", ".")
287        .replace("%2f", "/")
288        .replace("%5c", "\\");
289
290    // Check for path traversal attempts (including encoded versions)
291    let traversal_patterns = ["../", "..\\", "..", "..;", "....//", "%2e%2e", "%252e%252e"];
292    for pattern in &traversal_patterns {
293        if path_str.contains(pattern) || decoded.contains(pattern) {
294            return Err(ValidationError::PathTraversal(path_str.to_string()).into());
295        }
296    }
297
298    // Check for null bytes
299    if path_str.contains('\0') {
300        return Err(ValidationError::InvalidPath("Path contains null bytes".to_string()).into());
301    }
302
303    // Check for command injection characters
304    let dangerous_chars = ['|', '&', ';', '$', '`', '\n'];
305    if path_str.chars().any(|c| dangerous_chars.contains(&c)) {
306        return Err(
307            ValidationError::InvalidPath("Path contains dangerous characters".to_string()).into(),
308        );
309    }
310
311    // Validate each component
312    for component in path.components() {
313        if let Some(name) = component.as_os_str().to_str() {
314            if name.len() > MAX_FILE_NAME_LENGTH {
315                return Err(ValidationError::InvalidPath(format!(
316                    "Component '{}' exceeds maximum length",
317                    name
318                ))
319                .into());
320            }
321
322            // Check for invalid characters
323            if name.contains('\0') {
324                return Err(ValidationError::InvalidPath(format!(
325                    "Component '{}' contains invalid characters",
326                    name
327                ))
328                .into());
329            }
330        }
331    }
332
333    Ok(())
334}
335
336// ===== Cryptographic Parameter Validation =====
337
338/// Validate key size for cryptographic operations
339pub fn validate_key_size(size: usize, expected: usize) -> P2pResult<()> {
340    if size != expected {
341        return Err(ValidationError::InvalidCryptoParam(format!(
342            "Invalid key size: expected {} bytes, got {}",
343            expected, size
344        ))
345        .into());
346    }
347    Ok(())
348}
349
350/// Validate nonce size
351pub fn validate_nonce_size(size: usize, expected: usize) -> P2pResult<()> {
352    if size != expected {
353        return Err(ValidationError::InvalidCryptoParam(format!(
354            "Invalid nonce size: expected {} bytes, got {}",
355            expected, size
356        ))
357        .into());
358    }
359    Ok(())
360}
361
362// ===== DHT Key/Value Validation =====
363
364/// Validate DHT key
365pub fn validate_dht_key(key: &[u8], ctx: &ValidationContext) -> P2pResult<()> {
366    if key.is_empty() {
367        return Err(ValidationError::InvalidFormat("DHT key cannot be empty".to_string()).into());
368    }
369
370    if key.len() > ctx.max_key_size {
371        return Err(ValidationError::InvalidKeySize {
372            size: key.len(),
373            max: ctx.max_key_size,
374        }
375        .into());
376    }
377
378    Ok(())
379}
380
381/// Validate DHT value
382pub fn validate_dht_value(value: &[u8], ctx: &ValidationContext) -> P2pResult<()> {
383    if value.len() > ctx.max_value_size {
384        return Err(ValidationError::InvalidValueSize {
385            size: value.len(),
386            max: ctx.max_value_size,
387        }
388        .into());
389    }
390
391    Ok(())
392}
393
394// ===== Rate Limiting =====
395
396/// Rate limiter for preventing abuse (unified engine)
397#[derive(Debug)]
398pub struct RateLimiter {
399    /// Shared token bucket engine for global and per-IP limiting
400    engine: crate::rate_limit::SharedEngine<IpAddr>,
401    /// Configuration
402    #[allow(dead_code)]
403    config: RateLimitConfig,
404}
405
/// Rate limit configuration.
///
/// `RateLimiter::new` forwards only `window`, `max_requests` and `burst_size`
/// to the engine. `adaptive` and `cleanup_interval` are stored but not
/// consulted in this module.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Length of the rate-limiting window.
    pub window: Duration,
    /// Maximum number of requests admitted per window.
    pub max_requests: u32,
    /// Number of requests that may be admitted in a burst.
    pub burst_size: u32,
    /// Whether adaptive throttling is requested.
    pub adaptive: bool,
    /// Interval at which expired entries should be cleaned up.
    pub cleanup_interval: Duration,
}
420
421impl Default for RateLimitConfig {
422    fn default() -> Self {
423        Self {
424            window: DEFAULT_RATE_LIMIT_WINDOW,
425            max_requests: DEFAULT_MAX_REQUESTS_PER_WINDOW,
426            burst_size: DEFAULT_BURST_SIZE,
427            adaptive: true,
428            cleanup_interval: Duration::from_secs(300), // 5 minutes
429        }
430    }
431}
432
433// Deprecated per-module bucket removed; using crate::rate_limit::Engine instead.
434
435impl RateLimiter {
436    /// Create a new rate limiter
437    pub fn new(config: RateLimitConfig) -> Self {
438        let engine_cfg = crate::rate_limit::EngineConfig {
439            window: config.window,
440            max_requests: config.max_requests,
441            burst_size: config.burst_size,
442        };
443        Self {
444            engine: std::sync::Arc::new(crate::rate_limit::Engine::new(engine_cfg)),
445            config,
446        }
447    }
448
449    /// Check if a request from an IP is allowed
450    pub fn check_ip(&self, ip: &IpAddr) -> P2pResult<()> {
451        // Global limit
452        if !self.engine.try_consume_global() {
453            return Err(ValidationError::RateLimitExceeded {
454                identifier: "global".to_string(),
455            }
456            .into());
457        }
458
459        // Per-IP limit
460        if !self.engine.try_consume_key(ip) {
461            return Err(ValidationError::RateLimitExceeded {
462                identifier: ip.to_string(),
463            }
464            .into());
465        }
466
467        Ok(())
468    }
469
470    /// Clean up expired entries
471    pub fn cleanup(&self) {
472        // Not required with the unified engine (buckets age out via window). No-op.
473    }
474}
475
476// ===== Validation Implementations for Common Types =====
477
/// A network message as seen by the validation layer.
#[derive(Debug)]
pub struct NetworkMessage {
    /// Sender's peer ID.
    pub peer_id: String,
    /// Raw message payload.
    pub payload: Vec<u8>,
    /// Send time in whole seconds since the Unix epoch.
    pub timestamp: u64,
}
485
486impl Validate for NetworkMessage {
487    fn validate(&self, ctx: &ValidationContext) -> P2pResult<()> {
488        // Validate peer ID
489        validate_peer_id(&self.peer_id)?;
490
491        // Validate payload size
492        validate_message_size(self.payload.len(), ctx.max_message_size)?;
493
494        // Validate timestamp (not too far in future)
495        let now = std::time::SystemTime::now()
496            .duration_since(std::time::UNIX_EPOCH)
497            .map_err(|e| P2PError::Internal(format!("System time error: {}", e).into()))?
498            .as_secs();
499
500        if self.timestamp > now + 300 {
501            // 5 minutes tolerance
502            return Err(
503                ValidationError::InvalidFormat("Timestamp too far in future".to_string()).into(),
504            );
505        }
506
507        Ok(())
508    }
509}
510
/// An API request as seen by the validation layer.
#[derive(Debug)]
pub struct ApiRequest {
    /// HTTP method name, e.g. `"GET"`.
    pub method: String,
    /// Request path; validation requires it to start with `/`.
    pub path: String,
    /// Request parameters, mapping parameter name to value.
    pub params: HashMap<String, String>,
}
518
519impl Validate for ApiRequest {
520    fn validate(&self, _ctx: &ValidationContext) -> P2pResult<()> {
521        // Validate method
522        match self.method.as_str() {
523            "GET" | "POST" | "PUT" | "DELETE" => {}
524            _ => {
525                return Err(ValidationError::InvalidFormat(format!(
526                    "Invalid HTTP method: {}",
527                    self.method
528                ))
529                .into());
530            }
531        }
532
533        // Validate path
534        if !self.path.starts_with('/') {
535            return Err(
536                ValidationError::InvalidFormat("Path must start with /".to_string()).into(),
537            );
538        }
539
540        if self.path.contains("..") {
541            return Err(ValidationError::PathTraversal(self.path.clone()).into());
542        }
543
544        // Validate parameters
545        for (key, value) in &self.params {
546            if key.is_empty() {
547                return Err(
548                    ValidationError::InvalidFormat("Empty parameter key".to_string()).into(),
549                );
550            }
551
552            // Check for SQL injection patterns
553            let lower_value = value.to_lowercase();
554            let sql_patterns = [
555                "select ", "insert ", "update ", "delete ", "drop ", "union ", "exec ", "--", "/*",
556                "*/", "'", "\"", " or ", " and ", "1=1", "1='1",
557            ];
558
559            for pattern in &sql_patterns {
560                if lower_value.contains(pattern) {
561                    return Err(ValidationError::InvalidFormat(
562                        "Suspicious parameter value: potential SQL injection".to_string(),
563                    )
564                    .into());
565                }
566            }
567
568            // Check for command injection patterns
569            let dangerous_chars = ['|', '&', ';', '$', '`', '\n', '\0'];
570            if value.chars().any(|c| dangerous_chars.contains(&c)) {
571                return Err(ValidationError::InvalidFormat(
572                    "Dangerous characters in parameter value".to_string(),
573                )
574                .into());
575            }
576        }
577
578        Ok(())
579    }
580}
581
582/// Configuration value validation
583pub fn validate_config_value<T>(value: &str, min: Option<T>, max: Option<T>) -> P2pResult<T>
584where
585    T: std::str::FromStr + PartialOrd + std::fmt::Display,
586{
587    let parsed = value
588        .parse::<T>()
589        .map_err(|_| ValidationError::InvalidFormat(format!("Failed to parse value: {}", value)))?;
590
591    if let Some(min_val) = min
592        && parsed < min_val
593    {
594        return Err(ValidationError::InvalidFormat(format!(
595            "Value {} is less than minimum {}",
596            parsed, min_val
597        ))
598        .into());
599    }
600
601    if let Some(max_val) = max
602        && parsed > max_val
603    {
604        return Err(ValidationError::InvalidFormat(format!(
605            "Value {} is greater than maximum {}",
606            parsed, max_val
607        ))
608        .into());
609    }
610
611    Ok(parsed)
612}
613
/// Sanitize a string for safe usage, truncating the result to `max_length` chars.
///
/// Steps:
/// 1. Remove `<` and `>`.
/// 2. Delete each markup keyword (`script`, `javascript:`, `onerror`,
///    `onload`, `onclick`, `alert`, `iframe`) in that order. Each keyword gets
///    a single, case-sensitive pass, so ordinary words are affected too
///    (`"description"` becomes `"deion"`). A keyword reassembled by an
///    earlier removal may survive (`"scrscriptipt"` becomes `"script"`).
/// 3. Remove a fixed set of invisible/filler characters (word joiner,
///    halfwidth hangul filler, zero-width space/non-joiner/joiner).
/// 4. Keep only alphanumeric characters (Unicode-aware), `_`, `-` and `.`.
///    Spaces are dropped.
///
/// Step 4 is what guarantees that no markup or punctuation survives.
pub fn sanitize_string(input: &str, max_length: usize) -> String {
    const MARKUP_FRAGMENTS: [&str; 7] = [
        "script",
        "javascript:",
        "onerror",
        "onload",
        "onclick",
        "alert",
        "iframe",
    ];
    const INVISIBLE_CHARS: [char; 5] = [
        '\u{2060}', // Word joiner
        '\u{ffa0}', // Halfwidth hangul filler
        '\u{200b}', // Zero width space
        '\u{200c}', // Zero width non-joiner
        '\u{200d}', // Zero width joiner
    ];

    let without_brackets = input.replace(['<', '>'], "");
    let without_keywords = MARKUP_FRAGMENTS
        .iter()
        .fold(without_brackets, |text, &fragment| text.replace(fragment, ""));
    let visible = without_keywords.replace(INVISIBLE_CHARS, "");

    visible
        .chars()
        .filter(|&ch| ch.is_alphanumeric() || matches!(ch, '_' | '-' | '.'))
        .take(max_length)
        .collect()
}
641
// Unit tests for the validation helpers and the rate limiter.
#[cfg(test)]
mod tests {
    use super::*;

    /// Peer IDs: accepted shapes, length bounds, and the character-set rule.
    #[test]
    fn test_peer_id_validation() {
        // Valid peer IDs
        assert!(validate_peer_id("valid_peer_id_123").is_ok());
        assert!(validate_peer_id("PEER-ID-WITH-CAPS").is_ok());

        // Invalid lengths
        assert!(validate_peer_id("short").is_err()); // Too short
        assert!(validate_peer_id(&"x".repeat(100)).is_err()); // Too long

        // Invalid characters. These inputs are long enough (16..=64) to pass the
        // length check, so they exercise the character-set rule instead of being
        // rejected for length first.
        assert!(validate_peer_id("invalid peer id here").is_err()); // Contains space
        assert!(validate_peer_id("peer@id_long_enough").is_err()); // Invalid character
    }

    /// Network addresses: public accepted, loopback rejected unless allowed.
    #[test]
    fn test_network_address_validation() {
        let ctx = ValidationContext::default();

        // Valid addresses
        let addr: SocketAddr = "8.8.8.8:53".parse().unwrap();
        assert!(validate_network_address(&addr, &ctx).is_ok());

        // Invalid addresses
        let localhost: SocketAddr = "127.0.0.1:80".parse().unwrap();
        assert!(validate_network_address(&localhost, &ctx).is_err());

        // Allow localhost when configured
        let ctx_localhost = ValidationContext::default().allow_localhost();
        assert!(validate_network_address(&localhost, &ctx_localhost).is_ok());
    }

    /// File paths: ordinary paths accepted; traversal and null bytes rejected.
    #[test]
    fn test_file_path_validation() {
        // Valid paths
        assert!(validate_file_path(Path::new("data/file.txt")).is_ok());
        assert!(validate_file_path(Path::new("/usr/local/bin")).is_ok());

        // Invalid paths
        assert!(validate_file_path(Path::new("../etc/passwd")).is_err());
        assert!(validate_file_path(Path::new("file\0name")).is_err());
    }

    /// Rate limiter: a burst is admitted, the next request is refused, and
    /// capacity returns after the window. Only `is_err` is asserted, not which
    /// bucket refused the request. Timing-dependent: sleeps 600 ms.
    #[test]
    fn test_rate_limiter() {
        let config = RateLimitConfig {
            window: Duration::from_millis(500), // Shorter window for testing
            max_requests: 10,
            burst_size: 5,
            ..Default::default()
        };

        let limiter = RateLimiter::new(config);
        let ip: IpAddr = "192.168.1.1".parse().unwrap();

        // Should allow burst
        for _ in 0..5 {
            assert!(limiter.check_ip(&ip).is_ok());
        }

        // Should start rate limiting after burst
        assert!(limiter.check_ip(&ip).is_err()); // Should be rate limited now

        // After waiting longer than the window, should allow again
        std::thread::sleep(Duration::from_millis(600));
        assert!(limiter.check_ip(&ip).is_ok());
    }

    /// Messages: a fresh, well-formed message passes; a short peer ID fails.
    #[test]
    fn test_message_validation() {
        let ctx = ValidationContext::default();

        let valid_msg = NetworkMessage {
            peer_id: "valid_peer_id_123".to_string(),
            payload: vec![0u8; 1024],
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
        };

        assert!(valid_msg.validate(&ctx).is_ok());

        // Test invalid message (rejected by the peer ID length check)
        let invalid_msg = NetworkMessage {
            peer_id: "short".to_string(),
            payload: vec![0u8; 1024],
            timestamp: 0,
        };

        assert!(invalid_msg.validate(&ctx).is_err());
    }

    /// Sanitization: disallowed characters dropped, output truncated.
    #[test]
    fn test_sanitization() {
        assert_eq!(sanitize_string("hello world!", 20), "helloworld");

        assert_eq!(sanitize_string("test@#$%123", 20), "test123");

        assert_eq!(
            sanitize_string("very_long_string_that_exceeds_limit", 10),
            "very_long_"
        );
    }
}