sklears_core/
input_sanitization.rs

//! Input sanitization for untrusted data.
//!
//! This module provides comprehensive input sanitization and validation for
//! data coming from untrusted sources, to prevent security vulnerabilities and
//! ensure data integrity in machine learning workflows.
6use crate::error::{Result, SklearsError};
7use crate::types::{Array1, Array2, FloatBounds};
8use std::collections::HashMap;
9
10/// Trait for sanitizing input data
11pub trait Sanitize {
12    /// Sanitize the input and return a cleaned version
13    fn sanitize(self) -> Result<Self>
14    where
15        Self: Sized;
16
17    /// Check if input is safe without modifying it
18    fn is_safe(&self) -> bool;
19
20    /// Get detailed information about safety issues
21    fn safety_issues(&self) -> Vec<SafetyIssue>;
22}
23
/// A single problem detected in input data by [`Sanitize::safety_issues`].
#[derive(Debug, Clone, PartialEq)]
pub enum SafetyIssue {
    /// One or more NaN values were found.
    ContainsNaN {
        /// Number of NaN values.
        count: usize,
        /// Human-readable position of each NaN value.
        locations: Vec<String>,
    },
    /// One or more infinite values were found.
    ContainsInfinity {
        /// Number of infinite values.
        count: usize,
        /// Human-readable position of each infinite value.
        locations: Vec<String>,
    },
    /// Some values fall outside the permitted range.
    OutOfRange {
        /// Smallest permitted value.
        min_allowed: f64,
        /// Largest permitted value.
        max_allowed: f64,
        /// Number of values outside the range.
        violations: usize,
    },
    /// The data does not have the expected shape.
    InvalidShape {
        /// Expected dimensions.
        expected: Vec<usize>,
        /// Dimensions actually found.
        actual: Vec<usize>,
    },
    /// Data was required but none was supplied.
    EmptyData,
    /// Content matched a pattern associated with injection or traversal attacks.
    SuspiciousPattern {
        /// Category label of the matched pattern.
        pattern: String,
        /// Explanation of what was matched.
        description: String,
    },
    /// The string contains characters that are not allowed.
    UnsafeCharacters {
        /// Each disallowed character found, in order of appearance.
        characters: Vec<char>,
    },
    /// The input is larger than the permitted limit.
    ExceedsLimits {
        /// Actual size of the input.
        size: usize,
        /// Maximum permitted size.
        limit: usize,
    },
}
60
/// Policy settings controlling what [`InputSanitizer`] accepts, repairs, or rejects.
#[derive(Debug, Clone)]
pub struct SanitizationConfig {
    /// When `true`, NaN values are replaced during sanitization and tolerated
    /// (rather than rejected) during validation.
    pub remove_nan: bool,
    /// When `true`, infinite values are replaced during sanitization and
    /// tolerated (rather than rejected) during validation.
    pub remove_infinity: bool,
    /// Clamp numeric values into `valid_range`; has no effect while
    /// `valid_range` is `None`.
    pub clamp_values: bool,
    /// Inclusive `(min, max)` range used when clamping numeric values.
    pub valid_range: Option<(f64, f64)>,
    /// Maximum number of array elements; `None` disables the check.
    pub max_array_size: Option<usize>,
    /// Maximum string length in bytes; `None` disables the check.
    pub max_string_length: Option<usize>,
    /// Characters stripped from strings during sanitization.
    pub forbidden_chars: Vec<char>,
    /// Run the suspicious-pattern checks (SQL, script, path traversal) on strings.
    pub deep_validation: bool,
}
81
82impl Default for SanitizationConfig {
83    fn default() -> Self {
84        Self {
85            remove_nan: true,
86            remove_infinity: true,
87            clamp_values: false,
88            valid_range: None,
89            max_array_size: Some(1_000_000), // 1M elements
90            max_string_length: Some(1000),
91            forbidden_chars: vec!['\0', '\x01', '\x02', '\x03'],
92            deep_validation: true,
93        }
94    }
95}
96
97/// Input sanitizer with configurable policies
98#[allow(dead_code)]
99pub struct InputSanitizer {
100    config: SanitizationConfig,
101    validation_cache: std::sync::Mutex<HashMap<String, bool>>,
102}
103
104impl InputSanitizer {
105    /// Create a new input sanitizer with default configuration
106    pub fn new() -> Self {
107        Self {
108            config: SanitizationConfig::default(),
109            validation_cache: std::sync::Mutex::new(HashMap::new()),
110        }
111    }
112
113    /// Create a new input sanitizer with custom configuration
114    pub fn with_config(config: SanitizationConfig) -> Self {
115        Self {
116            config,
117            validation_cache: std::sync::Mutex::new(HashMap::new()),
118        }
119    }
120
121    /// Sanitize a 2D array
122    pub fn sanitize_array2<T>(&self, array: Array2<T>) -> Result<Array2<T>>
123    where
124        T: FloatBounds + Copy,
125    {
126        // Check size limits
127        if let Some(max_size) = self.config.max_array_size {
128            if array.len() > max_size {
129                return Err(SklearsError::InvalidData {
130                    reason: format!("Array size {} exceeds limit {max_size}", array.len()),
131                });
132            }
133        }
134
135        let mut sanitized = array.clone();
136        let mut removed_count = 0;
137
138        // Check for NaN and infinity values
139        for element in sanitized.iter_mut() {
140            if self.config.remove_nan && element.is_nan() {
141                *element = T::zero();
142                removed_count += 1;
143            } else if self.config.remove_infinity && element.is_infinite() {
144                *element = if element.is_sign_positive() {
145                    T::from(1e10).unwrap_or(T::one())
146                } else {
147                    T::from(-1e10).unwrap_or(-T::one())
148                };
149                removed_count += 1;
150            }
151
152            // Clamp values if configured
153            if let Some((min_val, max_val)) = self.config.valid_range {
154                if self.config.clamp_values {
155                    let val = element.to_f64().unwrap_or(0.0);
156                    if val < min_val {
157                        *element = T::from(min_val).unwrap_or(T::zero());
158                    } else if val > max_val {
159                        *element = T::from(max_val).unwrap_or(T::one());
160                    }
161                }
162            }
163        }
164
165        if removed_count > 0 {
166            log::warn!("Sanitized {removed_count} problematic values in array");
167        }
168
169        Ok(sanitized)
170    }
171
172    /// Sanitize a 1D array
173    pub fn sanitize_array1<T>(&self, array: Array1<T>) -> Result<Array1<T>>
174    where
175        T: FloatBounds + Copy,
176    {
177        // Check size limits
178        if let Some(max_size) = self.config.max_array_size {
179            if array.len() > max_size {
180                return Err(SklearsError::InvalidData {
181                    reason: format!("Array size {} exceeds limit {max_size}", array.len()),
182                });
183            }
184        }
185
186        let mut sanitized = array.clone();
187        let mut removed_count = 0;
188
189        // Check for NaN and infinity values
190        for element in sanitized.iter_mut() {
191            if self.config.remove_nan && element.is_nan() {
192                *element = T::zero();
193                removed_count += 1;
194            } else if self.config.remove_infinity && element.is_infinite() {
195                *element = if element.is_sign_positive() {
196                    T::from(1e10).unwrap_or(T::one())
197                } else {
198                    T::from(-1e10).unwrap_or(-T::one())
199                };
200                removed_count += 1;
201            }
202        }
203
204        if removed_count > 0 {
205            log::warn!("Sanitized {removed_count} problematic values in 1D array");
206        }
207
208        Ok(sanitized)
209    }
210
211    /// Sanitize a string input
212    pub fn sanitize_string(&self, input: String) -> Result<String> {
213        // Check length limits
214        if let Some(max_len) = self.config.max_string_length {
215            if input.len() > max_len {
216                return Err(SklearsError::InvalidData {
217                    reason: format!("String length {} exceeds limit {}", input.len(), max_len),
218                });
219            }
220        }
221
222        // Remove forbidden characters
223        let sanitized = input
224            .chars()
225            .filter(|c| !self.config.forbidden_chars.contains(c))
226            .collect::<String>();
227
228        // Check for suspicious patterns
229        if self.config.deep_validation {
230            self.check_suspicious_patterns(&sanitized)?;
231        }
232
233        Ok(sanitized)
234    }
235
236    /// Check for suspicious patterns in strings
237    fn check_suspicious_patterns(&self, input: &str) -> Result<()> {
238        // Check for potential SQL injection patterns
239        let sql_patterns = [
240            "DROP TABLE",
241            "DELETE FROM",
242            "INSERT INTO",
243            "UPDATE SET",
244            "UNION SELECT",
245        ];
246        for pattern in &sql_patterns {
247            if input.to_uppercase().contains(pattern) {
248                return Err(SklearsError::InvalidData {
249                    reason: format!("Potentially dangerous SQL pattern detected: {pattern}"),
250                });
251            }
252        }
253
254        // Check for script injection patterns
255        let script_patterns = ["<script", "javascript:", "onload=", "onerror="];
256        for pattern in &script_patterns {
257            if input.to_lowercase().contains(pattern) {
258                return Err(SklearsError::InvalidData {
259                    reason: format!("Potentially dangerous script pattern detected: {pattern}"),
260                });
261            }
262        }
263
264        // Check for path traversal patterns
265        if input.contains("../") || input.contains("..\\") {
266            return Err(SklearsError::InvalidData {
267                reason: "Path traversal pattern detected".to_string(),
268            });
269        }
270
271        Ok(())
272    }
273
274    /// Validate numeric input ranges
275    pub fn validate_range<T>(&self, value: T, min: T, max: T) -> Result<()>
276    where
277        T: PartialOrd + std::fmt::Display,
278    {
279        if value < min || value > max {
280            return Err(SklearsError::InvalidParameter {
281                name: "value".to_string(),
282                reason: format!("Value {value} is outside valid range [{min}, {max}]"),
283            });
284        }
285        Ok(())
286    }
287
288    /// Comprehensive input validation
289    pub fn validate_ml_input<T>(
290        &self,
291        features: &Array2<T>,
292        targets: Option<&Array1<T>>,
293    ) -> Result<()>
294    where
295        T: FloatBounds + std::fmt::Display,
296    {
297        // Check if features array is empty
298        if features.is_empty() {
299            return Err(SklearsError::InvalidData {
300                reason: "Feature array cannot be empty".to_string(),
301            });
302        }
303
304        // Check for invalid dimensions
305        if features.nrows() == 0 || features.ncols() == 0 {
306            return Err(SklearsError::InvalidData {
307                reason: "Feature array must have positive dimensions".to_string(),
308            });
309        }
310
311        // Check targets if provided
312        if let Some(targets) = targets {
313            if targets.len() != features.nrows() {
314                return Err(SklearsError::ShapeMismatch {
315                    expected: format!("{} target values", features.nrows()),
316                    actual: format!("{} target values", targets.len()),
317                });
318            }
319
320            // Check for problematic values in targets
321            for (i, &value) in targets.iter().enumerate() {
322                if value.is_nan() {
323                    return Err(SklearsError::InvalidData {
324                        reason: format!("NaN value found in targets at index {i}"),
325                    });
326                }
327                if value.is_infinite() {
328                    return Err(SklearsError::InvalidData {
329                        reason: format!("Infinite value found in targets at index {i}"),
330                    });
331                }
332            }
333        }
334
335        // Check for problematic values in features
336        let mut nan_count = 0;
337        let mut inf_count = 0;
338
339        for (i, row) in features.outer_iter().enumerate() {
340            for (j, &value) in row.iter().enumerate() {
341                if value.is_nan() {
342                    nan_count += 1;
343                    if !self.config.remove_nan {
344                        return Err(SklearsError::InvalidData {
345                            reason: format!("NaN value found in features at position ({i}, {j})"),
346                        });
347                    }
348                }
349                if value.is_infinite() {
350                    inf_count += 1;
351                    if !self.config.remove_infinity {
352                        return Err(SklearsError::InvalidData {
353                            reason: format!(
354                                "Infinite value found in features at position ({i}, {j})"
355                            ),
356                        });
357                    }
358                }
359            }
360        }
361
362        if nan_count > 0 || inf_count > 0 {
363            log::warn!("Found {nan_count} NaN and {inf_count} infinite values in features");
364        }
365
366        Ok(())
367    }
368}
369
370impl Default for InputSanitizer {
371    fn default() -> Self {
372        Self::new()
373    }
374}
375
376/// Implementations of Sanitize trait for common types
377impl<T> Sanitize for Array2<T>
378where
379    T: FloatBounds + Copy,
380{
381    fn sanitize(self) -> Result<Self> {
382        let sanitizer = InputSanitizer::new();
383        sanitizer.sanitize_array2(self)
384    }
385
386    fn is_safe(&self) -> bool {
387        self.safety_issues().is_empty()
388    }
389
390    fn safety_issues(&self) -> Vec<SafetyIssue> {
391        let mut issues = Vec::new();
392
393        // Check for empty data
394        if self.is_empty() {
395            issues.push(SafetyIssue::EmptyData);
396            return issues;
397        }
398
399        // Check for NaN and infinity values
400        let mut nan_count = 0;
401        let mut inf_count = 0;
402        let mut nan_locations = Vec::new();
403        let mut inf_locations = Vec::new();
404
405        for (i, row) in self.outer_iter().enumerate() {
406            for (j, &value) in row.iter().enumerate() {
407                if value.is_nan() {
408                    nan_count += 1;
409                    nan_locations.push(format!("({i}, {j})"));
410                }
411                if value.is_infinite() {
412                    inf_count += 1;
413                    inf_locations.push(format!("({i}, {j})"));
414                }
415            }
416        }
417
418        if nan_count > 0 {
419            issues.push(SafetyIssue::ContainsNaN {
420                count: nan_count,
421                locations: nan_locations,
422            });
423        }
424
425        if inf_count > 0 {
426            issues.push(SafetyIssue::ContainsInfinity {
427                count: inf_count,
428                locations: inf_locations,
429            });
430        }
431
432        // Check size limits
433        if self.len() > 1_000_000 {
434            issues.push(SafetyIssue::ExceedsLimits {
435                size: self.len(),
436                limit: 1_000_000,
437            });
438        }
439
440        issues
441    }
442}
443
444impl<T> Sanitize for Array1<T>
445where
446    T: FloatBounds + Copy,
447{
448    fn sanitize(self) -> Result<Self> {
449        let sanitizer = InputSanitizer::new();
450        sanitizer.sanitize_array1(self)
451    }
452
453    fn is_safe(&self) -> bool {
454        self.safety_issues().is_empty()
455    }
456
457    fn safety_issues(&self) -> Vec<SafetyIssue> {
458        let mut issues = Vec::new();
459
460        // Check for empty data
461        if self.is_empty() {
462            issues.push(SafetyIssue::EmptyData);
463            return issues;
464        }
465
466        // Check for NaN and infinity values
467        let mut nan_count = 0;
468        let mut inf_count = 0;
469        let mut nan_locations = Vec::new();
470        let mut inf_locations = Vec::new();
471
472        for (i, &value) in self.iter().enumerate() {
473            if value.is_nan() {
474                nan_count += 1;
475                nan_locations.push(format!("[{i}]"));
476            }
477            if value.is_infinite() {
478                inf_count += 1;
479                inf_locations.push(format!("[{i}]"));
480            }
481        }
482
483        if nan_count > 0 {
484            issues.push(SafetyIssue::ContainsNaN {
485                count: nan_count,
486                locations: nan_locations,
487            });
488        }
489
490        if inf_count > 0 {
491            issues.push(SafetyIssue::ContainsInfinity {
492                count: inf_count,
493                locations: inf_locations,
494            });
495        }
496
497        issues
498    }
499}
500
501impl Sanitize for String {
502    fn sanitize(self) -> Result<Self> {
503        let sanitizer = InputSanitizer::new();
504        sanitizer.sanitize_string(self)
505    }
506
507    fn is_safe(&self) -> bool {
508        self.safety_issues().is_empty()
509    }
510
511    fn safety_issues(&self) -> Vec<SafetyIssue> {
512        let mut issues = Vec::new();
513
514        // Check length
515        if self.len() > 1000 {
516            issues.push(SafetyIssue::ExceedsLimits {
517                size: self.len(),
518                limit: 1000,
519            });
520        }
521
522        // Check for forbidden characters
523        let forbidden_chars = ['\0', '\x01', '\x02', '\x03'];
524        let found_chars: Vec<char> = self
525            .chars()
526            .filter(|c| forbidden_chars.contains(c))
527            .collect();
528
529        if !found_chars.is_empty() {
530            issues.push(SafetyIssue::UnsafeCharacters {
531                characters: found_chars,
532            });
533        }
534
535        // Check for suspicious patterns
536        let dangerous_patterns = [
537            ("SQL_INJECTION", "DROP TABLE"),
538            ("SCRIPT_INJECTION", "<script"),
539            ("PATH_TRAVERSAL", "../"),
540        ];
541
542        for (pattern_type, pattern) in &dangerous_patterns {
543            if self.to_lowercase().contains(&pattern.to_lowercase()) {
544                issues.push(SafetyIssue::SuspiciousPattern {
545                    pattern: pattern_type.to_string(),
546                    description: format!("Contains potentially dangerous pattern: {pattern}"),
547                });
548            }
549        }
550
551        issues
552    }
553}
554
555/// Convenience functions for quick sanitization
556/// Sanitize machine learning input data
557pub fn sanitize_ml_data<T>(
558    features: Array2<T>,
559    targets: Option<Array1<T>>,
560) -> Result<(Array2<T>, Option<Array1<T>>)>
561where
562    T: FloatBounds + Copy,
563{
564    let sanitizer = InputSanitizer::new();
565
566    // Validate first
567    sanitizer.validate_ml_input(&features, targets.as_ref())?;
568
569    // Sanitize features
570    let clean_features = sanitizer.sanitize_array2(features)?;
571
572    // Sanitize targets if provided
573    let clean_targets = if let Some(targets) = targets {
574        Some(sanitizer.sanitize_array1(targets)?)
575    } else {
576        None
577    };
578
579    Ok((clean_features, clean_targets))
580}
581
582/// Quick safety check for ML data
583pub fn is_ml_data_safe<T>(features: &Array2<T>, targets: Option<&Array1<T>>) -> bool
584where
585    T: FloatBounds + Copy,
586{
587    features.is_safe() && targets.map_or(true, |t| t.is_safe())
588}
589
#[cfg(test)]
mod tests {
    // `super::*` already brings `Array1`/`Array2` into scope.
    use super::*;

