use scirs2_core::ndarray::Array1;
use scirs2_core::random::{
    essentials::Normal, prelude::*, rngs::StdRng, Distribution, Rng, SeedableRng,
};
use sklears_core::error::Result;
use sklears_core::traits::{Estimator, Fit, Predict};
use sklears_core::types::{Features, Float};

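/// Strategy used by the robust dummy estimators to summarize the target
/// values while limiting the influence of outliers.
///
/// A minimal construction sketch (illustrative only; the parameter values
/// are arbitrary, not tuned recommendations):
///
/// ```ignore
/// let strategy = RobustStrategy::OutlierResistant {
///     contamination: 0.1,
///     detection_method: OutlierDetectionMethod::IQR { multiplier: 1.5 },
/// };
/// let regressor = RobustDummyRegressor::new(strategy);
/// ```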
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum RobustStrategy {
    /// Detect and remove outliers before estimating location and scale.
    OutlierResistant {
        /// Expected proportion of outliers in the data.
        contamination: Float,
        /// Method used to flag outliers.
        detection_method: OutlierDetectionMethod,
    },
    /// Trim a fixed proportion from each tail, then average the rest.
    TrimmedMean {
        /// Proportion trimmed from each tail, in `[0, 0.5)`.
        trim_proportion: Float,
    },
    /// Combine a robust scale estimator with a robust location estimator.
    RobustScale {
        /// Estimator for the spread of the data.
        scale_estimator: ScaleEstimator,
        /// Estimator for the center of the data.
        location_estimator: LocationEstimator,
    },
    /// Target a specific breakdown point via symmetric trimming.
    BreakdownPoint { breakdown_point: Float },
    /// Huber M-estimation of the location via iterative reweighting.
    InfluenceResistant {
        /// Huber tuning constant; larger residuals are downweighted.
        huber_delta: Float,
        /// Maximum number of reweighting iterations.
        max_iter: usize,
        /// Convergence tolerance on the location estimate.
        tolerance: Float,
    },
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum OutlierDetectionMethod {
    /// Flag points outside `[Q1 - multiplier * IQR, Q3 + multiplier * IQR]`.
    IQR { multiplier: Float },
    /// Flag points whose modified z-score `0.6745 * |x - median| / MAD`
    /// exceeds `threshold`.
    ModifiedZScore { threshold: Float },
    /// Flag points whose distance from the median exceeds `threshold`
    /// times the median of all such distances.
    MedianDistance { threshold: Float },
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ScaleEstimator {
    /// Median absolute deviation, scaled by 1.4826 for normal consistency.
    MAD,
    /// Rousseeuw-Croux Qn estimator based on pairwise distances.
    Qn,
    /// Interquartile range, scaled by 1/1.349 for normal consistency.
    IQR,
    /// Rousseeuw-Croux Sn estimator (median of per-point median distances).
    Sn,
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum LocationEstimator {
    /// Sample median.
    Median,
    /// Mean after trimming `trim_proportion` from each tail.
    TrimmedMean { trim_proportion: Float },
    /// Huber M-estimator with tuning constant `delta`.
    Huber { delta: Float },
    /// Tukey biweight (bisquare) estimator.
    Biweight,
}

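/// Dummy regressor that predicts a robustly estimated location of `y`,
/// ignoring the features except for shape validation.
///
/// A minimal fit/predict sketch (illustrative; assumes the types are in
/// scope and mirrors the tests at the bottom of this file):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array2};
/// use sklears_core::traits::{Fit, Predict};
///
/// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let y = array![1.0, 2.0, 3.0, 100.0]; // one gross outlier
/// let model = RobustDummyRegressor::new(RobustStrategy::RobustScale {
///     scale_estimator: ScaleEstimator::MAD,
///     location_estimator: LocationEstimator::Median,
/// });
/// let fitted = model.fit(&x, &y).unwrap();
/// // Every prediction equals the robust location (here the median, 2.5).
/// let predictions = fitted.predict(&x).unwrap();
/// ```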
#[derive(Debug, Clone)]
pub struct RobustDummyRegressor<State = sklears_core::traits::Untrained> {
    /// Robust estimation strategy.
    pub strategy: RobustStrategy,
    /// Optional RNG seed for reproducible predictions.
    pub random_state: Option<u64>,

    /// Fitted robust location estimate.
    pub(crate) robust_location_: Option<Float>,
    /// Fitted robust scale estimate.
    pub(crate) robust_scale_: Option<Float>,
    /// Mask marking the training samples detected as outliers.
    pub(crate) outlier_mask_: Option<Array1<bool>>,
    /// Training targets with detected outliers removed.
    pub(crate) clean_data_: Option<Array1<Float>>,
    /// Breakdown point achieved (or targeted) by the fitted estimator.
    pub(crate) breakdown_point_: Option<Float>,
    /// Final M-estimation weights (`InfluenceResistant` only).
    pub(crate) m_weights_: Option<Array1<Float>>,

    pub(crate) _state: std::marker::PhantomData<State>,
}

impl RobustDummyRegressor {
    /// Create a new untrained regressor with the given strategy.
    pub fn new(strategy: RobustStrategy) -> Self {
        Self {
            strategy,
            random_state: None,
            robust_location_: None,
            robust_scale_: None,
            outlier_mask_: None,
            clean_data_: None,
            breakdown_point_: None,
            m_weights_: None,
            _state: std::marker::PhantomData,
        }
    }

    /// Set the RNG seed used by strategies with stochastic predictions.
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Breakdown point of the fitted estimator, if available.
    pub fn breakdown_point(&self) -> Option<Float> {
        self.breakdown_point_
    }

    /// Mask of detected outliers, if the strategy computes one.
    pub fn outlier_mask(&self) -> Option<&Array1<bool>> {
        self.outlier_mask_.as_ref()
    }

    /// M-estimation weights, if the strategy computes them.
    pub fn m_weights(&self) -> Option<&Array1<Float>> {
        self.m_weights_.as_ref()
    }
}

impl Default for RobustDummyRegressor {
    fn default() -> Self {
        Self::new(RobustStrategy::TrimmedMean {
            trim_proportion: 0.1,
        })
    }
}

impl Estimator for RobustDummyRegressor {
    type Config = ();
    type Error = sklears_core::error::SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Features, Array1<Float>> for RobustDummyRegressor {
    type Fitted = RobustDummyRegressor<sklears_core::traits::Trained>;

    fn fit(self, x: &Features, y: &Array1<Float>) -> Result<Self::Fitted> {
        if x.is_empty() || y.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Input cannot be empty".to_string(),
            ));
        }

        if x.nrows() != y.len() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Number of samples in X and y must be equal".to_string(),
            ));
        }

        let mut fitted = RobustDummyRegressor {
            strategy: self.strategy.clone(),
            random_state: self.random_state,
            robust_location_: None,
            robust_scale_: None,
            outlier_mask_: None,
            clean_data_: None,
            breakdown_point_: None,
            m_weights_: None,
            _state: std::marker::PhantomData,
        };

        match &self.strategy {
            RobustStrategy::OutlierResistant {
                contamination,
                detection_method,
            } => {
                fitted.fit_outlier_resistant(y, *contamination, detection_method)?;
            }
            RobustStrategy::TrimmedMean { trim_proportion } => {
                fitted.fit_trimmed_mean(y, *trim_proportion)?;
            }
            RobustStrategy::RobustScale {
                scale_estimator,
                location_estimator,
            } => {
                fitted.fit_robust_scale(y, scale_estimator, location_estimator)?;
            }
            RobustStrategy::BreakdownPoint { breakdown_point } => {
                fitted.fit_breakdown_point(y, *breakdown_point)?;
            }
            RobustStrategy::InfluenceResistant {
                huber_delta,
                max_iter,
                tolerance,
            } => {
                fitted.fit_influence_resistant(y, *huber_delta, *max_iter, *tolerance)?;
            }
        }

        Ok(fitted)
    }
}

impl RobustDummyRegressor<sklears_core::traits::Trained> {
    /// Breakdown point of the fitted estimator, if available.
    pub fn breakdown_point(&self) -> Option<Float> {
        self.breakdown_point_
    }

    /// Mask of detected outliers, if the strategy computes one.
    pub fn outlier_mask(&self) -> Option<&Array1<bool>> {
        self.outlier_mask_.as_ref()
    }

    /// M-estimation weights, if the strategy computes them.
    pub fn m_weights(&self) -> Option<&Array1<Float>> {
        self.m_weights_.as_ref()
    }

    fn fit_outlier_resistant(
        &mut self,
        y: &Array1<Float>,
        contamination: Float,
        detection_method: &OutlierDetectionMethod,
    ) -> Result<()> {
        let outlier_mask = self.detect_outliers(y, detection_method, contamination)?;
        let clean_data: Array1<Float> = y
            .iter()
            .zip(outlier_mask.iter())
            .filter_map(|(&value, &is_outlier)| if !is_outlier { Some(value) } else { None })
            .collect();

        if clean_data.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "All data points detected as outliers".to_string(),
            ));
        }

        let location = self.compute_median(&clean_data);
        let scale = self.compute_mad(&clean_data, location);

        // Empirical breakdown point: fraction of samples flagged as outliers.
        let n_outliers = outlier_mask.iter().filter(|&&x| x).count();
        let breakdown_point = n_outliers as Float / y.len() as Float;

        self.robust_location_ = Some(location);
        self.robust_scale_ = Some(scale);
        self.outlier_mask_ = Some(outlier_mask);
        self.clean_data_ = Some(clean_data);
        self.breakdown_point_ = Some(breakdown_point);

        Ok(())
    }

    fn fit_trimmed_mean(&mut self, y: &Array1<Float>, trim_proportion: Float) -> Result<()> {
        if !(0.0..0.5).contains(&trim_proportion) {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Trim proportion must be in [0, 0.5)".to_string(),
            ));
        }

        let mut sorted_y = y.to_vec();
        sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let n = sorted_y.len();
        let trim_count = (n as Float * trim_proportion).floor() as usize;

        if trim_count * 2 >= n {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Too much trimming for dataset size".to_string(),
            ));
        }

        let trimmed_data = &sorted_y[trim_count..(n - trim_count)];
        let location = trimmed_data.iter().sum::<Float>() / trimmed_data.len() as Float;

        // Sample standard deviation of the trimmed data; guard against a
        // single remaining value, which would otherwise divide by zero.
        let mean = location;
        let scale = if trimmed_data.len() > 1 {
            let variance = trimmed_data
                .iter()
                .map(|&x| (x - mean).powi(2))
                .sum::<Float>()
                / (trimmed_data.len() - 1) as Float;
            variance.sqrt()
        } else {
            0.0
        };

        // Symmetric trimming tolerates up to `trim_proportion` contamination.
        let breakdown_point = trim_proportion;

        self.robust_location_ = Some(location);
        self.robust_scale_ = Some(scale);
        self.breakdown_point_ = Some(breakdown_point);

        Ok(())
    }

    fn fit_robust_scale(
        &mut self,
        y: &Array1<Float>,
        scale_estimator: &ScaleEstimator,
        location_estimator: &LocationEstimator,
    ) -> Result<()> {
        let location = self.compute_robust_location(y, location_estimator)?;
        let scale = self.compute_robust_scale(y, scale_estimator, location)?;

        // Nominal breakdown points for the estimator combinations; the
        // fallback of 0.25 is a conservative default.
        let breakdown_point = match (location_estimator, scale_estimator) {
            (LocationEstimator::Median, ScaleEstimator::MAD) => 0.5,
            (LocationEstimator::Median, ScaleEstimator::Qn) => 0.5,
            (LocationEstimator::TrimmedMean { trim_proportion }, _) => *trim_proportion,
            _ => 0.25,
        };

        self.robust_location_ = Some(location);
        self.robust_scale_ = Some(scale);
        self.breakdown_point_ = Some(breakdown_point);

        Ok(())
    }

    fn fit_breakdown_point(&mut self, y: &Array1<Float>, target_breakdown: Float) -> Result<()> {
        if target_breakdown <= 0.0 || target_breakdown >= 0.5 {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Breakdown point must be between 0 and 0.5".to_string(),
            ));
        }

        // Symmetric trimming by proportion `p` yields a breakdown point of
        // `p`, so the target maps directly onto the trim proportion.
        let trim_proportion = target_breakdown;
        self.fit_trimmed_mean(y, trim_proportion)?;

        Ok(())
    }

    fn fit_influence_resistant(
        &mut self,
        y: &Array1<Float>,
        huber_delta: Float,
        max_iter: usize,
        tolerance: Float,
    ) -> Result<()> {
        // Initialize with the median and MAD, then iteratively reweight.
        let mut location = self.compute_median(y);
        let initial_scale = self.compute_mad(y, location);
        // Guard against zero scale (e.g. constant data), which would
        // otherwise produce NaN weights below.
        let initial_scale = if initial_scale > 0.0 { initial_scale } else { 1.0 };

        let mut weights = Array1::ones(y.len());

        for _iter in 0..max_iter {
            let old_location = location;

            // Huber weights: 1 inside the delta region, delta/|r| outside.
            for i in 0..y.len() {
                let residual = (y[i] - location).abs();
                let scaled_residual = residual / initial_scale;

                weights[i] = if scaled_residual <= huber_delta {
                    1.0
                } else {
                    huber_delta / scaled_residual
                };
            }

            let weighted_sum: Float = y.iter().zip(weights.iter()).map(|(&yi, &wi)| wi * yi).sum();
            let weight_sum: Float = weights.sum();

            if weight_sum > 0.0 {
                location = weighted_sum / weight_sum;
            }

            // Stop once the location estimate has converged.
            if (location - old_location).abs() < tolerance {
                break;
            }
        }

        // Weighted sample standard deviation around the final location.
        let weighted_variance: Float = y
            .iter()
            .zip(weights.iter())
            .map(|(&yi, &wi)| wi * (yi - location).powi(2))
            .sum();
        let effective_sample_size: Float = weights.sum();

        let scale = if effective_sample_size > 1.0 {
            (weighted_variance / (effective_sample_size - 1.0)).sqrt()
        } else {
            initial_scale
        };

        // Heuristic breakdown point for the Huber estimator: smaller deltas
        // downweight more aggressively and tolerate more contamination.
        let breakdown_point = 1.0 / (2.0 * huber_delta + 1.0);

        self.robust_location_ = Some(location);
        self.robust_scale_ = Some(scale);
        self.m_weights_ = Some(weights);
        self.breakdown_point_ = Some(breakdown_point);

        Ok(())
    }

    fn detect_outliers(
        &self,
        y: &Array1<Float>,
        method: &OutlierDetectionMethod,
        _contamination: Float,
    ) -> Result<Array1<bool>> {
        let mut outlier_mask = Array1::from_elem(y.len(), false);

        match method {
            OutlierDetectionMethod::IQR { multiplier } => {
                let mut sorted_y = y.to_vec();
                sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());

                // Simple index-based quartiles.
                let n = sorted_y.len();
                let q1_idx = n / 4;
                let q3_idx = 3 * n / 4;

                let q1 = sorted_y[q1_idx];
                let q3 = sorted_y[q3_idx];
                let iqr = q3 - q1;

                let lower_bound = q1 - multiplier * iqr;
                let upper_bound = q3 + multiplier * iqr;

                for (i, &value) in y.iter().enumerate() {
                    if value < lower_bound || value > upper_bound {
                        outlier_mask[i] = true;
                    }
                }
            }
            OutlierDetectionMethod::ModifiedZScore { threshold } => {
                let median = self.compute_median(y);
                let mad = self.compute_mad(y, median);

                // Note: `compute_mad` already applies the 1.4826 consistency
                // factor, so this is more conservative than the textbook
                // modified z-score, which divides by the raw MAD.
                if mad > 0.0 {
                    for (i, &value) in y.iter().enumerate() {
                        let modified_z_score = 0.6745 * (value - median).abs() / mad;
                        if modified_z_score > *threshold {
                            outlier_mask[i] = true;
                        }
                    }
                }
            }
            OutlierDetectionMethod::MedianDistance { threshold } => {
                let median = self.compute_median(y);
                let distances: Array1<Float> =
                    y.iter().map(|&value| (value - median).abs()).collect();
                let distance_threshold = self.compute_median(&distances) * threshold;

                for (i, &distance) in distances.iter().enumerate() {
                    if distance > distance_threshold {
                        outlier_mask[i] = true;
                    }
                }
            }
        }

        Ok(outlier_mask)
    }

    fn compute_median(&self, data: &Array1<Float>) -> Float {
        let mut sorted_data = data.to_vec();
        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let n = sorted_data.len();
        if n % 2 == 0 {
            (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
        } else {
            sorted_data[n / 2]
        }
    }

    /// Median absolute deviation, scaled by 1.4826 (approximately
    /// 1 / Phi^-1(3/4)) so it is consistent for the normal standard deviation.
    fn compute_mad(&self, data: &Array1<Float>, median: Float) -> Float {
        let deviations: Array1<Float> = data.iter().map(|&x| (x - median).abs()).collect();
        self.compute_median(&deviations) * 1.4826
    }

    fn compute_robust_location(
        &self,
        y: &Array1<Float>,
        estimator: &LocationEstimator,
    ) -> Result<Float> {
        match estimator {
            LocationEstimator::Median => Ok(self.compute_median(y)),
            LocationEstimator::TrimmedMean { trim_proportion } => {
                let mut sorted_y = y.to_vec();
                sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());

                let n = sorted_y.len();
                let trim_count = (n as Float * trim_proportion).floor() as usize;

                if trim_count * 2 >= n {
                    return Err(sklears_core::error::SklearsError::InvalidInput(
                        "Too much trimming for dataset size".to_string(),
                    ));
                }

                let trimmed_data = &sorted_y[trim_count..(n - trim_count)];
                Ok(trimmed_data.iter().sum::<Float>() / trimmed_data.len() as Float)
            }
            LocationEstimator::Huber { delta } => {
                // Iteratively reweighted Huber location with a fixed MAD scale.
                let mut location = self.compute_median(y);
                let scale = self.compute_mad(y, location);

                // With zero scale (constant data) the median is already the
                // natural location estimate.
                if scale == 0.0 {
                    return Ok(location);
                }

                for _iter in 0..10 {
                    let old_location = location;
                    let mut weighted_sum = 0.0;
                    let mut weight_sum = 0.0;

                    for &yi in y.iter() {
                        let residual = (yi - location).abs();
                        let weight = if residual <= delta * scale {
                            1.0
                        } else {
                            delta * scale / residual
                        };

                        weighted_sum += weight * yi;
                        weight_sum += weight;
                    }

                    if weight_sum > 0.0 {
                        location = weighted_sum / weight_sum;
                    }

                    if (location - old_location).abs() < 1e-6 {
                        break;
                    }
                }

                Ok(location)
            }
            LocationEstimator::Biweight => {
                let median = self.compute_median(y);
                let mad = self.compute_mad(y, median);

                if mad == 0.0 {
                    return Ok(median);
                }

                let mut weighted_sum = 0.0;
                let mut weight_sum = 0.0;

                // Tukey biweight with the conventional tuning constant c = 9
                // (in MAD units); points with |u| >= 1 get zero weight.
                for &yi in y.iter() {
                    let u = (yi - median) / (9.0 * mad);
                    if u.abs() < 1.0 {
                        let weight = (1.0 - u * u).powi(2);
                        weighted_sum += weight * yi;
                        weight_sum += weight;
                    }
                }

                if weight_sum > 0.0 {
                    Ok(weighted_sum / weight_sum)
                } else {
                    Ok(median)
                }
            }
        }
    }

    fn compute_robust_scale(
        &self,
        y: &Array1<Float>,
        estimator: &ScaleEstimator,
        location: Float,
    ) -> Result<Float> {
        match estimator {
            ScaleEstimator::MAD => Ok(self.compute_mad(y, location)),
            ScaleEstimator::IQR => {
                let mut sorted_y = y.to_vec();
                sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());

                let n = sorted_y.len();
                let q1_idx = n / 4;
                let q3_idx = 3 * n / 4;

                let q1 = sorted_y[q1_idx];
                let q3 = sorted_y[q3_idx];
                // 1.349 is the IQR of the standard normal, making the
                // estimate consistent for the normal standard deviation.
                Ok((q3 - q1) / 1.349)
            }
            ScaleEstimator::Qn => {
                // Simplified Qn: first quartile of all pairwise absolute
                // differences, scaled by the consistency constant 2.2219.
                let mut pairwise_distances = Vec::new();
                for i in 0..y.len() {
                    for j in (i + 1)..y.len() {
                        pairwise_distances.push((y[i] - y[j]).abs());
                    }
                }

                if pairwise_distances.is_empty() {
                    return Ok(0.0);
                }

                pairwise_distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let q1_idx = pairwise_distances.len() / 4;
                Ok(pairwise_distances[q1_idx] * 2.2219)
            }
            ScaleEstimator::Sn => {
                // Sn: for each point, take the median distance to all other
                // points; the estimate is the median of those medians,
                // scaled by the consistency constant 1.1926.
                let mut medians = Vec::new();

                for i in 0..y.len() {
                    let mut distances: Vec<Float> = y
                        .iter()
                        .enumerate()
                        .filter(|(j, _)| *j != i)
                        .map(|(_, &yj)| (y[i] - yj).abs())
                        .collect();

                    if !distances.is_empty() {
                        distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
                        let median_dist = if distances.len() % 2 == 0 {
                            (distances[distances.len() / 2 - 1] + distances[distances.len() / 2])
                                / 2.0
                        } else {
                            distances[distances.len() / 2]
                        };
                        medians.push(median_dist);
                    }
                }

                if medians.is_empty() {
                    return Ok(0.0);
                }

                medians.sort_by(|a, b| a.partial_cmp(b).unwrap());
                let result = if medians.len() % 2 == 0 {
                    (medians[medians.len() / 2 - 1] + medians[medians.len() / 2]) / 2.0
                } else {
                    medians[medians.len() / 2]
                };

                Ok(result * 1.1926)
            }
        }
    }
}

impl Predict<Features, Array1<Float>> for RobustDummyRegressor<sklears_core::traits::Trained> {
    fn predict(&self, x: &Features) -> Result<Array1<Float>> {
        if x.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Input cannot be empty".to_string(),
            ));
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        let location = self.robust_location_.unwrap_or(0.0);
        let scale = self.robust_scale_.unwrap_or(1.0);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::seed_from_u64(0)
        };

        match &self.strategy {
            RobustStrategy::OutlierResistant { .. }
            | RobustStrategy::TrimmedMean { .. }
            | RobustStrategy::RobustScale { .. }
            | RobustStrategy::BreakdownPoint { .. } => {
                // Deterministic strategies predict the robust location.
                predictions.fill(location);
            }
            RobustStrategy::InfluenceResistant { .. } => {
                // Add small Gaussian jitter (10% of the robust scale)
                // around the robust location.
                if scale > 0.0 {
                    let normal = Normal::new(location, scale * 0.1).unwrap();
                    for i in 0..n_samples {
                        predictions[i] = normal.sample(&mut rng);
                    }
                } else {
                    predictions.fill(location);
                }
            }
        }

        Ok(predictions)
    }
}

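/// Dummy classifier that samples labels from class frequencies estimated
/// after filtering out rare "outlier" labels.
///
/// A minimal sketch (illustrative; mirrors the test at the bottom of this
/// file):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array2};
/// use sklears_core::traits::{Fit, Predict};
///
/// let x = Array2::from_shape_vec((6, 1), vec![0.0; 6]).unwrap();
/// let y = array![0, 0, 0, 1, 1, 1];
/// let clf = RobustDummyClassifier::default().with_random_state(42);
/// let fitted = clf.fit(&x, &y).unwrap();
/// let predictions = fitted.predict(&x).unwrap(); // labels drawn from {0, 1}
/// ```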
#[derive(Debug, Clone)]
pub struct RobustDummyClassifier<State = sklears_core::traits::Untrained> {
    /// Robust estimation strategy.
    pub strategy: RobustStrategy,
    /// Optional RNG seed for reproducible predictions.
    pub random_state: Option<u64>,

    /// Class probabilities estimated from the non-outlier samples.
    pub(crate) robust_class_probs_: Option<Array1<Float>>,
    /// Sorted unique class labels seen during fitting.
    pub(crate) classes_: Option<Array1<i32>>,
    /// Mask marking the training samples treated as label outliers.
    pub(crate) outlier_mask_: Option<Array1<bool>>,

    pub(crate) _state: std::marker::PhantomData<State>,
}

impl RobustDummyClassifier {
    /// Create a new untrained classifier with the given strategy.
    pub fn new(strategy: RobustStrategy) -> Self {
        Self {
            strategy,
            random_state: None,
            robust_class_probs_: None,
            classes_: None,
            outlier_mask_: None,
            _state: std::marker::PhantomData,
        }
    }

    /// Set the RNG seed used when sampling predictions.
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Mask of samples treated as label outliers, if fitted.
    pub fn outlier_mask(&self) -> Option<&Array1<bool>> {
        self.outlier_mask_.as_ref()
    }
}

impl Default for RobustDummyClassifier {
    fn default() -> Self {
        Self::new(RobustStrategy::OutlierResistant {
            contamination: 0.1,
            detection_method: OutlierDetectionMethod::IQR { multiplier: 1.5 },
        })
    }
}

impl Estimator for RobustDummyClassifier {
    type Config = ();
    type Error = sklears_core::error::SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Features, Array1<i32>> for RobustDummyClassifier {
    type Fitted = RobustDummyClassifier<sklears_core::traits::Trained>;

    fn fit(self, x: &Features, y: &Array1<i32>) -> Result<Self::Fitted> {
        if x.is_empty() || y.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Input cannot be empty".to_string(),
            ));
        }

        if x.nrows() != y.len() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Number of samples in X and y must be equal".to_string(),
            ));
        }

        // Collect the sorted unique class labels.
        let mut unique_classes = y.iter().cloned().collect::<Vec<_>>();
        unique_classes.sort_unstable();
        unique_classes.dedup();
        let classes = Array1::from_vec(unique_classes);

        let mut fitted = RobustDummyClassifier {
            strategy: self.strategy.clone(),
            random_state: self.random_state,
            robust_class_probs_: None,
            classes_: Some(classes.clone()),
            outlier_mask_: None,
            _state: std::marker::PhantomData,
        };

        let mut class_counts = std::collections::HashMap::new();
        for &class in y.iter() {
            *class_counts.entry(class).or_insert(0) += 1;
        }

        // Samples whose class occurs in less than 5% of the data are treated
        // as label outliers and excluded from the probability estimates.
        let total_samples = y.len() as f64;
        let min_frequency = 0.05;
        let mut outlier_mask = Array1::from_elem(y.len(), false);

        for (i, &class) in y.iter().enumerate() {
            let class_freq = *class_counts.get(&class).unwrap() as f64 / total_samples;
            if class_freq < min_frequency {
                outlier_mask[i] = true;
            }
        }

        // Re-count classes over the non-outlier samples only.
        let mut robust_class_counts = std::collections::HashMap::new();
        let mut total_clean_samples = 0;

        for (i, &class) in y.iter().enumerate() {
            if !outlier_mask[i] {
                *robust_class_counts.entry(class).or_insert(0) += 1;
                total_clean_samples += 1;
            }
        }

        let mut class_probs = Array1::zeros(classes.len());
        for (i, &class) in classes.iter().enumerate() {
            let count = *robust_class_counts.get(&class).unwrap_or(&0);
            class_probs[i] = if total_clean_samples > 0 {
                count as Float / total_clean_samples as Float
            } else {
                // Degenerate case: fall back to a uniform distribution.
                1.0 / classes.len() as Float
            };
        }

        fitted.robust_class_probs_ = Some(class_probs);
        fitted.outlier_mask_ = Some(outlier_mask);

        Ok(fitted)
    }
}

impl RobustDummyClassifier<sklears_core::traits::Trained> {
    /// Mask of samples treated as label outliers during fitting.
    pub fn outlier_mask(&self) -> Option<&Array1<bool>> {
        self.outlier_mask_.as_ref()
    }
}

impl Predict<Features, Array1<i32>> for RobustDummyClassifier<sklears_core::traits::Trained> {
    fn predict(&self, x: &Features) -> Result<Array1<i32>> {
        if x.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Input cannot be empty".to_string(),
            ));
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);

        let classes = self.classes_.as_ref().unwrap();
        let class_probs = self.robust_class_probs_.as_ref().unwrap();

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::seed_from_u64(0)
        };

        // Sample each prediction from the robust class distribution via
        // inverse-CDF sampling over the cumulative probabilities.
        for i in 0..n_samples {
            let rand_val: Float = rng.gen();
            let mut cumulative_prob = 0.0;
            let mut selected_class = classes[0];

            for (j, &class) in classes.iter().enumerate() {
                cumulative_prob += class_probs[j];
                if rand_val <= cumulative_prob {
                    selected_class = class;
                    break;
                }
            }
            predictions[i] = selected_class;
        }

        Ok(predictions)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_outlier_resistant_regressor() {
        let x = Array2::from_shape_vec(
            (10, 2),
            vec![
                1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0,
                100.0, 101.0, 102.0, 103.0,
            ],
        )
        .unwrap();
        // The last two targets are gross outliers.
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, 101.0];

        let regressor = RobustDummyRegressor::new(RobustStrategy::OutlierResistant {
            contamination: 0.2,
            detection_method: OutlierDetectionMethod::IQR { multiplier: 1.5 },
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 10);

        // At least the gross outliers should be flagged.
        let outlier_mask = fitted.outlier_mask().unwrap();
        let n_outliers = outlier_mask.iter().filter(|&&x| x).count();
        assert!(n_outliers > 0);
    }

    #[test]
    fn test_trimmed_mean_regressor() {
        let x = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0, 101.0];

        let regressor = RobustDummyRegressor::new(RobustStrategy::TrimmedMean {
            trim_proportion: 0.2,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 10);

        // The trimmed mean should be pulled far less by the outliers than
        // the ordinary mean.
        let robust_mean = predictions[0];
        let regular_mean = y.mean().unwrap();
        assert!(robust_mean < regular_mean);
    }

    #[test]
    fn test_robust_scale_regressor() {
        let x = Array2::from_shape_vec((8, 2), (0..16).map(|x| x as f64).collect()).unwrap();
        // The last target is a gross outlier.
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 100.0];

        let regressor = RobustDummyRegressor::new(RobustStrategy::RobustScale {
            scale_estimator: ScaleEstimator::MAD,
            location_estimator: LocationEstimator::Median,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 8);

        // Predictions should equal the median of y.
        let mut sorted_y = y.to_vec();
        sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let expected_median = (sorted_y[3] + sorted_y[4]) / 2.0;
        assert_abs_diff_eq!(predictions[0], expected_median, epsilon = 0.1);
    }

    #[test]
    fn test_breakdown_point_regressor() {
        let x = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let regressor = RobustDummyRegressor::new(RobustStrategy::BreakdownPoint {
            breakdown_point: 0.3,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 10);
        assert_eq!(fitted.breakdown_point().unwrap(), 0.3);
    }

    #[test]
    fn test_influence_resistant_regressor() {
        let x = Array2::from_shape_vec((8, 2), (0..16).map(|x| x as f64).collect()).unwrap();
        // The last target is a gross outlier.
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 100.0];

        let regressor = RobustDummyRegressor::new(RobustStrategy::InfluenceResistant {
            huber_delta: 1.345,
            max_iter: 50,
            tolerance: 1e-6,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 8);

        let weights = fitted.m_weights().unwrap();
        assert_eq!(weights.len(), 8);

        // The outlier should receive a smaller M-estimation weight.
        assert!(weights[7] < weights[0]);
    }

    #[test]
    fn test_robust_classifier() {
        let x = Array2::from_shape_vec((10, 2), (0..20).map(|x| x as f64).collect()).unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 3, 3];

        let classifier = RobustDummyClassifier::new(RobustStrategy::OutlierResistant {
            contamination: 0.1,
            detection_method: OutlierDetectionMethod::IQR { multiplier: 1.5 },
        })
        .with_random_state(42);

        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 10);

        // Every prediction must be one of the observed classes.
        let classes = fitted.classes_.as_ref().unwrap();
        for &pred in predictions.iter() {
            assert!(classes.iter().any(|&c| c == pred));
        }
    }
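
    // Two additional sketches (illustrative; they only exercise the public
    // API shown above). First: an out-of-range trim proportion should be
    // rejected at fit time.
    #[test]
    fn test_trimmed_mean_rejects_invalid_proportion() {
        let x = Array2::from_shape_vec((4, 2), (0..8).map(|x| x as f64).collect()).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let regressor = RobustDummyRegressor::new(RobustStrategy::TrimmedMean {
            trim_proportion: 0.6, // outside [0, 0.5)
        });
        assert!(regressor.fit(&x, &y).is_err());
    }

    // Second: with the same seed, sampling-based predictions should be
    // reproducible, since predict seeds a StdRng from `random_state`.
    #[test]
    fn test_classifier_predictions_reproducible_with_seed() {
        let x = Array2::from_shape_vec((6, 2), (0..12).map(|x| x as f64).collect()).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let make = || {
            RobustDummyClassifier::new(RobustStrategy::OutlierResistant {
                contamination: 0.1,
                detection_method: OutlierDetectionMethod::IQR { multiplier: 1.5 },
            })
            .with_random_state(7)
        };

        let p1 = make().fit(&x, &y).unwrap().predict(&x).unwrap();
        let p2 = make().fit(&x, &y).unwrap().predict(&x).unwrap();
        assert_eq!(p1, p2);
    }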
}