Skip to main content

scirs2_transform/
selection.rs

1//! Feature selection utilities
2//!
3//! This module provides methods for selecting relevant features from datasets,
4//! which can help reduce dimensionality and improve model performance.
5
6use scirs2_core::ndarray::ArrayStatCompat;
7use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9
10use crate::error::{Result, TransformError};
11use statrs::statistics::Statistics;
12
13/// VarianceThreshold for removing low-variance features
14///
15/// Features with variance below the threshold are removed. This is useful for
16/// removing features that are mostly constant and don't provide much information.
17pub struct VarianceThreshold {
18    /// Variance threshold for feature selection
19    threshold: f64,
20    /// Variances computed for each feature (learned during fit)
21    variances_: Option<Array1<f64>>,
22    /// Indices of selected features
23    selected_features_: Option<Vec<usize>>,
24}
25
26impl VarianceThreshold {
27    /// Creates a new VarianceThreshold selector
28    ///
29    /// # Arguments
30    /// * `threshold` - Features with variance below this threshold are removed (default: 0.0)
31    ///
32    /// # Returns
33    /// * A new VarianceThreshold instance
34    ///
35    /// # Examples
36    /// ```
37    /// use scirs2_transform::selection::VarianceThreshold;
38    ///
39    /// // Remove features with variance less than 0.1
40    /// let selector = VarianceThreshold::new(0.1);
41    /// ```
42    pub fn new(threshold: f64) -> Result<Self> {
43        if threshold < 0.0 {
44            return Err(TransformError::InvalidInput(
45                "Threshold must be non-negative".to_string(),
46            ));
47        }
48
49        Ok(VarianceThreshold {
50            threshold,
51            variances_: None,
52            selected_features_: None,
53        })
54    }
55
56    /// Creates a VarianceThreshold with default threshold (0.0)
57    ///
58    /// This will only remove features that are completely constant.
59    pub fn with_defaults() -> Self {
60        Self::new(0.0).expect("Operation failed")
61    }
62
63    /// Fits the VarianceThreshold to the input data
64    ///
65    /// # Arguments
66    /// * `x` - The input data, shape (n_samples, n_features)
67    ///
68    /// # Returns
69    /// * `Result<()>` - Ok if successful, Err otherwise
70    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
71    where
72        S: Data,
73        S::Elem: Float + NumCast,
74    {
75        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
76
77        let n_samples = x_f64.shape()[0];
78        let n_features = x_f64.shape()[1];
79
80        if n_samples == 0 || n_features == 0 {
81            return Err(TransformError::InvalidInput("Empty input data".to_string()));
82        }
83
84        if n_samples < 2 {
85            return Err(TransformError::InvalidInput(
86                "At least 2 samples required to compute variance".to_string(),
87            ));
88        }
89
90        // Compute variance for each feature
91        let mut variances = Array1::zeros(n_features);
92        let mut selected_features = Vec::new();
93
94        for j in 0..n_features {
95            let feature_data = x_f64.column(j);
96
97            // Calculate mean
98            let mean = feature_data.iter().sum::<f64>() / n_samples as f64;
99
100            // Calculate variance (using population variance for consistency with sklearn)
101            let variance = feature_data
102                .iter()
103                .map(|&x| (x - mean).powi(2))
104                .sum::<f64>()
105                / n_samples as f64;
106
107            variances[j] = variance;
108
109            // Select feature if variance is above threshold
110            if variance > self.threshold {
111                selected_features.push(j);
112            }
113        }
114
115        self.variances_ = Some(variances);
116        self.selected_features_ = Some(selected_features);
117
118        Ok(())
119    }
120
121    /// Transforms the input data by removing low-variance features
122    ///
123    /// # Arguments
124    /// * `x` - The input data, shape (n_samples, n_features)
125    ///
126    /// # Returns
127    /// * `Result<Array2<f64>>` - The transformed data with selected features only
128    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
129    where
130        S: Data,
131        S::Elem: Float + NumCast,
132    {
133        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
134
135        let n_samples = x_f64.shape()[0];
136        let n_features = x_f64.shape()[1];
137
138        if self.selected_features_.is_none() {
139            return Err(TransformError::TransformationError(
140                "VarianceThreshold has not been fitted".to_string(),
141            ));
142        }
143
144        let selected_features = self.selected_features_.as_ref().expect("Operation failed");
145
146        // Check feature consistency
147        if let Some(ref variances) = self.variances_ {
148            if n_features != variances.len() {
149                return Err(TransformError::InvalidInput(format!(
150                    "x has {} features, but VarianceThreshold was fitted with {} features",
151                    n_features,
152                    variances.len()
153                )));
154            }
155        }
156
157        let n_selected = selected_features.len();
158        let mut transformed = Array2::zeros((n_samples, n_selected));
159
160        // Copy selected features
161        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
162            for i in 0..n_samples {
163                transformed[[i, new_idx]] = x_f64[[i, old_idx]];
164            }
165        }
166
167        Ok(transformed)
168    }
169
170    /// Fits the VarianceThreshold to the input data and transforms it
171    ///
172    /// # Arguments
173    /// * `x` - The input data, shape (n_samples, n_features)
174    ///
175    /// # Returns
176    /// * `Result<Array2<f64>>` - The transformed data with selected features only
177    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
178    where
179        S: Data,
180        S::Elem: Float + NumCast,
181    {
182        self.fit(x)?;
183        self.transform(x)
184    }
185
186    /// Returns the variances computed for each feature
187    ///
188    /// # Returns
189    /// * `Option<&Array1<f64>>` - The variances for each feature
190    pub fn variances(&self) -> Option<&Array1<f64>> {
191        self.variances_.as_ref()
192    }
193
194    /// Returns the indices of selected features
195    ///
196    /// # Returns
197    /// * `Option<&Vec<usize>>` - Indices of features that pass the variance threshold
198    pub fn get_support(&self) -> Option<&Vec<usize>> {
199        self.selected_features_.as_ref()
200    }
201
202    /// Returns a boolean mask indicating which features are selected
203    ///
204    /// # Returns
205    /// * `Option<Array1<bool>>` - Boolean mask where true indicates selected features
206    pub fn get_support_mask(&self) -> Option<Array1<bool>> {
207        if let (Some(ref variances), Some(ref selected)) =
208            (&self.variances_, &self.selected_features_)
209        {
210            let n_features = variances.len();
211            let mut mask = Array1::from_elem(n_features, false);
212
213            for &idx in selected {
214                mask[idx] = true;
215            }
216
217            Some(mask)
218        } else {
219            None
220        }
221    }
222
223    /// Returns the number of selected features
224    ///
225    /// # Returns
226    /// * `Option<usize>` - Number of features that pass the variance threshold
227    pub fn n_features_selected(&self) -> Option<usize> {
228        self.selected_features_.as_ref().map(|s| s.len())
229    }
230
231    /// Inverse transform - not applicable for feature selection
232    ///
233    /// This method is not implemented for feature selection as it's not possible
234    /// to reconstruct removed features.
235    pub fn inverse_transform<S>(&self, _x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
236    where
237        S: Data,
238        S::Elem: Float + NumCast,
239    {
240        Err(TransformError::TransformationError(
241            "inverse_transform is not supported for feature selection".to_string(),
242        ))
243    }
244}
245
/// Recursive Feature Elimination (RFE) for feature selection
///
/// RFE works by recursively removing features and evaluating model performance.
/// This implementation uses a feature importance scoring function to rank features.
///
/// NOTE(review): the derived `Debug`/`Clone` impls are only usable when the
/// scoring function `F` itself implements those traits — plain closures do not.
#[derive(Debug, Clone)]
pub struct RecursiveFeatureElimination<F>
where
    F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
{
    /// Number of features to select (elimination stops at this count)
    n_features_to_select: usize,
    /// Number of features to remove at each iteration (kept >= 1 by `with_step`)
    step: usize,
    /// Feature importance scoring function
    /// Takes (X, y) and returns importance scores for each feature
    importance_func: F,
    /// Indices of selected features (populated by `fit`)
    selected_features_: Option<Vec<usize>>,
    /// Feature rankings (1 is best; eliminated features get larger numbers)
    ranking_: Option<Array1<usize>>,
    /// Final importance scores of the surviving features (populated by `fit`)
    scores_: Option<Array1<f64>>,
}
269
270impl<F> RecursiveFeatureElimination<F>
271where
272    F: Fn(&Array2<f64>, &Array1<f64>) -> Result<Array1<f64>>,
273{
274    /// Creates a new RFE selector
275    ///
276    /// # Arguments
277    /// * `n_features_to_select` - Number of features to select
278    /// * `importance_func` - Function that computes feature importance scores
279    pub fn new(n_features_to_select: usize, importancefunc: F) -> Self {
280        RecursiveFeatureElimination {
281            n_features_to_select,
282            step: 1,
283            importance_func: importancefunc,
284            selected_features_: None,
285            ranking_: None,
286            scores_: None,
287        }
288    }
289
290    /// Set the number of features to remove at each iteration
291    pub fn with_step(mut self, step: usize) -> Self {
292        self.step = step.max(1);
293        self
294    }
295
296    /// Fit the RFE selector
297    ///
298    /// # Arguments
299    /// * `x` - Training data, shape (n_samples, n_features)
300    /// * `y` - Target values, shape (n_samples,)
301    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
302        let n_samples = x.shape()[0];
303        let n_features = x.shape()[1];
304
305        if n_samples != y.len() {
306            return Err(TransformError::InvalidInput(format!(
307                "X has {} samples but y has {} samples",
308                n_samples,
309                y.len()
310            )));
311        }
312
313        if self.n_features_to_select > n_features {
314            return Err(TransformError::InvalidInput(format!(
315                "n_features_to_select={} must be <= n_features={}",
316                self.n_features_to_select, n_features
317            )));
318        }
319
320        // Initialize with all features
321        let mut remaining_features: Vec<usize> = (0..n_features).collect();
322        let mut ranking = Array1::zeros(n_features);
323        let mut current_rank = 1;
324
325        // Recursively eliminate features
326        while remaining_features.len() > self.n_features_to_select {
327            // Create subset of data with remaining features
328            let x_subset = self.subset_features(x, &remaining_features);
329
330            // Get feature importances
331            let importances = (self.importance_func)(&x_subset, y)?;
332
333            if importances.len() != remaining_features.len() {
334                return Err(TransformError::InvalidInput(
335                    "Importance function returned wrong number of scores".to_string(),
336                ));
337            }
338
339            // Find features to eliminate
340            let n_to_remove = (self.step).min(remaining_features.len() - self.n_features_to_select);
341
342            // Get indices of features with lowest importance
343            let mut indices: Vec<usize> = (0..importances.len()).collect();
344            indices.sort_by(|&i, &j| {
345                importances[i]
346                    .partial_cmp(&importances[j])
347                    .expect("Operation failed")
348            });
349
350            // Mark eliminated features with current rank
351            for i in 0..n_to_remove {
352                let feature_idx = remaining_features[indices[i]];
353                ranking[feature_idx] = n_features - current_rank + 1;
354                current_rank += 1;
355            }
356
357            // Remove eliminated features
358            let eliminated: std::collections::HashSet<usize> =
359                indices.iter().take(n_to_remove).cloned().collect();
360            let features_to_retain: Vec<usize> = remaining_features
361                .iter()
362                .filter(|&&idx| !eliminated.contains(&idx))
363                .cloned()
364                .collect();
365            remaining_features = features_to_retain;
366        }
367
368        // Mark remaining features as rank 1
369        for &feature_idx in &remaining_features {
370            ranking[feature_idx] = 1;
371        }
372
373        // Compute final scores for selected features
374        let x_final = self.subset_features(x, &remaining_features);
375        let final_scores = (self.importance_func)(&x_final, y)?;
376
377        let mut scores = Array1::zeros(n_features);
378        for (i, &feature_idx) in remaining_features.iter().enumerate() {
379            scores[feature_idx] = final_scores[i];
380        }
381
382        self.selected_features_ = Some(remaining_features);
383        self.ranking_ = Some(ranking);
384        self.scores_ = Some(scores);
385
386        Ok(())
387    }
388
389    /// Create a subset of features
390    fn subset_features(&self, x: &Array2<f64>, features: &[usize]) -> Array2<f64> {
391        let n_samples = x.shape()[0];
392        let n_selected = features.len();
393        let mut subset = Array2::zeros((n_samples, n_selected));
394
395        for (new_idx, &old_idx) in features.iter().enumerate() {
396            subset.column_mut(new_idx).assign(&x.column(old_idx));
397        }
398
399        subset
400    }
401
402    /// Transform data by selecting features
403    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
404        if self.selected_features_.is_none() {
405            return Err(TransformError::TransformationError(
406                "RFE has not been fitted".to_string(),
407            ));
408        }
409
410        let selected = self.selected_features_.as_ref().expect("Operation failed");
411        Ok(self.subset_features(x, selected))
412    }
413
414    /// Fit and transform in one step
415    pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
416        self.fit(x, y)?;
417        self.transform(x)
418    }
419
420    /// Get selected feature indices
421    pub fn get_support(&self) -> Option<&Vec<usize>> {
422        self.selected_features_.as_ref()
423    }
424
425    /// Get feature rankings (1 is best)
426    pub fn ranking(&self) -> Option<&Array1<usize>> {
427        self.ranking_.as_ref()
428    }
429
430    /// Get feature scores
431    pub fn scores(&self) -> Option<&Array1<f64>> {
432        self.scores_.as_ref()
433    }
434}
435
/// Mutual information based feature selection
///
/// Selects features based on mutual information between features and target.
#[derive(Debug, Clone)]
pub struct MutualInfoSelector {
    /// Number of top-scoring features to keep
    k: usize,
    /// Whether to use discrete mutual information (classification-style targets)
    discrete_target: bool,
    /// Number of neighbors for KNN estimation; inputs with fewer than
    /// `n_neighbors + 1` samples score 0
    n_neighbors: usize,
    /// Selected feature indices (populated by `fit`)
    selected_features_: Option<Vec<usize>>,
    /// Mutual information scores per feature (populated by `fit`)
    scores_: Option<Array1<f64>>,
}
452
453impl MutualInfoSelector {
454    /// Create a new mutual information selector
455    ///
456    /// # Arguments
457    /// * `k` - Number of top features to select
458    pub fn new(k: usize) -> Self {
459        MutualInfoSelector {
460            k,
461            discrete_target: false,
462            n_neighbors: 3,
463            selected_features_: None,
464            scores_: None,
465        }
466    }
467
468    /// Use discrete mutual information (for classification)
469    pub fn with_discrete_target(mut self) -> Self {
470        self.discrete_target = true;
471        self
472    }
473
474    /// Set number of neighbors for KNN estimation
475    pub fn with_n_neighbors(mut self, nneighbors: usize) -> Self {
476        self.n_neighbors = nneighbors;
477        self
478    }
479
480    /// Estimate mutual information using KNN method (simplified)
481    fn estimate_mutual_info(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
482        let n = x.len();
483        if n < self.n_neighbors + 1 {
484            return 0.0;
485        }
486
487        // Simple correlation-based approximation for continuous variables
488        if !self.discrete_target {
489            // Standardize variables
490            let x_mean = x.mean_or(0.0);
491            let y_mean = y.mean_or(0.0);
492            let x_std = x.std(0.0);
493            let y_std = y.std(0.0);
494
495            if x_std < 1e-10 || y_std < 1e-10 {
496                return 0.0;
497            }
498
499            let mut correlation = 0.0;
500            for i in 0..n {
501                correlation += (x[i] - x_mean) * (y[i] - y_mean);
502            }
503            correlation /= (n as f64 - 1.0) * x_std * y_std;
504
505            // Convert correlation to mutual information approximation
506            // MI ≈ -0.5 * log(1 - r²) for Gaussian variables
507            if correlation.abs() >= 1.0 {
508                return 5.0; // Cap at reasonable value
509            }
510            (-0.5 * (1.0 - correlation * correlation).ln()).max(0.0)
511        } else {
512            // For discrete targets, use a simple grouping approach
513            let mut groups = std::collections::HashMap::new();
514
515            for i in 0..n {
516                let key = y[i].round() as i64;
517                groups.entry(key).or_insert_with(Vec::new).push(x[i]);
518            }
519
520            // Calculate between-group variance / total variance ratio
521            let total_mean = x.mean_or(0.0);
522            let total_var = x.variance();
523
524            if total_var < 1e-10 {
525                return 0.0;
526            }
527
528            let mut between_var = 0.0;
529            for (_, values) in groups {
530                let group_mean = values.iter().sum::<f64>() / values.len() as f64;
531                let weight = values.len() as f64 / n as f64;
532                between_var += weight * (group_mean - total_mean).powi(2);
533            }
534
535            (between_var / total_var).min(1.0) * 2.0 // Scale to reasonable range
536        }
537    }
538
539    /// Fit the selector
540    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
541        let n_features = x.shape()[1];
542
543        if self.k > n_features {
544            return Err(TransformError::InvalidInput(format!(
545                "k={} must be <= n_features={}",
546                self.k, n_features
547            )));
548        }
549
550        // Compute mutual information for each feature
551        let mut scores = Array1::zeros(n_features);
552
553        for j in 0..n_features {
554            let feature = x.column(j).to_owned();
555            scores[j] = self.estimate_mutual_info(&feature, y);
556        }
557
558        // Select top k features
559        let mut indices: Vec<usize> = (0..n_features).collect();
560        indices.sort_by(|&i, &j| scores[j].partial_cmp(&scores[i]).expect("Operation failed"));
561
562        let selected_features = indices.into_iter().take(self.k).collect();
563
564        self.scores_ = Some(scores);
565        self.selected_features_ = Some(selected_features);
566
567        Ok(())
568    }
569
570    /// Transform data by selecting features
571    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
572        if self.selected_features_.is_none() {
573            return Err(TransformError::TransformationError(
574                "MutualInfoSelector has not been fitted".to_string(),
575            ));
576        }
577
578        let selected = self.selected_features_.as_ref().expect("Operation failed");
579        let n_samples = x.shape()[0];
580        let mut transformed = Array2::zeros((n_samples, self.k));
581
582        for (new_idx, &old_idx) in selected.iter().enumerate() {
583            transformed.column_mut(new_idx).assign(&x.column(old_idx));
584        }
585
586        Ok(transformed)
587    }
588
589    /// Fit and transform in one step
590    pub fn fit_transform(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Array2<f64>> {
591        self.fit(x, y)?;
592        self.transform(x)
593    }
594
595    /// Get selected feature indices
596    pub fn get_support(&self) -> Option<&Vec<usize>> {
597        self.selected_features_.as_ref()
598    }
599
600    /// Get mutual information scores
601    pub fn scores(&self) -> Option<&Array1<f64>> {
602        self.scores_.as_ref()
603    }
604}
605
606#[cfg(test)]
607mod tests {
608    use super::*;
609    use approx::assert_abs_diff_eq;
610    use scirs2_core::ndarray::Array;
611
    #[test]
    fn test_variance_threshold_basic() {
        // Create test data with different variances
        // Feature 0: [1, 1, 1] - constant, variance = 0
        // Feature 1: [1, 2, 3] - varying, variance > 0
        // Feature 2: [5, 5, 5] - constant, variance = 0
        // Feature 3: [1, 3, 5] - varying, variance > 0
        // (Row-major layout: each group of 4 values is one sample.)
        let data = Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .expect("Operation failed");

        // Default threshold (0.0) removes only exactly-constant features.
        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).expect("Operation failed");

        // Should keep features 1 and 3 (indices 1 and 3)
        assert_eq!(transformed.shape(), &[3, 2]);

        // Check that we kept the right features
        let selected = selector.get_support().expect("Operation failed");
        assert_eq!(selected, &[1, 3]);

        // Check transformed values
        assert_abs_diff_eq!(transformed[[0, 0]], 1.0, epsilon = 1e-10); // Feature 1, sample 0
        assert_abs_diff_eq!(transformed[[1, 0]], 2.0, epsilon = 1e-10); // Feature 1, sample 1
        assert_abs_diff_eq!(transformed[[2, 0]], 3.0, epsilon = 1e-10); // Feature 1, sample 2

        assert_abs_diff_eq!(transformed[[0, 1]], 1.0, epsilon = 1e-10); // Feature 3, sample 0
        assert_abs_diff_eq!(transformed[[1, 1]], 3.0, epsilon = 1e-10); // Feature 3, sample 1
        assert_abs_diff_eq!(transformed[[2, 1]], 5.0, epsilon = 1e-10); // Feature 3, sample 2
    }