    #[test]
    fn test_array_sanitization() {
        let mut array: Array2<f64> = Array2::zeros((2, 3));
        array[[0, 0]] = f64::NAN;
        array[[1, 1]] = f64::INFINITY;

        assert!(!array.is_safe());
        let issues = array.safety_issues();
        assert!(!issues.is_empty());

        let sanitized = array.sanitize().unwrap();
        assert!(sanitized.is_safe());
    }

    #[test]
    fn test_string_sanitization() {
        let dangerous_string = "Hello\0World<script>alert('xss')</script>".to_string();

        assert!(!dangerous_string.is_safe());
        let issues = dangerous_string.safety_issues();
        assert!(!issues.is_empty());

        // This should fail due to dangerous patterns
        assert!(dangerous_string.sanitize().is_err());

        // Test a string with only forbidden characters (no dangerous patterns)
        let string_with_forbidden_chars = "Hello\0World".to_string();
        let sanitized = string_with_forbidden_chars.sanitize().unwrap();
        assert!(!sanitized.contains('\0'));
    }

    #[test]
    fn test_suspicious_patterns_detected() {
        let sanitizer = InputSanitizer::new();
        // SQL patterns are matched regardless of case.
        assert!(sanitizer
            .sanitize_string("drop table users".to_string())
            .is_err());
        assert!(sanitizer
            .sanitize_string("../etc/passwd".to_string())
            .is_err());
        assert!(sanitizer.sanitize_string("hello world".to_string()).is_ok());
    }

