// scry_learn/feature_selection.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Feature selection transformers.
3//!
4//! Remove low-information features before training to reduce
5//! overfitting and speed up downstream models.
6//!
7//! # Examples
8//!
9//! ```ignore
10//! use scry_learn::prelude::*;
11//!
12//! let mut vt = VarianceThreshold::new().threshold(0.1);
13//! vt.fit(&data)?;
14//! vt.transform(&mut data)?; // drops constant / near-constant columns
15//! ```
16
17use crate::dataset::Dataset;
18use crate::error::{Result, ScryLearnError};
19use crate::preprocess::Transformer;
20
21// ---------------------------------------------------------------------------
22// VarianceThreshold
23// ---------------------------------------------------------------------------
24
/// Remove features whose variance falls below a threshold.
///
/// By default, removes only constant features (threshold = 0.0).
/// Useful as a lightweight first step before expensive feature selectors.
///
/// # Examples
///
/// ```
/// use scry_learn::dataset::Dataset;
/// use scry_learn::preprocess::Transformer;
/// use scry_learn::feature_selection::VarianceThreshold;
///
/// let mut data = Dataset::new(
///     vec![
///         vec![1.0, 2.0, 3.0],  // variable
///         vec![5.0, 5.0, 5.0],  // constant — will be removed
///     ],
///     vec![0.0, 1.0, 0.0],
///     vec!["a".into(), "b".into()],
///     "target",
/// );
///
/// let mut vt = VarianceThreshold::new();
/// vt.fit_transform(&mut data).unwrap();
/// assert_eq!(data.n_features(), 1); // only "a" remains
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct VarianceThreshold {
    /// Minimum variance a feature must *exceed* to be kept (strict `>`).
    threshold: f64,
    /// Per-feature population variances; empty until `fit` runs.
    variances_: Vec<f64>,
    /// Selection mask — `true` at index `j` keeps column `j`; empty until `fit` runs.
    mask_: Vec<bool>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}
