1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
7use scirs2_core::numeric::{Float, NumCast};
8
9use crate::error::{Result, TransformError};
10use statrs::statistics::Statistics;
11
12pub struct VarianceThreshold {
17 threshold: f64,
19 variances_: Option<Array1<f64>>,
21 selected_features_: Option<Vec<usize>>,
23}
24
25impl VarianceThreshold {
26 pub fn new(threshold: f64) -> Result<Self> {
42 if threshold < 0.0 {
43 return Err(TransformError::InvalidInput(
44 "Threshold must be non-negative".to_string(),
45 ));
46 }
47
48 Ok(VarianceThreshold {
49 threshold,
50 variances_: None,
51 selected_features_: None,
52 })
53 }
54
55 pub fn with_defaults() -> Self {
59 Self::new(0.0).unwrap()
60 }
61
62 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
70 where
71 S: Data,
72 S::Elem: Float + NumCast,
73 {
74 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
75
76 let n_samples = x_f64.shape()[0];
77 let n_features = x_f64.shape()[1];
78
79 if n_samples == 0 || n_features == 0 {
80 return Err(TransformError::InvalidInput("Empty input data".to_string()));
81 }
82
83 if n_samples < 2 {
84 return Err(TransformError::InvalidInput(
85 "At least 2 samples required to compute variance".to_string(),
86 ));
87 }
88
89 let mut variances = Array1::zeros(n_features);
91 let mut selected_features = Vec::new();
92
93 for j in 0..n_features {
94 let feature_data = x_f64.column(j);
95
96 let mean = feature_data.iter().sum::<f64>() / n_samples as f64;
98
99 let variance = feature_data
101 .iter()
102 .map(|&x| (x - mean).powi(2))
103 .sum::<f64>()
104 / n_samples as f64;
105
106 variances[j] = variance;
107
108 if variance > self.threshold {
110 selected_features.push(j);
111 }
112 }
113
114 self.variances_ = Some(variances);
115 self.selected_features_ = Some(selected_features);
116
117 Ok(())
118 }
119
120 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
128 where
129 S: Data,
130 S::Elem: Float + NumCast,
131 {
132 let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
133
134 let n_samples = x_f64.shape()[0];
135 let n_features = x_f64.shape()[1];
136
137 if self.selected_features_.is_none() {
138 return Err(TransformError::TransformationError(
139 "VarianceThreshold has not been fitted".to_string(),
140 ));
141 }
142
143 let selected_features = self.selected_features_.as_ref().unwrap();
144
145 if let Some(ref variances) = self.variances_ {
147 if n_features != variances.len() {
148 return Err(TransformError::InvalidInput(format!(
149 "x has {} features, but VarianceThreshold was fitted with {} features",
150 n_features,
151 variances.len()
152 )));
153 }
154 }
155
156 let n_selected = selected_features.len();
157 let mut transformed = Array2::zeros((n_samples, n_selected));
158
159 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
161 for i in 0..n_samples {
162 transformed[[i, new_idx]] = x_f64[[i, old_idx]];
163 }
164 }
165
166 Ok(transformed)
167 }
168
169 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
177 where
178 S: Data,
179 S::Elem: Float + NumCast,
180 {
181 self.fit(x)?;
182 self.transform(x)
183 }
184
185 pub fn variances(&self) -> Option<&Array1<f64>> {
190 self.variances_.as_ref()
191 }
192
193 pub fn get_support(&self) -> Option<&Vec<usize>> {
198 self.selected_features_.as_ref()
199 }
200
201 pub fn get_support_mask(&self) -> Option<Array1<bool>> {
206 if let (Some(ref variances), Some(ref selected)) =
207 (&self.variances_, &self.selected_features_)
208 {
209 let n_features = variances.len();
210 let mut mask = Array1::from_elem(n_features, false);
211
212 for &idx in selected {
213 mask[idx] = true;
214 }
215
216 Some(mask)
217 } else {
218 None
219 }
220 }
221
222 pub fn n_features_selected(&self) -> Option<usize> {
227 self.selected_features_.as_ref().map(|s| s.len())
228 }
229
230 pub fn inverse_transform<S>(&self, _x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
235 where
236 S: Data,
237 S::Elem: Float + NumCast,
238 {
239 Err(TransformError::TransformationError(
240 "inverse_transform is not supported for feature selection".to_string(),
241 ))
242 }
243}
244
/// Recursive feature elimination driven by a caller-supplied importance function.
///
/// `fit` repeatedly scores the remaining features with `importance_func` and
/// drops the lowest-scoring ones until `n_features_to_select` remain.
#[derive(Debug, Clone)]
pub struct RecursiveFeatureElimination<F>
where
    F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
{
    /// Number of features that remain once elimination stops.
    n_features_to_select: usize,
    /// Maximum number of features removed per elimination round (at least 1).
    step: usize,
    /// Called with the current feature subset and the target; must return one
    /// importance score per column of the subset.
    importance_func: F,
    /// Indices of the features kept by `fit`; `None` until fitted.
    selected_features_: Option<Vec<usize>>,
    /// Rank per original feature: 1 for selected features, larger values for
    /// features that were eliminated earlier. `None` until fitted.
    ranking_: Option<Array1<usize>>,
    /// Final importance score of each selected feature, indexed by original
    /// feature; eliminated features hold 0.0. `None` until fitted.
    scores_: Option<Array1<f64>>,
}
268
269impl<F> RecursiveFeatureElimination<F>
270where
271 F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
272{
273 pub fn new(n_features_to_select: usize, importancefunc: F) -> Self {
279 RecursiveFeatureElimination {
280 n_features_to_select,
281 step: 1,
282 importance_func: importancefunc,
283 selected_features_: None,
284 ranking_: None,
285 scores_: None,
286 }
287 }
288
289 pub fn with_step(mut self, step: usize) -> Self {
291 self.step = step.max(1);
292 self
293 }
294
295 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
301 let n_samples = x.shape()[0];
302 let n_features = x.shape()[1];
303
304 if n_samples != y.len() {
305 return Err(TransformError::InvalidInput(format!(
306 "X has {} samples but y has {} samples",
307 n_samples,
308 y.len()
309 )));
310 }
311
312 if self.n_features_to_select > n_features {
313 return Err(TransformError::InvalidInput(format!(
314 "n_features_to_select={} must be <= n_features={}",
315 self.n_features_to_select, n_features
316 )));
317 }
318
319 let mut remaining_features: Vec<usize> = (0..n_features).collect();
321 let mut ranking = Array1::zeros(n_features);
322 let mut current_rank = 1;
323
324 while remaining_features.len() > self.n_features_to_select {
326 let x_subset = self.subset_features(x, &remaining_features);
328
329 let importances = (self.importance_func)(&x_subset, y)?;
331
332 if importances.len() != remaining_features.len() {
333 return Err(TransformError::InvalidInput(
334 "Importance function returned wrong number of scores".to_string(),
335 ));
336 }
337
338 let n_to_remove = (self.step).min(remaining_features.len() - self.n_features_to_select);
340
341 let mut indices: Vec<usize> = (0..importances.len()).collect();
343 indices.sort_by(|&i, &j| importances[i].partial_cmp(&importances[j]).unwrap());
344
345 for i in 0..n_to_remove {
347 let feature_idx = remaining_features[indices[i]];
348 ranking[feature_idx] = n_features - current_rank + 1;
349 current_rank += 1;
350 }
351
352 let eliminated: std::collections::HashSet<usize> =
354 indices.iter().take(n_to_remove).cloned().collect();
355 let features_to_retain: Vec<usize> = remaining_features
356 .iter()
357 .filter(|&&idx| !eliminated.contains(&idx))
358 .cloned()
359 .collect();
360 remaining_features = features_to_retain;
361 }
362
363 for &feature_idx in &remaining_features {
365 ranking[feature_idx] = 1;
366 }
367
368 let x_final = self.subset_features(x, &remaining_features);
370 let final_scores = (self.importance_func)(&x_final, y)?;
371
372 let mut scores = Array1::zeros(n_features);
373 for (i, &feature_idx) in remaining_features.iter().enumerate() {
374 scores[feature_idx] = final_scores[i];
375 }
376
377 self.selected_features_ = Some(remaining_features);
378 self.ranking_ = Some(ranking);
379 self.scores_ = Some(scores);
380
381 Ok(())
382 }
383
384 fn subset_features(&self, x: &Array2<f64>, features: &[usize]) -> Array2<f64> {
386 let n_samples = x.shape()[0];
387 let n_selected = features.len();
388 let mut subset = Array2::zeros((n_samples, n_selected));
389
390 for (new_idx, &old_idx) in features.iter().enumerate() {
391 subset.column_mut(new_idx).assign(&x.column(old_idx));
392 }
393
394 subset
395 }
396
397 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
399 if self.selected_features_.is_none() {
400 return Err(TransformError::TransformationError(
401 "RFE has not been fitted".to_string(),
402 ));
403 }
404
405 let selected = self.selected_features_.as_ref().unwrap();
406 Ok(self.subset_features(x, selected))
407 }
408
409 pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
411 self.fit(x, y)?;
412 self.transform(x)
413 }
414
415 pub fn get_support(&self) -> Option<&Vec<usize>> {
417 self.selected_features_.as_ref()
418 }
419
420 pub fn ranking(&self) -> Option<&Array1<usize>> {
422 self.ranking_.as_ref()
423 }
424
425 pub fn scores(&self) -> Option<&Array1<f64>> {
427 self.scores_.as_ref()
428 }
429}
430
/// Selects the `k` features with the highest estimated dependence on a target.
#[derive(Debug, Clone)]
pub struct MutualInfoSelector {
    /// Number of top-scoring features kept by `fit`.
    k: usize,
    /// When `true`, the target is rounded to integers and treated as class
    /// labels; otherwise a correlation-based estimate is used.
    discrete_target: bool,
    /// Minimum-sample guard: a feature scores 0.0 when there are fewer than
    /// `n_neighbors + 1` samples.
    n_neighbors: usize,
    /// Indices of the selected features, best first; `None` until fitted.
    selected_features_: Option<Vec<usize>>,
    /// Score of every feature seen by `fit`; `None` until fitted.
    scores_: Option<Array1<f64>>,
}
447
448impl MutualInfoSelector {
449 pub fn new(k: usize) -> Self {
454 MutualInfoSelector {
455 k,
456 discrete_target: false,
457 n_neighbors: 3,
458 selected_features_: None,
459 scores_: None,
460 }
461 }
462
463 pub fn with_discrete_target(mut self) -> Self {
465 self.discrete_target = true;
466 self
467 }
468
469 pub fn with_n_neighbors(mut self, nneighbors: usize) -> Self {
471 self.n_neighbors = nneighbors;
472 self
473 }
474
475 fn estimate_mutual_info(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
477 let n = x.len();
478 if n < self.n_neighbors + 1 {
479 return 0.0;
480 }
481
482 if !self.discrete_target {
484 let x_mean = x.mean().unwrap_or(0.0);
486 let y_mean = y.mean().unwrap_or(0.0);
487 let x_std = x.std(0.0);
488 let y_std = y.std(0.0);
489
490 if x_std < 1e-10 || y_std < 1e-10 {
491 return 0.0;
492 }
493
494 let mut correlation = 0.0;
495 for i in 0..n {
496 correlation += (x[i] - x_mean) * (y[i] - y_mean);
497 }
498 correlation /= (n as f64 - 1.0) * x_std * y_std;
499
500 if correlation.abs() >= 1.0 {
503 return 5.0; }
505 (-0.5 * (1.0 - correlation * correlation).ln()).max(0.0)
506 } else {
507 let mut groups = std::collections::HashMap::new();
509
510 for i in 0..n {
511 let key = y[i].round() as i64;
512 groups.entry(key).or_insert_with(Vec::new).push(x[i]);
513 }
514
515 let total_mean = x.mean().unwrap_or(0.0);
517 let total_var = x.variance();
518
519 if total_var < 1e-10 {
520 return 0.0;
521 }
522
523 let mut between_var = 0.0;
524 for (_, values) in groups {
525 let group_mean = values.iter().sum::<f64>() / values.len() as f64;
526 let weight = values.len() as f64 / n as f64;
527 between_var += weight * (group_mean - total_mean).powi(2);
528 }
529
530 (between_var / total_var).min(1.0) * 2.0 }
532 }
533
534 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
536 let n_features = x.shape()[1];
537
538 if self.k > n_features {
539 return Err(TransformError::InvalidInput(format!(
540 "k={} must be <= n_features={}",
541 self.k, n_features
542 )));
543 }
544
545 let mut scores = Array1::zeros(n_features);
547
548 for j in 0..n_features {
549 let feature = x.column(j).to_owned();
550 scores[j] = self.estimate_mutual_info(&feature, y);
551 }
552
553 let mut indices: Vec<usize> = (0..n_features).collect();
555 indices.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).unwrap());
556
557 let selected_features = indices.into_iter().take(self.k).collect();
558
559 self.scores_ = Some(scores);
560 self.selected_features_ = Some(selected_features);
561
562 Ok(())
563 }
564
565 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
567 if self.selected_features_.is_none() {
568 return Err(TransformError::TransformationError(
569 "MutualInfoSelector has not been fitted".to_string(),
570 ));
571 }
572
573 let selected = self.selected_features_.as_ref().unwrap();
574 let n_samples = x.shape()[0];
575 let mut transformed = Array2::zeros((n_samples, self.k));
576
577 for (new_idx, &old_idx) in selected.iter().enumerate() {
578 transformed.column_mut(new_idx).assign(&x.column(old_idx));
579 }
580
581 Ok(transformed)
582 }
583
584 pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
586 self.fit(x, y)?;
587 self.transform(x)
588 }
589
590 pub fn get_support(&self) -> Option<&Vec<usize>> {
592 self.selected_features_.as_ref()
593 }
594
595 pub fn scores(&self) -> Option<&Array1<f64>> {
597 self.scores_.as_ref()
598 }
599}
600
601#[cfg(test)]
602mod tests {
603 use super::*;
604 use approx::assert_abs_diff_eq;
605 use scirs2_core::ndarray::Array;
606
607 #[test]
608 fn test_variance_threshold_basic() {
609 let data = Array::from_shape_vec(
615 (3, 4),
616 vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
617 )
618 .unwrap();
619
620 let mut selector = VarianceThreshold::with_defaults();
621 let transformed = selector.fit_transform(&data).unwrap();
622
623 assert_eq!(transformed.shape(), &[3, 2]);
625
626 let selected = selector.get_support().unwrap();
628 assert_eq!(selected, &[1, 3]);
629
630 assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[2, 0]], 3.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[0, 1]], 1.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[1, 1]], 3.0, epsilon = 1e-10); assert_abs_diff_eq!(transformed[[2, 1]], 5.0, epsilon = 1e-10); }
639
640 #[test]
641 fn test_variance_threshold_custom() {
642 let data = Array::from_shape_vec(
644 (4, 3),
645 vec![
646 1.0, 1.0, 1.0, 2.0, 1.1, 2.0, 3.0, 1.0, 3.0, 4.0, 1.1, 4.0, ],
651 )
652 .unwrap();
653
654 let mut selector = VarianceThreshold::new(0.1).unwrap();
656 let transformed = selector.fit_transform(&data).unwrap();
657
658 assert_eq!(transformed.shape(), &[4, 2]);
661
662 let selected = selector.get_support().unwrap();
663 assert_eq!(selected, &[0, 2]);
664
665 let variances = selector.variances().unwrap();
667 assert!(variances[0] > 0.1); assert!(variances[1] <= 0.1); assert!(variances[2] > 0.1); }
671
672 #[test]
673 fn test_variance_threshold_support_mask() {
674 let data = Array::from_shape_vec(
675 (3, 4),
676 vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
677 )
678 .unwrap();
679
680 let mut selector = VarianceThreshold::with_defaults();
681 selector.fit(&data).unwrap();
682
683 let mask = selector.get_support_mask().unwrap();
684 assert_eq!(mask.len(), 4);
685 assert!(!mask[0]); assert!(mask[1]); assert!(!mask[2]); assert!(mask[3]); assert_eq!(selector.n_features_selected().unwrap(), 2);
691 }
692
693 #[test]
694 fn test_variance_threshold_all_removed() {
695 let data = Array::from_shape_vec((3, 2), vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0]).unwrap();
697
698 let mut selector = VarianceThreshold::with_defaults();
699 let transformed = selector.fit_transform(&data).unwrap();
700
701 assert_eq!(transformed.shape(), &[3, 0]);
703 assert_eq!(selector.n_features_selected().unwrap(), 0);
704 }
705
706 #[test]
707 fn test_variance_threshold_errors() {
708 assert!(VarianceThreshold::new(-0.1).is_err());
710
711 let small_data = Array::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
713 let mut selector = VarianceThreshold::with_defaults();
714 assert!(selector.fit(&small_data).is_err());
715
716 let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
718 let selector_unfitted = VarianceThreshold::with_defaults();
719 assert!(selector_unfitted.transform(&data).is_err());
720
721 let mut selector = VarianceThreshold::with_defaults();
723 selector.fit(&data).unwrap();
724 assert!(selector.inverse_transform(&data).is_err());
725 }
726
727 #[test]
728 fn test_variance_threshold_feature_mismatch() {
729 let train_data =
730 Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
731 .unwrap();
732 let test_data = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); let mut selector = VarianceThreshold::with_defaults();
735 selector.fit(&train_data).unwrap();
736 assert!(selector.transform(&test_data).is_err());
737 }
738
739 #[test]
740 fn test_variance_calculation() {
741 let data = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
744
745 let mut selector = VarianceThreshold::with_defaults();
746 selector.fit(&data).unwrap();
747
748 let variances = selector.variances().unwrap();
749 let expected_variance = 2.0 / 3.0;
750 assert_abs_diff_eq!(variances[0], expected_variance, epsilon = 1e-10);
751 }
752
753 #[test]
754 fn test_rfe_basic() {
755 let n_samples = 100;
757 let mut data_vec = Vec::new();
758 let mut target_vec = Vec::new();
759
760 for i in 0..n_samples {
761 let x1 = i as f64 / n_samples as f64;
762 let x2 = (i as f64 / n_samples as f64).sin();
763 let x3 = scirs2_core::random::random::<f64>(); let x4 = 2.0 * x1; data_vec.extend_from_slice(&[x1, x2, x3, x4]);
767 target_vec.push(3.0 * x1 + x4 + 0.1 * scirs2_core::random::random::<f64>());
768 }
769
770 let x = Array::from_shape_vec((n_samples, 4), data_vec).unwrap();
771 let y = Array::from_vec(target_vec);
772
773 let importance_func = |x: &Array2<f64>, y: &Array1<f64>| -> Result<Array1<f64>> {
775 let n_features = x.shape()[1];
776 let mut scores = Array1::zeros(n_features);
777
778 for j in 0..n_features {
779 let feature = x.column(j);
780 let corr = pearson_correlation(&feature.to_owned(), y);
781 scores[j] = corr.abs();
782 }
783
784 Ok(scores)
785 };
786
787 let mut rfe = RecursiveFeatureElimination::new(2, importance_func);
788 let transformed = rfe.fit_transform(&x, &y).unwrap();
789
790 assert_eq!(transformed.shape()[1], 2);
792
793 let selected = rfe.get_support().unwrap();
795 assert!(selected.contains(&0) || selected.contains(&3));
796 }
797
798 #[test]
799 fn test_mutual_info_continuous() {
800 let n_samples = 100;
802 let mut x_data = Vec::new();
803 let mut y_data = Vec::new();
804
805 for i in 0..n_samples {
806 let t = i as f64 / n_samples as f64 * 2.0 * std::f64::consts::PI;
807
808 let x0 = t;
810 let x1 = scirs2_core::random::random::<f64>();
812 let x2 = t.sin();
814
815 x_data.extend_from_slice(&[x0, x1, x2]);
816 y_data.push(t + 0.5 * t.sin());
817 }
818
819 let x = Array::from_shape_vec((n_samples, 3), x_data).unwrap();
820 let y = Array::from_vec(y_data);
821
822 let mut selector = MutualInfoSelector::new(2);
823 selector.fit(&x, &y).unwrap();
824
825 let scores = selector.scores().unwrap();
826
827 assert!(scores[0] > scores[1]);
831 assert!(scores[2] > scores[1]);
832 }
833
834 #[test]
835 fn test_mutual_info_discrete() {
836 let x = Array::from_shape_vec(
838 (6, 3),
839 vec![
840 1.0, 0.1, 5.0, 1.1, 0.2, 5.1, 2.0, 0.1, 4.0, 2.1, 0.2, 4.1, 3.0, 0.1, 3.0, 3.1, 0.2, 3.1, ],
847 )
848 .unwrap();
849
850 let y = Array::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);
851
852 let mut selector = MutualInfoSelector::new(2).with_discrete_target();
853 let transformed = selector.fit_transform(&x, &y).unwrap();
854
855 assert_eq!(transformed.shape(), &[6, 2]);
856
857 let selected = selector.get_support().unwrap();
859 assert!(!selected.contains(&1));
860 }
861
862 fn pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
864 #[allow(unused_variables)]
865 let n = x.len() as f64;
866 let x_mean = x.mean().unwrap_or(0.0);
867 let y_mean = y.mean().unwrap_or(0.0);
868
869 let mut num = 0.0;
870 let mut x_var = 0.0;
871 let mut y_var = 0.0;
872
873 for i in 0..x.len() {
874 let x_diff = x[i] - x_mean;
875 let y_diff = y[i] - y_mean;
876 num += x_diff * y_diff;
877 x_var += x_diff * x_diff;
878 y_var += y_diff * y_diff;
879 }
880
881 if x_var * y_var > 0.0 {
882 num / (x_var * y_var).sqrt()
883 } else {
884 0.0
885 }
886 }
887}