    #[test]
    fn test_ml_data_validation() {
        let features: Array2<f64> = Array2::zeros((100, 5));
        let targets: Array1<f64> = Array1::zeros(100);

        let sanitizer = InputSanitizer::new();
        assert!(sanitizer
            .validate_ml_input(&features, Some(&targets))
            .is_ok());

        // Test mismatched dimensions
        let bad_targets: Array1<f64> = Array1::zeros(50);
        assert!(sanitizer
            .validate_ml_input(&features, Some(&bad_targets))
            .is_err());
    }

    #[test]
    fn test_empty_features_rejected() {
        let features: Array2<f64> = Array2::zeros((0, 3));
        let sanitizer = InputSanitizer::new();
        assert!(sanitizer.validate_ml_input(&features, None).is_err());
    }

    #[test]
    fn test_nan_targets_rejected() {
        let features: Array2<f64> = Array2::zeros((3, 2));
        let mut targets: Array1<f64> = Array1::zeros(3);
        targets[1] = f64::NAN;

        let sanitizer = InputSanitizer::new();
        assert!(sanitizer
            .validate_ml_input(&features, Some(&targets))
            .is_err());
    }

    #[test]
    fn test_sanitize_ml_data_replaces_non_finite_features() {
        let mut features: Array2<f64> = Array2::zeros((1, 3));
        features[[0, 0]] = f64::NAN;
        features[[0, 1]] = f64::INFINITY;
        features[[0, 2]] = f64::NEG_INFINITY;

        let (clean, targets) = sanitize_ml_data(features, None).unwrap();
        assert!(targets.is_none());
        assert_eq!(clean[[0, 0]], 0.0);
        assert_eq!(clean[[0, 1]], 1e10);
        assert_eq!(clean[[0, 2]], -1e10);
    }

    #[test]
    fn test_sanitization_config() {
        let config = SanitizationConfig {
            max_string_length: Some(10),
            ..SanitizationConfig::default()
        };

        let sanitizer = InputSanitizer::with_config(config);
        let long_string = "This is a very long string that exceeds the limit".to_string();

        assert!(sanitizer.sanitize_string(long_string).is_err());
    }

    #[test]
    fn test_range_validation() {
        let sanitizer = InputSanitizer::new();

        assert!(sanitizer.validate_range(5.0, 0.0, 10.0).is_ok());
        assert!(sanitizer.validate_range(-1.0, 0.0, 10.0).is_err());
        assert!(sanitizer.validate_range(15.0, 0.0, 10.0).is_err());
    }
}