1use crate::lda::{LinearDiscriminantAnalysis, LinearDiscriminantAnalysisConfig};
8use scirs2_core::ndarray::{Array1, Array2, Axis};
10use sklears_core::{
11 error::Result,
12 prelude::SklearsError,
13 traits::{Estimator, Fit, Predict, PredictProba, Trained, Transform},
14 types::Float,
15};
16use std::collections::HashSet;
17
/// Direction in which the sequential feature search proceeds.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectionDirection {
 /// Start from an empty set and greedily add features one at a time.
 Forward,
 /// Start from all features and greedily remove them one at a time.
 Backward,
}
26
/// Configuration for [`SequentialFeatureSelection`].
#[derive(Debug, Clone)]
pub struct SequentialFeatureSelectionConfig {
 /// Whether to add features (forward) or remove them (backward).
 pub direction: SelectionDirection,
 /// Exact number of features to keep; takes precedence over `n_features_fraction`.
 pub n_features_to_select: Option<usize>,
 /// Fraction of features to keep, used when `n_features_to_select` is `None`.
 pub n_features_fraction: Option<Float>,
 /// Scoring metric name: "accuracy" or "neg_log_loss" (any other value falls back to accuracy).
 pub scoring: String,
 /// Number of cross-validation folds (contiguous, unshuffled splits).
 pub cv: usize,
 /// Minimum score improvement below which the search stops early.
 pub tol: Float,
 /// Configuration for the underlying LDA estimator.
 /// NOTE(review): stored but not applied when estimators are fitted in this file — confirm intent.
 pub estimator_config: LinearDiscriminantAnalysisConfig,
 /// Optional random seed.
 /// NOTE(review): stored but never consumed here (the CV split is deterministic) — confirm intent.
 pub random_state: Option<u64>,
 /// Print progress to stdout during selection.
 pub verbose: bool,
}
49
50impl Default for SequentialFeatureSelectionConfig {
51 fn default() -> Self {
52 Self {
53 direction: SelectionDirection::Forward,
54 n_features_to_select: None,
55 n_features_fraction: Some(0.5),
56 scoring: "accuracy".to_string(),
57 cv: 5,
58 tol: 1e-4,
59 estimator_config: LinearDiscriminantAnalysisConfig::default(),
60 random_state: None,
61 verbose: false,
62 }
63 }
64}
65
/// Sequential feature selector: a builder that runs a greedy forward or
/// backward search over feature subsets, scored by cross-validated LDA
/// performance.
#[derive(Debug, Clone)]
pub struct SequentialFeatureSelection {
 // Selection parameters, populated via the builder methods below.
 config: SequentialFeatureSelectionConfig,
}
71
/// Fitted result of sequential feature selection.
#[derive(Debug, Clone)]
pub struct TrainedSequentialFeatureSelection {
 // Boolean mask over the input features; `true` marks a selected feature.
 support: Array1<bool>,
 // Forward: features in the order they were added.
 // Backward: the surviving feature indices (unordered) — see `fit`.
 selection_path: Vec<usize>,
 // Cross-validated score recorded at each selection step.
 scores: Vec<Float>,
 // LDA refitted on the selected feature subset.
 estimator: LinearDiscriminantAnalysis<Trained>,
 // Number of features seen during `fit`; used to validate `transform` input.
 n_features_in: usize,
 // The configuration the selector was fitted with.
 config: SequentialFeatureSelectionConfig,
}
88
/// Record of a single step of the selection process.
///
/// NOTE(review): not constructed anywhere in this file — presumably intended
/// for external reporting; confirm it is still needed.
#[derive(Debug, Clone)]
pub struct SelectionStep {
 /// Feature added (forward) or removed (backward) at this step.
 pub feature_idx: usize,
 /// Feature set after the step.
 pub current_features: Vec<usize>,
 /// Cross-validated score after the step.
 pub score: Float,
 /// Score change relative to the previous step.
 pub improvement: Float,
}
101
102impl Default for SequentialFeatureSelection {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl SequentialFeatureSelection {
109 pub fn new() -> Self {
111 Self {
112 config: SequentialFeatureSelectionConfig::default(),
113 }
114 }
115
116 pub fn direction(mut self, direction: SelectionDirection) -> Self {
118 self.config.direction = direction;
119 self
120 }
121
122 pub fn n_features_to_select(mut self, n_features: usize) -> Self {
124 self.config.n_features_to_select = Some(n_features);
125 self.config.n_features_fraction = None;
126 self
127 }
128
129 pub fn n_features_fraction(mut self, fraction: Float) -> Self {
131 self.config.n_features_fraction = Some(fraction.max(0.0).min(1.0));
132 self.config.n_features_to_select = None;
133 self
134 }
135
136 pub fn scoring(mut self, scoring: &str) -> Self {
138 self.config.scoring = scoring.to_string();
139 self
140 }
141
142 pub fn cv(mut self, cv_folds: usize) -> Self {
144 self.config.cv = cv_folds.max(2);
145 self
146 }
147
148 pub fn tol(mut self, tol: Float) -> Self {
150 self.config.tol = tol;
151 self
152 }
153
154 pub fn estimator_config(mut self, config: LinearDiscriminantAnalysisConfig) -> Self {
156 self.config.estimator_config = config;
157 self
158 }
159
160 pub fn random_state(mut self, seed: u64) -> Self {
162 self.config.random_state = Some(seed);
163 self
164 }
165
166 pub fn verbose(mut self, verbose: bool) -> Self {
168 self.config.verbose = verbose;
169 self
170 }
171
172 fn cross_validate(
174 &self,
175 x: &Array2<Float>,
176 y: &Array1<i32>,
177 feature_indices: &[usize],
178 ) -> Result<Float> {
179 if feature_indices.is_empty() {
180 return Ok(0.0);
181 }
182
183 let cv_folds = self.config.cv;
184 let n_samples = x.nrows();
185 let fold_size = n_samples / cv_folds;
186 let mut scores = Vec::new();
187
188 for fold in 0..cv_folds {
189 let test_start = fold * fold_size;
190 let test_end = if fold == cv_folds - 1 {
191 n_samples
192 } else {
193 (fold + 1) * fold_size
194 };
195
196 let mut train_indices = Vec::new();
198 let mut test_indices = Vec::new();
199
200 for i in 0..n_samples {
201 if i >= test_start && i < test_end {
202 test_indices.push(i);
203 } else {
204 train_indices.push(i);
205 }
206 }
207
208 if train_indices.is_empty() || test_indices.is_empty() {
209 continue;
210 }
211
212 let x_train = x.select(Axis(0), &train_indices);
214 let x_train_selected = x_train.select(Axis(1), feature_indices);
215 let y_train = y.select(Axis(0), &train_indices);
216
217 let x_test = x.select(Axis(0), &test_indices);
218 let x_test_selected = x_test.select(Axis(1), feature_indices);
219 let y_test = y.select(Axis(0), &test_indices);
220
221 let estimator = LinearDiscriminantAnalysis::new();
223
224 if let Ok(fitted) = estimator.fit(&x_train_selected, &y_train) {
225 let score = match self.config.scoring.as_str() {
226 "accuracy" => {
227 if let Ok(predictions) = fitted.predict(&x_test_selected) {
228 let correct = predictions
229 .iter()
230 .zip(y_test.iter())
231 .filter(|(&pred, &true_val)| pred == true_val)
232 .count();
233 correct as Float / y_test.len() as Float
234 } else {
235 0.0
236 }
237 }
238 "neg_log_loss" => {
239 if let Ok(probas) = fitted.predict_proba(&x_test_selected) {
240 let mut log_loss = 0.0;
242 let classes = fitted.classes();
243
244 for (i, &true_label) in y_test.iter().enumerate() {
245 if let Some(class_idx) =
246 classes.iter().position(|&c| c == true_label)
247 {
248 let prob = probas[[i, class_idx]].max(1e-15); log_loss -= prob.ln();
250 }
251 }
252 -log_loss / y_test.len() as Float
253 } else {
254 0.0
255 }
256 }
257 _ => {
258 if let Ok(predictions) = fitted.predict(&x_test_selected) {
260 let correct = predictions
261 .iter()
262 .zip(y_test.iter())
263 .filter(|(&pred, &true_val)| pred == true_val)
264 .count();
265 correct as Float / y_test.len() as Float
266 } else {
267 0.0
268 }
269 }
270 };
271
272 scores.push(score);
273 }
274 }
275
276 if scores.is_empty() {
277 return Ok(0.0);
278 }
279
280 let mean_score = scores.iter().sum::<Float>() / scores.len() as Float;
282 Ok(mean_score)
283 }
284
285 fn forward_selection(
287 &self,
288 x: &Array2<Float>,
289 y: &Array1<i32>,
290 n_features_to_select: usize,
291 ) -> Result<(Vec<usize>, Vec<Float>)> {
292 let n_features = x.ncols();
293 let mut selected_features: HashSet<usize> = HashSet::new();
294 let mut selection_path = Vec::new();
295 let mut scores = Vec::new();
296 let mut best_score = 0.0;
297
298 if self.config.verbose {
299 println!(
300 "Starting forward selection to select {} features",
301 n_features_to_select
302 );
303 }
304
305 for step in 0..n_features_to_select {
306 let mut best_feature = None;
307 let mut best_step_score = -Float::INFINITY;
308
309 for feature_idx in 0..n_features {
311 if selected_features.contains(&feature_idx) {
312 continue;
313 }
314
315 let mut candidate_features: Vec<usize> =
317 selected_features.iter().cloned().collect();
318 candidate_features.push(feature_idx);
319 candidate_features.sort_unstable();
320
321 let score = self.cross_validate(x, y, &candidate_features)?;
323
324 if score > best_step_score {
325 best_step_score = score;
326 best_feature = Some(feature_idx);
327 }
328 }
329
330 if let Some(feature_idx) = best_feature {
332 selected_features.insert(feature_idx);
333 selection_path.push(feature_idx);
334 scores.push(best_step_score);
335
336 let improvement = best_step_score - best_score;
337 best_score = best_step_score;
338
339 if self.config.verbose {
340 println!(
341 "Step {}: Added feature {}, score: {:.4}, improvement: {:.4}",
342 step + 1,
343 feature_idx,
344 best_step_score,
345 improvement
346 );
347 }
348
349 if step > 1 && improvement < self.config.tol && improvement < 0.0 {
351 if self.config.verbose {
352 println!(
353 "Early stopping: improvement {:.6} < tolerance {:.6}",
354 improvement, self.config.tol
355 );
356 }
357 break;
358 }
359 } else {
360 break; }
362 }
363
364 Ok((selection_path, scores))
365 }
366
367 fn backward_elimination(
369 &self,
370 x: &Array2<Float>,
371 y: &Array1<i32>,
372 n_features_to_select: usize,
373 ) -> Result<(Vec<usize>, Vec<Float>)> {
374 let n_features = x.ncols();
375 let mut remaining_features: HashSet<usize> = (0..n_features).collect();
376 let mut elimination_path = Vec::new();
377 let mut scores = Vec::new();
378
379 let all_features: Vec<usize> = (0..n_features).collect();
381 let mut best_score = self.cross_validate(x, y, &all_features)?;
382 scores.push(best_score);
383
384 if self.config.verbose {
385 println!(
386 "Starting backward elimination from {} to {} features",
387 n_features, n_features_to_select
388 );
389 println!("Initial score with all features: {:.4}", best_score);
390 }
391
392 while remaining_features.len() > n_features_to_select {
394 let mut worst_feature = None;
395 let mut best_step_score = -Float::INFINITY;
396
397 for &feature_idx in &remaining_features {
399 let candidate_features: Vec<usize> = remaining_features
401 .iter()
402 .filter(|&&f| f != feature_idx)
403 .cloned()
404 .collect();
405
406 if candidate_features.is_empty() {
407 continue;
408 }
409
410 let score = self.cross_validate(x, y, &candidate_features)?;
412
413 if score > best_step_score {
414 best_step_score = score;
415 worst_feature = Some(feature_idx);
416 }
417 }
418
419 if let Some(feature_idx) = worst_feature {
421 remaining_features.remove(&feature_idx);
422 elimination_path.push(feature_idx);
423 scores.push(best_step_score);
424
425 let improvement = best_step_score - best_score;
426 best_score = best_step_score;
427
428 if self.config.verbose {
429 println!(
430 "Step {}: Removed feature {}, score: {:.4}, improvement: {:.4}",
431 elimination_path.len(),
432 feature_idx,
433 best_step_score,
434 improvement
435 );
436 }
437
438 if improvement < -self.config.tol {
440 if self.config.verbose {
441 println!(
442 "Early stopping: performance degradation {:.6} > tolerance {:.6}",
443 -improvement, self.config.tol
444 );
445 }
446 remaining_features.insert(feature_idx);
448 elimination_path.pop();
449 scores.pop();
450 break;
451 }
452 } else {
453 break; }
455 }
456
457 let remaining_features_vec: Vec<usize> = remaining_features.into_iter().collect();
459
460 Ok((remaining_features_vec, scores))
461 }
462}
463
/// Exposes the selector's configuration through the common `Estimator` trait.
impl Estimator for SequentialFeatureSelection {
 type Config = SequentialFeatureSelectionConfig;
 type Error = SklearsError;
 type Float = Float;

 /// Borrow the current configuration.
 fn config(&self) -> &Self::Config {
 &self.config
 }
}
473
474impl Fit<Array2<Float>, Array1<i32>> for SequentialFeatureSelection {
475 type Fitted = TrainedSequentialFeatureSelection;
476
477 fn fit(self, x: &Array2<Float>, y: &Array1<i32>) -> Result<TrainedSequentialFeatureSelection> {
478 if x.nrows() != y.len() {
479 return Err(SklearsError::ShapeMismatch {
480 expected: "X.shape[0] == y.shape[0]".to_string(),
481 actual: format!("X.shape[0]={}, y.shape[0]={}", x.nrows(), y.len()),
482 });
483 }
484
485 let n_features = x.ncols();
486
487 let n_features_to_select = if let Some(n) = self.config.n_features_to_select {
489 n.min(n_features)
490 } else if let Some(fraction) = self.config.n_features_fraction {
491 ((n_features as Float * fraction).round() as usize)
492 .max(1)
493 .min(n_features)
494 } else {
495 n_features / 2
496 };
497
498 if n_features_to_select == 0 {
499 return Err(SklearsError::InvalidParameter {
500 name: "n_features_to_select".to_string(),
501 reason: "Number of features to select must be greater than 0".to_string(),
502 });
503 }
504
505 let (selection_path, scores) = match self.config.direction {
507 SelectionDirection::Forward => self.forward_selection(x, y, n_features_to_select)?,
508 SelectionDirection::Backward => {
509 self.backward_elimination(x, y, n_features_to_select)?
510 }
511 };
512
513 let mut support = Array1::from_elem(n_features, false);
515 let selected_features = match self.config.direction {
516 SelectionDirection::Forward => selection_path.clone(),
517 SelectionDirection::Backward => selection_path.clone(), };
519
520 for &feature_idx in &selected_features {
521 if feature_idx < n_features {
522 support[feature_idx] = true;
523 }
524 }
525
526 if selected_features.is_empty() {
528 return Err(SklearsError::InvalidParameter {
529 name: "feature_selection".to_string(),
530 reason: "No features were selected".to_string(),
531 });
532 }
533
534 let x_selected = x.select(Axis(1), &selected_features);
535 let estimator = LinearDiscriminantAnalysis::new();
536 let final_estimator = estimator.fit(&x_selected, y)?;
537
538 if self.config.verbose {
539 println!(
540 "Sequential feature selection completed. Selected {} features.",
541 selected_features.len()
542 );
543 println!("Selected features: {:?}", selected_features);
544 }
545
546 Ok(TrainedSequentialFeatureSelection {
547 support,
548 selection_path,
549 scores,
550 estimator: final_estimator,
551 n_features_in: n_features,
552 config: self.config.clone(),
553 })
554 }
555}
556
557impl TrainedSequentialFeatureSelection {
558 pub fn support(&self) -> &Array1<bool> {
560 &self.support
561 }
562
563 pub fn selection_path(&self) -> &[usize] {
565 &self.selection_path
566 }
567
568 pub fn scores(&self) -> &[Float] {
570 &self.scores
571 }
572
573 pub fn n_features(&self) -> usize {
575 self.support.iter().filter(|&&x| x).count()
576 }
577
578 pub fn get_support_indices(&self) -> Vec<usize> {
580 self.support
581 .iter()
582 .enumerate()
583 .filter_map(|(i, &selected)| if selected { Some(i) } else { None })
584 .collect()
585 }
586
587 pub fn estimator(&self) -> &LinearDiscriminantAnalysis<Trained> {
589 &self.estimator
590 }
591}
592
593impl Transform<Array2<Float>, Array2<Float>> for TrainedSequentialFeatureSelection {
594 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
595 if x.ncols() != self.n_features_in {
596 return Err(SklearsError::FeatureMismatch {
597 expected: self.n_features_in,
598 actual: x.ncols(),
599 });
600 }
601
602 let selected_features = self.get_support_indices();
603 Ok(x.select(Axis(1), &selected_features))
604 }
605}
606
impl Predict<Array2<Float>, Array1<i32>> for TrainedSequentialFeatureSelection {
 /// Predict class labels: project `x` onto the selected features, then
 /// delegate to the refitted LDA estimator.
 fn predict(&self, x: &Array2<Float>) -> Result<Array1<i32>> {
 let x_transformed = self.transform(x)?;
 self.estimator.predict(&x_transformed)
 }
}
613
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
 use super::*;
 use scirs2_core::ndarray::array;

 // Forward selection on 5 features keeps exactly the requested 3 and
 // records a score per step.
 #[test]
 fn test_forward_selection() {
 let x = array![
 [1.0, 2.0, 3.0, 4.0, 5.0],
 [2.0, 3.0, 4.0, 5.0, 6.0],
 [3.0, 4.0, 5.0, 6.0, 7.0],
 [4.0, 5.0, 6.0, 7.0, 8.0],
 [5.0, 6.0, 7.0, 8.0, 9.0],
 [6.0, 7.0, 8.0, 9.0, 10.0]
 ];
 let y = array![0, 0, 0, 1, 1, 1];

 let sfs = SequentialFeatureSelection::new()
 .direction(SelectionDirection::Forward)
 .n_features_to_select(3)
 .cv(2);

 let fitted = sfs.fit(&x, &y).unwrap();

 assert_eq!(fitted.n_features(), 3);
 assert_eq!(fitted.support().len(), 5);
 assert_eq!(fitted.selection_path().len(), 3);
 assert!(!fitted.scores().is_empty());
 }

 // Backward elimination from 5 features down to the requested 3.
 #[test]
 fn test_backward_elimination() {
 let x = array![
 [1.0, 2.0, 3.0, 4.0, 5.0],
 [2.0, 3.0, 4.0, 5.0, 6.0],
 [3.0, 4.0, 5.0, 6.0, 7.0],
 [4.0, 5.0, 6.0, 7.0, 8.0],
 [5.0, 6.0, 7.0, 8.0, 9.0],
 [6.0, 7.0, 8.0, 9.0, 10.0]
 ];
 let y = array![0, 0, 0, 1, 1, 1];

 let sfs = SequentialFeatureSelection::new()
 .direction(SelectionDirection::Backward)
 .n_features_to_select(3)
 .cv(2);

 let fitted = sfs.fit(&x, &y).unwrap();

 assert_eq!(fitted.n_features(), 3);
 assert_eq!(fitted.support().len(), 5);
 assert!(!fitted.scores().is_empty());
 }

 // `transform` keeps only the selected columns and all rows.
 #[test]
 fn test_sfs_transform() {
 let x = array![
 [1.0, 2.0, 3.0, 4.0],
 [2.0, 3.0, 4.0, 5.0],
 [3.0, 4.0, 5.0, 6.0],
 [4.0, 5.0, 6.0, 7.0]
 ];
 let y = array![0, 0, 1, 1];

 let sfs = SequentialFeatureSelection::new()
 .direction(SelectionDirection::Forward)
 .n_features_to_select(2);

 let fitted = sfs.fit(&x, &y).unwrap();
 let x_transformed = fitted.transform(&x).unwrap();

 assert_eq!(x_transformed.ncols(), 2);
 assert_eq!(x_transformed.nrows(), 4);
 }

 // `predict` yields one label per input row via the refitted estimator.
 #[test]
 fn test_sfs_predict() {
 let x = array![
 [1.0, 2.0, 3.0, 4.0],
 [2.0, 3.0, 4.0, 5.0],
 [3.0, 4.0, 5.0, 6.0],
 [4.0, 5.0, 6.0, 7.0]
 ];
 let y = array![0, 0, 1, 1];

 let sfs = SequentialFeatureSelection::new()
 .direction(SelectionDirection::Forward)
 .n_features_to_select(2);

 let fitted = sfs.fit(&x, &y).unwrap();
 let predictions = fitted.predict(&x).unwrap();

 assert_eq!(predictions.len(), 4);
 }

 // A fraction of 0.6 over 5 features rounds to 3 selected features.
 #[test]
 fn test_sfs_with_fraction() {
 let x = array![
 [1.0, 2.0, 3.0, 4.0, 5.0],
 [2.0, 3.0, 4.0, 5.0, 6.0],
 [3.0, 4.0, 5.0, 6.0, 7.0],
 [4.0, 5.0, 6.0, 7.0, 8.0]
 ];
 let y = array![0, 0, 1, 1];

 let sfs = SequentialFeatureSelection::new()
 .direction(SelectionDirection::Forward)
 .n_features_fraction(0.6); let fitted = sfs.fit(&x, &y).unwrap();

 assert_eq!(fitted.n_features(), 3);
 }

 // Both supported scoring metrics drive the search to the requested size.
 #[test]
 fn test_different_scoring_methods() {
 let x = array![
 [1.0, 2.0, 3.0],
 [2.0, 3.0, 4.0],
 [3.0, 4.0, 5.0],
 [4.0, 5.0, 6.0]
 ];
 let y = array![0, 0, 1, 1];

 let scoring_methods = ["accuracy", "neg_log_loss"];

 for method in &scoring_methods {
 let sfs = SequentialFeatureSelection::new()
 .direction(SelectionDirection::Forward)
 .n_features_to_select(2)
 .scoring(method);

 let fitted = sfs.fit(&x, &y).unwrap();
 assert_eq!(fitted.n_features(), 2);
 }
 }

 // `get_support_indices` must agree with the boolean support mask.
 #[test]
 fn test_sfs_support_indices() {
 let x = array![
 [1.0, 2.0, 3.0, 4.0, 5.0],
 [2.0, 3.0, 4.0, 5.0, 6.0],
 [3.0, 4.0, 5.0, 6.0, 7.0],
 [4.0, 5.0, 6.0, 7.0, 8.0]
 ];
 let y = array![0, 0, 1, 1];

 let sfs = SequentialFeatureSelection::new()
 .direction(SelectionDirection::Forward)
 .n_features_to_select(3);

 let fitted = sfs.fit(&x, &y).unwrap();
 let support_indices = fitted.get_support_indices();

 assert_eq!(support_indices.len(), 3);
 assert!(support_indices.iter().all(|&i| i < 5));

 let support = fitted.support();
 for (i, &selected) in support.iter().enumerate() {
 if selected {
 assert!(support_indices.contains(&i));
 } else {
 assert!(!support_indices.contains(&i));
 }
 }
 }
}