1use scirs2_core::ndarray::ArrayStatCompat;
7use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9
10use crate::error::{Result, TransformError};
11use statrs::statistics::Statistics;
12
13pub struct VarianceThreshold {
18 threshold: f64,
20 variances_: Option<Array1<f64>>,
22 selected_features_: Option<Vec<usize>>,
24}
25
26impl VarianceThreshold {
27 pub fn new(threshold: f64) -> Result<Self> {
43 if threshold < 0.0 {
44 return Err(TransformError::InvalidInput(
45 "Threshold must be non-negative".to_string(),
46 ));
47 }
48
49 Ok(VarianceThreshold {
50 threshold,
51 variances_: None,
52 selected_features_: None,
53 })
54 }
55
56 pub fn with_defaults() -> Self {
60 Self::new(0.0).expect("Operation failed")
61 }
62
63 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
71 where
72 S: Data,
73 S::Elem: Float + NumCast,
74 {
75 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
76
77 let n_samples = x_f64.shape()[0];
78 let n_features = x_f64.shape()[1];
79
80 if n_samples == 0 || n_features == 0 {
81 return Err(TransformError::InvalidInput("Empty input data".to_string()));
82 }
83
84 if n_samples < 2 {
85 return Err(TransformError::InvalidInput(
86 "At least 2 samples required to compute variance".to_string(),
87 ));
88 }
89
90 let mut variances = Array1::zeros(n_features);
92 let mut selected_features = Vec::new();
93
94 for j in 0..n_features {
95 let feature_data = x_f64.column(j);
96
97 let mean = feature_data.iter().sum::<f64>() / n_samples as f64;
99
100 let variance = feature_data
102 .iter()
103 .map(|&x| (x - mean).powi(2))
104 .sum::<f64>()
105 / n_samples as f64;
106
107 variances[j] = variance;
108
109 if variance > self.threshold {
111 selected_features.push(j);
112 }
113 }
114
115 self.variances_ = Some(variances);
116 self.selected_features_ = Some(selected_features);
117
118 Ok(())
119 }
120
121 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
129 where
130 S: Data,
131 S::Elem: Float + NumCast,
132 {
133 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
134
135 let n_samples = x_f64.shape()[0];
136 let n_features = x_f64.shape()[1];
137
138 if self.selected_features_.is_none() {
139 return Err(TransformError::TransformationError(
140 "VarianceThreshold has not been fitted".to_string(),
141 ));
142 }
143
144 let selected_features = self.selected_features_.as_ref().expect("Operation failed");
145
146 if let Some(ref variances) = self.variances_ {
148 if n_features != variances.len() {
149 return Err(TransformError::InvalidInput(format!(
150 "x has {} features, but VarianceThreshold was fitted with {} features",
151 n_features,
152 variances.len()
153 )));
154 }
155 }
156
157 let n_selected = selected_features.len();
158 let mut transformed = Array2::zeros((n_samples, n_selected));
159
160 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
162 for i in 0..n_samples {
163 transformed[[i, new_idx]] = x_f64[[i, old_idx]];
164 }
165 }
166
167 Ok(transformed)
168 }
169
170 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
178 where
179 S: Data,
180 S::Elem: Float + NumCast,
181 {
182 self.fit(x)?;
183 self.transform(x)
184 }
185
186 pub fn variances(&self) -> Option<&Array1<f64>> {
191 self.variances_.as_ref()
192 }
193
194 pub fn get_support(&self) -> Option<&Vec<usize>> {
199 self.selected_features_.as_ref()
200 }
201
202 pub fn get_support_mask(&self) -> Option<Array1<bool>> {
207 if let (Some(ref variances), Some(ref selected)) =
208 (&self.variances_, &self.selected_features_)
209 {
210 let n_features = variances.len();
211 let mut mask = Array1::from_elem(n_features, false);
212
213 for &idx in selected {
214 mask[idx] = true;
215 }
216
217 Some(mask)
218 } else {
219 None
220 }
221 }
222
223 pub fn n_features_selected(&self) -> Option<usize> {
228 self.selected_features_.as_ref().map(|s| s.len())
229 }
230
231 pub fn inverse_transform<S>(&self, _x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
236 where
237 S: Data,
238 S::Elem: Float + NumCast,
239 {
240 Err(TransformError::TransformationError(
241 "inverse_transform is not supported for feature selection".to_string(),
242 ))
243 }
244}
245
246#[derive(Debug, Clone)]
251pub struct RecursiveFeatureElimination<F>
252where
253 F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
254{
255 n_features_to_select: usize,
257 step: usize,
259 importance_func: F,
262 selected_features_: Option<Vec<usize>>,
264 ranking_: Option<Array1<usize>>,
266 scores_: Option<Array1<f64>>,
268}
269
270impl<F> RecursiveFeatureElimination<F>
271where
272 F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
273{
274 pub fn new(n_features_to_select: usize, importancefunc: F) -> Self {
280 RecursiveFeatureElimination {
281 n_features_to_select,
282 step: 1,
283 importance_func: importancefunc,
284 selected_features_: None,
285 ranking_: None,
286 scores_: None,
287 }
288 }
289
290 pub fn with_step(mut self, step: usize) -> Self {
292 self.step = step.max(1);
293 self
294 }
295
296 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
302 let n_samples = x.shape()[0];
303 let n_features = x.shape()[1];
304
305 if n_samples != y.len() {
306 return Err(TransformError::InvalidInput(format!(
307 "X has {} samples but y has {} samples",
308 n_samples,
309 y.len()
310 )));
311 }
312
313 if self.n_features_to_select > n_features {
314 return Err(TransformError::InvalidInput(format!(
315 "n_features_to_select={} must be <= n_features={}",
316 self.n_features_to_select, n_features
317 )));
318 }
319
320 let mut remaining_features: Vec<usize> = (0..n_features).collect();
322 let mut ranking = Array1::zeros(n_features);
323 let mut current_rank = 1;
324
325 while remaining_features.len() > self.n_features_to_select {
327 let x_subset = self.subset_features(x, &remaining_features);
329
330 let importances = (self.importance_func)(&x_subset, y)?;
332
333 if importances.len() != remaining_features.len() {
334 return Err(TransformError::InvalidInput(
335 "Importance function returned wrong number of scores".to_string(),
336 ));
337 }
338
339 let n_to_remove = (self.step).min(remaining_features.len() - self.n_features_to_select);
341
342 let mut indices: Vec<usize> = (0..importances.len()).collect();
344 indices.sort_by(|&i, &j| {
345 importances[i]
346 .partial_cmp(&importances[j])
347 .expect("Operation failed")
348 });
349
350 for i in 0..n_to_remove {
352 let feature_idx = remaining_features[indices[i]];
353 ranking[feature_idx] = n_features - current_rank + 1;
354 current_rank += 1;
355 }
356
357 let eliminated: std::collections::HashSet<usize> =
359 indices.iter().take(n_to_remove).cloned().collect();
360 let features_to_retain: Vec<usize> = remaining_features
361 .iter()
362 .filter(|&&idx| !eliminated.contains(&idx))
363 .cloned()
364 .collect();
365 remaining_features = features_to_retain;
366 }
367
368 for &feature_idx in &remaining_features {
370 ranking[feature_idx] = 1;
371 }
372
373 let x_final = self.subset_features(x, &remaining_features);
375 let final_scores = (self.importance_func)(&x_final, y)?;
376
377 let mut scores = Array1::zeros(n_features);
378 for (i, &feature_idx) in remaining_features.iter().enumerate() {
379 scores[feature_idx] = final_scores[i];
380 }
381
382 self.selected_features_ = Some(remaining_features);
383 self.ranking_ = Some(ranking);
384 self.scores_ = Some(scores);
385
386 Ok(())
387 }
388
389 fn subset_features(&self, x: &Array2<f64>, features: &[usize]) -> Array2<f64> {
391 let n_samples = x.shape()[0];
392 let n_selected = features.len();
393 let mut subset = Array2::zeros((n_samples, n_selected));
394
395 for (new_idx, &old_idx) in features.iter().enumerate() {
396 subset.column_mut(new_idx).assign(&x.column(old_idx));
397 }
398
399 subset
400 }
401
402 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
404 if self.selected_features_.is_none() {
405 return Err(TransformError::TransformationError(
406 "RFE has not been fitted".to_string(),
407 ));
408 }
409
410 let selected = self.selected_features_.as_ref().expect("Operation failed");
411 Ok(self.subset_features(x, selected))
412 }
413
414 pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
416 self.fit(x, y)?;
417 self.transform(x)
418 }
419
420 pub fn get_support(&self) -> Option<&Vec<usize>> {
422 self.selected_features_.as_ref()
423 }
424
425 pub fn ranking(&self) -> Option<&Array1<usize>> {
427 self.ranking_.as_ref()
428 }
429
430 pub fn scores(&self) -> Option<&Array1<f64>> {
432 self.scores_.as_ref()
433 }
434}
435
/// Selects the `k` features with the highest dependency score against a target.
///
/// Each feature is scored independently against the target, and the `k` best-scoring
/// features are kept. For a continuous target, the score is derived from the Pearson
/// correlation. For a discrete target, values are rounded to integer class labels and the
/// score is derived from the between-class spread of the feature.
#[derive(Debug, Clone)]
pub struct MutualInfoSelector {
    // Number of features to keep.
    k: usize,
    // Whether the target holds class labels (rounded to `i64`) rather than continuous values.
    discrete_target: bool,
    // Minimum-sample guard only: a feature scores 0.0 when there are fewer than
    // `n_neighbors + 1` samples. No nearest-neighbour estimator is used.
    n_neighbors: usize,
    // Indices of the `k` highest-scoring features, best first; `None` until fitted.
    selected_features_: Option<Vec<usize>>,
    // Per-feature scores computed by `fit`; `None` until fitted.
    scores_: Option<Array1<f64>>,
}
452
453impl MutualInfoSelector {
454 pub fn new(k: usize) -> Self {
459 MutualInfoSelector {
460 k,
461 discrete_target: false,
462 n_neighbors: 3,
463 selected_features_: None,
464 scores_: None,
465 }
466 }
467
468 pub fn with_discrete_target(mut self) -> Self {
470 self.discrete_target = true;
471 self
472 }
473
474 pub fn with_n_neighbors(mut self, nneighbors: usize) -> Self {
476 self.n_neighbors = nneighbors;
477 self
478 }
479
480 fn estimate_mutual_info(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
482 let n = x.len();
483 if n < self.n_neighbors + 1 {
484 return 0.0;
485 }
486
487 if !self.discrete_target {
489 let x_mean = x.mean_or(0.0);
491 let y_mean = y.mean_or(0.0);
492 let x_std = x.std(0.0);
493 let y_std = y.std(0.0);
494
495 if x_std < 1e-10 || y_std < 1e-10 {
496 return 0.0;
497 }
498
499 let mut correlation = 0.0;
500 for i in 0..n {
501 correlation += (x[i] - x_mean) * (y[i] - y_mean);
502 }
503 correlation /= (n as f64 - 1.0) * x_std * y_std;
504
505 if correlation.abs() >= 1.0 {
508 return 5.0; }
510 (-0.5 * (1.0 - correlation * correlation).ln()).max(0.0)
511 } else {
512 let mut groups = std::collections::HashMap::new();
514
515 for i in 0..n {
516 let key = y[i].round() as i64;
517 groups.entry(key).or_insert_with(Vec::new).push(x[i]);
518 }
519
520 let total_mean = x.mean_or(0.0);
522 let total_var = x.variance();
523
524 if total_var < 1e-10 {
525 return 0.0;
526 }
527
528 let mut between_var = 0.0;
529 for (_, values) in groups {
530 let group_mean = values.iter().sum::<f64>() / values.len() as f64;
531 let weight = values.len() as f64 / n as f64;
532 between_var += weight * (group_mean - total_mean).powi(2);
533 }
534
535 (between_var / total_var).min(1.0) * 2.0 }
537 }
538
539 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
541 let n_features = x.shape()[1];
542
543 if self.k > n_features {
544 return Err(TransformError::InvalidInput(format!(
545 "k={} must be <= n_features={}",
546 self.k, n_features
547 )));
548 }
549
550 let mut scores = Array1::zeros(n_features);
552
553 for j in 0..n_features {
554 let feature = x.column(j).to_owned();
555 scores[j] = self.estimate_mutual_info(&feature, y);
556 }
557
558 let mut indices: Vec<usize> = (0..n_features).collect();
560 indices.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).expect("Operation failed"));
561
562 let selected_features = indices.into_iter().take(self.k).collect();
563
564 self.scores_ = Some(scores);
565 self.selected_features_ = Some(selected_features);
566
567 Ok(())
568 }
569
570 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
572 if self.selected_features_.is_none() {
573 return Err(TransformError::TransformationError(
574 "MutualInfoSelector has not been fitted".to_string(),
575 ));
576 }
577
578 let selected = self.selected_features_.as_ref().expect("Operation failed");
579 let n_samples = x.shape()[0];
580 let mut transformed = Array2::zeros((n_samples, self.k));
581
582 for (new_idx, &old_idx) in selected.iter().enumerate() {
583 transformed.column_mut(new_idx).assign(&x.column(old_idx));
584 }
585
586 Ok(transformed)
587 }
588
589 pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
591 self.fit(x, y)?;
592 self.transform(x)
593 }
594
595 pub fn get_support(&self) -> Option<&Vec<usize>> {
597 self.selected_features_.as_ref()
598 }
599
600 pub fn scores(&self) -> Option<&Array1<f64>> {
602 self.scores_.as_ref()
603 }
604}
605
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    // With the default zero threshold, constant columns are removed.
    #[test]
    fn test_variance_threshold_basic() {
        // 3 samples x 4 features: column 0 is all 1.0 and column 2 is all 5.0 (constant).
        let data = Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .expect("Operation failed");

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).expect("Operation failed");

        assert_eq!(transformed.shape(), &[3, 2]);

        let selected = selector.get_support().expect("Operation failed");
        assert_eq!(selected, &[1, 3]);

        // Output column 0 is input column 1; output column 1 is input column 3.
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[0, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[1, 1]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(transformed[[2, 1]], 5.0, epsilon = 1e-10);
    }

    // A non-zero threshold drops a low-variance (but not constant) column.
    #[test]
    fn test_variance_threshold_custom() {
        // Columns 0 and 2 run 1..4 (population variance 1.25); column 1 alternates
        // 1.0 / 1.1 (population variance 0.0025).
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 1.0, 1.0, 2.0, 1.1, 2.0, 3.0, 1.0, 3.0, 4.0, 1.1, 4.0,
            ],
        )
        .expect("Operation failed");

        let mut selector = VarianceThreshold::new(0.1).expect("Operation failed");
        let transformed = selector.fit_transform(&data).expect("Operation failed");

        assert_eq!(transformed.shape(), &[4, 2]);

        let selected = selector.get_support().expect("Operation failed");
        assert_eq!(selected, &[0, 2]);

        let variances = selector.variances().expect("Operation failed");
        assert!(variances[0] > 0.1);
        assert!(variances[1] <= 0.1);
        assert!(variances[2] > 0.1);
    }

    // The boolean mask and the selected-feature count agree with `get_support`.
    #[test]
    fn test_variance_threshold_support_mask() {
        let data = Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .expect("Operation failed");

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).expect("Operation failed");

        let mask = selector.get_support_mask().expect("Operation failed");
        assert_eq!(mask.len(), 4);
        assert!(!mask[0]);
        assert!(mask[1]);
        assert!(!mask[2]);
        assert!(mask[3]);
        assert_eq!(selector.n_features_selected().expect("Operation failed"), 2);
    }

    // When every column is constant, the output has zero columns rather than erroring.
    #[test]
    fn test_variance_threshold_all_removed() {
        let data = Array::from_shape_vec((3, 2), vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0])
            .expect("Operation failed");

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).expect("Operation failed");

        assert_eq!(transformed.shape(), &[3, 0]);
        assert_eq!(selector.n_features_selected().expect("Operation failed"), 0);
    }

    // Error paths: negative threshold, single sample, unfitted transform, inverse_transform.
    #[test]
    fn test_variance_threshold_errors() {
        assert!(VarianceThreshold::new(-0.1).is_err());

        // Variance needs at least two samples.
        let small_data = Array::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("Operation failed");
        let mut selector = VarianceThreshold::with_defaults();
        assert!(selector.fit(&small_data).is_err());

        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed");
        let selector_unfitted = VarianceThreshold::with_defaults();
        assert!(selector_unfitted.transform(&data).is_err());

        // inverse_transform is unsupported even after fitting.
        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).expect("Operation failed");
        assert!(selector.inverse_transform(&data).is_err());
    }

    // Transforming data with a different column count than the fit data must fail.
    #[test]
    fn test_variance_threshold_feature_mismatch() {
        let train_data =
            Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .expect("Operation failed");
        let test_data =
            Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("Operation failed");
        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&train_data).expect("Operation failed");
        assert!(selector.transform(&test_data).is_err());
    }

    // Variance is the population variance: for [1, 2, 3] it is (1 + 0 + 1) / 3.
    #[test]
    fn test_variance_calculation() {
        let data = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).expect("Operation failed");

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).expect("Operation failed");

        let variances = selector.variances().expect("Operation failed");
        let expected_variance = 2.0 / 3.0;
        assert_abs_diff_eq!(variances[0], expected_variance, epsilon = 1e-10);
    }

    // RFE with |Pearson r| importances keeps at least one of the target-driving features.
    // NOTE: uses unseeded random values; the order of `random` calls defines the data.
    #[test]
    fn test_rfe_basic() {
        let n_samples = 100;
        let mut data_vec = Vec::new();
        let mut target_vec = Vec::new();

        for i in 0..n_samples {
            let x1 = i as f64 / n_samples as f64;
            let x2 = (i as f64 / n_samples as f64).sin();
            // x3 is pure noise; x4 is an exact multiple of x1.
            let x3 = scirs2_core::random::random::<f64>();
            let x4 = 2.0 * x1;
            data_vec.extend_from_slice(&[x1, x2, x3, x4]);
            // Target is driven by x1 and x4, plus small noise.
            target_vec.push(3.0 * x1 + x4 + 0.1 * scirs2_core::random::random::<f64>());
        }

        let x = Array::from_shape_vec((n_samples, 4), data_vec).expect("Operation failed");
        let y = Array::from_vec(target_vec);

        // Importance of each column = absolute Pearson correlation with the target.
        let importance_func = |x: &Array2<f64>, y: &Array1<f64>| -> Result<Array1<f64>> {
            let n_features = x.shape()[1];
            let mut scores = Array1::zeros(n_features);

            for j in 0..n_features {
                let feature = x.column(j);
                let corr = pearson_correlation(&feature.to_owned(), y);
                scores[j] = corr.abs();
            }

            Ok(scores)
        };

        let mut rfe = RecursiveFeatureElimination::new(2, importance_func);
        let transformed = rfe.fit_transform(&x, &y).expect("Operation failed");

        assert_eq!(transformed.shape()[1], 2);

        let selected = rfe.get_support().expect("Operation failed");
        assert!(selected.contains(&0) || selected.contains(&3));
    }

    // For a continuous target, informative features score above a pure-noise feature.
    #[test]
    fn test_mutual_info_continuous() {
        let n_samples = 100;
        let mut x_data = Vec::new();
        let mut y_data = Vec::new();

        for i in 0..n_samples {
            let t = i as f64 / n_samples as f64 * 2.0 * std::f64::consts::PI;

            // x0 = t (strongly related), x1 = noise, x2 = sin(t) (partly related).
            let x0 = t;
            let x1 = scirs2_core::random::random::<f64>();
            let x2 = t.sin();

            x_data.extend_from_slice(&[x0, x1, x2]);
            y_data.push(t + 0.5 * t.sin());
        }

        let x = Array::from_shape_vec((n_samples, 3), x_data).expect("Operation failed");
        let y = Array::from_vec(y_data);

        let mut selector = MutualInfoSelector::new(2);
        selector.fit(&x, &y).expect("Operation failed");

        let scores = selector.scores().expect("Operation failed");

        assert!(scores[0] > scores[1]);
        assert!(scores[2] > scores[1]);
    }

    // For a discrete target, a feature with identical values in every class is not selected.
    #[test]
    fn test_mutual_info_discrete() {
        // Feature 1 alternates 0.1 / 0.2 identically within each class, so it carries no
        // class information; features 0 and 2 shift with the class.
        let x = Array::from_shape_vec(
            (6, 3),
            vec![
                1.0, 0.1, 5.0, 1.1, 0.2, 5.1, 2.0, 0.1, 4.0, 2.1, 0.2, 4.1, 3.0, 0.1, 3.0, 3.1, 0.2, 3.1,
            ],
        )
        .expect("Operation failed");

        let y = Array::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);

        let mut selector = MutualInfoSelector::new(2).with_discrete_target();
        let transformed = selector.fit_transform(&x, &y).expect("Operation failed");

        assert_eq!(transformed.shape(), &[6, 2]);

        let selected = selector.get_support().expect("Operation failed");
        assert!(!selected.contains(&1));
    }

    // Pearson correlation coefficient of two equal-length vectors; 0.0 when either vector
    // has zero variance.
    fn pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        #[allow(unused_variables)]
        let n = x.len() as f64;
        let x_mean = x.mean_or(0.0);
        let y_mean = y.mean_or(0.0);

        let mut num = 0.0;
        let mut x_var = 0.0;
        let mut y_var = 0.0;

        for i in 0..x.len() {
            let x_diff = x[i] - x_mean;
            let y_diff = y[i] - y_mean;
            num += x_diff * y_diff;
            x_var += x_diff * x_diff;
            y_var += y_diff * y_diff;
        }

        if x_var * y_var > 0.0 {
            num / (x_var * y_var).sqrt()
        } else {
            0.0
        }
    }
}