1use crate::error::{Result, SklearsError};
7use crate::types::{Array1, Array2, FloatBounds};
8use std::collections::HashMap;
9
10pub trait Sanitize {
12 fn sanitize(self) -> Result<Self>
14 where
15 Self: Sized;
16
17 fn is_safe(&self) -> bool;
19
20 fn safety_issues(&self) -> Vec<SafetyIssue>;
22}
23
/// A single problem found by [`Sanitize::safety_issues`].
#[derive(Debug, Clone, PartialEq)]
pub enum SafetyIssue {
    /// One or more NaN values were found.
    ContainsNaN {
        /// Number of NaN values.
        count: usize,
        /// Human-readable positions, e.g. `"(i, j)"` for 2-D or `"[i]"` for 1-D data.
        locations: Vec<String>,
    },
    /// One or more infinite values were found.
    ContainsInfinity {
        /// Number of infinite values.
        count: usize,
        /// Human-readable positions, formatted like those of `ContainsNaN`.
        locations: Vec<String>,
    },
    /// Values fell outside an allowed `[min_allowed, max_allowed]` interval.
    OutOfRange {
        /// Lower bound of the allowed interval.
        min_allowed: f64,
        /// Upper bound of the allowed interval.
        max_allowed: f64,
        /// Number of values outside the interval.
        violations: usize,
    },
    /// The data's shape differs from the one expected.
    InvalidShape {
        /// Expected dimensions.
        expected: Vec<usize>,
        /// Dimensions actually observed.
        actual: Vec<usize>,
    },
    /// The data contains no elements.
    EmptyData,
    /// Input matched a pattern considered dangerous.
    SuspiciousPattern {
        /// Category label for the match (e.g. `"SQL_INJECTION"`).
        pattern: String,
        /// Human-readable explanation of the match.
        description: String,
    },
    /// Input contained forbidden characters.
    UnsafeCharacters {
        /// The offending characters, in order of appearance.
        characters: Vec<char>,
    },
    /// Input is larger than the permitted limit.
    ExceedsLimits {
        /// Observed size.
        size: usize,
        /// Maximum permitted size.
        limit: usize,
    },
}
60
/// Options controlling how an `InputSanitizer` treats its input.
#[derive(Debug, Clone)]
pub struct SanitizationConfig {
    /// Replace NaN array elements with zero instead of treating them as errors.
    pub remove_nan: bool,
    /// Replace infinite array elements with a large finite value instead of
    /// treating them as errors.
    pub remove_infinity: bool,
    /// Clamp array elements into `valid_range` during sanitization.
    pub clamp_values: bool,
    /// Inclusive `(min, max)` range used together with `clamp_values`.
    pub valid_range: Option<(f64, f64)>,
    /// Maximum number of array elements accepted; `None` disables the check.
    pub max_array_size: Option<usize>,
    /// Maximum string length (as measured by `String::len`) accepted; `None`
    /// disables the check.
    pub max_string_length: Option<usize>,
    /// Characters stripped from strings during sanitization.
    pub forbidden_chars: Vec<char>,
    /// Run the suspicious-pattern checks on sanitized strings.
    pub deep_validation: bool,
}
81
82impl Default for SanitizationConfig {
83 fn default() -> Self {
84 Self {
85 remove_nan: true,
86 remove_infinity: true,
87 clamp_values: false,
88 valid_range: None,
89 max_array_size: Some(1_000_000), max_string_length: Some(1000),
91 forbidden_chars: vec!['\0', '\x01', '\x02', '\x03'],
92 deep_validation: true,
93 }
94 }
95}
96
97#[allow(dead_code)]
99pub struct InputSanitizer {
100 config: SanitizationConfig,
101 validation_cache: std::sync::Mutex<HashMap<String, bool>>,
102}
103
104impl InputSanitizer {
105 pub fn new() -> Self {
107 Self {
108 config: SanitizationConfig::default(),
109 validation_cache: std::sync::Mutex::new(HashMap::new()),
110 }
111 }
112
113 pub fn with_config(config: SanitizationConfig) -> Self {
115 Self {
116 config,
117 validation_cache: std::sync::Mutex::new(HashMap::new()),
118 }
119 }
120
121 pub fn sanitize_array2<T>(&self, array: Array2<T>) -> Result<Array2<T>>
123 where
124 T: FloatBounds + Copy,
125 {
126 if let Some(max_size) = self.config.max_array_size {
128 if array.len() > max_size {
129 return Err(SklearsError::InvalidData {
130 reason: format!("Array size {} exceeds limit {max_size}", array.len()),
131 });
132 }
133 }
134
135 let mut sanitized = array.clone();
136 let mut removed_count = 0;
137
138 for element in sanitized.iter_mut() {
140 if self.config.remove_nan && element.is_nan() {
141 *element = T::zero();
142 removed_count += 1;
143 } else if self.config.remove_infinity && element.is_infinite() {
144 *element = if element.is_sign_positive() {
145 T::from(1e10).unwrap_or(T::one())
146 } else {
147 T::from(-1e10).unwrap_or(-T::one())
148 };
149 removed_count += 1;
150 }
151
152 if let Some((min_val, max_val)) = self.config.valid_range {
154 if self.config.clamp_values {
155 let val = element.to_f64().unwrap_or(0.0);
156 if val < min_val {
157 *element = T::from(min_val).unwrap_or(T::zero());
158 } else if val > max_val {
159 *element = T::from(max_val).unwrap_or(T::one());
160 }
161 }
162 }
163 }
164
165 if removed_count > 0 {
166 log::warn!("Sanitized {removed_count} problematic values in array");
167 }
168
169 Ok(sanitized)
170 }
171
172 pub fn sanitize_array1<T>(&self, array: Array1<T>) -> Result<Array1<T>>
174 where
175 T: FloatBounds + Copy,
176 {
177 if let Some(max_size) = self.config.max_array_size {
179 if array.len() > max_size {
180 return Err(SklearsError::InvalidData {
181 reason: format!("Array size {} exceeds limit {max_size}", array.len()),
182 });
183 }
184 }
185
186 let mut sanitized = array.clone();
187 let mut removed_count = 0;
188
189 for element in sanitized.iter_mut() {
191 if self.config.remove_nan && element.is_nan() {
192 *element = T::zero();
193 removed_count += 1;
194 } else if self.config.remove_infinity && element.is_infinite() {
195 *element = if element.is_sign_positive() {
196 T::from(1e10).unwrap_or(T::one())
197 } else {
198 T::from(-1e10).unwrap_or(-T::one())
199 };
200 removed_count += 1;
201 }
202 }
203
204 if removed_count > 0 {
205 log::warn!("Sanitized {removed_count} problematic values in 1D array");
206 }
207
208 Ok(sanitized)
209 }
210
211 pub fn sanitize_string(&self, input: String) -> Result<String> {
213 if let Some(max_len) = self.config.max_string_length {
215 if input.len() > max_len {
216 return Err(SklearsError::InvalidData {
217 reason: format!("String length {} exceeds limit {}", input.len(), max_len),
218 });
219 }
220 }
221
222 let sanitized = input
224 .chars()
225 .filter(|c| !self.config.forbidden_chars.contains(c))
226 .collect::<String>();
227
228 if self.config.deep_validation {
230 self.check_suspicious_patterns(&sanitized)?;
231 }
232
233 Ok(sanitized)
234 }
235
236 fn check_suspicious_patterns(&self, input: &str) -> Result<()> {
238 let sql_patterns = [
240 "DROP TABLE",
241 "DELETE FROM",
242 "INSERT INTO",
243 "UPDATE SET",
244 "UNION SELECT",
245 ];
246 for pattern in &sql_patterns {
247 if input.to_uppercase().contains(pattern) {
248 return Err(SklearsError::InvalidData {
249 reason: format!("Potentially dangerous SQL pattern detected: {pattern}"),
250 });
251 }
252 }
253
254 let script_patterns = ["<script", "javascript:", "onload=", "onerror="];
256 for pattern in &script_patterns {
257 if input.to_lowercase().contains(pattern) {
258 return Err(SklearsError::InvalidData {
259 reason: format!("Potentially dangerous script pattern detected: {pattern}"),
260 });
261 }
262 }
263
264 if input.contains("../") || input.contains("..\\") {
266 return Err(SklearsError::InvalidData {
267 reason: "Path traversal pattern detected".to_string(),
268 });
269 }
270
271 Ok(())
272 }
273
274 pub fn validate_range<T>(&self, value: T, min: T, max: T) -> Result<()>
276 where
277 T: PartialOrd + std::fmt::Display,
278 {
279 if value < min || value > max {
280 return Err(SklearsError::InvalidParameter {
281 name: "value".to_string(),
282 reason: format!("Value {value} is outside valid range [{min}, {max}]"),
283 });
284 }
285 Ok(())
286 }
287
288 pub fn validate_ml_input<T>(
290 &self,
291 features: &Array2<T>,
292 targets: Option<&Array1<T>>,
293 ) -> Result<()>
294 where
295 T: FloatBounds + std::fmt::Display,
296 {
297 if features.is_empty() {
299 return Err(SklearsError::InvalidData {
300 reason: "Feature array cannot be empty".to_string(),
301 });
302 }
303
304 if features.nrows() == 0 || features.ncols() == 0 {
306 return Err(SklearsError::InvalidData {
307 reason: "Feature array must have positive dimensions".to_string(),
308 });
309 }
310
311 if let Some(targets) = targets {
313 if targets.len() != features.nrows() {
314 return Err(SklearsError::ShapeMismatch {
315 expected: format!("{} target values", features.nrows()),
316 actual: format!("{} target values", targets.len()),
317 });
318 }
319
320 for (i, &value) in targets.iter().enumerate() {
322 if value.is_nan() {
323 return Err(SklearsError::InvalidData {
324 reason: format!("NaN value found in targets at index {i}"),
325 });
326 }
327 if value.is_infinite() {
328 return Err(SklearsError::InvalidData {
329 reason: format!("Infinite value found in targets at index {i}"),
330 });
331 }
332 }
333 }
334
335 let mut nan_count = 0;
337 let mut inf_count = 0;
338
339 for (i, row) in features.outer_iter().enumerate() {
340 for (j, &value) in row.iter().enumerate() {
341 if value.is_nan() {
342 nan_count += 1;
343 if !self.config.remove_nan {
344 return Err(SklearsError::InvalidData {
345 reason: format!("NaN value found in features at position ({i}, {j})"),
346 });
347 }
348 }
349 if value.is_infinite() {
350 inf_count += 1;
351 if !self.config.remove_infinity {
352 return Err(SklearsError::InvalidData {
353 reason: format!(
354 "Infinite value found in features at position ({i}, {j})"
355 ),
356 });
357 }
358 }
359 }
360 }
361
362 if nan_count > 0 || inf_count > 0 {
363 log::warn!("Found {nan_count} NaN and {inf_count} infinite values in features");
364 }
365
366 Ok(())
367 }
368}
369
370impl Default for InputSanitizer {
371 fn default() -> Self {
372 Self::new()
373 }
374}
375
376impl<T> Sanitize for Array2<T>
378where
379 T: FloatBounds + Copy,
380{
381 fn sanitize(self) -> Result<Self> {
382 let sanitizer = InputSanitizer::new();
383 sanitizer.sanitize_array2(self)
384 }
385
386 fn is_safe(&self) -> bool {
387 self.safety_issues().is_empty()
388 }
389
390 fn safety_issues(&self) -> Vec<SafetyIssue> {
391 let mut issues = Vec::new();
392
393 if self.is_empty() {
395 issues.push(SafetyIssue::EmptyData);
396 return issues;
397 }
398
399 let mut nan_count = 0;
401 let mut inf_count = 0;
402 let mut nan_locations = Vec::new();
403 let mut inf_locations = Vec::new();
404
405 for (i, row) in self.outer_iter().enumerate() {
406 for (j, &value) in row.iter().enumerate() {
407 if value.is_nan() {
408 nan_count += 1;
409 nan_locations.push(format!("({i}, {j})"));
410 }
411 if value.is_infinite() {
412 inf_count += 1;
413 inf_locations.push(format!("({i}, {j})"));
414 }
415 }
416 }
417
418 if nan_count > 0 {
419 issues.push(SafetyIssue::ContainsNaN {
420 count: nan_count,
421 locations: nan_locations,
422 });
423 }
424
425 if inf_count > 0 {
426 issues.push(SafetyIssue::ContainsInfinity {
427 count: inf_count,
428 locations: inf_locations,
429 });
430 }
431
432 if self.len() > 1_000_000 {
434 issues.push(SafetyIssue::ExceedsLimits {
435 size: self.len(),
436 limit: 1_000_000,
437 });
438 }
439
440 issues
441 }
442}
443
444impl<T> Sanitize for Array1<T>
445where
446 T: FloatBounds + Copy,
447{
448 fn sanitize(self) -> Result<Self> {
449 let sanitizer = InputSanitizer::new();
450 sanitizer.sanitize_array1(self)
451 }
452
453 fn is_safe(&self) -> bool {
454 self.safety_issues().is_empty()
455 }
456
457 fn safety_issues(&self) -> Vec<SafetyIssue> {
458 let mut issues = Vec::new();
459
460 if self.is_empty() {
462 issues.push(SafetyIssue::EmptyData);
463 return issues;
464 }
465
466 let mut nan_count = 0;
468 let mut inf_count = 0;
469 let mut nan_locations = Vec::new();
470 let mut inf_locations = Vec::new();
471
472 for (i, &value) in self.iter().enumerate() {
473 if value.is_nan() {
474 nan_count += 1;
475 nan_locations.push(format!("[{i}]"));
476 }
477 if value.is_infinite() {
478 inf_count += 1;
479 inf_locations.push(format!("[{i}]"));
480 }
481 }
482
483 if nan_count > 0 {
484 issues.push(SafetyIssue::ContainsNaN {
485 count: nan_count,
486 locations: nan_locations,
487 });
488 }
489
490 if inf_count > 0 {
491 issues.push(SafetyIssue::ContainsInfinity {
492 count: inf_count,
493 locations: inf_locations,
494 });
495 }
496
497 issues
498 }
499}
500
501impl Sanitize for String {
502 fn sanitize(self) -> Result<Self> {
503 let sanitizer = InputSanitizer::new();
504 sanitizer.sanitize_string(self)
505 }
506
507 fn is_safe(&self) -> bool {
508 self.safety_issues().is_empty()
509 }
510
511 fn safety_issues(&self) -> Vec<SafetyIssue> {
512 let mut issues = Vec::new();
513
514 if self.len() > 1000 {
516 issues.push(SafetyIssue::ExceedsLimits {
517 size: self.len(),
518 limit: 1000,
519 });
520 }
521
522 let forbidden_chars = ['\0', '\x01', '\x02', '\x03'];
524 let found_chars: Vec<char> = self
525 .chars()
526 .filter(|c| forbidden_chars.contains(c))
527 .collect();
528
529 if !found_chars.is_empty() {
530 issues.push(SafetyIssue::UnsafeCharacters {
531 characters: found_chars,
532 });
533 }
534
535 let dangerous_patterns = [
537 ("SQL_INJECTION", "DROP TABLE"),
538 ("SCRIPT_INJECTION", "<script"),
539 ("PATH_TRAVERSAL", "../"),
540 ];
541
542 for (pattern_type, pattern) in &dangerous_patterns {
543 if self.to_lowercase().contains(&pattern.to_lowercase()) {
544 issues.push(SafetyIssue::SuspiciousPattern {
545 pattern: pattern_type.to_string(),
546 description: format!("Contains potentially dangerous pattern: {pattern}"),
547 });
548 }
549 }
550
551 issues
552 }
553}
554
555pub fn sanitize_ml_data<T>(
558 features: Array2<T>,
559 targets: Option<Array1<T>>,
560) -> Result<(Array2<T>, Option<Array1<T>>)>
561where
562 T: FloatBounds + Copy,
563{
564 let sanitizer = InputSanitizer::new();
565
566 sanitizer.validate_ml_input(&features, targets.as_ref())?;
568
569 let clean_features = sanitizer.sanitize_array2(features)?;
571
572 let clean_targets = if let Some(targets) = targets {
574 Some(sanitizer.sanitize_array1(targets)?)
575 } else {
576 None
577 };
578
579 Ok((clean_features, clean_targets))
580}
581
582pub fn is_ml_data_safe<T>(features: &Array2<T>, targets: Option<&Array1<T>>) -> bool
584where
585 T: FloatBounds + Copy,
586{
587 features.is_safe() && targets.map_or(true, |t| t.is_safe())
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_sanitization() {
        let mut array: Array2<f64> = Array2::zeros((2, 3));
        array[[0, 0]] = f64::NAN;
        array[[1, 1]] = f64::INFINITY;

        assert!(!array.is_safe());
        assert!(!array.safety_issues().is_empty());

        let cleaned = array
            .sanitize()
            .expect("NaN/inf values are replaced, not rejected, by default");
        assert!(cleaned.is_safe());
    }

    #[test]
    fn test_string_sanitization() {
        let dangerous = String::from("Hello\0World<script>alert('xss')</script>");
        assert!(!dangerous.is_safe());
        assert!(!dangerous.safety_issues().is_empty());
        // The script tag is rejected outright by deep validation.
        assert!(dangerous.sanitize().is_err());

        // A forbidden character alone is stripped rather than rejected.
        let cleaned = String::from("Hello\0World").sanitize().unwrap();
        assert!(!cleaned.contains('\0'));
    }

    #[test]
    fn test_ml_data_validation() {
        let sanitizer = InputSanitizer::new();
        let features: Array2<f64> = Array2::zeros((100, 5));

        let matching: Array1<f64> = Array1::zeros(100);
        assert!(sanitizer.validate_ml_input(&features, Some(&matching)).is_ok());

        let mismatched: Array1<f64> = Array1::zeros(50);
        assert!(sanitizer.validate_ml_input(&features, Some(&mismatched)).is_err());
    }

    #[test]
    fn test_sanitization_config() {
        let config = SanitizationConfig {
            max_string_length: Some(10),
            ..SanitizationConfig::default()
        };
        let sanitizer = InputSanitizer::with_config(config);

        let too_long = String::from("This is a very long string that exceeds the limit");
        assert!(sanitizer.sanitize_string(too_long).is_err());
    }

    #[test]
    fn test_range_validation() {
        let sanitizer = InputSanitizer::new();

        assert!(sanitizer.validate_range(5.0, 0.0, 10.0).is_ok());
        for outside in [-1.0, 15.0] {
            assert!(sanitizer.validate_range(outside, 0.0, 10.0).is_err());
        }
    }
}