644
    #[test]
    fn test_variance_threshold_custom() {
        // Create test data with specific variances
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 1.0, 1.0, // Sample 0
                2.0, 1.1, 2.0, // Sample 1
                3.0, 1.0, 3.0, // Sample 2
                4.0, 1.1, 4.0, // Sample 3
            ],
        )
        .expect("Operation failed");

        // Set threshold to remove features with very low variance
        let mut selector = VarianceThreshold::new(0.1).expect("Operation failed");
        let transformed = selector.fit_transform(&data).expect("Operation failed");

        // Feature 1 has very low variance (between 1.0 and 1.1), should be removed
        // (population variance of [1.0, 1.1, 1.0, 1.1] is 0.0025, well below 0.1)
        // Features 0 and 2 have higher variance, should be kept
        assert_eq!(transformed.shape(), &[4, 2]);

        let selected = selector.get_support().expect("Operation failed");
        assert_eq!(selected, &[0, 2]);

        // Check variances
        let variances = selector.variances().expect("Operation failed");
        assert!(variances[0] > 0.1); // Feature 0 variance
        assert!(variances[1] <= 0.1); // Feature 1 variance (should be low)
        assert!(variances[2] > 0.1); // Feature 2 variance
    }
676
    #[test]
    fn test_variance_threshold_support_mask() {
        // Same fixture as test_variance_threshold_basic: features 0 and 2
        // are constant, features 1 and 3 vary.
        let data = Array::from_shape_vec(
            (3, 4),
            vec![1.0, 1.0, 5.0, 1.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 5.0, 5.0],
        )
        .expect("Operation failed");

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).expect("Operation failed");

        // The mask must cover every input feature, not just the kept ones.
        let mask = selector.get_support_mask().expect("Operation failed");
        assert_eq!(mask.len(), 4);
        assert!(!mask[0]); // Feature 0 is constant
        assert!(mask[1]); // Feature 1 has variance
        assert!(!mask[2]); // Feature 2 is constant
        assert!(mask[3]); // Feature 3 has variance

        assert_eq!(selector.n_features_selected().expect("Operation failed"), 2);
    }
