scirs2_transform/
selection.rs

1//! Feature selection utilities
2//!
3//! This module provides methods for selecting relevant features from datasets,
4//! which can help reduce dimensionality and improve model performance.
5
6use scirs2_core::ndarray::ArrayStatCompat;
7use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9
10use crate::error::{Result, TransformError};
11use statrs::statistics::Statistics;
12
13/// VarianceThreshold for removing low-variance features
14///
15/// Features with variance below the threshold are removed. This is useful for
16/// removing features that are mostly constant and don't provide much information.
17pub struct VarianceThreshold {
18    /// Variance threshold for feature selection
19    threshold: f64,
20    /// Variances computed for each feature (learned during fit)
21    variances_: Option<Array1<f64>>,
22    /// Indices of selected features
23    selected_features_: Option<Vec<usize>>,
24}
25
26impl VarianceThreshold {
27    /// Creates a new VarianceThreshold selector
28    ///
29    /// # Arguments
30    /// * `threshold` - Features with variance below this threshold are removed (default: 0.0)
31    ///
32    /// # Returns
33    /// * A new VarianceThreshold instance
34    ///
35    /// # Examples
36    /// ```
37    /// use scirs2_transform::selection::VarianceThreshold;
38    ///
39    /// // Remove features with variance less than 0.1
40    /// let selector = VarianceThreshold::new(0.1);
41    /// ```
42    pub fn new(threshold: f64) -> Result<Self> {
43        if threshold < 0.0 {
44            return Err(TransformError::InvalidInput(
45                "Threshold must be non-negative".to_string(),
46            ));
47        }
48
49        Ok(VarianceThreshold {
50            threshold,
51            variances_: None,
52            selected_features_: None,
53        })
54    }
55
56    /// Creates a VarianceThreshold with default threshold (0.0)
57    ///
58    /// This will only remove features that are completely constant.
59    pub fn with_defaults() -> Self {
60        Self::new(0.0).unwrap()
61    }
62
63    /// Fits the VarianceThreshold to the input data
64    ///
65    /// # Arguments
66    /// * `x` - The input data, shape (n_samples, n_features)
67    ///
68    /// # Returns
69    /// * `Result<()>` - Ok if successful, Err otherwise
70    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
71    where
72        S: Data,
73        S::Elem: Float + NumCast,
74    {
75        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
76
77        let n_samples = x_f64.shape()[0];
78        let n_features = x_f64.shape()[1];
79
80        if n_samples == 0 || n_features == 0 {
81            return Err(TransformError::InvalidInput("Empty input data".to_string()));
82        }
83
84        if n_samples < 2 {
85            return Err(TransformError::InvalidInput(
86                "At least 2 samples required to compute variance".to_string(),
87            ));
88        }
89
90        // Compute variance for each feature
91        let mut variances = Array1::zeros(n_features);
92        let mut selected_features = Vec::new();
93
94        for j in 0..n_features {
95            let feature_data = x_f64.column(j);
96
97            // Calculate mean
98            let mean = feature_data.iter().sum::<f64>() / n_samples as f64;
99
100            // Calculate variance (using population variance for consistency with sklearn)
101            let variance = feature_data
102                .iter()
103                .map(|&x| (x - mean).powi(2))
104                .sum::<f64>()
105                / n_samples as f64;
106
107            variances[j] = variance;
108
109            // Select feature if variance is above threshold
110            if variance > self.threshold {
111                selected_features.push(j);
112            }
113        }
114
115        self.variances_ = Some(variances);
116        self.selected_features_ = Some(selected_features);
117
118        Ok(())
119    }
120
121    /// Transforms the input data by removing low-variance features
122    ///
123    /// # Arguments
124    /// * `x` - The input data, shape (n_samples, n_features)
125    ///
126    /// # Returns
127    /// * `Result<Array2<f64>>` - The transformed data with selected features only
128    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
129    where
130        S: Data,
131        S::Elem: Float + NumCast,
132    {
133        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
134
135        let n_samples = x_f64.shape()[0];
136        let n_features = x_f64.shape()[1];
137
138        if self.selected_features_.is_none() {
139            return Err(TransformError::TransformationError(
140                "VarianceThreshold has not been fitted".to_string(),
141            ));
142        }
143
144        let selected_features = self.selected_features_.as_ref().unwrap();
145
146        // Check feature consistency
147        if let Some(ref variances) = self.variances_ {
148            if n_features != variances.len() {
149                return Err(TransformError::InvalidInput(format!(
150                    "x has {} features, but VarianceThreshold was fitted with {} features",
151                    n_features,
152                    variances.len()
153                )));
154            }
155        }
156
157        let n_selected = selected_features.len();
158        let mut transformed = Array2::zeros((n_samples, n_selected));
159
160        // Copy selected features
161        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
162            for i in 0..n_samples {
163                transformed[[i, new_idx]] = x_f64[[i, old_idx]];
164            }
165        }
166
167        Ok(transformed)
168    }
169
170    /// Fits the VarianceThreshold to the input data and transforms it
171    ///
172    /// # Arguments
173    /// * `x` - The input data, shape (n_samples, n_features)
174    ///
175    /// # Returns
176    /// * `Result<Array2<f64>>` - The transformed data with selected features only
177    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
178    where
179        S: Data,
180        S::Elem: Float + NumCast,
181    {
182        self.fit(x)?;
183        self.transform(x)
184    }
185
186    /// Returns the variances computed for each feature
187    ///
188    /// # Returns
189    /// * `Option<&Array1<f64>>` - The variances for each feature
190    pub fn variances(&self) -> Option<&Array1<f64>> {
191        self.variances_.as_ref()
192    }
193
194    /// Returns the indices of selected features
195    ///
196    /// # Returns
197    /// * `Option<&Vec<usize>>` - Indices of features that pass the variance threshold
198    pub fn get_support(&self) -> Option<&Vec<usize>> {
199        self.selected_features_.as_ref()
200    }
201
202    /// Returns a boolean mask indicating which features are selected
203    ///
204    /// # Returns
205    /// * `Option<Array1<bool>>` - Boolean mask where true indicates selected features
206    pub fn get_support_mask(&self) -> Option<Array1<bool>> {
207        if let (Some(ref variances), Some(ref selected)) =
208            (&self.variances_, &self.selected_features_)
209        {
210            let n_features = variances.len();
211            let mut mask = Array1::from_elem(n_features, false);
212
213            for &idx in selected {
214                mask[idx] = true;
215            }
216
217            Some(mask)
218        } else {
219            None
220        }
221    }
222
223    /// Returns the number of selected features
224    ///
225    /// # Returns
226    /// * `Option<usize>` - Number of features that pass the variance threshold
227    pub fn n_features_selected(&self) -> Option<usize> {
228        self.selected_features_.as_ref().map(|s| s.len())
229    }
230
231    /// Inverse transform - not applicable for feature selection
232    ///
233    /// This method is not implemented for feature selection as it's not possible
234    /// to reconstruct removed features.
235    pub fn inverse_transform<S>(&self, _x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
236    where
237        S: Data,
238        S::Elem: Float + NumCast,
239    {
240        Err(TransformError::TransformationError(
241            "inverse_transform is not supported for feature selection".to_string(),
242        ))
243    }
244}
245
246/// Recursive Feature Elimination (RFE) for feature selection
247///
248/// RFE works by recursively removing features and evaluating model performance.
249/// This implementation uses a feature importance scoring function to rank features.
250#[derive(Debug, Clone)]
251pub struct RecursiveFeatureElimination<F>
252where
253    F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
254{
255    /// Number of features to select
256    n_features_to_select: usize,
257    /// Number of features to remove at each iteration
258    step: usize,
259    /// Feature importance scoring function
260    /// Takes (X, y) and returns importance scores for each feature
261    importance_func: F,
262    /// Indices of selected features
263    selected_features_: Option<Vec<usize>>,
264    /// Feature rankings (1 is best)
265    ranking_: Option<Array1<usize>>,
266    /// Feature importance scores
267    scores_: Option<Array1<f64>>,
268}
269
270impl<F> RecursiveFeatureElimination<F>
271where
272    F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
273{
274    /// Creates a new RFE selector
275    ///
276    /// # Arguments
277    /// * `n_features_to_select` - Number of features to select
278    /// * `importance_func` - Function that computes feature importance scores
279    pub fn new(n_features_to_select: usize, importancefunc: F) -> Self {
280        RecursiveFeatureElimination {
281            n_features_to_select,
282            step: 1,
283            importance_func: importancefunc,
284            selected_features_: None,
285            ranking_: None,
286            scores_: None,
287        }
288    }
289
290    /// Set the number of features to remove at each iteration
291    pub fn with_step(mut self, step: usize) -> Self {
292        self.step = step.max(1);
293        self
294    }
295
296    /// Fit the RFE selector
297    ///
298    /// # Arguments
299    /// * `x` - Training data, shape (n_samples, n_features)
300    /// * `y` - Target values, shape (n_samples,)
301    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
302        let n_samples = x.shape()[0];
303        let n_features = x.shape()[1];
304
305        if n_samples != y.len() {
306            return Err(TransformError::InvalidInput(format!(
307                "X has {} samples but y has {} samples",
308                n_samples,
309                y.len()
310            )));
311        }
312
313        if self.n_features_to_select > n_features {
314            return Err(TransformError::InvalidInput(format!(
315                "n_features_to_select={} must be <= n_features={}",
316                self.n_features_to_select, n_features
317            )));
318        }
319
320        // Initialize with all features
321        let mut remaining_features: Vec<usize> = (0..n_features).collect();
322        let mut ranking = Array1::zeros(n_features);
323        let mut current_rank = 1;
324
325        // Recursively eliminate features
326        while remaining_features.len() > self.n_features_to_select {
327            // Create subset of data with remaining features
328            let x_subset = self.subset_features(x, &remaining_features);
329
330            // Get feature importances
331            let importances = (self.importance_func)(&x_subset, y)?;
332
333            if importances.len() != remaining_features.len() {
334                return Err(TransformError::InvalidInput(
335                    "Importance function returned wrong number of scores".to_string(),
336                ));
337            }
338
339            // Find features to eliminate
340            let n_to_remove = (self.step).min(remaining_features.len() - self.n_features_to_select);
341
342            // Get indices of features with lowest importance
343            let mut indices: Vec<usize> = (0..importances.len()).collect();
344            indices.sort_by(|&i, &j| importances[i].partial_cmp(&importances[j]).unwrap());
345
346            // Mark eliminated features with current rank
347            for i in 0..n_to_remove {
348                let feature_idx = remaining_features[indices[i]];
349                ranking[feature_idx] = n_features - current_rank + 1;
350                current_rank += 1;
351            }
352
353            // Remove eliminated features
354            let eliminated: std::collections::HashSet<usize> =
355                indices.iter().take(n_to_remove).cloned().collect();
356            let features_to_retain: Vec<usize> = remaining_features
357                .iter()
358                .filter(|&&idx| !eliminated.contains(&idx))
359                .cloned()
360                .collect();
361            remaining_features = features_to_retain;
362        }
363
364        // Mark remaining features as rank 1
365        for &feature_idx in &remaining_features {
366            ranking[feature_idx] = 1;
367        }
368
369        // Compute final scores for selected features
370        let x_final = self.subset_features(x, &remaining_features);
371        let final_scores = (self.importance_func)(&x_final, y)?;
372
373        let mut scores = Array1::zeros(n_features);
374        for (i, &feature_idx) in remaining_features.iter().enumerate() {
375            scores[feature_idx] = final_scores[i];
376        }
377
378        self.selected_features_ = Some(remaining_features);
379        self.ranking_ = Some(ranking);
380        self.scores_ = Some(scores);
381
382        Ok(())
383    }
384
385    /// Create a subset of features
386    fn subset_features(&self, x: &Array2<f64>, features: &[usize]) -> Array2<f64> {
387        let n_samples = x.shape()[0];
388        let n_selected = features.len();
389        let mut subset = Array2::zeros((n_samples, n_selected));
390
391        for (new_idx, &old_idx) in features.iter().enumerate() {
392            subset.column_mut(new_idx).assign(&x.column(old_idx));
393        }
394
395        subset
396    }
397
398    /// Transform data by selecting features
399    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
400        if self.selected_features_.is_none() {
401            return Err(TransformError::TransformationError(
402                "RFE has not been fitted".to_string(),
403            ));
404        }
405
406        let selected = self.selected_features_.as_ref().unwrap();
407        Ok(self.subset_features(x, selected))
408    }
409
410    /// Fit and transform in one step
411    pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
412        self.fit(x, y)?;
413        self.transform(x)
414    }
415
416    /// Get selected feature indices
417    pub fn get_support(&self) -> Option<&Vec<usize>> {
418        self.selected_features_.as_ref()
419    }
420
421    /// Get feature rankings (1 is best)
422    pub fn ranking(&self) -> Option<&Array1<usize>> {
423        self.ranking_.as_ref()
424    }
425
426    /// Get feature scores
427    pub fn scores(&self) -> Option<&Array1<f64>> {
428        self.scores_.as_ref()
429    }
430}
431
432/// Mutual information based feature selection
433///
434/// Selects features based on mutual information between features and target.
435#[derive(Debug, Clone)]
436pub struct MutualInfoSelector {
437    /// Number of features to select
438    k: usize,
439    /// Whether to use discrete mutual information
440    discrete_target: bool,
441    /// Number of neighbors for KNN estimation
442    n_neighbors: usize,
443    /// Selected feature indices
444    selected_features_: Option<Vec<usize>>,
445    /// Mutual information scores
446    scores_: Option<Array1<f64>>,
447}
448
449impl MutualInfoSelector {
450    /// Create a new mutual information selector
451    ///
452    /// # Arguments
453    /// * `k` - Number of top features to select
454    pub fn new(k: usize) -> Self {
455        MutualInfoSelector {
456            k,
457            discrete_target: false,
458            n_neighbors: 3,
459            selected_features_: None,
460            scores_: None,
461        }
462    }
463
464    /// Use discrete mutual information (for classification)
465    pub fn with_discrete_target(mut self) -> Self {
466        self.discrete_target = true;
467        self
468    }
469
470    /// Set number of neighbors for KNN estimation
471    pub fn with_n_neighbors(mut self, nneighbors: usize) -> Self {
472        self.n_neighbors = nneighbors;
473        self
474    }
475
476    /// Estimate mutual information using KNN method (simplified)
477    fn estimate_mutual_info(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
478        let n = x.len();
479        if n < self.n_neighbors + 1 {
480            return 0.0;
481        }
482
483        // Simple correlation-based approximation for continuous variables
484        if !self.discrete_target {
485            // Standardize variables
486            let x_mean = x.mean_or(0.0);
487            let y_mean = y.mean_or(0.0);
488            let x_std = x.std(0.0);
489            let y_std = y.std(0.0);
490
491            if x_std < 1e-10 || y_std < 1e-10 {
492                return 0.0;
493            }
494
495            let mut correlation = 0.0;
496            for i in 0..n {
497                correlation += (x[i] - x_mean) * (y[i] - y_mean);
498            }
499            correlation /= (n as f64 - 1.0) * x_std * y_std;
500
501            // Convert correlation to mutual information approximation
502            // MI ≈ -0.5 * log(1 - r²) for Gaussian variables
503            if correlation.abs() >= 1.0 {
504                return 5.0; // Cap at reasonable value
505            }
506            (-0.5 * (1.0 - correlation * correlation).ln()).max(0.0)
507        } else {
508            // For discrete targets, use a simple grouping approach
509            let mut groups = std::collections::HashMap::new();
510
511            for i in 0..n {
512                let key = y[i].round() as i64;
513                groups.entry(key).or_insert_with(Vec::new).push(x[i]);
514            }
515
516            // Calculate between-group variance / total variance ratio
517            let total_mean = x.mean_or(0.0);
518            let total_var = x.variance();
519
520            if total_var < 1e-10 {
521                return 0.0;
522            }
523
524            let mut between_var = 0.0;
525            for (_, values) in groups {
526                let group_mean = values.iter().sum::<f64>() / values.len() as f64;
527                let weight = values.len() as f64 / n as f64;
528                between_var += weight * (group_mean - total_mean).powi(2);
529            }
530
531            (between_var / total_var).min(1.0) * 2.0 // Scale to reasonable range
532        }
533    }
534
535    /// Fit the selector
536    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
537        let n_features = x.shape()[1];
538
539        if self.k > n_features {
540            return Err(TransformError::InvalidInput(format!(
541                "k={} must be <= n_features={}",
542                self.k, n_features
543            )));
544        }
545
546        // Compute mutual information for each feature
547        let mut scores = Array1::zeros(n_features);
548
549        for j in 0..n_features {
550            let feature = x.column(j).to_owned();
551            scores[j] = self.estimate_mutual_info(&feature, y);
552        }
553
554        // Select top k features
555        let mut indices: Vec<usize> = (0..n_features).collect();
556        indices.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).unwrap());
557
558        let selected_features = indices.into_iter().take(self.k).collect();
559
560        self.scores_ = Some(scores);
561        self.selected_features_ = Some(selected_features);
562
563        Ok(())
564    }
565
566    /// Transform data by selecting features
567    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
568        if self.selected_features_.is_none() {
569            return Err(TransformError::TransformationError(
570                "MutualInfoSelector has not been fitted".to_string(),
571            ));
572        }
573
574        let selected = self.selected_features_.as_ref().unwrap();
575        let n_samples = x.shape()[0];
576        let mut transformed = Array2::zeros((n_samples, self.k));
577
578        for (new_idx, &old_idx) in selected.iter().enumerate() {
579            transformed.column_mut(new_idx).assign(&x.column(old_idx));
580        }
581
582        Ok(transformed)
583    }
584
585    /// Fit and transform in one step
586    pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
587        self.fit(x, y)?;
588        self.transform(x)
589    }
590
591    /// Get selected feature indices
592    pub fn get_support(&self) -> Option<&Vec<usize>> {
593        self.selected_features_.as_ref()
594    }
595
596    /// Get mutual information scores
597    pub fn scores(&self) -> Option<&Array1<f64>> {
598        self.scores_.as_ref()
599    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_variance_threshold_basic() {
        // Create test data with different variances
        // Feature 0: [1, 1, 1] - constant, variance = 0
        // Feature 1: [1, 2, 3] - varying, variance > 0
        // Feature 2: [5, 5, 5] - constant, variance = 0
        // Feature 3: [1, 3, 5] - varying, variance > 0
        let data = Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).unwrap();

        // Should keep features 1 and 3 (indices 1 and 3)
        assert_eq!(transformed.shape(), &[3, 2]);

        // Check that we kept the right features
        let selected = selector.get_support().unwrap();
        assert_eq!(selected, &[1, 3]);

        // Check transformed values
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10); // Feature 1, sample 0
        assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-10); // Feature 1, sample 1
        assert_abs_diff_eq!(transformed[[2, 0]], 3.0, epsilon = 1e-10); // Feature 1, sample 2

        assert_abs_diff_eq!(transformed[[0, 1]], 1.0, epsilon = 1e-10); // Feature 3, sample 0
        assert_abs_diff_eq!(transformed[[1, 1]], 3.0, epsilon = 1e-10); // Feature 3, sample 1
        assert_abs_diff_eq!(transformed[[2, 1]], 5.0, epsilon = 1e-10); // Feature 3, sample 2
    }

    #[test]
    fn test_variance_threshold_custom() {
        // Create test data with specific variances
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 1.0, 1.0, // Sample 0
                2.0, 1.1, 2.0, // Sample 1
                3.0, 1.0, 3.0, // Sample 2
                4.0, 1.1, 4.0, // Sample 3
            ],
        )
        .unwrap();

        // Set threshold to remove features with very low variance
        let mut selector = VarianceThreshold::new(0.1).unwrap();
        let transformed = selector.fit_transform(&data).unwrap();

        // Feature 1 only alternates between 1.0 and 1.1 (variance 0.0025), so it is removed.
        // Features 0 and 2 have higher variance and are kept.
        assert_eq!(transformed.shape(), &[4, 2]);

        let selected = selector.get_support().unwrap();
        assert_eq!(selected, &[0, 2]);

        // Check variances
        let variances = selector.variances().unwrap();
        assert!(variances[0] > 0.1); // Feature 0 variance
        assert!(variances[1] <= 0.1); // Feature 1 variance (should be low)
        assert!(variances[2] > 0.1); // Feature 2 variance
    }

    #[test]
    fn test_variance_threshold_support_mask() {
        let data = Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).unwrap();

        let mask = selector.get_support_mask().unwrap();
        assert_eq!(mask.len(), 4);
        assert!(!mask[0]); // Feature 0 is constant
        assert!(mask[1]); // Feature 1 has variance
        assert!(!mask[2]); // Feature 2 is constant
        assert!(mask[3]); // Feature 3 has variance

        assert_eq!(selector.n_features_selected().unwrap(), 2);
    }

    #[test]
    fn test_variance_threshold_all_removed() {
        // Create data where all features are constant
        let data = Array::from_shape_vec((3, 2), vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).unwrap();

        // All features should be removed
        assert_eq!(transformed.shape(), &[3, 0]);
        assert_eq!(selector.n_features_selected().unwrap(), 0);
    }

    #[test]
    fn test_variance_threshold_errors() {
        // Test negative threshold
        assert!(VarianceThreshold::new(-0.1).is_err());

        // Test with insufficient samples
        let small_data = Array::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let mut selector = VarianceThreshold::with_defaults();
        assert!(selector.fit(&small_data).is_err());

        // Test transform before fit
        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let selector_unfitted = VarianceThreshold::with_defaults();
        assert!(selector_unfitted.transform(&data).is_err());

        // Test inverse transform (should always fail)
        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).unwrap();
        assert!(selector.inverse_transform(&data).is_err());
    }

    #[test]
    fn test_variance_threshold_feature_mismatch() {
        let train_data =
            Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();
        let test_data = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); // Different number of features

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&train_data).unwrap();
        assert!(selector.transform(&test_data).is_err());
    }

    #[test]
    fn test_variance_calculation() {
        // Test variance calculation manually
        // Data: [1, 2, 3] should have variance = ((1-2)² + (2-2)² + (3-2)²) / 3 = (1 + 0 + 1) / 3 = 2/3
        let data = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).unwrap();

        let variances = selector.variances().unwrap();
        let expected_variance = 2.0 / 3.0;
        assert_abs_diff_eq!(variances[0], expected_variance, epsilon = 1e-10);
    }

    #[test]
    fn test_rfe_basic() {
        // Create test data where features have clear importance
        let n_samples = 100;
        let mut data_vec = Vec::new();
        let mut target_vec = Vec::new();

        for i in 0..n_samples {
            let x1 = i as f64 / n_samples as f64;
            let x2 = (i as f64 / n_samples as f64).sin();
            let x3 = scirs2_core::random::random::<f64>(); // Noise
            let x4 = 2.0 * x1; // Highly correlated with target

            data_vec.extend_from_slice(&[x1, x2, x3, x4]);
            target_vec.push(3.0 * x1 + x4 + 0.1 * scirs2_core::random::random::<f64>());
        }

        let x = Array::from_shape_vec((n_samples, 4), data_vec).unwrap();
        let y = Array::from_vec(target_vec);

        // Simple importance function based on correlation
        let importance_func = |x: &Array2<f64>, y: &Array1<f64>| -> Result<Array1<f64>> {
            let n_features = x.shape()[1];
            let mut scores = Array1::zeros(n_features);

            for j in 0..n_features {
                let feature = x.column(j);
                let corr = pearson_correlation(&feature.to_owned(), y);
                scores[j] = corr.abs();
            }

            Ok(scores)
        };

        let mut rfe = RecursiveFeatureElimination::new(2, importance_func);
        let transformed = rfe.fit_transform(&x, &y).unwrap();

        // Should select 2 features
        assert_eq!(transformed.shape()[1], 2);

        // Features 0 and 3 are exactly linear in the target's signal; the noise
        // feature (2) and the slightly non-linear sin feature (1) score lower.
        let selected = rfe.get_support().unwrap();
        assert!(selected.contains(&0) && selected.contains(&3));
    }

    #[test]
    fn test_mutual_info_continuous() {
        // Create data with clear relationships
        let n_samples = 100;
        let mut x_data = Vec::new();
        let mut y_data = Vec::new();

        for i in 0..n_samples {
            let t = i as f64 / n_samples as f64 * 2.0 * std::f64::consts::PI;

            // Feature 0: Strongly related to target
            let x0 = t;
            // Feature 1: Noise
            let x1 = scirs2_core::random::random::<f64>();
            // Feature 2: Non-linearly related
            let x2 = t.sin();

            x_data.extend_from_slice(&[x0, x1, x2]);
            y_data.push(t + 0.5 * t.sin());
        }

        let x = Array::from_shape_vec((n_samples, 3), x_data).unwrap();
        let y = Array::from_vec(y_data);

        let mut selector = MutualInfoSelector::new(2);
        selector.fit(&x, &y).unwrap();

        let scores = selector.scores().unwrap();

        // Feature 0 should have highest score (linear relationship)
        // Feature 2 should have second highest (non-linear relationship)
        // Feature 1 should have lowest score (noise)
        assert!(scores[0] > scores[1]);
        assert!(scores[2] > scores[1]);
    }

    #[test]
    fn test_mutual_info_discrete() {
        // Create classification-like data
        let x = Array::from_shape_vec(
            (6, 3),
            vec![
                1.0, 0.1, 5.0, // Class 0
                1.1, 0.2, 5.1, // Class 0
                2.0, 0.1, 4.0, // Class 1
                2.1, 0.2, 4.1, // Class 1
                3.0, 0.1, 3.0, // Class 2
                3.1, 0.2, 3.1, // Class 2
            ],
        )
        .unwrap();

        let y = Array::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);

        let mut selector = MutualInfoSelector::new(2).with_discrete_target();
        let transformed = selector.fit_transform(&x, &y).unwrap();

        assert_eq!(transformed.shape(), &[6, 2]);

        // Feature 1 (middle column) has identical class means, so it carries no
        // class information and should be excluded
        let selected = selector.get_support().unwrap();
        assert!(!selected.contains(&1));
    }

    // Helper: Pearson correlation of two equal-length vectors (0.0 if either is constant)
    fn pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        let x_mean = x.mean_or(0.0);
        let y_mean = y.mean_or(0.0);

        let mut num = 0.0;
        let mut x_var = 0.0;
        let mut y_var = 0.0;

        for (&xi, &yi) in x.iter().zip(y.iter()) {
            let x_diff = xi - x_mean;
            let y_diff = yi - y_mean;
            num += x_diff * y_diff;
            x_var += x_diff * x_diff;
            y_var += y_diff * y_diff;
        }

        if x_var * y_var > 0.0 {
            num / (x_var * y_var).sqrt()
        } else {
            0.0
        }
    }
}