sklears_feature_selection/fluent_api.rs

//! Fluent API for Feature Selection Configuration
//!
//! This module provides a fluent, builder-style API for configuring complex feature selection
//! pipelines with method chaining and configuration presets for common use cases.
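//!
//! # Example
//!
//! A minimal usage sketch (not compiled as a doctest); `x` and `y` below are assumed to be an
//! `Array2<f64>` feature matrix and an `Array1<f64>` target from `scirs2_core::ndarray`:
//!
//! ```ignore
//! use sklears_feature_selection::fluent_api::FeatureSelectionBuilder;
//!
//! let result = FeatureSelectionBuilder::new()
//!     .preset("quick_filter")
//!     .random_state(42)
//!     .verbose()
//!     .fit_transform(x.view(), y.view())?;
//! println!("kept {} features", result.selected_features.len());
//! ```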

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::error::{Result as SklResult, SklearsError};
use std::collections::HashMap;

type Result<T> = SklResult<T>;

/// Fluent API for building feature selection configurations
#[derive(Debug, Clone)]
pub struct FeatureSelectionBuilder {
    steps: Vec<SelectionStep>,
    config: FluentConfig,
    presets_applied: Vec<String>,
}

/// Individual step in the fluent selection pipeline
#[derive(Debug, Clone)]
pub enum SelectionStep {
    /// Remove features whose variance falls below `threshold`
    VarianceFilter { threshold: f64 },
    /// Keep the `k` highest-scoring features according to `score_func`
    SelectKBestFilter { k: usize, score_func: String },
    /// Recursive feature elimination driven by the named estimator
    RFEWrapper {
        estimator_name: String,
        n_features: Option<usize>,
    },
    /// Recursive feature elimination with cross-validation
    RFECVWrapper {
        estimator_name: String,
        cv_folds: usize,
    },
    /// User-defined filter identified by `name` with numeric parameters
    CustomFilter {
        name: String,
        params: HashMap<String, f64>,
    },
}

/// Configuration for the fluent API
#[derive(Debug, Clone)]
pub struct FluentConfig {
    /// Run selection steps in parallel where supported
    pub parallel: bool,
    /// Seed for reproducible results
    pub random_state: Option<u64>,
    /// Print per-step progress information
    pub verbose: bool,
    /// Whether to cache intermediate results
    pub cache_results: bool,
    /// Optional fraction of the data held out for validation
    pub validation_split: Option<f64>,
    /// Name of the scoring metric to use (e.g. `"f1_score"`)
    pub scoring_metric: String,
}

impl Default for FluentConfig {
    fn default() -> Self {
        Self {
            parallel: false,
            random_state: None,
            verbose: false,
            cache_results: true,
            validation_split: None,
            scoring_metric: "f1_score".to_string(),
        }
    }
}

/// Results from fluent feature selection
#[derive(Debug, Clone)]
pub struct FluentSelectionResult {
    pub selected_features: Vec<usize>,
    pub feature_scores: Array1<f64>,
    pub step_results: Vec<StepResult>,
    pub total_execution_time: f64,
    pub config_used: FluentConfig,
}

/// Summary of a single pipeline step's execution
#[derive(Debug, Clone)]
pub struct StepResult {
    pub step_name: String,
    pub features_before: usize,
    pub features_after: usize,
    pub execution_time: f64,
    pub step_scores: Option<Array1<f64>>,
}

#[allow(non_snake_case)] // feature matrices keep the conventional `X` naming
impl FeatureSelectionBuilder {
    /// Create a new fluent feature selection builder
    pub fn new() -> Self {
        Self {
            steps: Vec::new(),
            config: FluentConfig::default(),
            presets_applied: Vec::new(),
        }
    }

    /// Apply a preset configuration for common use cases.
    ///
    /// Recognized presets: `"high_dimensional"`, `"quick_filter"`, `"comprehensive"`,
    /// `"time_series"`, `"text_data"`, `"biomedical"`, `"finance"`, and
    /// `"computer_vision"`. Unknown names emit a warning and add no steps.
    pub fn preset(mut self, preset_name: &str) -> Self {
        self.presets_applied.push(preset_name.to_string());

        match preset_name {
            "high_dimensional" => self.apply_high_dimensional_preset(),
            "quick_filter" => self.apply_quick_filter_preset(),
            "comprehensive" => self.apply_comprehensive_preset(),
            "time_series" => self.apply_time_series_preset(),
            "text_data" => self.apply_text_data_preset(),
            "biomedical" => self.apply_biomedical_preset(),
            "finance" => self.apply_finance_preset(),
            "computer_vision" => self.apply_computer_vision_preset(),
            _ => {
                eprintln!(
                    "Warning: Unknown preset '{}', using default configuration",
                    preset_name
                );
                self
            }
        }
    }

    /// Enable parallel processing
    pub fn parallel(mut self) -> Self {
        self.config.parallel = true;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self
    }

    /// Enable verbose output
    pub fn verbose(mut self) -> Self {
        self.config.verbose = true;
        self
    }

    /// Set validation split ratio
    pub fn validation_split(mut self, ratio: f64) -> Self {
        self.config.validation_split = Some(ratio);
        self
    }

    /// Set scoring metric
    pub fn scoring(mut self, metric: &str) -> Self {
        self.config.scoring_metric = metric.to_string();
        self
    }

    /// Add variance threshold filtering step
    pub fn remove_low_variance(mut self, threshold: f64) -> Self {
        self.steps.push(SelectionStep::VarianceFilter { threshold });
        self
    }

    /// Add SelectKBest filtering step
    pub fn select_k_best(mut self, k: usize) -> Self {
        self.steps.push(SelectionStep::SelectKBestFilter {
            k,
            score_func: "f_classif".to_string(),
        });
        self
    }

    /// Add SelectKBest with custom scoring function
    pub fn select_k_best_with_scorer(mut self, k: usize, score_func: &str) -> Self {
        self.steps.push(SelectionStep::SelectKBestFilter {
            k,
            score_func: score_func.to_string(),
        });
        self
    }

    /// Add Recursive Feature Elimination step
    pub fn rfe(mut self, estimator: &str, n_features: Option<usize>) -> Self {
        self.steps.push(SelectionStep::RFEWrapper {
            estimator_name: estimator.to_string(),
            n_features,
        });
        self
    }

    /// Add RFE with Cross-Validation step
    pub fn rfe_cv(mut self, estimator: &str, cv_folds: usize) -> Self {
        self.steps.push(SelectionStep::RFECVWrapper {
            estimator_name: estimator.to_string(),
            cv_folds,
        });
        self
    }

    /// Add custom filter step
    pub fn custom_filter(mut self, name: &str, params: HashMap<String, f64>) -> Self {
        self.steps.push(SelectionStep::CustomFilter {
            name: name.to_string(),
            params,
        });
        self
    }

    /// Build and execute the feature selection pipeline
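    ///
    /// # Example
    ///
    /// A sketch of the call shape (not compiled as a doctest); `x` and `y` are assumed to be
    /// an `Array2<f64>` feature matrix and an `Array1<f64>` target of matching length:
    ///
    /// ```ignore
    /// let result = FeatureSelectionBuilder::new()
    ///     .remove_low_variance(0.01)
    ///     .select_k_best(10)
    ///     .fit_transform(x.view(), y.view())?;
    /// assert_eq!(result.step_results.len(), 2);
    /// ```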
    pub fn fit_transform(
        &self,
        X: ArrayView2<f64>,
        y: ArrayView1<f64>,
    ) -> Result<FluentSelectionResult> {
        let start_time = std::time::Instant::now();
        let mut current_X = X.to_owned();
        let mut selected_features: Vec<usize> = (0..X.ncols()).collect();
        let mut step_results = Vec::new();

        for (step_idx, step) in self.steps.iter().enumerate() {
            let step_start = std::time::Instant::now();
            let features_before = current_X.ncols();

            let step_result = match step {
                SelectionStep::VarianceFilter { threshold } => {
                    self.apply_variance_filter(&mut current_X, &mut selected_features, *threshold)?
                }
                SelectionStep::SelectKBestFilter { k, score_func } => self.apply_select_k_best(
                    &mut current_X,
                    &y,
                    &mut selected_features,
                    *k,
                    score_func,
                )?,
                SelectionStep::RFEWrapper {
                    estimator_name,
                    n_features,
                } => self.apply_rfe(
                    &mut current_X,
                    &y,
                    &mut selected_features,
                    estimator_name,
                    *n_features,
                )?,
                SelectionStep::RFECVWrapper {
                    estimator_name,
                    cv_folds,
                } => self.apply_rfe_cv(
                    &mut current_X,
                    &y,
                    &mut selected_features,
                    estimator_name,
                    *cv_folds,
                )?,
                SelectionStep::CustomFilter { name, params } => self.apply_custom_filter(
                    &mut current_X,
                    &y,
                    &mut selected_features,
                    name,
                    params,
                )?,
            };

            let step_time = step_start.elapsed().as_secs_f64();
            step_results.push(StepResult {
                step_name: format!("Step_{}: {:?}", step_idx + 1, step),
                features_before,
                features_after: current_X.ncols(),
                execution_time: step_time,
                step_scores: step_result,
            });

            if self.config.verbose {
                println!(
                    "Step {}: {} features -> {} features ({:.3}s)",
                    step_idx + 1,
                    features_before,
                    current_X.ncols(),
                    step_time
                );
            }
        }

        let total_time = start_time.elapsed().as_secs_f64();

        // Generate final feature scores (simplified)
        let feature_scores = if selected_features.is_empty() {
            Array1::zeros(0)
        } else {
            Array1::ones(selected_features.len())
        };

        Ok(FluentSelectionResult {
            selected_features,
            feature_scores,
            step_results,
            total_execution_time: total_time,
            config_used: self.config.clone(),
        })
    }

    // Private methods for applying different steps
    fn apply_variance_filter(
        &self,
        X: &mut Array2<f64>,
        selected_features: &mut Vec<usize>,
        threshold: f64,
    ) -> Result<Option<Array1<f64>>> {
        // Simple variance-based filtering implementation
        let variances: Vec<f64> = (0..X.ncols())
            .map(|col| {
                let column = X.column(col);
                let mean = column.mean().unwrap_or(0.0);
                column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
                    / (column.len() as f64 - 1.0)
            })
            .collect();

        let keep_indices: Vec<usize> = variances
            .iter()
            .enumerate()
            .filter(|(_, &var)| var >= threshold)
            .map(|(idx, _)| idx)
            .collect();

        if keep_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "All features removed by variance threshold".to_string(),
            ));
        }

        // Update selected features
        *selected_features = keep_indices.iter().map(|&i| selected_features[i]).collect();

        // Create new X with only selected columns
        let new_X = Array2::from_shape_fn((X.nrows(), keep_indices.len()), |(row, col)| {
            X[[row, keep_indices[col]]]
        });
        *X = new_X;

        Ok(Some(Array1::from(variances)))
    }

    fn apply_select_k_best(
        &self,
        X: &mut Array2<f64>,
        y: &ArrayView1<f64>,
        selected_features: &mut Vec<usize>,
        k: usize,
        _score_func: &str,
    ) -> Result<Option<Array1<f64>>> {
        if k >= X.ncols() {
            return Ok(None); // No filtering needed
        }

        // Simple correlation-based scoring (placeholder implementation)
        let scores: Vec<f64> = (0..X.ncols())
            .map(|col| {
                let x_col = X.column(col);
                let x_mean = x_col.mean().unwrap_or(0.0);
                let y_mean = y.mean().unwrap_or(0.0);

                let numerator: f64 = x_col
                    .iter()
                    .zip(y.iter())
                    .map(|(&x, &y_val)| (x - x_mean) * (y_val - y_mean))
                    .sum();

                let x_var: f64 = x_col.iter().map(|&x| (x - x_mean).powi(2)).sum();
                let y_var: f64 = y.iter().map(|&y_val| (y_val - y_mean).powi(2)).sum();

                if x_var > 0.0 && y_var > 0.0 {
                    numerator.abs() / (x_var * y_var).sqrt()
                } else {
                    0.0
                }
            })
            .collect();

        // Get top k features
        let mut score_indices: Vec<(usize, f64)> =
            scores.iter().enumerate().map(|(i, &s)| (i, s)).collect();
        score_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let keep_indices: Vec<usize> = score_indices.iter().take(k).map(|(idx, _)| *idx).collect();

        // Update selected features
        *selected_features = keep_indices.iter().map(|&i| selected_features[i]).collect();

        // Create new X with only selected columns
        let new_X = Array2::from_shape_fn((X.nrows(), k), |(row, col)| X[[row, keep_indices[col]]]);
        *X = new_X;

        Ok(Some(Array1::from(scores)))
    }

    fn apply_rfe(
        &self,
        X: &mut Array2<f64>,
        _y: &ArrayView1<f64>,
        selected_features: &mut Vec<usize>,
        _estimator_name: &str,
        n_features: Option<usize>,
    ) -> Result<Option<Array1<f64>>> {
        let target_features = n_features.unwrap_or(X.ncols() / 2).min(X.ncols());

        if target_features >= X.ncols() {
            return Ok(None); // No elimination needed
        }

        // Simple RFE implementation using feature importance
        let mut current_features: Vec<usize> = (0..X.ncols()).collect();
        let mut current_X = X.clone();

        while current_features.len() > target_features {
            // Calculate feature importance (simplified using variance)
            let importances: Vec<f64> = (0..current_X.ncols())
                .map(|col| {
                    let column = current_X.column(col);
                    let mean = column.mean().unwrap_or(0.0);
                    column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / column.len() as f64
                })
                .collect();

            // Remove feature with lowest importance
            let min_idx = importances
                .iter()
                .enumerate()
                .min_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(idx, _)| idx)
                .unwrap();

            current_features.remove(min_idx);

            // Create new X without the removed feature
            let mut new_data = Vec::new();
            for row in 0..current_X.nrows() {
                for col in 0..current_X.ncols() {
                    if col != min_idx {
                        new_data.push(current_X[[row, col]]);
                    }
                }
            }
            current_X =
                Array2::from_shape_vec((current_X.nrows(), current_features.len()), new_data)
                    .map_err(|_| {
                        SklearsError::InvalidInput("Failed to reshape array".to_string())
                    })?;
        }

        // Update selected features
        *selected_features = current_features
            .iter()
            .map(|&i| selected_features[i])
            .collect();
        *X = current_X;

        Ok(Some(Array1::ones(selected_features.len())))
    }

    fn apply_rfe_cv(
        &self,
        X: &mut Array2<f64>,
        y: &ArrayView1<f64>,
        selected_features: &mut Vec<usize>,
        estimator_name: &str,
        _cv_folds: usize,
    ) -> Result<Option<Array1<f64>>> {
        // Simplified: delegate to plain RFE with a fixed target size instead of
        // choosing the number of features via cross-validation
        let optimal_features = X.ncols() / 3;
        self.apply_rfe(
            X,
            y,
            selected_features,
            estimator_name,
            Some(optimal_features),
        )
    }

    fn apply_custom_filter(
        &self,
        X: &mut Array2<f64>,
        _y: &ArrayView1<f64>,
        selected_features: &mut Vec<usize>,
        name: &str,
        params: &HashMap<String, f64>,
    ) -> Result<Option<Array1<f64>>> {
        match name {
            "correlation_threshold" => {
                let threshold = *params.get("threshold").unwrap_or(&0.5);
                // Simple correlation-based filtering
                let keep_ratio = 1.0 - threshold;
                let target_features = ((X.ncols() as f64) * keep_ratio) as usize;

                if target_features >= X.ncols() {
                    return Ok(None);
                }

                // Keep the first `target_features` columns (simplified placeholder)
                let keep_indices: Vec<usize> = (0..target_features).collect();
                *selected_features = keep_indices.iter().map(|&i| selected_features[i]).collect();

                let new_X =
                    Array2::from_shape_fn((X.nrows(), target_features), |(row, col)| X[[row, col]]);
                *X = new_X;

                Ok(Some(Array1::ones(target_features)))
            }
            _ => Err(SklearsError::InvalidInput(format!(
                "Unknown custom filter: {}",
                name
            ))),
        }
    }

    // Preset implementations
    fn apply_high_dimensional_preset(mut self) -> Self {
        self.config.parallel = true;
        self.remove_low_variance(0.01)
            .select_k_best(1000)
            .rfe("linear_svm", Some(100))
    }

    fn apply_quick_filter_preset(self) -> Self {
        self.remove_low_variance(0.0).select_k_best(50)
    }

    fn apply_comprehensive_preset(mut self) -> Self {
        self.config.parallel = true;
        self.config.validation_split = Some(0.2);
        self.remove_low_variance(0.001)
            .select_k_best_with_scorer(500, "mutual_info")
            .rfe_cv("random_forest", 5)
    }

    fn apply_time_series_preset(mut self) -> Self {
        self.config.scoring_metric = "mse".to_string();
        self.remove_low_variance(0.001)
            .select_k_best_with_scorer(100, "f_regression")
    }

    fn apply_text_data_preset(self) -> Self {
        self.remove_low_variance(0.0)
            .select_k_best_with_scorer(1000, "chi2")
            .rfe("naive_bayes", Some(200))
    }

    fn apply_biomedical_preset(mut self) -> Self {
        self.config.validation_split = Some(0.3);
        self.remove_low_variance(0.01)
            .select_k_best_with_scorer(500, "mutual_info")
            .rfe_cv("svm", 10)
    }

    fn apply_finance_preset(mut self) -> Self {
        self.config.scoring_metric = "sharpe_ratio".to_string();
        self.remove_low_variance(0.001)
            .select_k_best_with_scorer(50, "f_regression")
            .custom_filter("correlation_threshold", {
                let mut params = HashMap::new();
                params.insert("threshold".to_string(), 0.8);
                params
            })
    }

    fn apply_computer_vision_preset(mut self) -> Self {
        self.config.parallel = true;
        self.remove_low_variance(0.0)
            .select_k_best(2000)
            .rfe("cnn", Some(500))
    }
}

impl Default for FeatureSelectionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Convenience functions for common use cases
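///
/// A short usage sketch (not compiled as a doctest):
///
/// ```ignore
/// use sklears_feature_selection::fluent_api::presets;
///
/// // biomedical() applies the "biomedical" preset with a 0.3 validation split and seed 42
/// let builder = presets::biomedical();
/// ```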
pub mod presets {
    use super::*;

    /// Quick exploratory feature selection using the `quick_filter` preset
    pub fn quick_eda() -> FeatureSelectionBuilder {
        FeatureSelectionBuilder::new().preset("quick_filter")
    }

    /// High-dimensional data feature selection
    pub fn high_dimensional() -> FeatureSelectionBuilder {
        FeatureSelectionBuilder::new()
            .preset("high_dimensional")
            .parallel()
    }

    /// Comprehensive feature selection with validation
    pub fn comprehensive() -> FeatureSelectionBuilder {
        FeatureSelectionBuilder::new()
            .preset("comprehensive")
            .verbose()
            .validation_split(0.2)
    }

    /// Time series feature selection
    pub fn time_series() -> FeatureSelectionBuilder {
        FeatureSelectionBuilder::new()
            .preset("time_series")
            .scoring("mse")
    }

    /// Text classification feature selection
    pub fn text_classification() -> FeatureSelectionBuilder {
        FeatureSelectionBuilder::new()
            .preset("text_data")
            .scoring("f1_score")
    }

    /// Biomedical data feature selection
    pub fn biomedical() -> FeatureSelectionBuilder {
        FeatureSelectionBuilder::new()
            .preset("biomedical")
            .validation_split(0.3)
            .random_state(42)
    }

    /// Financial data feature selection
    pub fn finance() -> FeatureSelectionBuilder {
        FeatureSelectionBuilder::new()
            .preset("finance")
            .scoring("sharpe_ratio")
    }

    /// Computer vision feature selection
    pub fn computer_vision() -> FeatureSelectionBuilder {
        FeatureSelectionBuilder::new()
            .preset("computer_vision")
            .parallel()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fluent_api_basic() {
        let builder = FeatureSelectionBuilder::new()
            .remove_low_variance(0.1)
            .select_k_best(10)
            .verbose();

        assert_eq!(builder.steps.len(), 2);
        assert!(builder.config.verbose);
    }

    #[test]
    fn test_preset_application() {
        let builder = FeatureSelectionBuilder::new().preset("high_dimensional");

        assert!(builder.config.parallel);
        assert_eq!(builder.presets_applied, vec!["high_dimensional"]);
        assert_eq!(builder.steps.len(), 3); // variance + select_k_best + rfe
    }

    #[test]
    fn test_method_chaining() {
        let builder = FeatureSelectionBuilder::new()
            .parallel()
            .random_state(42)
            .verbose()
            .validation_split(0.2)
            .scoring("f1_score")
            .remove_low_variance(0.01)
            .select_k_best(100);

        assert!(builder.config.parallel);
        assert_eq!(builder.config.random_state, Some(42));
        assert!(builder.config.verbose);
        assert_eq!(builder.config.validation_split, Some(0.2));
        assert_eq!(builder.config.scoring_metric, "f1_score");
        assert_eq!(builder.steps.len(), 2);
    }

    #[test]
    fn test_convenience_presets() {
        let quick = presets::quick_eda();
        assert_eq!(quick.presets_applied, vec!["quick_filter"]);

        let comprehensive = presets::comprehensive();
        assert!(comprehensive.config.verbose);
        assert_eq!(comprehensive.config.validation_split, Some(0.2));

        let biomedical = presets::biomedical();
        assert_eq!(biomedical.config.random_state, Some(42));
        assert_eq!(biomedical.config.validation_split, Some(0.3));
    }
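
    // An end-to-end sketch over small synthetic data: a variance filter with a zero
    // threshold keeps every column, then SelectKBest reduces them to three. The data
    // and expectations here are illustrative assumptions, not fixtures from the crate.
    #[test]
    fn test_fit_transform_pipeline_sketch() {
        // 20 samples x 5 features; column j of row i holds (i + j) as f64
        let x = Array2::from_shape_fn((20, 5), |(i, j)| (i + j) as f64);
        // Target linearly related to every feature column
        let y = Array1::from_shape_fn(20, |i| i as f64);

        let result = FeatureSelectionBuilder::new()
            .remove_low_variance(0.0)
            .select_k_best(3)
            .fit_transform(x.view(), y.view())
            .expect("pipeline should run on well-formed data");

        assert_eq!(result.selected_features.len(), 3);
        assert_eq!(result.step_results.len(), 2);
    }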
}