697
    #[test]
    fn test_variance_threshold_all_removed() {
        // Create data where all features are constant
        let data = Array::from_shape_vec((3, 2), vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0])
            .expect("Operation failed");

        let mut selector = VarianceThreshold::with_defaults();
        let transformed = selector.fit_transform(&data).expect("Operation failed");

        // All features should be removed: transform succeeds but yields a
        // zero-width matrix rather than an error.
        assert_eq!(transformed.shape(), &[3, 0]);
        assert_eq!(selector.n_features_selected().expect("Operation failed"), 0);
    }
711
    #[test]
    fn test_variance_threshold_errors() {
        // Test negative threshold
        assert!(VarianceThreshold::new(-0.1).is_err());

        // Test with insufficient samples (variance needs at least 2 rows)
        let small_data = Array::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("Operation failed");
        let mut selector = VarianceThreshold::with_defaults();
        assert!(selector.fit(&small_data).is_err());

        // Test transform before fit
        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed");
        let selector_unfitted = VarianceThreshold::with_defaults();
        assert!(selector_unfitted.transform(&data).is_err());

        // Test inverse transform (should always fail — dropped columns
        // cannot be reconstructed)
        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).expect("Operation failed");
        assert!(selector.inverse_transform(&data).is_err());
    }
733
    #[test]
    fn test_variance_threshold_feature_mismatch() {
        // Fit on 3 features, then transform data with only 2: the width
        // consistency check must reject it.
        let train_data =
            Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .expect("Operation failed");
        let test_data =
            Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("Operation failed"); // Different number of features

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&train_data).expect("Operation failed");
        assert!(selector.transform(&test_data).is_err());
    }
