1use crate::imputation::OutlierAwareImputer;
18use crate::outlier_detection::{OutlierDetectionMethod, OutlierDetector};
19use crate::outlier_transformation::{OutlierTransformationMethod, OutlierTransformer};
20use crate::scaling::RobustScaler;
21use scirs2_core::ndarray::Array2;
22use sklears_core::{
23 error::{Result, SklearsError},
24 traits::{Fit, Trained, Transform, Untrained},
25 types::Float,
26};
27use std::marker::PhantomData;
28
/// Named preset that a `RobustPreprocessorConfig` was built from.
///
/// The variant only labels the configuration (it is printed by the preprocessing report);
/// the concrete settings are produced by the matching `RobustPreprocessorConfig` constructor.
/// `PartialEq`/`Eq` are derived so strategies can be compared directly instead of through
/// discriminant casts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RobustStrategy {
    /// Fixed threshold 3.0, contamination rate 0.05, outlier transformation disabled.
    Conservative,
    /// Threshold 2.5, contamination rate 0.1, adaptive thresholds, `Log1p` transformation.
    Moderate,
    /// Threshold 2.0, contamination rate 0.15, adaptive thresholds, `BoxCox` transformation.
    Aggressive,
    /// Default settings with adaptive thresholds enabled (see `RobustPreprocessorConfig::custom`).
    Custom,
}
41
/// Settings for a `RobustPreprocessor`: which stages run during `fit` and how they are tuned.
#[derive(Debug, Clone)]
pub struct RobustPreprocessorConfig {
    /// Preset label; printed in the preprocessing report.
    pub strategy: RobustStrategy,
    /// Fit an outlier detector during `fit`.
    pub enable_outlier_detection: bool,
    /// Fit an outlier transformer during `fit` and apply it in `transform`.
    pub enable_outlier_transformation: bool,
    /// Fill NaN values with per-column medians during `fit`.
    pub enable_outlier_imputation: bool,
    /// Request robust scaling.
    /// NOTE(review): `fit` in this file never fits a scaler, so this flag does not affect `transform`.
    pub enable_robust_scaling: bool,
    /// Fixed detection threshold, used when `adaptive_thresholds` is false
    /// (`fit` falls back to 2.5 when this is `None`).
    pub outlier_threshold: Option<Float>,
    /// Method handed to the outlier detector.
    pub detection_method: OutlierDetectionMethod,
    /// Method handed to the outlier transformer.
    pub transformation_method: OutlierTransformationMethod,
    /// Expected outlier rate; determines the detection threshold when `adaptive_thresholds` is true.
    pub contamination_rate: Float,
    /// Derive the detection threshold from `contamination_rate` instead of `outlier_threshold`.
    pub adaptive_thresholds: bool,
    /// Quantile pair (default `(25.0, 75.0)`), presumably percentiles for robust scaling.
    /// NOTE(review): not read anywhere in this file's visible code.
    pub quantile_range: (Float, Float),
    /// NOTE(review): not read anywhere in this file's visible code; presumably for robust scaling.
    pub with_centering: bool,
    /// NOTE(review): not read anywhere in this file's visible code; presumably for robust scaling.
    pub with_scaling: bool,
    /// Request parallel execution.
    /// NOTE(review): not read anywhere in this file's visible code.
    pub parallel: bool,
}
74
75impl Default for RobustPreprocessorConfig {
76 fn default() -> Self {
77 Self {
78 strategy: RobustStrategy::Moderate,
79 enable_outlier_detection: true,
80 enable_outlier_transformation: true,
81 enable_outlier_imputation: true,
82 enable_robust_scaling: true,
83 outlier_threshold: None, detection_method: OutlierDetectionMethod::MahalanobisDistance,
85 transformation_method: OutlierTransformationMethod::Log1p,
86 contamination_rate: 0.1,
87 adaptive_thresholds: true,
88 quantile_range: (25.0, 75.0),
89 with_centering: true,
90 with_scaling: true,
91 parallel: true,
92 }
93 }
94}
95
96impl RobustPreprocessorConfig {
97 pub fn conservative() -> Self {
99 Self {
100 strategy: RobustStrategy::Conservative,
101 outlier_threshold: Some(3.0),
102 contamination_rate: 0.05,
103 adaptive_thresholds: false,
104 enable_outlier_transformation: false,
105 transformation_method: OutlierTransformationMethod::RobustScale,
106 ..Self::default()
107 }
108 }
109
110 pub fn moderate() -> Self {
112 Self {
113 strategy: RobustStrategy::Moderate,
114 outlier_threshold: Some(2.5),
115 contamination_rate: 0.1,
116 adaptive_thresholds: true,
117 transformation_method: OutlierTransformationMethod::Log1p,
118 ..Self::default()
119 }
120 }
121
122 pub fn aggressive() -> Self {
124 Self {
125 strategy: RobustStrategy::Aggressive,
126 outlier_threshold: Some(2.0),
127 contamination_rate: 0.15,
128 adaptive_thresholds: true,
129 transformation_method: OutlierTransformationMethod::BoxCox,
130 ..Self::default()
131 }
132 }
133
134 pub fn custom() -> Self {
136 Self {
137 strategy: RobustStrategy::Custom,
138 adaptive_thresholds: true,
139 ..Self::default()
140 }
141 }
142}
143
/// Multi-stage robust preprocessing pipeline, typed by its fitting state.
///
/// The stages are imputation, outlier detection, outlier transformation and robust scaling.
/// A `RobustPreprocessor<Untrained>` is configured through builder methods and becomes a
/// `RobustPreprocessor<Trained>` via `Fit::fit`.
#[derive(Debug, Clone)]
pub struct RobustPreprocessor<State = Untrained> {
    /// Pipeline configuration.
    config: RobustPreprocessorConfig,
    /// Zero-sized marker carrying the `Untrained` / `Trained` state.
    state: PhantomData<State>,
    /// Fitted outlier detector, present when detection was enabled during `fit`.
    outlier_detector_: Option<OutlierDetector<Trained>>,
    /// Fitted outlier transformer; applied by `transform` when present.
    outlier_transformer_: Option<OutlierTransformer<Trained>>,
    /// Imputer applied by `transform` when present.
    /// NOTE(review): `fit` in this file never populates this field.
    outlier_imputer_: Option<OutlierAwareImputer>,
    /// Robust scaler applied by `transform` when present.
    /// NOTE(review): `fit` in this file never populates this field.
    robust_scaler_: Option<RobustScaler>,
    /// Diagnostics collected during `fit`.
    preprocessing_stats_: Option<RobustPreprocessingStats>,
    /// Number of columns seen during `fit`; `transform` rejects inputs of any other width.
    n_features_in_: Option<usize>,
}
158
/// Diagnostics gathered while fitting a `RobustPreprocessor`.
#[derive(Debug, Clone)]
pub struct RobustPreprocessingStats {
    /// Outlier count per feature.
    ///
    /// `fit` fills every entry with the detector's overall outlier count, so this is not a
    /// true per-feature breakdown.
    pub outliers_per_feature: Vec<usize>,
    /// Outlier rate per feature, read as a percentage (0-100) by the scoring, recommendation
    /// and report code.
    pub outlier_percentages: Vec<Float>,
    /// Detection threshold used during `fit`, repeated for each feature.
    pub adaptive_thresholds: Vec<Float>,
    /// Overall robustness score, clamped to `[0, 1]`.
    pub robustness_score: Float,
    /// Missing-value counts before and after imputation.
    pub missing_stats: MissingValueStats,
    /// Effect of the outlier transformation on each feature's distribution.
    pub transformation_stats: TransformationStats,
    /// Weighted quality-improvement score, clamped to `[0, 1]`.
    pub quality_improvement: Float,
}
177
/// NaN bookkeeping for the imputation stage.
#[derive(Debug, Clone)]
pub struct MissingValueStats {
    /// Number of NaN entries in the input passed to `fit`.
    pub missing_before: usize,
    /// Number of NaN entries remaining after the imputation stage.
    pub missing_after: usize,
    /// Share of missing entries that were filled.
    ///
    /// When computed by `fit`, this is `1 - missing_after / max(missing_before, 1)`.
    pub imputation_success_rate: Float,
}
185
/// Per-feature effect of the outlier transformation, as measured during `fit`.
#[derive(Debug, Clone)]
pub struct TransformationStats {
    /// Decrease in absolute skewness (never negative).
    pub skewness_reduction: Vec<Float>,
    /// Decrease in absolute excess kurtosis (never negative).
    pub kurtosis_reduction: Vec<Float>,
    /// NOTE(review): never written by `fit` in this file; it stays at its initial zeros.
    pub normality_improvement: Vec<Float>,
}
196
197impl RobustPreprocessor<Untrained> {
198 pub fn new() -> Self {
200 Self {
201 config: RobustPreprocessorConfig::default(),
202 state: PhantomData,
203 outlier_detector_: None,
204 outlier_transformer_: None,
205 outlier_imputer_: None,
206 robust_scaler_: None,
207 preprocessing_stats_: None,
208 n_features_in_: None,
209 }
210 }
211
212 pub fn conservative() -> Self {
214 Self::new().config(RobustPreprocessorConfig::conservative())
215 }
216
217 pub fn moderate() -> Self {
219 Self::new().config(RobustPreprocessorConfig::moderate())
220 }
221
222 pub fn aggressive() -> Self {
224 Self::new().config(RobustPreprocessorConfig::aggressive())
225 }
226
227 pub fn custom() -> Self {
229 Self::new().config(RobustPreprocessorConfig::custom())
230 }
231
232 pub fn config(mut self, config: RobustPreprocessorConfig) -> Self {
234 self.config = config;
235 self
236 }
237
238 pub fn outlier_detection(mut self, enable: bool) -> Self {
240 self.config.enable_outlier_detection = enable;
241 self
242 }
243
244 pub fn outlier_transformation(mut self, enable: bool) -> Self {
246 self.config.enable_outlier_transformation = enable;
247 self
248 }
249
250 pub fn outlier_imputation(mut self, enable: bool) -> Self {
252 self.config.enable_outlier_imputation = enable;
253 self
254 }
255
256 pub fn robust_scaling(mut self, enable: bool) -> Self {
258 self.config.enable_robust_scaling = enable;
259 self
260 }
261
262 pub fn detection_method(mut self, method: OutlierDetectionMethod) -> Self {
264 self.config.detection_method = method;
265 self
266 }
267
268 pub fn transformation_method(mut self, method: OutlierTransformationMethod) -> Self {
270 self.config.transformation_method = method;
271 self
272 }
273
274 pub fn outlier_threshold(mut self, threshold: Float) -> Self {
276 self.config.outlier_threshold = Some(threshold);
277 self.config.adaptive_thresholds = false;
278 self
279 }
280
281 pub fn adaptive_thresholds(mut self, enable: bool) -> Self {
283 self.config.adaptive_thresholds = enable;
284 if enable {
285 self.config.outlier_threshold = None;
286 }
287 self
288 }
289
290 pub fn contamination_rate(mut self, rate: Float) -> Self {
292 self.config.contamination_rate = rate;
293 self
294 }
295
296 pub fn quantile_range(mut self, range: (Float, Float)) -> Self {
298 self.config.quantile_range = range;
299 self
300 }
301
302 pub fn with_centering(mut self, center: bool) -> Self {
304 self.config.with_centering = center;
305 self
306 }
307
308 pub fn with_scaling(mut self, scale: bool) -> Self {
310 self.config.with_scaling = scale;
311 self
312 }
313
314 pub fn parallel(mut self, enable: bool) -> Self {
316 self.config.parallel = enable;
317 self
318 }
319}
320
321impl Fit<Array2<Float>, ()> for RobustPreprocessor<Untrained> {
322 type Fitted = RobustPreprocessor<Trained>;
323
324 fn fit(mut self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
325 let (n_samples, n_features) = x.dim();
326
327 if n_samples == 0 || n_features == 0 {
328 return Err(SklearsError::InvalidInput(
329 "Input array is empty".to_string(),
330 ));
331 }
332
333 self.n_features_in_ = Some(n_features);
334
335 let mut stats = RobustPreprocessingStats {
337 outliers_per_feature: vec![0; n_features],
338 outlier_percentages: vec![0.0; n_features],
339 adaptive_thresholds: vec![0.0; n_features],
340 robustness_score: 0.0,
341 missing_stats: MissingValueStats {
342 missing_before: 0,
343 missing_after: 0,
344 imputation_success_rate: 0.0,
345 },
346 transformation_stats: TransformationStats {
347 skewness_reduction: vec![0.0; n_features],
348 kurtosis_reduction: vec![0.0; n_features],
349 normality_improvement: vec![0.0; n_features],
350 },
351 quality_improvement: 0.0,
352 };
353
354 stats.missing_stats.missing_before = x.iter().filter(|x| x.is_nan()).count();
356
357 let mut current_data = x.clone();
358
359 if self.config.enable_outlier_imputation {
361 let threshold = self.get_adaptive_threshold(¤t_data, 0.5)?;
362
363 let _imputer = OutlierAwareImputer::exclude_outliers(threshold, "mad")?
364 .base_strategy(crate::imputation::ImputationStrategy::Median);
365
366 for j in 0..current_data.ncols() {
369 let mut column: Vec<Float> = current_data.column(j).to_vec();
370 column.retain(|x| !x.is_nan()); if !column.is_empty() {
372 column
373 .sort_by(|a, b| a.partial_cmp(b).expect("matrix indexing should be valid"));
374 let median = column[column.len() / 2];
375
376 for i in 0..current_data.nrows() {
378 if current_data[[i, j]].is_nan() {
379 current_data[[i, j]] = median;
380 }
381 }
382 }
383 }
384
385 stats.missing_stats.missing_after = current_data.iter().filter(|x| x.is_nan()).count();
389 stats.missing_stats.imputation_success_rate = 1.0
390 - (stats.missing_stats.missing_after as Float
391 / stats.missing_stats.missing_before.max(1) as Float);
392 }
393
394 if self.config.enable_outlier_detection {
396 let threshold = if self.config.adaptive_thresholds {
397 self.get_adaptive_threshold(¤t_data, self.config.contamination_rate)?
398 } else {
399 self.config.outlier_threshold.unwrap_or(2.5)
400 };
401
402 let detector = OutlierDetector::new()
403 .method(self.config.detection_method)
404 .threshold(threshold);
405
406 let fitted_detector = detector.fit(¤t_data, &())?;
407
408 let outlier_result = fitted_detector.detect_outliers(¤t_data)?;
410 stats.outliers_per_feature = vec![outlier_result.summary.n_outliers; n_features]; stats.outlier_percentages = vec![outlier_result.summary.outlier_fraction; n_features]; stats.adaptive_thresholds = vec![threshold; n_features];
414
415 self.outlier_detector_ = Some(fitted_detector);
416 }
417
418 if self.config.enable_outlier_transformation {
420 let transformer = OutlierTransformer::new()
421 .method(self.config.transformation_method)
422 .handle_negatives(true)
423 .feature_wise(true);
424
425 let fitted_transformer = transformer.fit(¤t_data, &())?;
426
427 let original_stats = self.compute_distribution_stats(¤t_data);
429
430 current_data = fitted_transformer.transform(¤t_data)?;
431
432 let transformed_stats = self.compute_distribution_stats(¤t_data);
434 stats.transformation_stats.skewness_reduction = original_stats
435 .iter()
436 .zip(transformed_stats.iter())
437 .map(|((orig_skew, _), (trans_skew, _))| {
438 (orig_skew.abs() - trans_skew.abs()).max(0.0)
439 })
440 .collect();
441
442 stats.transformation_stats.kurtosis_reduction = original_stats
443 .iter()
444 .zip(transformed_stats.iter())
445 .map(|((_, orig_kurt), (_, trans_kurt))| {
446 (orig_kurt.abs() - trans_kurt.abs()).max(0.0)
447 })
448 .collect();
449
450 self.outlier_transformer_ = Some(fitted_transformer);
451 }
452
453 if self.config.enable_robust_scaling {
455 let _scaler = RobustScaler::new();
456 }
463
464 stats.robustness_score = self.compute_robustness_score(&stats);
466
467 stats.quality_improvement = self.compute_quality_improvement(&stats);
469
470 self.preprocessing_stats_ = Some(stats);
471
472 Ok(RobustPreprocessor {
473 config: self.config,
474 state: PhantomData,
475 outlier_detector_: self.outlier_detector_,
476 outlier_transformer_: self.outlier_transformer_,
477 outlier_imputer_: self.outlier_imputer_,
478 robust_scaler_: self.robust_scaler_,
479 preprocessing_stats_: self.preprocessing_stats_,
480 n_features_in_: self.n_features_in_,
481 })
482 }
483}
484
485impl RobustPreprocessor<Untrained> {
486 fn get_adaptive_threshold(
488 &self,
489 data: &Array2<Float>,
490 contamination_rate: Float,
491 ) -> Result<Float> {
492 let valid_values: Vec<Float> = data.iter().filter(|x| x.is_finite()).copied().collect();
493
494 if valid_values.is_empty() {
495 return Ok(2.5); }
497
498 let mut sorted_values = valid_values.clone();
500 sorted_values.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
501
502 let median = if sorted_values.len() % 2 == 0 {
503 let mid = sorted_values.len() / 2;
504 (sorted_values[mid - 1] + sorted_values[mid]) / 2.0
505 } else {
506 sorted_values[sorted_values.len() / 2]
507 };
508
509 let deviations: Vec<Float> = valid_values.iter().map(|x| (x - median).abs()).collect();
511 let mut sorted_deviations = deviations;
512 sorted_deviations.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
513
514 let _mad = if sorted_deviations.len() % 2 == 0 {
515 let mid = sorted_deviations.len() / 2;
516 (sorted_deviations[mid - 1] + sorted_deviations[mid]) / 2.0
517 } else {
518 sorted_deviations[sorted_deviations.len() / 2]
519 };
520
521 let base_threshold = 2.5;
524 let adaptation_factor = 1.0 - contamination_rate;
525 let threshold = base_threshold * adaptation_factor + 1.5 * contamination_rate;
526
527 Ok(threshold.clamp(1.5, 4.0)) }
529
530 fn compute_distribution_stats(&self, data: &Array2<Float>) -> Vec<(Float, Float)> {
532 (0..data.ncols())
533 .map(|j| {
534 let column = data.column(j);
535 let valid_values: Vec<Float> =
536 column.iter().filter(|x| x.is_finite()).copied().collect();
537
538 if valid_values.len() < 3 {
539 return (0.0, 0.0);
540 }
541
542 let mean = valid_values.iter().sum::<Float>() / valid_values.len() as Float;
543 let variance = valid_values
544 .iter()
545 .map(|x| (x - mean).powi(2))
546 .sum::<Float>()
547 / valid_values.len() as Float;
548 let std = variance.sqrt();
549
550 if std == 0.0 {
551 return (0.0, 0.0);
552 }
553
554 let skewness = valid_values
556 .iter()
557 .map(|x| ((x - mean) / std).powi(3))
558 .sum::<Float>()
559 / valid_values.len() as Float;
560
561 let kurtosis = valid_values
563 .iter()
564 .map(|x| ((x - mean) / std).powi(4))
565 .sum::<Float>()
566 / valid_values.len() as Float
567 - 3.0; (skewness, kurtosis)
570 })
571 .collect()
572 }
573
574 fn compute_robustness_score(&self, stats: &RobustPreprocessingStats) -> Float {
576 let mut score = 1.0;
577
578 let avg_outlier_rate = stats.outlier_percentages.iter().sum::<Float>()
580 / stats.outlier_percentages.len() as Float;
581 score *= (1.0 - avg_outlier_rate / 100.0).max(0.1);
582
583 score *= stats.missing_stats.imputation_success_rate;
585
586 let avg_skewness_reduction = stats
588 .transformation_stats
589 .skewness_reduction
590 .iter()
591 .sum::<Float>()
592 / stats.transformation_stats.skewness_reduction.len() as Float;
593 score *= (1.0 + avg_skewness_reduction / 10.0).min(1.5);
594
595 score.clamp(0.0, 1.0)
596 }
597
598 fn compute_quality_improvement(&self, stats: &RobustPreprocessingStats) -> Float {
600 let imputation_improvement = stats.missing_stats.imputation_success_rate * 0.3;
601 let outlier_improvement = (1.0
602 - stats.outlier_percentages.iter().sum::<Float>()
603 / (stats.outlier_percentages.len() as Float * 100.0))
604 * 0.4;
605 let transformation_improvement = (stats
606 .transformation_stats
607 .skewness_reduction
608 .iter()
609 .sum::<Float>()
610 / stats.transformation_stats.skewness_reduction.len() as Float)
611 * 0.3;
612
613 (imputation_improvement + outlier_improvement + transformation_improvement).clamp(0.0, 1.0)
614 }
615}
616
617impl Transform<Array2<Float>, Array2<Float>> for RobustPreprocessor<Trained> {
618 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
619 let (_n_samples, n_features) = x.dim();
620
621 if n_features != self.n_features_in().expect("operation should succeed") {
622 return Err(SklearsError::FeatureMismatch {
623 expected: self.n_features_in().expect("operation should succeed"),
624 actual: n_features,
625 });
626 }
627
628 let mut result = x.clone();
629
630 if let Some(ref imputer) = self.outlier_imputer_ {
634 result = imputer.transform(&result)?;
635 }
636
637 if let Some(ref transformer) = self.outlier_transformer_ {
639 result = transformer.transform(&result)?;
640 }
641
642 if let Some(ref scaler) = self.robust_scaler_ {
644 result = scaler.transform(&result)?;
645 }
646
647 Ok(result)
648 }
649}
650
651impl RobustPreprocessor<Trained> {
652 pub fn n_features_in(&self) -> Option<usize> {
654 self.n_features_in_
655 }
656
657 pub fn preprocessing_stats(&self) -> Option<&RobustPreprocessingStats> {
659 self.preprocessing_stats_.as_ref()
660 }
661
662 pub fn outlier_detector(&self) -> Option<&OutlierDetector<Trained>> {
664 self.outlier_detector_.as_ref()
665 }
666
667 pub fn outlier_transformer(&self) -> Option<&OutlierTransformer<Trained>> {
669 self.outlier_transformer_.as_ref()
670 }
671
672 pub fn outlier_imputer(&self) -> Option<&OutlierAwareImputer> {
674 self.outlier_imputer_.as_ref()
675 }
676
677 pub fn robust_scaler(&self) -> Option<&RobustScaler> {
679 self.robust_scaler_.as_ref()
680 }
681
682 pub fn preprocessing_report(&self) -> Result<String> {
684 let stats = self.preprocessing_stats_.as_ref().ok_or_else(|| {
685 SklearsError::InvalidInput("No preprocessing statistics available".to_string())
686 })?;
687
688 let mut report = String::new();
689
690 report.push_str("=== Robust Preprocessing Report ===\n\n");
691
692 report.push_str(&format!(
694 "Robustness Score: {:.3}\n",
695 stats.robustness_score
696 ));
697 report.push_str(&format!(
698 "Quality Improvement: {:.3}\n",
699 stats.quality_improvement
700 ));
701 report.push('\n');
702
703 report.push_str("=== Missing Value Handling ===\n");
705 report.push_str(&format!(
706 "Missing values before: {}\n",
707 stats.missing_stats.missing_before
708 ));
709 report.push_str(&format!(
710 "Missing values after: {}\n",
711 stats.missing_stats.missing_after
712 ));
713 report.push_str(&format!(
714 "Imputation success rate: {:.1}%\n",
715 stats.missing_stats.imputation_success_rate * 100.0
716 ));
717 report.push('\n');
718
719 if !stats.outliers_per_feature.is_empty() {
721 report.push_str("=== Outlier Detection ===\n");
722 for (i, (&count, &percentage)) in stats
723 .outliers_per_feature
724 .iter()
725 .zip(stats.outlier_percentages.iter())
726 .enumerate()
727 {
728 report.push_str(&format!(
729 "Feature {}: {} outliers ({:.1}%)\n",
730 i, count, percentage
731 ));
732 }
733 report.push('\n');
734 }
735
736 if !stats.transformation_stats.skewness_reduction.is_empty() {
738 report.push_str("=== Transformation Effectiveness ===\n");
739 for (i, (&skew_red, &kurt_red)) in stats
740 .transformation_stats
741 .skewness_reduction
742 .iter()
743 .zip(stats.transformation_stats.kurtosis_reduction.iter())
744 .enumerate()
745 {
746 report.push_str(&format!(
747 "Feature {}: Skewness reduction: {:.3}, Kurtosis reduction: {:.3}\n",
748 i, skew_red, kurt_red
749 ));
750 }
751 report.push('\n');
752 }
753
754 report.push_str("=== Configuration ===\n");
756 report.push_str(&format!("Strategy: {:?}\n", self.config.strategy));
757 report.push_str(&format!(
758 "Outlier detection: {}\n",
759 self.config.enable_outlier_detection
760 ));
761 report.push_str(&format!(
762 "Outlier transformation: {}\n",
763 self.config.enable_outlier_transformation
764 ));
765 report.push_str(&format!(
766 "Outlier imputation: {}\n",
767 self.config.enable_outlier_imputation
768 ));
769 report.push_str(&format!(
770 "Robust scaling: {}\n",
771 self.config.enable_robust_scaling
772 ));
773 report.push_str(&format!(
774 "Adaptive thresholds: {}\n",
775 self.config.adaptive_thresholds
776 ));
777
778 Ok(report)
779 }
780
781 pub fn is_effective(&self) -> bool {
783 if let Some(stats) = &self.preprocessing_stats_ {
784 stats.robustness_score > 0.7 && stats.quality_improvement > 0.5
785 } else {
786 false
787 }
788 }
789
790 pub fn get_recommendations(&self) -> Vec<String> {
792 let mut recommendations = Vec::new();
793
794 if let Some(stats) = &self.preprocessing_stats_ {
795 if stats.robustness_score < 0.5 {
796 recommendations
797 .push("Consider using a more aggressive robust strategy".to_string());
798 }
799
800 let avg_outlier_rate = stats.outlier_percentages.iter().sum::<Float>()
801 / stats.outlier_percentages.len() as Float;
802 if avg_outlier_rate > 20.0 {
803 recommendations.push(
804 "High outlier rate detected - consider additional data cleaning".to_string(),
805 );
806 }
807
808 if stats.missing_stats.imputation_success_rate < 0.8 {
809 recommendations.push(
810 "Low imputation success rate - consider alternative imputation strategies"
811 .to_string(),
812 );
813 }
814
815 let avg_skewness_reduction = stats
816 .transformation_stats
817 .skewness_reduction
818 .iter()
819 .sum::<Float>()
820 / stats.transformation_stats.skewness_reduction.len() as Float;
821 if avg_skewness_reduction < 0.1 {
822 recommendations.push("Low transformation effectiveness - consider alternative transformation methods".to_string());
823 }
824
825 if stats.quality_improvement < 0.3 {
826 recommendations.push(
827 "Low overall quality improvement - consider reviewing preprocessing pipeline"
828 .to_string(),
829 );
830 }
831 }
832
833 if recommendations.is_empty() {
834 recommendations
835 .push("Preprocessing appears effective - no specific recommendations".to_string());
836 }
837
838 recommendations
839 }
840}
841
842impl Default for RobustPreprocessor<Untrained> {
843 fn default() -> Self {
844 Self::new()
845 }
846}
847
// Unit tests covering:
// - construction presets and builder configuration;
// - fit/transform round-trips, with and without missing values;
// - the text report and recommendations;
// - error paths (empty input, feature-count mismatch).
// NOTE(review): `non_snake_case` is allowed module-wide; no test name here appears to need it.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // `new()` uses the Moderate preset with detection and scaling enabled.
    #[test]
    fn test_robust_preprocessor_creation() {
        let preprocessor = RobustPreprocessor::new();
        // Strategies are compared through their discriminants.
        assert_eq!(
            preprocessor.config.strategy as u8,
            RobustStrategy::Moderate as u8
        );
        assert!(preprocessor.config.enable_outlier_detection);
        assert!(preprocessor.config.enable_robust_scaling);
    }

    // The conservative preset uses a low contamination rate and fixed thresholds.
    #[test]
    fn test_robust_preprocessor_conservative() {
        let preprocessor = RobustPreprocessor::conservative();
        assert_eq!(
            preprocessor.config.strategy as u8,
            RobustStrategy::Conservative as u8
        );
        assert_eq!(preprocessor.config.contamination_rate, 0.05);
        assert!(!preprocessor.config.adaptive_thresholds);
    }

    // The aggressive preset uses a higher contamination rate and a lower fixed threshold.
    #[test]
    fn test_robust_preprocessor_aggressive() {
        let preprocessor = RobustPreprocessor::aggressive();
        assert_eq!(
            preprocessor.config.strategy as u8,
            RobustStrategy::Aggressive as u8
        );
        assert_eq!(preprocessor.config.contamination_rate, 0.15);
        assert_eq!(preprocessor.config.outlier_threshold, Some(2.0));
    }

    // Round-trip on 10x2 data with one extreme row (100.0, 1000.0): shape is preserved and
    // the fit is either effective or reasonably robust.
    #[test]
    fn test_robust_preprocessor_fit_transform() {
        let data = Array2::from_shape_vec(
            (10, 2),
            vec![
                1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0, 6.0, 60.0, 7.0, 70.0, 8.0,
                80.0, 100.0, 1000.0, 9.0, 90.0,
            ],
        )
        .expect("operation should succeed");

        let preprocessor = RobustPreprocessor::moderate();
        let fitted = preprocessor
            .fit(&data, &())
            .expect("model fitting should succeed");
        let result = fitted
            .transform(&data)
            .expect("transformation should succeed");

        assert_eq!(result.dim(), data.dim());

        assert!(
            fitted.is_effective()
                || fitted
                    .preprocessing_stats()
                    .expect("operation should succeed")
                    .robustness_score
                    > 0.3
        );
    }

    // With imputation and transformation disabled, NaNs must pass through `transform` unchanged.
    #[test]
    fn test_robust_preprocessor_with_missing_values() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0,
                10.0,
                2.0,
                Float::NAN,
                3.0,
                30.0,
                Float::NAN,
                40.0,
                5.0,
                50.0,
                100.0,
                1000.0,
                7.0,
                70.0,
                8.0,
                80.0,
            ],
        )
        .expect("operation should succeed");

        let preprocessor = RobustPreprocessor::moderate()
            .outlier_imputation(false)
            .outlier_transformation(false);
        let fitted = preprocessor
            .fit(&data, &())
            .expect("model fitting should succeed");
        let result = fitted
            .transform(&data)
            .expect("transformation should succeed");

        assert_eq!(result.dim(), data.dim());

        let missing_before = data.iter().filter(|x| x.is_nan()).count();
        let missing_after = result.iter().filter(|x| x.is_nan()).count();
        assert_eq!(missing_after, missing_before);
        let stats = fitted
            .preprocessing_stats()
            .expect("operation should succeed");
        assert!(stats.robustness_score >= 0.0);
    }

    // Builder methods write through to the configuration.
    #[test]
    fn test_robust_preprocessor_configuration() {
        let preprocessor = RobustPreprocessor::new()
            .outlier_detection(false)
            .robust_scaling(true)
            .outlier_threshold(2.0)
            .contamination_rate(0.05);

        assert!(!preprocessor.config.enable_outlier_detection);
        assert!(preprocessor.config.enable_robust_scaling);
        assert_eq!(preprocessor.config.outlier_threshold, Some(2.0));
        assert_eq!(preprocessor.config.contamination_rate, 0.05);
    }

    // The adaptive threshold always lies inside its clamp range [1.5, 4.0].
    #[test]
    fn test_adaptive_threshold_computation() {
        let data = Array2::from_shape_vec((6, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0])
            .expect("shape and data length should match");

        let preprocessor = RobustPreprocessor::new();
        let threshold = preprocessor
            .get_adaptive_threshold(&data, 0.1)
            .expect("operation should succeed");

        assert!(threshold >= 1.5 && threshold <= 4.0);
    }

    // The report contains its header and score lines.
    #[test]
    fn test_preprocessing_report() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0, 100.0, 1000.0,
            ],
        )
        .expect("operation should succeed");

        let preprocessor = RobustPreprocessor::moderate();
        let fitted = preprocessor
            .fit(&data, &())
            .expect("model fitting should succeed");

        let report = fitted
            .preprocessing_report()
            .expect("operation should succeed");
        assert!(report.contains("Robust Preprocessing Report"));
        assert!(report.contains("Robustness Score"));
        assert!(report.contains("Quality Improvement"));
    }

    // `get_recommendations` never returns an empty list.
    #[test]
    fn test_recommendations() {
        let data = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0])
            .expect("shape and data length should match");

        let preprocessor = RobustPreprocessor::conservative();
        let fitted = preprocessor
            .fit(&data, &())
            .expect("model fitting should succeed");

        let recommendations = fitted.get_recommendations();
        assert!(!recommendations.is_empty());
    }

    // Fitting on an empty (0x0) matrix is rejected.
    #[test]
    fn test_robust_preprocessor_error_handling() {
        let preprocessor = RobustPreprocessor::new();

        let empty_data =
            Array2::from_shape_vec((0, 0), vec![]).expect("shape and data length should match");
        assert!(preprocessor.fit(&empty_data, &()).is_err());
    }

    // `transform` rejects inputs whose column count differs from the training data.
    #[test]
    fn test_feature_mismatch() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("shape and data length should match");
        let wrong_data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("operation should succeed");

        let preprocessor = RobustPreprocessor::moderate();
        let fitted = preprocessor
            .fit(&data, &())
            .expect("model fitting should succeed");

        assert!(fitted.transform(&wrong_data).is_err());
    }
}