60
61impl VarianceThreshold {
62    /// Create a new selector with threshold 0.0 (remove only constants).
63    pub fn new() -> Self {
64        Self {
65            threshold: 0.0,
66            variances_: Vec::new(),
67            mask_: Vec::new(),
68            fitted: false,
69        }
70    }
71
72    /// Set the variance threshold.
73    ///
74    /// Features with variance ≤ this value are removed.
75    pub fn threshold(mut self, t: f64) -> Self {
76        self.threshold = t;
77        self
78    }
79
80    /// Per-feature variances computed during fit.
81    ///
82    /// # Panics
83    ///
84    /// Panics if called before [`VarianceThreshold::fit`].
85    pub fn variances(&self) -> &[f64] {
86        &self.variances_
87    }
88
89    /// Boolean mask of selected features.
90    ///
91    /// `true` at index `j` means feature `j` was kept.
92    pub fn get_support(&self) -> &[bool] {
93        &self.mask_
94    }
95}
96
97impl Default for VarianceThreshold {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl Transformer for VarianceThreshold {
104    fn fit(&mut self, data: &Dataset) -> Result<()> {
105        let n = data.n_samples();
106        if n == 0 {
107            return Err(ScryLearnError::EmptyDataset);
108        }
109        let nf = n as f64;
110
111        self.variances_ = Vec::with_capacity(data.n_features());
112        self.mask_ = Vec::with_capacity(data.n_features());
113
114        for col in &data.features {
115            let mean = col.iter().sum::<f64>() / nf;
116            let var = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / nf;
117            self.variances_.push(var);
118            self.mask_.push(var > self.threshold);
119        }
120
121        self.fitted = true;
122        Ok(())
123    }
124
125    fn transform(&self, data: &mut Dataset) -> Result<()> {
126        if !self.fitted {
127            return Err(ScryLearnError::NotFitted);
128        }
129        filter_features(data, &self.mask_);
130        Ok(())
131    }
132
133    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
134        Err(ScryLearnError::InvalidParameter(
135            "VarianceThreshold is not invertible — dropped columns cannot be restored".into(),
136        ))
137    }
138}
139
140// ---------------------------------------------------------------------------
141// SelectKBest
142// ---------------------------------------------------------------------------
143
/// Scoring function for feature selection.
///
/// Determines how each feature is scored relative to the target.
///
/// Marked `#[non_exhaustive]`: additional scoring functions may be added in
/// future versions, so downstream `match`es need a wildcard arm.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ScoreFn {
    /// ANOVA F-value: ratio of between-group variance to within-group variance.
    ///
    /// Best suited for classification tasks where higher F-values indicate
    /// features that separate classes well.
    FClassif,
}
157
/// Select the top-k highest-scoring features.
///
/// Uses a scoring function (e.g. ANOVA F-value) to rank features by their
/// discriminative power, then keeps only the `k` best.
///
/// # Examples
///
/// ```ignore
/// use scry_learn::prelude::*;
///
/// let mut sel = SelectKBest::new(ScoreFn::FClassif).k(2);
/// sel.fit(&data)?;
/// sel.transform(&mut data)?;
/// assert_eq!(data.n_features(), 2);
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct SelectKBest {
    /// Number of features to keep (clamped to the feature count at fit time).
    k: usize,
    /// How each feature is scored against the target.
    score_fn: ScoreFn,
    /// Per-feature scores; empty until `fit` runs.
    scores_: Vec<f64>,
    /// Selection mask — `true` at index `j` keeps column `j`; empty until `fit` runs.
    mask_: Vec<bool>,
    /// Whether `fit` has completed successfully.
    fitted: bool,
}
183
184impl SelectKBest {
185    /// Create a new selector with the given scoring function.
186    ///
187    /// Default: keep top 10 features.
188    pub fn new(score_fn: ScoreFn) -> Self {
189        Self {
190            k: 10,
191            score_fn,
192            scores_: Vec::new(),
193            mask_: Vec::new(),
194            fitted: false,
195        }
196    }
197
198    /// Set the number of top features to keep.
199    pub fn k(mut self, k: usize) -> Self {
200        self.k = k;
201        self
202    }
203
204    /// Per-feature scores computed during fit.
205    ///
206    /// Higher values indicate more discriminative features.
207    pub fn scores(&self) -> &[f64] {
208        &self.scores_
209    }
210
211    /// Boolean mask of selected features.
212    ///
213    /// `true` at index `j` means feature `j` was kept.
214    pub fn get_support(&self) -> &[bool] {
215        &self.mask_
216    }
217}
218
219impl Transformer for SelectKBest {
220    fn fit(&mut self, data: &Dataset) -> Result<()> {
221        let n = data.n_samples();
222        if n == 0 {
223            return Err(ScryLearnError::EmptyDataset);
224        }
225
226        self.scores_ = match self.score_fn {
227            ScoreFn::FClassif => f_classif(data),
228        };
229
230        let k = self.k.min(data.n_features());
231
232        // Find the k-th highest score to determine the cutoff.
233        let mut sorted_scores: Vec<f64> = self.scores_.clone();
234        sorted_scores.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
235        let cutoff = if k > 0 && k <= sorted_scores.len() {
236            sorted_scores[k - 1]
237        } else {
238            f64::NEG_INFINITY
239        };
240
241        // Build mask: keep features with score >= cutoff, but cap at k.
242        self.mask_ = vec![false; self.scores_.len()];
243        let mut kept = 0;
244        // First pass: mark features with score > cutoff.
245        for (i, &score) in self.scores_.iter().enumerate() {
246            if score > cutoff && kept < k {
247                self.mask_[i] = true;
248                kept += 1;
249            }
250        }
251        // Second pass: fill remaining slots with features exactly at cutoff.
252        for (i, &score) in self.scores_.iter().enumerate() {
253            if kept >= k {
254                break;
255            }
256            if !self.mask_[i] && (score - cutoff).abs() < 1e-12 {
257                self.mask_[i] = true;
258                kept += 1;
259            }
260        }
261
262        self.fitted = true;
263        Ok(())
264    }
265
266    fn transform(&self, data: &mut Dataset) -> Result<()> {
267        if !self.fitted {
268            return Err(ScryLearnError::NotFitted);
269        }
270        filter_features(data, &self.mask_);
271        Ok(())
272    }
273
274    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
275        Err(ScryLearnError::InvalidParameter(
276            "SelectKBest is not invertible — dropped columns cannot be restored".into(),
277        ))
278    }
279}
280
281// ---------------------------------------------------------------------------
282// ANOVA F-value (f_classif)
283// ---------------------------------------------------------------------------
284
285/// Compute ANOVA F-value for each feature vs. the target.
286///
287/// The F-value is the ratio of between-group variance to within-group
288/// variance. Higher F-values indicate features that separate classes well.
289///
290/// # Examples
291///
292/// ```ignore
293/// let scores = f_classif(&data);
294/// // scores[j] is the F-value for feature j
295/// ```
296pub fn f_classif(data: &Dataset) -> Vec<f64> {
297    let n = data.n_samples();
298    let n_features = data.n_features();
299
300    // Identify unique classes.
301    let mut class_set: Vec<i64> = data.target.iter().map(|&v| v as i64).collect();
302    class_set.sort_unstable();
303    class_set.dedup();
304    let n_classes = class_set.len();
305
306    if n_classes < 2 {
307        return vec![0.0; n_features];
308    }
309
310    // Build class membership lookup.
311    let class_indices: Vec<Vec<usize>> = class_set
312        .iter()
313        .map(|&c| (0..n).filter(|&i| data.target[i] as i64 == c).collect())
314        .collect();
315
316    let mut f_values = Vec::with_capacity(n_features);
317
318    for j in 0..n_features {
319        let col = &data.features[j];
320        let grand_mean = col.iter().sum::<f64>() / n as f64;
321
322        // Between-group sum of squares.
323        let mut ss_between = 0.0;
324        // Within-group sum of squares.
325        let mut ss_within = 0.0;
326
327        for group in &class_indices {
328            let n_g = group.len() as f64;
329            if n_g == 0.0 {
330                continue;
331            }
332            let group_mean = group.iter().map(|&i| col[i]).sum::<f64>() / n_g;
333            ss_between += n_g * (group_mean - grand_mean).powi(2);
334
335            for &i in group {
336                ss_within += (col[i] - group_mean).powi(2);
337            }
338        }
339
340        let df_between = (n_classes - 1) as f64;
341        let df_within = (n - n_classes) as f64;
342
343        let f_val = if df_within > 0.0 && ss_within > 1e-15 {
344            (ss_between / df_between) / (ss_within / df_within)
345        } else if ss_between > 1e-15 {
346            // Perfect separation: zero within-group variance.
347            f64::MAX
348        } else {
349            0.0
350        };
351
352        f_values.push(f_val);
353    }
354
355    f_values
356}
357
358// ---------------------------------------------------------------------------
359// Helpers
360// ---------------------------------------------------------------------------
361
362/// Filter a dataset's features and feature_names using a boolean mask.
363fn filter_features(data: &mut Dataset, mask: &[bool]) {
364    let mut new_features = Vec::new();
365    let mut new_names = Vec::new();
366
367    for (j, &keep) in mask.iter().enumerate() {
368        if keep {
369            new_features.push(data.features[j].clone());
370            new_names.push(data.feature_names[j].clone());
371        }
372    }
373
374    data.features = new_features;
375    data.feature_names = new_names;
376    data.sync_matrix();
377}
378
379// ---------------------------------------------------------------------------
380// Tests
381// ---------------------------------------------------------------------------
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::Pipeline;
    use crate::preprocess::StandardScaler;
    use crate::tree::DecisionTreeClassifier;

    /// Iris-like dataset where petal features (f2, f3) are much more
    /// discriminative than sepal features (f0, f1).
    ///
    /// 3 classes × 30 samples; a fixed RNG seed (123) keeps the fixture
    /// deterministic across runs. Class ranges for f2/f3 are disjoint,
    /// while f0/f1 ranges overlap between classes.
    fn iris_like() -> Dataset {
        let n_per_class = 30;
        let n = n_per_class * 3;
        let mut f0 = Vec::with_capacity(n);
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut f3 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);

        let mut rng = crate::rng::FastRng::new(123);

        for _ in 0..n_per_class {
            // Class 0
            f0.push(5.0 + rng.f64() * 0.5); // overlapping sepal
            f1.push(3.4 + rng.f64() * 0.4); // overlapping sepal
            f2.push(1.0 + rng.f64() * 0.5); // small petal — discriminative
            f3.push(0.1 + rng.f64() * 0.2); // small petal — discriminative
            target.push(0.0);
        }
        for _ in 0..n_per_class {
            // Class 1
            f0.push(5.5 + rng.f64() * 0.8); // overlapping sepal
            f1.push(2.5 + rng.f64() * 0.5); // overlapping sepal
            f2.push(4.0 + rng.f64() * 0.5); // medium petal
            f3.push(1.2 + rng.f64() * 0.3); // medium petal
            target.push(1.0);
        }
        for _ in 0..n_per_class {
            // Class 2
            f0.push(6.0 + rng.f64() * 1.0); // overlapping sepal
            f1.push(2.8 + rng.f64() * 0.5); // overlapping sepal
            f2.push(5.5 + rng.f64() * 0.5); // large petal
            f3.push(2.0 + rng.f64() * 0.3); // large petal
            target.push(2.0);
        }

        Dataset::new(
            vec![f0, f1, f2, f3],
            target,
            vec![
                "sepal_len".into(),
                "sepal_wid".into(),
                "petal_len".into(),
                "petal_wid".into(),
            ],
            "species",
        )
    }

    /// Default threshold (0.0) drops exactly the constant column.
    #[test]
    fn test_variance_threshold_removes_constant() {
        let mut data = Dataset::new(
            vec![
                vec![1.0, 2.0, 3.0, 4.0], // variable
                vec![5.0, 5.0, 5.0, 5.0], // constant → removed
                vec![0.0, 1.0, 0.0, 1.0], // variable
            ],
            vec![0.0, 1.0, 0.0, 1.0],
            vec!["a".into(), "b".into(), "c".into()],
            "t",
        );

        let mut vt = VarianceThreshold::new();
        vt.fit_transform(&mut data).unwrap();

        assert_eq!(data.n_features(), 2);
        assert_eq!(data.feature_names, vec!["a", "c"]);
    }

    /// A non-zero threshold drops low- but non-constant-variance columns.
    #[test]
    fn test_variance_threshold_custom() {
        let mut data = Dataset::new(
            vec![
                vec![1.0, 1.0, 1.0, 1.1],   // variance ≈ 0.0019
                vec![0.0, 10.0, 0.0, 10.0], // variance = 25
            ],
            vec![0.0; 4],
            vec!["low_var".into(), "high_var".into()],
            "t",
        );

        let mut vt = VarianceThreshold::new().threshold(0.01);
        vt.fit_transform(&mut data).unwrap();

        assert_eq!(data.n_features(), 1);
        assert_eq!(data.feature_names, vec!["high_var"]);
    }

    /// SelectKBest(k=2) on the iris-like fixture must rank and keep the
    /// two petal features, dropping both sepal features.
    #[test]
    fn test_select_k_best_petal_features_rank_highest() {
        let data = iris_like();

        let mut sel = SelectKBest::new(ScoreFn::FClassif).k(2);
        sel.fit(&data).unwrap();

        let scores = sel.scores();
        // Petal features (indices 2, 3) should have higher F-values
        // than sepal features (indices 0, 1).
        assert!(
            scores[2] > scores[0],
            "petal_len ({:.1}) should rank higher than sepal_len ({:.1})",
            scores[2],
            scores[0]
        );
        assert!(
            scores[3] > scores[1],
            "petal_wid ({:.1}) should rank higher than sepal_wid ({:.1})",
            scores[3],
            scores[1]
        );

        // After transform, only 2 features remain.
        let mut data_copy = data.clone();
        sel.transform(&mut data_copy).unwrap();
        assert_eq!(data_copy.n_features(), 2);

        // The kept features should be petal_len and petal_wid.
        let support = sel.get_support();
        assert!(!support[0], "sepal_len should be dropped");
        assert!(!support[1], "sepal_wid should be dropped");
        assert!(support[2], "petal_len should be kept");
        assert!(support[3], "petal_wid should be kept");
    }

    /// transform before fit must return an error, not panic.
    #[test]
    fn test_select_k_best_not_fitted() {
        let sel = SelectKBest::new(ScoreFn::FClassif);
        let mut data = Dataset::new(vec![vec![1.0]], vec![0.0], vec!["x".into()], "t");
        assert!(sel.transform(&mut data).is_err());
    }

    /// A perfectly class-separated feature must out-score a noisy one.
    #[test]
    fn test_f_classif_basic() {
        // One perfectly discriminative feature, one random.
        let data = Dataset::new(
            vec![
                vec![1.0, 1.0, 1.0, 10.0, 10.0, 10.0], // perfect separator
                vec![3.0, 7.0, 2.0, 5.0, 8.0, 1.0],    // noise
            ],
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
            vec!["good".into(), "noise".into()],
            "class",
        );

        let scores = f_classif(&data);
        assert!(
            scores[0] > scores[1],
            "good feature ({:.1}) should have higher F-value than noise ({:.1})",
            scores[0],
            scores[1]
        );
    }

    /// Smoke test: feature selection composes with the rest of the stack.
    #[test]
    fn test_pipeline_vt_scaler_dt() {
        // End-to-end: VarianceThreshold → StandardScaler → DecisionTree.
        let features = vec![
            vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0], // discriminative
            vec![5.0, 5.0, 5.0, 5.0, 5.0, 5.0],    // constant → removed
            vec![0.0, 0.5, 1.0, 5.0, 5.5, 6.0],    // discriminative
        ];
        let target = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Dataset::new(
            features,
            target,
            vec!["a".into(), "b".into(), "c".into()],
            "class",
        );

        let mut pipeline = Pipeline::new()
            .add_transformer(VarianceThreshold::new())
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());

        pipeline.fit(&data).unwrap();
        let preds = pipeline.predict(&data).unwrap();
        assert_eq!(preds.len(), 6);
    }
}