746
    #[test]
    fn test_variance_calculation() {
        // Test variance calculation manually
        // Data: [1, 2, 3] should have variance = ((1-2)² + (2-2)² + (3-2)²) / 3 = (1 + 0 + 1) / 3 = 2/3
        // (population variance: divide by n, not n - 1)
        let data = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).expect("Operation failed");

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&data).expect("Operation failed");

        let variances = selector.variances().expect("Operation failed");
        let expected_variance = 2.0 / 3.0;
        assert_abs_diff_eq!(variances[0], expected_variance, epsilon = 1e-10);
    }
760
    #[test]
    fn test_rfe_basic() {
        // Create test data where features have clear importance
        let n_samples = 100;
        let mut data_vec = Vec::new();
        let mut target_vec = Vec::new();

        for i in 0..n_samples {
            let x1 = i as f64 / n_samples as f64;
            let x2 = (i as f64 / n_samples as f64).sin();
            let x3 = scirs2_core::random::random::<f64>(); // Noise
            let x4 = 2.0 * x1; // Highly correlated with target

            data_vec.extend_from_slice(&[x1, x2, x3, x4]);
            // Target is 5·x1 plus small noise, so x1 and x4 dominate.
            target_vec.push(3.0 * x1 + x4 + 0.1 * scirs2_core::random::random::<f64>());
        }

        let x = Array::from_shape_vec((n_samples, 4), data_vec).expect("Operation failed");
        let y = Array::from_vec(target_vec);

        // Simple importance function based on correlation
        let importance_func = |x: &Array2<f64>, y: &Array1<f64>| -> Result<Array1<f64>> {
            let n_features = x.shape()[1];
            let mut scores = Array1::zeros(n_features);

            for j in 0..n_features {
                let feature = x.column(j);
                let corr = pearson_correlation(&feature.to_owned(), y);
                scores[j] = corr.abs();
            }

            Ok(scores)
        };

        let mut rfe = RecursiveFeatureElimination::new(2, importance_func);
        let transformed = rfe.fit_transform(&x, &y).expect("Operation failed");

        // Should select 2 features
        assert_eq!(transformed.shape()[1], 2);

        // Check that features 0 and 3 (most important) were selected
        // NOTE(review): the assertion is deliberately loose (`||`, not `&&`)
        // because x1 and x4 are collinear, so either may survive elimination.
        let selected = rfe.get_support().expect("Operation failed");
        assert!(selected.contains(&0) || selected.contains(&3));
    }
