scirs2_transform/
selection.rs

1//! Feature selection utilities
2//!
3//! This module provides methods for selecting relevant features from datasets,
4//! which can help reduce dimensionality and improve model performance.
5
6use ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
7use num_traits::{Float, NumCast};
8
9use crate::error::{Result, TransformError};
10
11/// VarianceThreshold for removing low-variance features
12///
13/// Features with variance below the threshold are removed. This is useful for
14/// removing features that are mostly constant and don't provide much information.
15pub struct VarianceThreshold {
16    /// Variance threshold for feature selection
17    threshold: f64,
18    /// Variances computed for each feature (learned during fit)
19    variances_: Option<Array1<f64>>,
20    /// Indices of selected features
21    selected_features_: Option<Vec<usize>>,
22}
23
24impl VarianceThreshold {
25    /// Creates a new VarianceThreshold selector
26    ///
27    /// # Arguments
28    /// * `threshold` - Features with variance below this threshold are removed (default: 0.0)
29    ///
30    /// # Returns
31    /// * A new VarianceThreshold instance
32    ///
33    /// # Examples
34    /// ```
35    /// use scirs2_transform::selection::VarianceThreshold;
36    ///
37    /// // Remove features with variance less than 0.1
38    /// let selector = VarianceThreshold::new(0.1);
39    /// ```
40    pub fn new(threshold: f64) -> Result<Self> {
41        if threshold < 0.0 {
42            return Err(TransformError::InvalidInput(
43                "Threshold must be non-negative".to_string(),
44            ));
45        }
46
47        Ok(VarianceThreshold {
48            threshold,
49            variances_: None,
50            selected_features_: None,
51        })
52    }
53
54    /// Creates a VarianceThreshold with default threshold (0.0)
55    ///
56    /// This will only remove features that are completely constant.
57    pub fn with_defaults() -> Self {
58        Self::new(0.0).unwrap()
59    }
60
61    /// Fits the VarianceThreshold to the input data
62    ///
63    /// # Arguments
64    /// * `x` - The input data, shape (n_samples, n_features)
65    ///
66    /// # Returns
67    /// * `Result<()>` - Ok if successful, Err otherwise
68    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
69    where
70        S: Data,
71        S::Elem: Float + NumCast,
72    {
73        let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
74
75        let n_samples = x_f64.shape()[0];
76        let n_features = x_f64.shape()[1];
77
78        if n_samples == 0 || n_features == 0 {
79            return Err(TransformError::InvalidInput("Empty input data".to_string()));
80        }
81
82        if n_samples < 2 {
83            return Err(TransformError::InvalidInput(
84                "At least 2 samples required to compute variance".to_string(),
85            ));
86        }
87
88        // Compute variance for each feature
89        let mut variances = Array1::zeros(n_features);
90        let mut selected_features = Vec::new();
91
92        for j in 0..n_features {
93            let feature_data = x_f64.column(j);
94
95            // Calculate mean
96            let mean = feature_data.iter().sum::<f64>() / n_samples as f64;
97
98            // Calculate variance (using population variance for consistency with sklearn)
99            let variance = feature_data
100                .iter()
101                .map(|&x| (x - mean).powi(2))
102                .sum::<f64>()
103                / n_samples as f64;
104
105            variances[j] = variance;
106
107            // Select feature if variance is above threshold
108            if variance > self.threshold {
109                selected_features.push(j);
110            }
111        }
112
113        self.variances_ = Some(variances);
114        self.selected_features_ = Some(selected_features);
115
116        Ok(())
117    }
118
119    /// Transforms the input data by removing low-variance features
120    ///
121    /// # Arguments
122    /// * `x` - The input data, shape (n_samples, n_features)
123    ///
124    /// # Returns
125    /// * `Result<Array2<f64>>` - The transformed data with selected features only
126    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
127    where
128        S: Data,
129        S::Elem: Float + NumCast,
130    {
131        let x_f64 = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
132
133        let n_samples = x_f64.shape()[0];
134        let n_features = x_f64.shape()[1];
135
136        if self.selected_features_.is_none() {
137            return Err(TransformError::TransformationError(
138                "VarianceThreshold has not been fitted".to_string(),
139            ));
140        }
141
142        let selected_features = self.selected_features_.as_ref().unwrap();
143
144        // Check feature consistency
145        if let Some(ref variances) = self.variances_ {
146            if n_features != variances.len() {
147                return Err(TransformError::InvalidInput(format!(
148                    "x has {} features, but VarianceThreshold was fitted with {} features",
149                    n_features,
150                    variances.len()
151                )));
152            }
153        }
154
155        let n_selected = selected_features.len();
156        let mut transformed = Array2::zeros((n_samples, n_selected));
157
158        // Copy selected features
159        for (new_idx, &old_idx) in selected_features.iter().enumerate() {
160            for i in 0..n_samples {
161                transformed[[i, new_idx]] = x_f64[[i, old_idx]];
162            }
163        }
164
165        Ok(transformed)
166    }
167
168    /// Fits the VarianceThreshold to the input data and transforms it
169    ///
170    /// # Arguments
171    /// * `x` - The input data, shape (n_samples, n_features)
172    ///
173    /// # Returns
174    /// * `Result<Array2<f64>>` - The transformed data with selected features only
175    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
176    where
177        S: Data,
178        S::Elem: Float + NumCast,
179    {
180        self.fit(x)?;
181        self.transform(x)
182    }
183
184    /// Returns the variances computed for each feature
185    ///
186    /// # Returns
187    /// * `Option<&Array1<f64>>` - The variances for each feature
188    pub fn variances(&self) -> Option<&Array1<f64>> {
189        self.variances_.as_ref()
190    }
191
192    /// Returns the indices of selected features
193    ///
194    /// # Returns
195    /// * `Option<&Vec<usize>>` - Indices of features that pass the variance threshold
196    pub fn get_support(&self) -> Option<&Vec<usize>> {
197        self.selected_features_.as_ref()
198    }
199
200    /// Returns a boolean mask indicating which features are selected
201    ///
202    /// # Returns
203    /// * `Option<Array1<bool>>` - Boolean mask where true indicates selected features
204    pub fn get_support_mask(&self) -> Option<Array1<bool>> {
205        if let (Some(ref variances), Some(ref selected)) =
206            (&self.variances_, &self.selected_features_)
207        {
208            let n_features = variances.len();
209            let mut mask = Array1::from_elem(n_features, false);
210
211            for &idx in selected {
212                mask[idx] = true;
213            }
214
215            Some(mask)
216        } else {
217            None
218        }
219    }
220
221    /// Returns the number of selected features
222    ///
223    /// # Returns
224    /// * `Option<usize>` - Number of features that pass the variance threshold
225    pub fn n_features_selected(&self) -> Option<usize> {
226        self.selected_features_.as_ref().map(|s| s.len())
227    }
228
229    /// Inverse transform - not applicable for feature selection
230    ///
231    /// This method is not implemented for feature selection as it's not possible
232    /// to reconstruct removed features.
233    pub fn inverse_transform<S>(&self, _x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
234    where
235        S: Data,
236        S::Elem: Float + NumCast,
237    {
238        Err(TransformError::TransformationError(
239            "inverse_transform is not supported for feature selection".to_string(),
240        ))
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::Array;

    /// Builds the 3x4 fixture shared by several tests.
    ///
    /// Columns 0 (`[1, 1, 1]`) and 2 (`[5, 5, 5]`) are constant; columns 1
    /// (`[1, 2, 3]`) and 3 (`[1, 3, 5]`) vary.
    fn mixed_variance_data() -> ndarray::Array2<f64> {
        Array::from_shape_vec(
            (3, 4),
            vec![
                1.0, 1.0, 5.0, 1.0, // sample 0
                1.0, 2.0, 5.0, 3.0, // sample 1
                1.0, 3.0, 5.0, 5.0, // sample 2
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_variance_threshold_basic() {
        let data = mixed_variance_data();
        let mut selector = VarianceThreshold::with_defaults();
        let output = selector.fit_transform(&data).unwrap();

        // Only the two varying columns (1 and 3) survive.
        assert_eq!(output.shape(), &[3, 2]);
        assert_eq!(selector.get_support().unwrap(), &[1, 3]);

        // Each output row holds (original column 1, original column 3).
        let expected = [[1.0, 1.0], [2.0, 3.0], [3.0, 5.0]];
        for (row, expected_row) in expected.iter().enumerate() {
            for (col, &want) in expected_row.iter().enumerate() {
                assert_abs_diff_eq!(output[[row, col]], want, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_variance_threshold_custom() {
        // Columns 0 and 2 are [1, 2, 3, 4]; column 1 only wobbles between 1.0 and 1.1.
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                1.0, 1.0, 1.0, // sample 0
                2.0, 1.1, 2.0, // sample 1
                3.0, 1.0, 3.0, // sample 2
                4.0, 1.1, 4.0, // sample 3
            ],
        )
        .unwrap();

        let mut selector = VarianceThreshold::new(0.1).unwrap();
        let output = selector.fit_transform(&data).unwrap();

        // The low-variance column 1 is dropped; columns 0 and 2 are kept.
        assert_eq!(output.shape(), &[4, 2]);
        assert_eq!(selector.get_support().unwrap(), &[0, 2]);

        let variances = selector.variances().unwrap();
        assert!(variances[0] > 0.1);
        assert!(variances[1] <= 0.1);
        assert!(variances[2] > 0.1);
    }

    #[test]
    fn test_variance_threshold_support_mask() {
        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&mixed_variance_data()).unwrap();

        // The mask is true exactly for the varying columns 1 and 3.
        let mask = selector.get_support_mask().unwrap();
        assert_eq!(mask.to_vec(), vec![false, true, false, true]);

        assert_eq!(selector.n_features_selected().unwrap(), 2);
    }

    #[test]
    fn test_variance_threshold_all_removed() {
        // Both columns are constant (5.0 and 10.0), so nothing survives.
        let constant_columns =
            Array::from_shape_vec((3, 2), vec![5.0, 10.0, 5.0, 10.0, 5.0, 10.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        let output = selector.fit_transform(&constant_columns).unwrap();

        assert_eq!(output.shape(), &[3, 0]);
        assert_eq!(selector.n_features_selected().unwrap(), 0);
    }

    #[test]
    fn test_variance_threshold_errors() {
        // A negative threshold is rejected at construction time.
        assert!(VarianceThreshold::new(-0.1).is_err());

        // Variance needs at least two samples.
        let single_row = Array::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        assert!(VarianceThreshold::with_defaults().fit(&single_row).is_err());

        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        // transform() before fit() must fail.
        assert!(VarianceThreshold::with_defaults().transform(&data).is_err());

        // inverse_transform() is never supported, even after fitting.
        let mut fitted = VarianceThreshold::with_defaults();
        fitted.fit(&data).unwrap();
        assert!(fitted.inverse_transform(&data).is_err());
    }

    #[test]
    fn test_variance_threshold_feature_mismatch() {
        let train = Array::from_shape_vec(
            (3, 3),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .unwrap();
        // Two columns instead of the three seen during fit.
        let narrower = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&train).unwrap();

        assert!(selector.transform(&narrower).is_err());
    }

    #[test]
    fn test_variance_calculation() {
        // Population variance of [1, 2, 3]: mean 2, squared deviations 1 + 0 + 1,
        // divided by n = 3 gives 2/3.
        let single_column = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();

        let mut selector = VarianceThreshold::with_defaults();
        selector.fit(&single_column).unwrap();

        let computed = selector.variances().unwrap()[0];
        assert_abs_diff_eq!(computed, 2.0 / 3.0, epsilon = 1e-10);
    }
}