// scirs2_transform/selection.rs

//! Feature selection utilities
//!
//! This module provides methods for selecting relevant features from datasets,
//! which can help reduce dimensionality and improve model performance.

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};

use crate::error::{Result, TransformError};
use statrs::statistics::Statistics;

/// VarianceThreshold for removing low-variance features
///
/// Features whose variance is at or below the threshold are removed. This is
/// useful for removing features that are mostly constant and don't provide
/// much information.
pub struct VarianceThreshold {
    /// Variance threshold for feature selection
    threshold: f64,
    /// Variances computed for each feature (learned during fit)
    variances_: Option<Array1<f64>>,
    /// Indices of selected features
    selected_features_: Option<Vec<usize>>,
}

impl VarianceThreshold {
    /// Creates a new VarianceThreshold selector
    ///
    /// # Arguments
    /// * `threshold` - Features whose variance is at or below this value are removed; must be non-negative
    ///
    /// # Returns
    /// * A new VarianceThreshold instance, or an error if the threshold is negative
    ///
    /// # Examples
    /// ```
    /// use scirs2_transform::selection::VarianceThreshold;
    ///
    /// // Remove features with variance of 0.1 or less
    /// let selector = VarianceThreshold::new(0.1).unwrap();
    /// ```
    pub fn new(threshold: f64) -> Result<Self> {
        if threshold < 0.0 {
            return Err(TransformError::InvalidInput(
                "Threshold must be non-negative".to_string(),
            ));
        }

        Ok(VarianceThreshold {
            threshold,
            variances_: None,
            selected_features_: None,
        })
    }

    /// Creates a VarianceThreshold with default threshold (0.0)
    ///
    /// This will only remove features that are completely constant.
    pub fn with_defaults() -> Self {
        Self::new(0.0).unwrap()
    }

    /// Fits the VarianceThreshold to the input data
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if n_samples == 0 || n_features == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }

        if n_samples < 2 {
            return Err(TransformError::InvalidInput(
                "At least 2 samples required to compute variance".to_string(),
            ));
        }

        // Compute variance for each feature
        let mut variances = Array1::zeros(n_features);
        let mut selected_features = Vec::new();

        for j in 0..n_features {
            let feature_data = x_f64.column(j);

            // Calculate mean
            let mean = feature_data.iter().sum::<f64>() / n_samples as f64;

            // Calculate variance (using population variance for consistency with sklearn)
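            // i.e. var = (1/n) * Σ (x_i - mean)², dividing by n rather than n - 1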
            let variance = feature_data
                .iter()
                .map(|&x| (x - mean).powi(2))
                .sum::<f64>()
                / n_samples as f64;

            variances[j] = variance;

            // Select feature if variance is above threshold
            if variance > self.threshold {
                selected_features.push(j);
            }
        }

        self.variances_ = Some(variances);
        self.selected_features_ = Some(selected_features);

        Ok(())
    }

    /// Transforms the input data by removing low-variance features
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data with selected features only
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));

        let n_samples = x_f64.shape()[0];
        let n_features = x_f64.shape()[1];

        if self.selected_features_.is_none() {
            return Err(TransformError::TransformationError(
                "VarianceThreshold has not been fitted".to_string(),
            ));
        }

        let selected_features = self.selected_features_.as_ref().unwrap();

        // Check feature consistency
        if let Some(ref variances) = self.variances_ {
            if n_features != variances.len() {
                return Err(TransformError::InvalidInput(format!(
                    "x has {} features, but VarianceThreshold was fitted with {} features",
                    n_features,
                    variances.len()
                )));
            }
        }

        let n_selected = selected_features.len();
        let mut transformed = Array2::zeros((n_samples, n_selected));

        // Copy selected features
        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
            for i in 0..n_samples {
                transformed[[i, new_idx]] = x_f64[[i, old_idx]];
            }
        }

        Ok(transformed)
    }

    /// Fits the VarianceThreshold to the input data and transforms it
    ///
    /// # Arguments
    /// * `x` - The input data, shape (n_samples, n_features)
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - The transformed data with selected features only
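    ///
    /// # Examples
    /// A minimal usage sketch (mirroring the unit tests below):
    /// ```
    /// use scirs2_core::ndarray::Array;
    /// use scirs2_transform::selection::VarianceThreshold;
    ///
    /// // Column 0 is constant and is dropped; column 1 varies and is kept.
    /// let x = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 1.0, 4.0, 1.0, 6.0]).unwrap();
    /// let mut selector = VarianceThreshold::with_defaults();
    /// let x_selected = selector.fit_transform(&x).unwrap();
    /// assert_eq!(x_selected.shape(), &[3, 1]);
    /// ```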
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Returns the variances computed for each feature
    ///
    /// # Returns
    /// * `Option<&Array1<f64>>` - The variances for each feature
    pub fn variances(&self) -> Option<&Array1<f64>> {
        self.variances_.as_ref()
    }

    /// Returns the indices of selected features
    ///
    /// # Returns
    /// * `Option<&Vec<usize>>` - Indices of features that pass the variance threshold
    pub fn get_support(&self) -> Option<&Vec<usize>> {
        self.selected_features_.as_ref()
    }

    /// Returns a boolean mask indicating which features are selected
    ///
    /// # Returns
    /// * `Option<Array1<bool>>` - Boolean mask where true indicates selected features
    pub fn get_support_mask(&self) -> Option<Array1<bool>> {
        if let (Some(ref variances), Some(ref selected)) =
            (&self.variances_, &self.selected_features_)
        {
            let n_features = variances.len();
            let mut mask = Array1::from_elem(n_features, false);

            for &idx in selected {
                mask[idx] = true;
            }

            Some(mask)
        } else {
            None
        }
    }

    /// Returns the number of selected features
    ///
    /// # Returns
    /// * `Option<usize>` - Number of features that pass the variance threshold
    pub fn n_features_selected(&self) -> Option<usize> {
        self.selected_features_.as_ref().map(|s| s.len())
    }

    /// Inverse transform - not applicable for feature selection
    ///
    /// This method is not implemented for feature selection as it's not possible
    /// to reconstruct removed features.
    pub fn inverse_transform<S>(&self, _x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        Err(TransformError::TransformationError(
            "inverse_transform is not supported for feature selection".to_string(),
        ))
    }
}

/// Recursive Feature Elimination (RFE) for feature selection
///
/// RFE works by recursively removing features and evaluating model performance.
/// This implementation uses a feature importance scoring function to rank features.
#[derive(Debug, Clone)]
pub struct RecursiveFeatureElimination<F>
where
    F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
{
    /// Number of features to select
    n_features_to_select: usize,
    /// Number of features to remove at each iteration
    step: usize,
    /// Feature importance scoring function
    /// Takes (X, y) and returns importance scores for each feature
    importance_func: F,
    /// Indices of selected features
    selected_features_: Option<Vec<usize>>,
    /// Feature rankings (1 is best)
    ranking_: Option<Array1<usize>>,
    /// Feature importance scores
    scores_: Option<Array1<f64>>,
}

impl<F> RecursiveFeatureElimination<F>
where
    F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
{
    /// Creates a new RFE selector
    ///
    /// # Arguments
    /// * `n_features_to_select` - Number of features to select
    /// * `importance_func` - Function that computes feature importance scores
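    ///
    /// # Examples
    /// A minimal sketch with a stand-in importance function (a model-based
    /// score would be typical in practice):
    /// ```
    /// use scirs2_core::ndarray::{Array1, Array2};
    /// use scirs2_transform::selection::RecursiveFeatureElimination;
    ///
    /// // Toy importance function: score each feature by its variance.
    /// let rfe = RecursiveFeatureElimination::new(2, |x: &Array2<f64>, _y: &Array1<f64>| {
    ///     let n = x.shape()[0] as f64;
    ///     let mut scores = Array1::zeros(x.shape()[1]);
    ///     for j in 0..x.shape()[1] {
    ///         let col = x.column(j);
    ///         let mean = col.sum() / n;
    ///         scores[j] = col.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n;
    ///     }
    ///     Ok(scores)
    /// })
    /// .with_step(1);
    /// ```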
    pub fn new(n_features_to_select: usize, importance_func: F) -> Self {
        RecursiveFeatureElimination {
            n_features_to_select,
            step: 1,
            importance_func,
            selected_features_: None,
            ranking_: None,
            scores_: None,
        }
    }

    /// Set the number of features to remove at each iteration
    pub fn with_step(mut self, step: usize) -> Self {
        self.step = step.max(1);
        self
    }

    /// Fit the RFE selector
    ///
    /// # Arguments
    /// * `x` - Training data, shape (n_samples, n_features)
    /// * `y` - Target values, shape (n_samples,)
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
        let n_samples = x.shape()[0];
        let n_features = x.shape()[1];

        if n_samples != y.len() {
            return Err(TransformError::InvalidInput(format!(
                "X has {} samples but y has {} samples",
                n_samples,
                y.len()
            )));
        }

        if self.n_features_to_select > n_features {
            return Err(TransformError::InvalidInput(format!(
                "n_features_to_select={} must be <= n_features={}",
                self.n_features_to_select, n_features
            )));
        }

        // Initialize with all features
        let mut remaining_features: Vec<usize> = (0..n_features).collect();
        let mut ranking = Array1::zeros(n_features);
        let mut current_rank = 1;

        // Recursively eliminate features
        while remaining_features.len() > self.n_features_to_select {
            // Create subset of data with remaining features
            let x_subset = self.subset_features(x, &remaining_features);

            // Get feature importances
            let importances = (self.importance_func)(&x_subset, y)?;

            if importances.len() != remaining_features.len() {
                return Err(TransformError::InvalidInput(
                    "Importance function returned wrong number of scores".to_string(),
                ));
            }

            // Find features to eliminate
            let n_to_remove = self.step.min(remaining_features.len() - self.n_features_to_select);

            // Get positions (into `remaining_features`) of the features with
            // the lowest importance
            let mut indices: Vec<usize> = (0..importances.len()).collect();
            indices.sort_by(|&i, &j| {
                importances[i]
                    .partial_cmp(&importances[j])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            // Mark eliminated features with current rank
            for i in 0..n_to_remove {
                let feature_idx = remaining_features[indices[i]];
                ranking[feature_idx] = n_features - current_rank + 1;
                current_rank += 1;
            }

            // Remove the eliminated features. `indices` holds positions into
            // `remaining_features`, so map them back to feature indices first.
            let eliminated: std::collections::HashSet<usize> = indices
                .iter()
                .take(n_to_remove)
                .map(|&pos| remaining_features[pos])
                .collect();
            remaining_features.retain(|idx| !eliminated.contains(idx));
        }

        // Mark remaining features as rank 1
        for &feature_idx in &remaining_features {
            ranking[feature_idx] = 1;
        }

        // Compute final scores for selected features
        let x_final = self.subset_features(x, &remaining_features);
        let final_scores = (self.importance_func)(&x_final, y)?;

        let mut scores = Array1::zeros(n_features);
        for (i, &feature_idx) in remaining_features.iter().enumerate() {
            scores[feature_idx] = final_scores[i];
        }

        self.selected_features_ = Some(remaining_features);
        self.ranking_ = Some(ranking);
        self.scores_ = Some(scores);

        Ok(())
    }

    /// Create a subset of features
    fn subset_features(&self, x: &Array2<f64>, features: &[usize]) -> Array2<f64> {
        let n_samples = x.shape()[0];
        let n_selected = features.len();
        let mut subset = Array2::zeros((n_samples, n_selected));

        for (new_idx, &old_idx) in features.iter().enumerate() {
            subset.column_mut(new_idx).assign(&x.column(old_idx));
        }

        subset
    }

    /// Transform data by selecting features
    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        if self.selected_features_.is_none() {
            return Err(TransformError::TransformationError(
                "RFE has not been fitted".to_string(),
            ));
        }

        let selected = self.selected_features_.as_ref().unwrap();
        Ok(self.subset_features(x, selected))
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
        self.fit(x, y)?;
        self.transform(x)
    }

    /// Get selected feature indices
    pub fn get_support(&self) -> Option<&Vec<usize>> {
        self.selected_features_.as_ref()
    }

    /// Get feature rankings (1 is best)
    pub fn ranking(&self) -> Option<&Array1<usize>> {
        self.ranking_.as_ref()
    }

    /// Get feature scores
    pub fn scores(&self) -> Option<&Array1<f64>> {
        self.scores_.as_ref()
    }
}

/// Mutual information based feature selection
///
/// Selects features based on mutual information between features and target.
#[derive(Debug, Clone)]
pub struct MutualInfoSelector {
    /// Number of features to select
    k: usize,
    /// Whether to use discrete mutual information
    discrete_target: bool,
    /// Number of neighbors for KNN estimation
    n_neighbors: usize,
    /// Selected feature indices
    selected_features_: Option<Vec<usize>>,
    /// Mutual information scores
    scores_: Option<Array1<f64>>,
}

impl MutualInfoSelector {
    /// Create a new mutual information selector
    ///
    /// # Arguments
    /// * `k` - Number of top features to select
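    ///
    /// # Examples
    /// ```
    /// use scirs2_transform::selection::MutualInfoSelector;
    ///
    /// // Keep the 2 highest-scoring features, treating the target as class labels
    /// let selector = MutualInfoSelector::new(2).with_discrete_target();
    /// ```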
    pub fn new(k: usize) -> Self {
        MutualInfoSelector {
            k,
            discrete_target: false,
            n_neighbors: 3,
            selected_features_: None,
            scores_: None,
        }
    }

    /// Use discrete mutual information (for classification)
    pub fn with_discrete_target(mut self) -> Self {
        self.discrete_target = true;
        self
    }

    /// Set the number of neighbors for KNN-style estimation
    pub fn with_n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.n_neighbors = n_neighbors;
        self
    }

    /// Estimate mutual information (simplified)
    ///
    /// For continuous targets this uses a correlation-based Gaussian
    /// approximation rather than a true KNN estimator; for discrete targets it
    /// uses a between-group variance ratio.
    fn estimate_mutual_info(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        let n = x.len();
        if n < self.n_neighbors + 1 {
            return 0.0;
        }

        // Simple correlation-based approximation for continuous variables
        if !self.discrete_target {
            // Standardize variables
            let x_mean = x.mean().unwrap_or(0.0);
            let y_mean = y.mean().unwrap_or(0.0);
            let x_std = x.std(0.0);
            let y_std = y.std(0.0);

            if x_std < 1e-10 || y_std < 1e-10 {
                return 0.0;
            }

            let mut correlation = 0.0;
            for i in 0..n {
                correlation += (x[i] - x_mean) * (y[i] - y_mean);
            }
            // Normalize by n (not n - 1) to match the population std used above
            correlation /= n as f64 * x_std * y_std;

            // Convert correlation to mutual information approximation
            // MI ≈ -0.5 * log(1 - r²) for Gaussian variables
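            // e.g. r = 0.9 gives MI ≈ -0.5 * ln(1 - 0.81) ≈ 0.83 nats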
            if correlation.abs() >= 1.0 {
                return 5.0; // Cap at reasonable value
            }
            (-0.5 * (1.0 - correlation * correlation).ln()).max(0.0)
        } else {
            // For discrete targets, use a simple grouping approach
            let mut groups = std::collections::HashMap::new();

            for i in 0..n {
                let key = y[i].round() as i64;
                groups.entry(key).or_insert_with(Vec::new).push(x[i]);
            }

            // Calculate between-group variance / total variance ratio
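            // (i.e. the correlation ratio η², a rough proxy for dependence strength)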
            let total_mean = x.mean().unwrap_or(0.0);
            let total_var = x.variance();

            if total_var < 1e-10 {
                return 0.0;
            }

            let mut between_var = 0.0;
            for (_, values) in groups {
                let group_mean = values.iter().sum::<f64>() / values.len() as f64;
                let weight = values.len() as f64 / n as f64;
                between_var += weight * (group_mean - total_mean).powi(2);
            }

            (between_var / total_var).min(1.0) * 2.0 // Scale to reasonable range
        }
    }

    /// Fit the selector
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
        let n_features = x.shape()[1];

        if self.k > n_features {
            return Err(TransformError::InvalidInput(format!(
                "k={} must be <= n_features={}",
                self.k, n_features
            )));
        }

        // Compute mutual information for each feature
        let mut scores = Array1::zeros(n_features);

        for j in 0..n_features {
            let feature = x.column(j).to_owned();
            scores[j] = self.estimate_mutual_info(&feature, y);
        }

        // Select top k features
        let mut indices: Vec<usize> = (0..n_features).collect();
        indices.sort_by(|&i, &j| {
            scores[j]
                .partial_cmp(&scores[i])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let selected_features = indices.into_iter().take(self.k).collect();

        self.scores_ = Some(scores);
        self.selected_features_ = Some(selected_features);

        Ok(())
    }

    /// Transform data by selecting features
    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        if self.selected_features_.is_none() {
            return Err(TransformError::TransformationError(
                "MutualInfoSelector has not been fitted".to_string(),
            ));
        }

        let selected = self.selected_features_.as_ref().unwrap();
        let n_samples = x.shape()[0];
        let mut transformed = Array2::zeros((n_samples, self.k));

        for (new_idx, &old_idx) in selected.iter().enumerate() {
            transformed.column_mut(new_idx).assign(&x.column(old_idx));
        }

        Ok(transformed)
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
        self.fit(x, y)?;
        self.transform(x)
    }

    /// Get selected feature indices
    pub fn get_support(&self) -> Option<&Vec<usize>> {
        self.selected_features_.as_ref()
    }

    /// Get mutual information scores
    pub fn scores(&self) -> Option<&Array1<f64>> {
        self.scores_.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_variance_threshold_basic() {
        // Create test data with different variances
        // Feature 0: [1, 1, 1] - constant, variance = 0
        // Feature 1: [1, 2, 3] - varying, variance > 0
        // Feature 2: [5, 5, 5] - constant, variance = 0
        // Feature 3: [1, 3, 5] - varying, variance > 0
        let data = Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).unwrap();

        // Should keep features 1 and 3 (indices 1 and 3)
        assert_eq!(transformed.shape(), &[3, 2]);

        // Check that we kept the right features
        let selected = selector.get_support().unwrap();
        assert_eq!(selected, &[1, 3]);

        // Check transformed values
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10); // Feature 1, sample 0
        assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-10); // Feature 1, sample 1
        assert_abs_diff_eq!(transformed[[2, 0]], 3.0, epsilon = 1e-10); // Feature 1, sample 2

        assert_abs_diff_eq!(transformed[[0, 1]], 1.0, epsilon = 1e-10); // Feature 3, sample 0
        assert_abs_diff_eq!(transformed[[1, 1]], 3.0, epsilon = 1e-10); // Feature 3, sample 1
        assert_abs_diff_eq!(transformed[[2, 1]], 5.0, epsilon = 1e-10); // Feature 3, sample 2
    }

    #[test]
    fn test_variance_threshold_custom() {
        // Create test data with specific variances
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 1.0, 1.0, // Sample 0
                2.0, 1.1, 2.0, // Sample 1
                3.0, 1.0, 3.0, // Sample 2
                4.0, 1.1, 4.0, // Sample 3
            ],
        )
        .unwrap();

        // Set threshold to remove features with very low variance
        let mut selector = VarianceThreshold::new(0.1).unwrap();
        let transformed = selector.fit_transform(&data).unwrap();

        // Feature 1 has very low variance (between 1.0 and 1.1), should be removed
        // Features 0 and 2 have higher variance, should be kept
        assert_eq!(transformed.shape(), &[4, 2]);

        let selected = selector.get_support().unwrap();
        assert_eq!(selected, &[0, 2]);

        // Check variances
        let variances = selector.variances().unwrap();
        assert!(variances[0] > 0.1); // Feature 0 variance
        assert!(variances[1] <= 0.1); // Feature 1 variance (should be low)
        assert!(variances[2] > 0.1); // Feature 2 variance
    }

    #[test]
    fn test_variance_threshold_support_mask() {
        let data = Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).unwrap();

        let mask = selector.get_support_mask().unwrap();
        assert_eq!(mask.len(), 4);
        assert!(!mask[0]); // Feature 0 is constant
        assert!(mask[1]); // Feature 1 has variance
        assert!(!mask[2]); // Feature 2 is constant
        assert!(mask[3]); // Feature 3 has variance

        assert_eq!(selector.n_features_selected().unwrap(), 2);
    }

    #[test]
    fn test_variance_threshold_all_removed() {
        // Create data where all features are constant
        let data = Array::from_shape_vec((3, 2), vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).unwrap();

        // All features should be removed
        assert_eq!(transformed.shape(), &[3, 0]);
        assert_eq!(selector.n_features_selected().unwrap(), 0);
    }

    #[test]
    fn test_variance_threshold_errors() {
        // Test negative threshold
        assert!(VarianceThreshold::new(-0.1).is_err());

        // Test with insufficient samples
        let small_data = Array::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let mut selector = VarianceThreshold::with_defaults();
        assert!(selector.fit(&small_data).is_err());

        // Test transform before fit
        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let selector_unfitted = VarianceThreshold::with_defaults();
        assert!(selector_unfitted.transform(&data).is_err());

        // Test inverse transform (should always fail)
        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).unwrap();
        assert!(selector.inverse_transform(&data).is_err());
    }

    #[test]
    fn test_variance_threshold_feature_mismatch() {
        let train_data =
            Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();
        let test_data = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); // Different number of features

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&train_data).unwrap();
        assert!(selector.transform(&test_data).is_err());
    }

    #[test]
    fn test_variance_calculation() {
        // Test variance calculation manually
        // Data: [1, 2, 3] should have variance = ((1-2)² + (2-2)² + (3-2)²) / 3 = (1 + 0 + 1) / 3 = 2/3
        let data = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).unwrap();

        let variances = selector.variances().unwrap();
        let expected_variance = 2.0 / 3.0;
        assert_abs_diff_eq!(variances[0], expected_variance, epsilon = 1e-10);
    }

    #[test]
    fn test_rfe_basic() {
        // Create test data where features have clear importance
        let n_samples = 100;
        let mut data_vec = Vec::new();
        let mut target_vec = Vec::new();

        for i in 0..n_samples {
            let x1 = i as f64 / n_samples as f64;
            let x2 = (i as f64 / n_samples as f64).sin();
            let x3 = scirs2_core::random::random::<f64>(); // Noise
            let x4 = 2.0 * x1; // Highly correlated with target

            data_vec.extend_from_slice(&[x1, x2, x3, x4]);
            target_vec.push(3.0 * x1 + x4 + 0.1 * scirs2_core::random::random::<f64>());
        }

        let x = Array::from_shape_vec((n_samples, 4), data_vec).unwrap();
        let y = Array::from_vec(target_vec);

        // Simple importance function based on correlation
        let importance_func = |x: &Array2<f64>, y: &Array1<f64>| -> Result<Array1<f64>> {
            let n_features = x.shape()[1];
            let mut scores = Array1::zeros(n_features);

            for j in 0..n_features {
                let feature = x.column(j);
                let corr = pearson_correlation(&feature.to_owned(), y);
                scores[j] = corr.abs();
            }

            Ok(scores)
        };

        let mut rfe = RecursiveFeatureElimination::new(2, importance_func);
        let transformed = rfe.fit_transform(&x, &y).unwrap();

        // Should select 2 features
        assert_eq!(transformed.shape()[1], 2);

        // At least one of the informative features should be selected; features
        // 0 and 3 are collinear, so either may survive elimination
        let selected = rfe.get_support().unwrap();
        assert!(selected.contains(&0) || selected.contains(&3));
    }

    #[test]
    fn test_mutual_info_continuous() {
        // Create data with clear relationships
        let n_samples = 100;
        let mut x_data = Vec::new();
        let mut y_data = Vec::new();

        for i in 0..n_samples {
            let t = i as f64 / n_samples as f64 * 2.0 * std::f64::consts::PI;

            // Feature 0: Strongly related to target
            let x0 = t;
            // Feature 1: Noise
            let x1 = scirs2_core::random::random::<f64>();
            // Feature 2: Non-linearly related
            let x2 = t.sin();

            x_data.extend_from_slice(&[x0, x1, x2]);
            y_data.push(t + 0.5 * t.sin());
        }

        let x = Array::from_shape_vec((n_samples, 3), x_data).unwrap();
        let y = Array::from_vec(y_data);

        let mut selector = MutualInfoSelector::new(2);
        selector.fit(&x, &y).unwrap();

        let scores = selector.scores().unwrap();

        // Feature 0 should have highest score (linear relationship)
        // Feature 2 should have second highest (non-linear relationship)
        // Feature 1 should have lowest score (noise)
        assert!(scores[0] > scores[1]);
        assert!(scores[2] > scores[1]);
    }

    #[test]
    fn test_mutual_info_discrete() {
        // Create classification-like data
        let x = Array::from_shape_vec(
            (6, 3),
            vec![
                1.0, 0.1, 5.0, // Class 0
                1.1, 0.2, 5.1, // Class 0
                2.0, 0.1, 4.0, // Class 1
                2.1, 0.2, 4.1, // Class 1
                3.0, 0.1, 3.0, // Class 2
                3.1, 0.2, 3.1, // Class 2
            ],
        )
        .unwrap();

        let y = Array::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);

        let mut selector = MutualInfoSelector::new(2).with_discrete_target();
        let transformed = selector.fit_transform(&x, &y).unwrap();

        assert_eq!(transformed.shape(), &[6, 2]);

        // Feature 1 (middle column) takes the same values in every class, so it
        // has no between-group variance and should be excluded
        let selected = selector.get_support().unwrap();
        assert!(!selected.contains(&1));
    }

    // Helper function for correlation
    fn pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        let x_mean = x.mean().unwrap_or(0.0);
        let y_mean = y.mean().unwrap_or(0.0);

        let mut num = 0.0;
        let mut x_var = 0.0;
        let mut y_var = 0.0;

        for i in 0..x.len() {
            let x_diff = x[i] - x_mean;
            let y_diff = y[i] - y_mean;
            num += x_diff * y_diff;
            x_var += x_diff * x_diff;
            y_var += y_diff * y_diff;
        }

        if x_var * y_var > 0.0 {
            num / (x_var * y_var).sqrt()
        } else {
            0.0
        }
    }
}