805
    #[test]
    fn test_mutual_info_continuous() {
        // Create data with clear relationships
        let n_samples = 100;
        let mut x_data = Vec::new();
        let mut y_data = Vec::new();

        for i in 0..n_samples {
            let t = i as f64 / n_samples as f64 * 2.0 * std::f64::consts::PI;

            // Feature 0: Strongly related to target
            let x0 = t;
            // Feature 1: Noise
            let x1 = scirs2_core::random::random::<f64>();
            // Feature 2: Non-linearly related
            let x2 = t.sin();

            x_data.extend_from_slice(&[x0, x1, x2]);
            y_data.push(t + 0.5 * t.sin());
        }

        let x = Array::from_shape_vec((n_samples, 3), x_data).expect("Operation failed");
        let y = Array::from_vec(y_data);

        let mut selector = MutualInfoSelector::new(2);
        selector.fit(&x, &y).expect("Operation failed");

        let scores = selector.scores().expect("Operation failed");

        // Feature 0 should have highest score (linear relationship)
        // Feature 2 should have second highest (non-linear relationship)
        // Feature 1 should have lowest score (noise)
        // Only relative ordering is asserted — absolute MI values are heuristic.
        assert!(scores[0] > scores[1]);
        assert!(scores[2] > scores[1]);
    }
