1use crate::dataset::Dataset;
18use crate::error::{Result, ScryLearnError};
19use crate::preprocess::Transformer;
20
/// Feature selector that removes every feature whose variance does not
/// strictly exceed a configurable threshold.
///
/// With the default threshold of `0.0`, only constant columns are dropped.
/// `fit` computes the population variance of each column (divisor `n`, see
/// the `Transformer` impl) and records a keep/drop mask; `transform` applies
/// that mask to a dataset.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct VarianceThreshold {
    // Variance a column must strictly exceed to be kept.
    threshold: f64,
    // Per-column population variances, filled in by `fit`.
    variances_: Vec<f64>,
    // Per-column keep (`true`) / drop (`false`) flags, filled in by `fit`.
    mask_: Vec<bool>,
    // Set once `fit` has completed successfully; guards `transform`.
    fitted: bool,
}
60
61impl VarianceThreshold {
62 pub fn new() -> Self {
64 Self {
65 threshold: 0.0,
66 variances_: Vec::new(),
67 mask_: Vec::new(),
68 fitted: false,
69 }
70 }
71
72 pub fn threshold(mut self, t: f64) -> Self {
76 self.threshold = t;
77 self
78 }
79
80 pub fn variances(&self) -> &[f64] {
86 &self.variances_
87 }
88
89 pub fn get_support(&self) -> &[bool] {
93 &self.mask_
94 }
95}
96
97impl Default for VarianceThreshold {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl Transformer for VarianceThreshold {
104 fn fit(&mut self, data: &Dataset) -> Result<()> {
105 let n = data.n_samples();
106 if n == 0 {
107 return Err(ScryLearnError::EmptyDataset);
108 }
109 let nf = n as f64;
110
111 self.variances_ = Vec::with_capacity(data.n_features());
112 self.mask_ = Vec::with_capacity(data.n_features());
113
114 for col in &data.features {
115 let mean = col.iter().sum::<f64>() / nf;
116 let var = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / nf;
117 self.variances_.push(var);
118 self.mask_.push(var > self.threshold);
119 }
120
121 self.fitted = true;
122 Ok(())
123 }
124
125 fn transform(&self, data: &mut Dataset) -> Result<()> {
126 if !self.fitted {
127 return Err(ScryLearnError::NotFitted);
128 }
129 filter_features(data, &self.mask_);
130 Ok(())
131 }
132
133 fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
134 Err(ScryLearnError::InvalidParameter(
135 "VarianceThreshold is not invertible — dropped columns cannot be restored".into(),
136 ))
137 }
138}
139
/// Univariate scoring function used by [`SelectKBest`] to rank features
/// against the target.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ScoreFn {
    /// One-way ANOVA F-value between each feature and the class target
    /// (see [`f_classif`]).
    FClassif,
}
157
/// Feature selector that keeps the `k` features with the highest scores
/// under a chosen [`ScoreFn`].
///
/// `fit` scores every feature and records a keep/drop mask for the top `k`
/// (ties at the cutoff are filled in column order); `transform` applies
/// that mask to a dataset.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct SelectKBest {
    // Number of top-scoring features to keep (clamped to n_features at fit).
    k: usize,
    // Scoring function used to rank features.
    score_fn: ScoreFn,
    // Per-column scores, filled in by `fit`.
    scores_: Vec<f64>,
    // Per-column keep (`true`) / drop (`false`) flags, filled in by `fit`.
    mask_: Vec<bool>,
    // Set once `fit` has completed successfully; guards `transform`.
    fitted: bool,
}
183
184impl SelectKBest {
185 pub fn new(score_fn: ScoreFn) -> Self {
189 Self {
190 k: 10,
191 score_fn,
192 scores_: Vec::new(),
193 mask_: Vec::new(),
194 fitted: false,
195 }
196 }
197
198 pub fn k(mut self, k: usize) -> Self {
200 self.k = k;
201 self
202 }
203
204 pub fn scores(&self) -> &[f64] {
208 &self.scores_
209 }
210
211 pub fn get_support(&self) -> &[bool] {
215 &self.mask_
216 }
217}
218
219impl Transformer for SelectKBest {
220 fn fit(&mut self, data: &Dataset) -> Result<()> {
221 let n = data.n_samples();
222 if n == 0 {
223 return Err(ScryLearnError::EmptyDataset);
224 }
225
226 self.scores_ = match self.score_fn {
227 ScoreFn::FClassif => f_classif(data),
228 };
229
230 let k = self.k.min(data.n_features());
231
232 let mut sorted_scores: Vec<f64> = self.scores_.clone();
234 sorted_scores.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
235 let cutoff = if k > 0 && k <= sorted_scores.len() {
236 sorted_scores[k - 1]
237 } else {
238 f64::NEG_INFINITY
239 };
240
241 self.mask_ = vec![false; self.scores_.len()];
243 let mut kept = 0;
244 for (i, &score) in self.scores_.iter().enumerate() {
246 if score > cutoff && kept < k {
247 self.mask_[i] = true;
248 kept += 1;
249 }
250 }
251 for (i, &score) in self.scores_.iter().enumerate() {
253 if kept >= k {
254 break;
255 }
256 if !self.mask_[i] && (score - cutoff).abs() < 1e-12 {
257 self.mask_[i] = true;
258 kept += 1;
259 }
260 }
261
262 self.fitted = true;
263 Ok(())
264 }
265
266 fn transform(&self, data: &mut Dataset) -> Result<()> {
267 if !self.fitted {
268 return Err(ScryLearnError::NotFitted);
269 }
270 filter_features(data, &self.mask_);
271 Ok(())
272 }
273
274 fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
275 Err(ScryLearnError::InvalidParameter(
276 "SelectKBest is not invertible — dropped columns cannot be restored".into(),
277 ))
278 }
279}
280
281pub fn f_classif(data: &Dataset) -> Vec<f64> {
297 let n = data.n_samples();
298 let n_features = data.n_features();
299
300 let mut class_set: Vec<i64> = data.target.iter().map(|&v| v as i64).collect();
302 class_set.sort_unstable();
303 class_set.dedup();
304 let n_classes = class_set.len();
305
306 if n_classes < 2 {
307 return vec![0.0; n_features];
308 }
309
310 let class_indices: Vec<Vec<usize>> = class_set
312 .iter()
313 .map(|&c| (0..n).filter(|&i| data.target[i] as i64 == c).collect())
314 .collect();
315
316 let mut f_values = Vec::with_capacity(n_features);
317
318 for j in 0..n_features {
319 let col = &data.features[j];
320 let grand_mean = col.iter().sum::<f64>() / n as f64;
321
322 let mut ss_between = 0.0;
324 let mut ss_within = 0.0;
326
327 for group in &class_indices {
328 let n_g = group.len() as f64;
329 if n_g == 0.0 {
330 continue;
331 }
332 let group_mean = group.iter().map(|&i| col[i]).sum::<f64>() / n_g;
333 ss_between += n_g * (group_mean - grand_mean).powi(2);
334
335 for &i in group {
336 ss_within += (col[i] - group_mean).powi(2);
337 }
338 }
339
340 let df_between = (n_classes - 1) as f64;
341 let df_within = (n - n_classes) as f64;
342
343 let f_val = if df_within > 0.0 && ss_within > 1e-15 {
344 (ss_between / df_between) / (ss_within / df_within)
345 } else if ss_between > 1e-15 {
346 f64::MAX
348 } else {
349 0.0
350 };
351
352 f_values.push(f_val);
353 }
354
355 f_values
356}
357
358fn filter_features(data: &mut Dataset, mask: &[bool]) {
364 let mut new_features = Vec::new();
365 let mut new_names = Vec::new();
366
367 for (j, &keep) in mask.iter().enumerate() {
368 if keep {
369 new_features.push(data.features[j].clone());
370 new_names.push(data.feature_names[j].clone());
371 }
372 }
373
374 data.features = new_features;
375 data.feature_names = new_names;
376 data.sync_matrix();
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::Pipeline;
    use crate::preprocess::StandardScaler;
    use crate::tree::DecisionTreeClassifier;

    /// Builds a 90-sample, 4-feature dataset loosely modelled on iris:
    /// three classes whose petal features separate much better than the
    /// sepal features. Seeded RNG keeps it deterministic.
    fn iris_like() -> Dataset {
        const PER_CLASS: usize = 30;
        let total = PER_CLASS * 3;

        let mut sepal_len = Vec::with_capacity(total);
        let mut sepal_wid = Vec::with_capacity(total);
        let mut petal_len = Vec::with_capacity(total);
        let mut petal_wid = Vec::with_capacity(total);
        let mut species = Vec::with_capacity(total);

        let mut rng = crate::rng::FastRng::new(123);

        // (base, spread) per feature, plus the class label, per class.
        let class_specs: [(f64, f64, f64, f64, f64, f64, f64, f64, f64); 3] = [
            (5.0, 0.5, 3.4, 0.4, 1.0, 0.5, 0.1, 0.2, 0.0),
            (5.5, 0.8, 2.5, 0.5, 4.0, 0.5, 1.2, 0.3, 1.0),
            (6.0, 1.0, 2.8, 0.5, 5.5, 0.5, 2.0, 0.3, 2.0),
        ];
        for &(b0, s0, b1, s1, b2, s2, b3, s3, label) in &class_specs {
            for _ in 0..PER_CLASS {
                // RNG draw order (f0, f1, f2, f3) matters for determinism.
                sepal_len.push(b0 + rng.f64() * s0);
                sepal_wid.push(b1 + rng.f64() * s1);
                petal_len.push(b2 + rng.f64() * s2);
                petal_wid.push(b3 + rng.f64() * s3);
                species.push(label);
            }
        }

        Dataset::new(
            vec![sepal_len, sepal_wid, petal_len, petal_wid],
            species,
            vec![
                "sepal_len".into(),
                "sepal_wid".into(),
                "petal_len".into(),
                "petal_wid".into(),
            ],
            "species",
        )
    }

    #[test]
    fn test_variance_threshold_removes_constant() {
        let mut ds = Dataset::new(
            vec![
                vec![1.0, 2.0, 3.0, 4.0],
                vec![5.0, 5.0, 5.0, 5.0], // constant column — zero variance
                vec![0.0, 1.0, 0.0, 1.0],
            ],
            vec![0.0, 1.0, 0.0, 1.0],
            vec!["a".into(), "b".into(), "c".into()],
            "t",
        );

        let mut selector = VarianceThreshold::new();
        selector.fit_transform(&mut ds).unwrap();

        // Only the constant column "b" should be removed.
        assert_eq!(ds.n_features(), 2);
        assert_eq!(ds.feature_names, vec!["a", "c"]);
    }

    #[test]
    fn test_variance_threshold_custom() {
        let mut ds = Dataset::new(
            vec![
                vec![1.0, 1.0, 1.0, 1.1], // tiny but nonzero variance
                vec![0.0, 10.0, 0.0, 10.0],
            ],
            vec![0.0; 4],
            vec!["low_var".into(), "high_var".into()],
            "t",
        );

        // A 0.01 cutoff should eliminate the nearly-constant column.
        let mut selector = VarianceThreshold::new().threshold(0.01);
        selector.fit_transform(&mut ds).unwrap();

        assert_eq!(ds.n_features(), 1);
        assert_eq!(ds.feature_names, vec!["high_var"]);
    }

    #[test]
    fn test_select_k_best_petal_features_rank_highest() {
        let data = iris_like();

        let mut selector = SelectKBest::new(ScoreFn::FClassif).k(2);
        selector.fit(&data).unwrap();

        // Petal features should carry much stronger class signal.
        let scores = selector.scores();
        assert!(
            scores[2] > scores[0],
            "petal_len ({:.1}) should rank higher than sepal_len ({:.1})",
            scores[2],
            scores[0]
        );
        assert!(
            scores[3] > scores[1],
            "petal_wid ({:.1}) should rank higher than sepal_wid ({:.1})",
            scores[3],
            scores[1]
        );

        let mut reduced = data.clone();
        selector.transform(&mut reduced).unwrap();
        assert_eq!(reduced.n_features(), 2);

        let support = selector.get_support();
        assert!(!support[0], "sepal_len should be dropped");
        assert!(!support[1], "sepal_wid should be dropped");
        assert!(support[2], "petal_len should be kept");
        assert!(support[3], "petal_wid should be kept");
    }

    #[test]
    fn test_select_k_best_not_fitted() {
        let selector = SelectKBest::new(ScoreFn::FClassif);
        let mut ds = Dataset::new(vec![vec![1.0]], vec![0.0], vec!["x".into()], "t");
        // Transforming before fitting must be rejected.
        assert!(selector.transform(&mut ds).is_err());
    }

    #[test]
    fn test_f_classif_basic() {
        let ds = Dataset::new(
            vec![
                vec![1.0, 1.0, 1.0, 10.0, 10.0, 10.0], // perfectly separated
                vec![3.0, 7.0, 2.0, 5.0, 8.0, 1.0],    // uninformative noise
            ],
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
            vec!["good".into(), "noise".into()],
            "class",
        );

        let scores = f_classif(&ds);
        assert!(
            scores[0] > scores[1],
            "good feature ({:.1}) should have higher F-value than noise ({:.1})",
            scores[0],
            scores[1]
        );
    }

    #[test]
    fn test_pipeline_vt_scaler_dt() {
        // End-to-end: variance filter -> scaler -> tree classifier.
        let ds = Dataset::new(
            vec![
                vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0],
                vec![5.0, 5.0, 5.0, 5.0, 5.0, 5.0], // constant, dropped by VT
                vec![0.0, 0.5, 1.0, 5.0, 5.5, 6.0],
            ],
            vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
            vec!["a".into(), "b".into(), "c".into()],
            "class",
        );

        let mut pipeline = Pipeline::new()
            .add_transformer(VarianceThreshold::new())
            .add_transformer(StandardScaler::new())
            .set_model(DecisionTreeClassifier::new());

        pipeline.fit(&ds).unwrap();
        let predictions = pipeline.predict(&ds).unwrap();
        assert_eq!(predictions.len(), 6);
    }
}