1use scirs2_core::ndarray::ArrayStatCompat;
7use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9
10use crate::error::{Result, TransformError};
11use statrs::statistics::Statistics;
12
13pub struct VarianceThreshold {
18 threshold: f64,
20 variances_: Option<Array1<f64>>,
22 selected_features_: Option<Vec<usize>>,
24}
25
26impl VarianceThreshold {
27 pub fn new(threshold: f64) -> Result<Self> {
43 if threshold < 0.0 {
44 return Err(TransformError::InvalidInput(
45 "Threshold must be non-negative".to_string(),
46 ));
47 }
48
49 Ok(VarianceThreshold {
50 threshold,
51 variances_: None,
52 selected_features_: None,
53 })
54 }
55
56 pub fn with_defaults() -> Self {
60 Self::new(0.0).unwrap()
61 }
62
63 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
71 where
72 S: Data,
73 S::Elem: Float + NumCast,
74 {
75 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
76
77 let n_samples = x_f64.shape()[0];
78 let n_features = x_f64.shape()[1];
79
80 if n_samples == 0 || n_features == 0 {
81 return Err(TransformError::InvalidInput("Empty input data".to_string()));
82 }
83
84 if n_samples < 2 {
85 return Err(TransformError::InvalidInput(
86 "At least 2 samples required to compute variance".to_string(),
87 ));
88 }
89
90 let mut variances = Array1::zeros(n_features);
92 let mut selected_features = Vec::new();
93
94 for j in 0..n_features {
95 let feature_data = x_f64.column(j);
96
97 let mean = feature_data.iter().sum::<f64>() / n_samples as f64;
99
100 let variance = feature_data
102 .iter()
103 .map(|&x| (x - mean).powi(2))
104 .sum::<f64>()
105 / n_samples as f64;
106
107 variances[j] = variance;
108
109 if variance > self.threshold {
111 selected_features.push(j);
112 }
113 }
114
115 self.variances_ = Some(variances);
116 self.selected_features_ = Some(selected_features);
117
118 Ok(())
119 }
120
121 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
129 where
130 S: Data,
131 S::Elem: Float + NumCast,
132 {
133 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
134
135 let n_samples = x_f64.shape()[0];
136 let n_features = x_f64.shape()[1];
137
138 if self.selected_features_.is_none() {
139 return Err(TransformError::TransformationError(
140 "VarianceThreshold has not been fitted".to_string(),
141 ));
142 }
143
144 let selected_features = self.selected_features_.as_ref().unwrap();
145
146 if let Some(ref variances) = self.variances_ {
148 if n_features != variances.len() {
149 return Err(TransformError::InvalidInput(format!(
150 "x has {} features, but VarianceThreshold was fitted with {} features",
151 n_features,
152 variances.len()
153 )));
154 }
155 }
156
157 let n_selected = selected_features.len();
158 let mut transformed = Array2::zeros((n_samples, n_selected));
159
160 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
162 for i in 0..n_samples {
163 transformed[[i, new_idx]] = x_f64[[i, old_idx]];
164 }
165 }
166
167 Ok(transformed)
168 }
169
170 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
178 where
179 S: Data,
180 S::Elem: Float + NumCast,
181 {
182 self.fit(x)?;
183 self.transform(x)
184 }
185
186 pub fn variances(&self) -> Option<&Array1<f64>> {
191 self.variances_.as_ref()
192 }
193
194 pub fn get_support(&self) -> Option<&Vec<usize>> {
199 self.selected_features_.as_ref()
200 }
201
202 pub fn get_support_mask(&self) -> Option<Array1<bool>> {
207 if let (Some(ref variances), Some(ref selected)) =
208 (&self.variances_, &self.selected_features_)
209 {
210 let n_features = variances.len();
211 let mut mask = Array1::from_elem(n_features, false);
212
213 for &idx in selected {
214 mask[idx] = true;
215 }
216
217 Some(mask)
218 } else {
219 None
220 }
221 }
222
223 pub fn n_features_selected(&self) -> Option<usize> {
228 self.selected_features_.as_ref().map(|s| s.len())
229 }
230
231 pub fn inverse_transform<S>(&self, _x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
236 where
237 S: Data,
238 S::Elem: Float + NumCast,
239 {
240 Err(TransformError::TransformationError(
241 "inverse_transform is not supported for feature selection".to_string(),
242 ))
243 }
244}
245
246#[derive(Debug, Clone)]
251pub struct RecursiveFeatureElimination<F>
252where
253 F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
254{
255 n_features_to_select: usize,
257 step: usize,
259 importance_func: F,
262 selected_features_: Option<Vec<usize>>,
264 ranking_: Option<Array1<usize>>,
266 scores_: Option<Array1<f64>>,
268}
269
270impl<F> RecursiveFeatureElimination<F>
271where
272 F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
273{
274 pub fn new(n_features_to_select: usize, importancefunc: F) -> Self {
280 RecursiveFeatureElimination {
281 n_features_to_select,
282 step: 1,
283 importance_func: importancefunc,
284 selected_features_: None,
285 ranking_: None,
286 scores_: None,
287 }
288 }
289
290 pub fn with_step(mut self, step: usize) -> Self {
292 self.step = step.max(1);
293 self
294 }
295
296 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
302 let n_samples = x.shape()[0];
303 let n_features = x.shape()[1];
304
305 if n_samples != y.len() {
306 return Err(TransformError::InvalidInput(format!(
307 "X has {} samples but y has {} samples",
308 n_samples,
309 y.len()
310 )));
311 }
312
313 if self.n_features_to_select > n_features {
314 return Err(TransformError::InvalidInput(format!(
315 "n_features_to_select={} must be <= n_features={}",
316 self.n_features_to_select, n_features
317 )));
318 }
319
320 let mut remaining_features: Vec<usize> = (0..n_features).collect();
322 let mut ranking = Array1::zeros(n_features);
323 let mut current_rank = 1;
324
325 while remaining_features.len() > self.n_features_to_select {
327 let x_subset = self.subset_features(x, &remaining_features);
329
330 let importances = (self.importance_func)(&x_subset, y)?;
332
333 if importances.len() != remaining_features.len() {
334 return Err(TransformError::InvalidInput(
335 "Importance function returned wrong number of scores".to_string(),
336 ));
337 }
338
339 let n_to_remove = (self.step).min(remaining_features.len() - self.n_features_to_select);
341
342 let mut indices: Vec<usize> = (0..importances.len()).collect();
344 indices.sort_by(|&i, &j| importances[i].partial_cmp(&importances[j]).unwrap());
345
346 for i in 0..n_to_remove {
348 let feature_idx = remaining_features[indices[i]];
349 ranking[feature_idx] = n_features - current_rank + 1;
350 current_rank += 1;
351 }
352
353 let eliminated: std::collections::HashSet<usize> =
355 indices.iter().take(n_to_remove).cloned().collect();
356 let features_to_retain: Vec<usize> = remaining_features
357 .iter()
358 .filter(|&&idx| !eliminated.contains(&idx))
359 .cloned()
360 .collect();
361 remaining_features = features_to_retain;
362 }
363
364 for &feature_idx in &remaining_features {
366 ranking[feature_idx] = 1;
367 }
368
369 let x_final = self.subset_features(x, &remaining_features);
371 let final_scores = (self.importance_func)(&x_final, y)?;
372
373 let mut scores = Array1::zeros(n_features);
374 for (i, &feature_idx) in remaining_features.iter().enumerate() {
375 scores[feature_idx] = final_scores[i];
376 }
377
378 self.selected_features_ = Some(remaining_features);
379 self.ranking_ = Some(ranking);
380 self.scores_ = Some(scores);
381
382 Ok(())
383 }
384
385 fn subset_features(&self, x: &Array2<f64>, features: &[usize]) -> Array2<f64> {
387 let n_samples = x.shape()[0];
388 let n_selected = features.len();
389 let mut subset = Array2::zeros((n_samples, n_selected));
390
391 for (new_idx, &old_idx) in features.iter().enumerate() {
392 subset.column_mut(new_idx).assign(&x.column(old_idx));
393 }
394
395 subset
396 }
397
398 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
400 if self.selected_features_.is_none() {
401 return Err(TransformError::TransformationError(
402 "RFE has not been fitted".to_string(),
403 ));
404 }
405
406 let selected = self.selected_features_.as_ref().unwrap();
407 Ok(self.subset_features(x, selected))
408 }
409
410 pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
412 self.fit(x, y)?;
413 self.transform(x)
414 }
415
416 pub fn get_support(&self) -> Option<&Vec<usize>> {
418 self.selected_features_.as_ref()
419 }
420
421 pub fn ranking(&self) -> Option<&Array1<usize>> {
423 self.ranking_.as_ref()
424 }
425
426 pub fn scores(&self) -> Option<&Array1<f64>> {
428 self.scores_.as_ref()
429 }
430}
431
432#[derive(Debug, Clone)]
436pub struct MutualInfoSelector {
437 k: usize,
439 discrete_target: bool,
441 n_neighbors: usize,
443 selected_features_: Option<Vec<usize>>,
445 scores_: Option<Array1<f64>>,
447}
448
449impl MutualInfoSelector {
450 pub fn new(k: usize) -> Self {
455 MutualInfoSelector {
456 k,
457 discrete_target: false,
458 n_neighbors: 3,
459 selected_features_: None,
460 scores_: None,
461 }
462 }
463
464 pub fn with_discrete_target(mut self) -> Self {
466 self.discrete_target = true;
467 self
468 }
469
470 pub fn with_n_neighbors(mut self, nneighbors: usize) -> Self {
472 self.n_neighbors = nneighbors;
473 self
474 }
475
476 fn estimate_mutual_info(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
478 let n = x.len();
479 if n < self.n_neighbors + 1 {
480 return 0.0;
481 }
482
483 if !self.discrete_target {
485 let x_mean = x.mean_or(0.0);
487 let y_mean = y.mean_or(0.0);
488 let x_std = x.std(0.0);
489 let y_std = y.std(0.0);
490
491 if x_std < 1e-10 || y_std < 1e-10 {
492 return 0.0;
493 }
494
495 let mut correlation = 0.0;
496 for i in 0..n {
497 correlation += (x[i] - x_mean) * (y[i] - y_mean);
498 }
499 correlation /= (n as f64 - 1.0) * x_std * y_std;
500
501 if correlation.abs() >= 1.0 {
504 return 5.0; }
506 (-0.5 * (1.0 - correlation * correlation).ln()).max(0.0)
507 } else {
508 let mut groups = std::collections::HashMap::new();
510
511 for i in 0..n {
512 let key = y[i].round() as i64;
513 groups.entry(key).or_insert_with(Vec::new).push(x[i]);
514 }
515
516 let total_mean = x.mean_or(0.0);
518 let total_var = x.variance();
519
520 if total_var < 1e-10 {
521 return 0.0;
522 }
523
524 let mut between_var = 0.0;
525 for (_, values) in groups {
526 let group_mean = values.iter().sum::<f64>() / values.len() as f64;
527 let weight = values.len() as f64 / n as f64;
528 between_var += weight * (group_mean - total_mean).powi(2);
529 }
530
531 (between_var / total_var).min(1.0) * 2.0 }
533 }
534
535 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
537 let n_features = x.shape()[1];
538
539 if self.k > n_features {
540 return Err(TransformError::InvalidInput(format!(
541 "k={} must be <= n_features={}",
542 self.k, n_features
543 )));
544 }
545
546 let mut scores = Array1::zeros(n_features);
548
549 for j in 0..n_features {
550 let feature = x.column(j).to_owned();
551 scores[j] = self.estimate_mutual_info(&feature, y);
552 }
553
554 let mut indices: Vec<usize> = (0..n_features).collect();
556 indices.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).unwrap());
557
558 let selected_features = indices.into_iter().take(self.k).collect();
559
560 self.scores_ = Some(scores);
561 self.selected_features_ = Some(selected_features);
562
563 Ok(())
564 }
565
566 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
568 if self.selected_features_.is_none() {
569 return Err(TransformError::TransformationError(
570 "MutualInfoSelector has not been fitted".to_string(),
571 ));
572 }
573
574 let selected = self.selected_features_.as_ref().unwrap();
575 let n_samples = x.shape()[0];
576 let mut transformed = Array2::zeros((n_samples, self.k));
577
578 for (new_idx, &old_idx) in selected.iter().enumerate() {
579 transformed.column_mut(new_idx).assign(&x.column(old_idx));
580 }
581
582 Ok(transformed)
583 }
584
585 pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
587 self.fit(x, y)?;
588 self.transform(x)
589 }
590
591 pub fn get_support(&self) -> Option<&Vec<usize>> {
593 self.selected_features_.as_ref()
594 }
595
596 pub fn scores(&self) -> Option<&Array1<f64>> {
598 self.scores_.as_ref()
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    /// 3x4 matrix whose columns 0 and 2 are constant while columns 1 and 3 vary.
    fn mixed_variance_data() -> Array2<f64> {
        Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .unwrap()
    }

    /// Angle `2*pi*i/n` for sample `i` of `n`.
    ///
    /// `cos(k * angle)` for a small integer `k` sums to (almost) zero against linear
    /// trends and against `sin(angle)` over a full period. That makes it a reproducible
    /// stand-in for an irrelevant "noise" feature, with no dependence on a global RNG.
    fn angle(i: usize, n: usize) -> f64 {
        2.0 * std::f64::consts::PI * i as f64 / n as f64
    }

    /// Pearson correlation of `x` and `y`, or 0.0 when either input is constant.
    fn pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        let x_mean = x.mean_or(0.0);
        let y_mean = y.mean_or(0.0);

        let (mut num, mut x_var, mut y_var) = (0.0_f64, 0.0_f64, 0.0_f64);
        for (&xi, &yi) in x.iter().zip(y.iter()) {
            let x_diff = xi - x_mean;
            let y_diff = yi - y_mean;
            num += x_diff * y_diff;
            x_var += x_diff * x_diff;
            y_var += y_diff * y_diff;
        }

        if x_var * y_var > 0.0 {
            num / (x_var * y_var).sqrt()
        } else {
            0.0
        }
    }

    #[test]
    fn test_variance_threshold_basic() {
        let data = mixed_variance_data();

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).unwrap();

        assert_eq!(transformed.shape(), &[3, 2]);
        assert_eq!(selector.get_support().unwrap(), &[1, 3]);

        // The kept columns (1 and 3) are copied through unchanged.
        let expected = [[1.0, 1.0], [2.0, 3.0], [3.0, 5.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_abs_diff_eq!(transformed[[i, j]], value, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_variance_threshold_custom() {
        // Column 1 only wiggles between 1.0 and 1.1, well below the 0.1 threshold.
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 1.0, 1.0, //
                2.0, 1.1, 2.0, //
                3.0, 1.0, 3.0, //
                4.0, 1.1, 4.0,
            ],
        )
        .unwrap();

        let mut selector = VarianceThreshold::new(0.1).unwrap();
        let transformed = selector.fit_transform(&data).unwrap();

        assert_eq!(transformed.shape(), &[4, 2]);
        assert_eq!(selector.get_support().unwrap(), &[0, 2]);

        let variances = selector.variances().unwrap();
        assert!(variances[0] > 0.1);
        assert!(variances[1] <= 0.1);
        assert!(variances[2] > 0.1);
    }

    #[test]
    fn test_variance_threshold_support_mask() {
        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&mixed_variance_data()).unwrap();

        let mask = selector.get_support_mask().unwrap();
        assert_eq!(mask.to_vec(), vec![false, true, false, true]);
        assert_eq!(selector.n_features_selected().unwrap(), 2);
    }

    #[test]
    fn test_variance_threshold_all_removed() {
        let data = Array::from_shape_vec((3, 2), vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).unwrap();

        assert_eq!(transformed.shape(), &[3, 0]);
        assert_eq!(selector.n_features_selected().unwrap(), 0);
    }

    #[test]
    fn test_variance_threshold_errors() {
        // Negative thresholds are rejected.
        assert!(VarianceThreshold::new(-0.1).is_err());

        // A single sample has no variance.
        let small_data = Array::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let mut selector = VarianceThreshold::with_defaults();
        assert!(selector.fit(&small_data).is_err());

        // Transforming before fitting fails.
        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let selector_unfitted = VarianceThreshold::with_defaults();
        assert!(selector_unfitted.transform(&data).is_err());

        // Dropped columns cannot be restored.
        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).unwrap();
        assert!(selector.inverse_transform(&data).is_err());
    }

    #[test]
    fn test_variance_threshold_feature_mismatch() {
        let train_data =
            Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();
        let test_data = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&train_data).unwrap();
        assert!(selector.transform(&test_data).is_err());
    }

    #[test]
    fn test_variance_calculation() {
        // Population variance of [1, 2, 3] is 2/3.
        let data = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).unwrap();

        let variances = selector.variances().unwrap();
        assert_abs_diff_eq!(variances[0], 2.0 / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rfe_basic() {
        let n_samples = 100;
        let mut data_vec = Vec::with_capacity(n_samples * 4);
        let mut target_vec = Vec::with_capacity(n_samples);

        for i in 0..n_samples {
            let x1 = i as f64 / n_samples as f64;
            let x2 = x1.sin();
            // Irrelevant feature: oscillates with no linear trend.
            let x3 = (3.0 * angle(i, n_samples)).cos();
            let x4 = 2.0 * x1;
            data_vec.extend_from_slice(&[x1, x2, x3, x4]);
            // Small deterministic perturbation in place of random noise.
            target_vec.push(3.0 * x1 + x4 + 0.01 * (7.0 * angle(i, n_samples)).cos());
        }

        let x = Array::from_shape_vec((n_samples, 4), data_vec).unwrap();
        let y = Array::from_vec(target_vec);

        let importance_func = |x: &Array2<f64>, y: &Array1<f64>| -> Result<Array1<f64>> {
            let n_features = x.shape()[1];
            let mut scores = Array1::zeros(n_features);
            for j in 0..n_features {
                scores[j] = pearson_correlation(&x.column(j).to_owned(), y).abs();
            }
            Ok(scores)
        };

        let mut rfe = RecursiveFeatureElimination::new(2, importance_func);
        let transformed = rfe.fit_transform(&x, &y).unwrap();

        assert_eq!(transformed.shape(), &[n_samples, 2]);
        // The noise column is removed first, then the slightly non-linear sine column.
        assert_eq!(rfe.get_support().unwrap(), &[0, 3]);
        assert_eq!(rfe.ranking().unwrap().to_vec(), vec![1, 3, 4, 1]);
    }

    #[test]
    fn test_mutual_info_continuous() {
        let n_samples = 100;
        let mut x_data = Vec::with_capacity(n_samples * 3);
        let mut y_data = Vec::with_capacity(n_samples);

        for i in 0..n_samples {
            let t = angle(i, n_samples);
            // x0: strongly related to y.
            let x0 = t;
            // x1: irrelevant, (nearly) uncorrelated with both t and sin(t).
            let x1 = (3.0 * t).cos();
            // x2: partially related to y.
            let x2 = t.sin();

            x_data.extend_from_slice(&[x0, x1, x2]);
            y_data.push(t + 0.5 * t.sin());
        }

        let x = Array::from_shape_vec((n_samples, 3), x_data).unwrap();
        let y = Array::from_vec(y_data);

        let mut selector = MutualInfoSelector::new(2);
        selector.fit(&x, &y).unwrap();

        let scores = selector.scores().unwrap();
        assert!(scores[0] > scores[1]);
        assert!(scores[2] > scores[1]);
        assert_eq!(selector.get_support().unwrap(), &[0, 2]);
    }

    #[test]
    fn test_mutual_info_discrete() {
        // Columns 0 and 2 separate the classes; column 1 has the same mean in every class.
        let x = Array::from_shape_vec(
            (6, 3),
            vec![
                1.0, 0.1, 5.0, //
                1.1, 0.2, 5.1, //
                2.0, 0.1, 4.0, //
                2.1, 0.2, 4.1, //
                3.0, 0.1, 3.0, //
                3.1, 0.2, 3.1,
            ],
        )
        .unwrap();

        let y = Array::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);

        let mut selector = MutualInfoSelector::new(2).with_discrete_target();
        let transformed = selector.fit_transform(&x, &y).unwrap();

        assert_eq!(transformed.shape(), &[6, 2]);
        assert!(!selector.get_support().unwrap().contains(&1));
    }
}