841
    #[test]
    fn test_mutual_info_discrete() {
        // Create classification-like data: three classes, two samples each.
        let x = Array::from_shape_vec(
            (6, 3),
            vec![
                1.0, 0.1, 5.0, // Class 0
                1.1, 0.2, 5.1, // Class 0
                2.0, 0.1, 4.0, // Class 1
                2.1, 0.2, 4.1, // Class 1
                3.0, 0.1, 3.0, // Class 2
                3.1, 0.2, 3.1, // Class 2
            ],
        )
        .expect("Operation failed");

        let y = Array::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);

        let mut selector = MutualInfoSelector::new(2).with_discrete_target();
        let transformed = selector.fit_transform(&x, &y).expect("Operation failed");

        assert_eq!(transformed.shape(), &[6, 2]);

        // Feature 1 (middle column) takes the same values in every class, so
        // it carries no class information and should be excluded.
        let selected = selector.get_support().expect("Operation failed");
        assert!(!selected.contains(&1));
    }
869
870    // Helper function for correlation
871    fn pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
872        #[allow(unused_variables)]
873        let n = x.len() as f64;
874        let x_mean = x.mean_or(0.0);
875        let y_mean = y.mean_or(0.0);
876
877        let mut num = 0.0;
878        let mut x_var = 0.0;
879        let mut y_var = 0.0;
880
881        for i in 0..x.len() {
882            let x_diff = x[i] - x_mean;
883            let y_diff = y[i] - y_mean;
884            num += x_diff * y_diff;
885            x_var += x_diff * x_diff;
886            y_var += y_diff * y_diff;
887        }
888
889        if x_var * y_var > 0.0 {
890            num / (x_var * y_var).sqrt()
891        } else {
892            0.0
893        }
894    }
895}