1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7use sklears_core::error::{Result as SklResult, SklearsError};
8use std::collections::HashMap;
9
/// Local shorthand for `sklears_core`'s `Result` type, so signatures in this module can write
/// `Result<T>`.
type Result<T> = SklResult<T>;
11
/// Fluent builder that accumulates feature-selection steps and configuration, then executes
/// them in insertion order via `FeatureSelectionBuilder::fit_transform`.
#[derive(Debug, Clone)]
pub struct FeatureSelectionBuilder {
    /// Selection steps, executed in the order they were added.
    steps: Vec<SelectionStep>,
    /// Run-wide configuration; a clone of it is stored in the selection result.
    config: FluentConfig,
    /// Every name passed to `preset`, in call order (unrecognised names included).
    presets_applied: Vec<String>,
}
19
/// One stage of a feature-selection pipeline built with [`FeatureSelectionBuilder`].
///
/// Steps run in the order they were added; each step only sees the columns that survived
/// the steps before it.
#[derive(Debug, Clone)]
pub enum SelectionStep {
    /// Drop low-variance columns, judged against `threshold`.
    VarianceFilter { threshold: f64 },
    /// Keep the `k` best-scoring columns. `score_func` names the intended scoring function
    /// (`"f_classif"` when added through `select_k_best`); the built-in step does not
    /// currently dispatch on it.
    SelectKBestFilter { k: usize, score_func: String },
    /// Recursive feature elimination.
    RFEWrapper {
        /// Name of the ranking estimator; the built-in step currently ignores it.
        estimator_name: String,

        /// Number of columns to keep; `None` means half of the current columns.
        n_features: Option<usize>,
    },
    /// Cross-validated recursive feature elimination.
    RFECVWrapper {
        /// Name of the ranking estimator; the built-in step currently ignores it.
        estimator_name: String,
        /// Requested number of cross-validation folds; the built-in step currently ignores it.
        cv_folds: usize,
    },
    /// A filter selected by `name` and tuned by numeric `params`. The built-in dispatcher
    /// only recognises `"correlation_threshold"`, which reads the `"threshold"` key.
    CustomFilter {
        name: String,
        params: HashMap<String, f64>,
    },
}
42
/// Run-wide options collected by the fluent builder.
///
/// `FeatureSelectionBuilder::fit_transform` itself only consults `verbose`; the remaining
/// fields are carried through to the result as `config_used`.
#[derive(Debug, Clone)]
pub struct FluentConfig {
    /// Whether parallel execution was requested.
    pub parallel: bool,
    /// Optional seed for reproducible randomness.
    pub random_state: Option<u64>,
    /// Print a one-line progress summary after each step.
    pub verbose: bool,
    /// Whether intermediate results may be cached.
    pub cache_results: bool,
    /// Optional fraction of samples to hold out for validation.
    pub validation_split: Option<f64>,
    /// Name of the metric used to judge selections.
    pub scoring_metric: String,
}

impl Default for FluentConfig {
    /// Sequential, unseeded, quiet, caching enabled, no validation split, scored by
    /// `"f1_score"`.
    fn default() -> Self {
        let scoring_metric = String::from("f1_score");
        FluentConfig {
            scoring_metric,
            validation_split: None,
            cache_results: true,
            verbose: false,
            random_state: None,
            parallel: false,
        }
    }
}
66
/// Outcome of `FeatureSelectionBuilder::fit_transform`.
#[derive(Debug, Clone)]
pub struct FluentSelectionResult {
    /// Indices, into the columns of the input matrix, of the features that survived every step.
    pub selected_features: Vec<usize>,
    /// One score per selected feature (currently always `1.0`).
    pub feature_scores: Array1<f64>,
    /// Per-step diagnostics, in execution order.
    pub step_results: Vec<StepResult>,
    /// Wall-clock duration of the whole run, in seconds.
    pub total_execution_time: f64,
    /// Clone of the builder configuration that was in effect for the run.
    pub config_used: FluentConfig,
}
76
/// Diagnostics recorded for a single pipeline step.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// `"Step_<n>: <Debug form of the step>"`, with `n` counted from 1.
    pub step_name: String,
    /// Number of columns entering the step.
    pub features_before: usize,
    /// Number of columns leaving the step.
    pub features_after: usize,
    /// Wall-clock duration of the step, in seconds.
    pub execution_time: f64,
    /// Scores reported by the step, or `None` when the step reported none.
    pub step_scores: Option<Array1<f64>>,
}
85
86impl FeatureSelectionBuilder {
87 pub fn new() -> Self {
89 Self {
90 steps: Vec::new(),
91 config: FluentConfig::default(),
92 presets_applied: Vec::new(),
93 }
94 }
95
96 pub fn preset(mut self, preset_name: &str) -> Self {
98 self.presets_applied.push(preset_name.to_string());
99
100 match preset_name {
101 "high_dimensional" => self.apply_high_dimensional_preset(),
102 "quick_filter" => self.apply_quick_filter_preset(),
103 "comprehensive" => self.apply_comprehensive_preset(),
104 "time_series" => self.apply_time_series_preset(),
105 "text_data" => self.apply_text_data_preset(),
106 "biomedical" => self.apply_biomedical_preset(),
107 "finance" => self.apply_finance_preset(),
108 "computer_vision" => self.apply_computer_vision_preset(),
109 _ => {
110 eprintln!(
111 "Warning: Unknown preset '{}', using default configuration",
112 preset_name
113 );
114 self
115 }
116 }
117 }
118
119 pub fn parallel(mut self) -> Self {
121 self.config.parallel = true;
122 self
123 }
124
125 pub fn random_state(mut self, seed: u64) -> Self {
127 self.config.random_state = Some(seed);
128 self
129 }
130
131 pub fn verbose(mut self) -> Self {
133 self.config.verbose = true;
134 self
135 }
136
137 pub fn validation_split(mut self, ratio: f64) -> Self {
139 self.config.validation_split = Some(ratio);
140 self
141 }
142
143 pub fn scoring(mut self, metric: &str) -> Self {
145 self.config.scoring_metric = metric.to_string();
146 self
147 }
148
149 pub fn remove_low_variance(mut self, threshold: f64) -> Self {
151 self.steps.push(SelectionStep::VarianceFilter { threshold });
152 self
153 }
154
155 pub fn select_k_best(mut self, k: usize) -> Self {
157 self.steps.push(SelectionStep::SelectKBestFilter {
158 k,
159 score_func: "f_classif".to_string(),
160 });
161 self
162 }
163
164 pub fn select_k_best_with_scorer(mut self, k: usize, score_func: &str) -> Self {
166 self.steps.push(SelectionStep::SelectKBestFilter {
167 k,
168 score_func: score_func.to_string(),
169 });
170 self
171 }
172
173 pub fn rfe(mut self, estimator: &str, n_features: Option<usize>) -> Self {
175 self.steps.push(SelectionStep::RFEWrapper {
176 estimator_name: estimator.to_string(),
177 n_features,
178 });
179 self
180 }
181
182 pub fn rfe_cv(mut self, estimator: &str, cv_folds: usize) -> Self {
184 self.steps.push(SelectionStep::RFECVWrapper {
185 estimator_name: estimator.to_string(),
186 cv_folds,
187 });
188 self
189 }
190
191 pub fn custom_filter(mut self, name: &str, params: HashMap<String, f64>) -> Self {
193 self.steps.push(SelectionStep::CustomFilter {
194 name: name.to_string(),
195 params,
196 });
197 self
198 }
199
200 pub fn fit_transform(
202 &self,
203 X: ArrayView2<f64>,
204 y: ArrayView1<f64>,
205 ) -> Result<FluentSelectionResult> {
206 let start_time = std::time::Instant::now();
207 let mut current_X = X.to_owned();
208 let mut selected_features: Vec<usize> = (0..X.ncols()).collect();
209 let mut step_results = Vec::new();
210
211 for (step_idx, step) in self.steps.iter().enumerate() {
212 let step_start = std::time::Instant::now();
213 let features_before = current_X.ncols();
214
215 let step_result = match step {
216 SelectionStep::VarianceFilter { threshold } => {
217 self.apply_variance_filter(&mut current_X, &mut selected_features, *threshold)?
218 }
219 SelectionStep::SelectKBestFilter { k, score_func } => self.apply_select_k_best(
220 &mut current_X,
221 &y,
222 &mut selected_features,
223 *k,
224 score_func,
225 )?,
226 SelectionStep::RFEWrapper {
227 estimator_name,
228 n_features,
229 } => self.apply_rfe(
230 &mut current_X,
231 &y,
232 &mut selected_features,
233 estimator_name,
234 *n_features,
235 )?,
236 SelectionStep::RFECVWrapper {
237 estimator_name,
238 cv_folds,
239 } => self.apply_rfe_cv(
240 &mut current_X,
241 &y,
242 &mut selected_features,
243 estimator_name,
244 *cv_folds,
245 )?,
246 SelectionStep::CustomFilter { name, params } => self.apply_custom_filter(
247 &mut current_X,
248 &y,
249 &mut selected_features,
250 name,
251 params,
252 )?,
253 };
254
255 let step_time = step_start.elapsed().as_secs_f64();
256 step_results.push(StepResult {
257 step_name: format!("Step_{}: {:?}", step_idx + 1, step),
258 features_before,
259 features_after: current_X.ncols(),
260 execution_time: step_time,
261 step_scores: step_result,
262 });
263
264 if self.config.verbose {
265 println!(
266 "Step {}: {} features -> {} features ({:.3}s)",
267 step_idx + 1,
268 features_before,
269 current_X.ncols(),
270 step_time
271 );
272 }
273 }
274
275 let total_time = start_time.elapsed().as_secs_f64();
276
277 let feature_scores = if selected_features.is_empty() {
279 Array1::zeros(0)
280 } else {
281 Array1::ones(selected_features.len())
282 };
283
284 Ok(FluentSelectionResult {
285 selected_features,
286 feature_scores,
287 step_results,
288 total_execution_time: total_time,
289 config_used: self.config.clone(),
290 })
291 }
292
293 fn apply_variance_filter(
295 &self,
296 X: &mut Array2<f64>,
297 selected_features: &mut Vec<usize>,
298 threshold: f64,
299 ) -> Result<Option<Array1<f64>>> {
300 let variances: Vec<f64> = (0..X.ncols())
302 .map(|col| {
303 let column = X.column(col);
304 let mean = column.mean().unwrap_or(0.0);
305 column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
306 / (column.len() as f64 - 1.0)
307 })
308 .collect();
309
310 let keep_indices: Vec<usize> = variances
311 .iter()
312 .enumerate()
313 .filter(|(_, &var)| var >= threshold)
314 .map(|(idx, _)| idx)
315 .collect();
316
317 if keep_indices.is_empty() {
318 return Err(SklearsError::InvalidInput(
319 "All features removed by variance threshold".to_string(),
320 ));
321 }
322
323 *selected_features = keep_indices.iter().map(|&i| selected_features[i]).collect();
325
326 let new_X = Array2::from_shape_fn((X.nrows(), keep_indices.len()), |(row, col)| {
328 X[[row, keep_indices[col]]]
329 });
330 *X = new_X;
331
332 Ok(Some(Array1::from(variances)))
333 }
334
335 fn apply_select_k_best(
336 &self,
337 X: &mut Array2<f64>,
338 y: &ArrayView1<f64>,
339 selected_features: &mut Vec<usize>,
340 k: usize,
341 _score_func: &str,
342 ) -> Result<Option<Array1<f64>>> {
343 if k >= X.ncols() {
344 return Ok(None); }
346
347 let scores: Vec<f64> = (0..X.ncols())
349 .map(|col| {
350 let x_col = X.column(col);
351 let x_mean = x_col.mean().unwrap_or(0.0);
352 let y_mean = y.mean().unwrap_or(0.0);
353
354 let numerator: f64 = x_col
355 .iter()
356 .zip(y.iter())
357 .map(|(&x, &y_val)| (x - x_mean) * (y_val - y_mean))
358 .sum();
359
360 let x_var: f64 = x_col.iter().map(|&x| (x - x_mean).powi(2)).sum();
361 let y_var: f64 = y.iter().map(|&y_val| (y_val - y_mean).powi(2)).sum();
362
363 if x_var > 0.0 && y_var > 0.0 {
364 numerator.abs() / (x_var * y_var).sqrt()
365 } else {
366 0.0
367 }
368 })
369 .collect();
370
371 let mut score_indices: Vec<(usize, f64)> =
373 scores.iter().enumerate().map(|(i, &s)| (i, s)).collect();
374 score_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
375
376 let keep_indices: Vec<usize> = score_indices.iter().take(k).map(|(idx, _)| *idx).collect();
377
378 *selected_features = keep_indices.iter().map(|&i| selected_features[i]).collect();
380
381 let new_X = Array2::from_shape_fn((X.nrows(), k), |(row, col)| X[[row, keep_indices[col]]]);
383 *X = new_X;
384
385 Ok(Some(Array1::from(scores)))
386 }
387
388 fn apply_rfe(
389 &self,
390 X: &mut Array2<f64>,
391 _y: &ArrayView1<f64>,
392 selected_features: &mut Vec<usize>,
393 _estimator_name: &str,
394 n_features: Option<usize>,
395 ) -> Result<Option<Array1<f64>>> {
396 let target_features = n_features.unwrap_or(X.ncols() / 2).min(X.ncols());
397
398 if target_features >= X.ncols() {
399 return Ok(None); }
401
402 let mut current_features: Vec<usize> = (0..X.ncols()).collect();
404 let mut current_X = X.clone();
405
406 while current_features.len() > target_features {
407 let importances: Vec<f64> = (0..current_X.ncols())
409 .map(|col| {
410 let column = current_X.column(col);
411 let mean = column.mean().unwrap_or(0.0);
412 column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / column.len() as f64
413 })
414 .collect();
415
416 let min_idx = importances
418 .iter()
419 .enumerate()
420 .min_by(|a, b| a.1.partial_cmp(b.1).unwrap())
421 .map(|(idx, _)| idx)
422 .unwrap();
423
424 current_features.remove(min_idx);
425
426 let mut new_data = Vec::new();
428 for row in 0..current_X.nrows() {
429 for col in 0..current_X.ncols() {
430 if col != min_idx {
431 new_data.push(current_X[[row, col]]);
432 }
433 }
434 }
435 current_X =
436 Array2::from_shape_vec((current_X.nrows(), current_features.len()), new_data)
437 .map_err(|_| {
438 SklearsError::InvalidInput("Failed to reshape array".to_string())
439 })?;
440 }
441
442 *selected_features = current_features
444 .iter()
445 .map(|&i| selected_features[i])
446 .collect();
447 *X = current_X;
448
449 Ok(Some(Array1::ones(selected_features.len())))
450 }
451
452 fn apply_rfe_cv(
453 &self,
454 X: &mut Array2<f64>,
455 y: &ArrayView1<f64>,
456 selected_features: &mut Vec<usize>,
457 estimator_name: &str,
458 _cv_folds: usize,
459 ) -> Result<Option<Array1<f64>>> {
460 let optimal_features = X.ncols() / 3; self.apply_rfe(
463 X,
464 y,
465 selected_features,
466 estimator_name,
467 Some(optimal_features),
468 )
469 }
470
471 fn apply_custom_filter(
472 &self,
473 X: &mut Array2<f64>,
474 _y: &ArrayView1<f64>,
475 selected_features: &mut Vec<usize>,
476 name: &str,
477 params: &HashMap<String, f64>,
478 ) -> Result<Option<Array1<f64>>> {
479 match name {
480 "correlation_threshold" => {
481 let threshold = params.get("threshold").unwrap_or(&0.5);
482 let keep_ratio = 1.0 - threshold;
484 let target_features = ((X.ncols() as f64) * keep_ratio) as usize;
485
486 if target_features >= X.ncols() {
487 return Ok(None);
488 }
489
490 let keep_indices: Vec<usize> = (0..target_features).collect();
492 *selected_features = keep_indices.iter().map(|&i| selected_features[i]).collect();
493
494 let new_X =
495 Array2::from_shape_fn((X.nrows(), target_features), |(row, col)| X[[row, col]]);
496 *X = new_X;
497
498 Ok(Some(Array1::ones(target_features)))
499 }
500 _ => Err(SklearsError::InvalidInput(format!(
501 "Unknown custom filter: {}",
502 name
503 ))),
504 }
505 }
506
507 fn apply_high_dimensional_preset(mut self) -> Self {
509 self.config.parallel = true;
510 self.remove_low_variance(0.01)
511 .select_k_best(1000)
512 .rfe("linear_svm", Some(100))
513 }
514
515 fn apply_quick_filter_preset(self) -> Self {
516 self.remove_low_variance(0.0).select_k_best(50)
517 }
518
519 fn apply_comprehensive_preset(mut self) -> Self {
520 self.config.parallel = true;
521 self.config.validation_split = Some(0.2);
522 self.remove_low_variance(0.001)
523 .select_k_best_with_scorer(500, "mutual_info")
524 .rfe_cv("random_forest", 5)
525 }
526
527 fn apply_time_series_preset(mut self) -> Self {
528 self.config.scoring_metric = "mse".to_string();
529 self.remove_low_variance(0.001)
530 .select_k_best_with_scorer(100, "f_regression")
531 }
532
533 fn apply_text_data_preset(self) -> Self {
534 self.remove_low_variance(0.0)
535 .select_k_best_with_scorer(1000, "chi2")
536 .rfe("naive_bayes", Some(200))
537 }
538
539 fn apply_biomedical_preset(mut self) -> Self {
540 self.config.validation_split = Some(0.3);
541 self.remove_low_variance(0.01)
542 .select_k_best_with_scorer(500, "mutual_info")
543 .rfe_cv("svm", 10)
544 }
545
546 fn apply_finance_preset(mut self) -> Self {
547 self.config.scoring_metric = "sharpe_ratio".to_string();
548 self.remove_low_variance(0.001)
549 .select_k_best_with_scorer(50, "f_regression")
550 .custom_filter("correlation_threshold", {
551 let mut params = HashMap::new();
552 params.insert("threshold".to_string(), 0.8);
553 params
554 })
555 }
556
557 fn apply_computer_vision_preset(mut self) -> Self {
558 self.config.parallel = true;
559 self.remove_low_variance(0.0)
560 .select_k_best(2000)
561 .rfe("cnn", Some(500))
562 }
563}
564
565impl Default for FeatureSelectionBuilder {
566 fn default() -> Self {
567 Self::new()
568 }
569}
570
571pub mod presets {
573 use super::*;
574
575 pub fn quick_eda() -> FeatureSelectionBuilder {
576 FeatureSelectionBuilder::new().preset("quick_filter")
577 }
578
579 pub fn high_dimensional() -> FeatureSelectionBuilder {
581 FeatureSelectionBuilder::new()
582 .preset("high_dimensional")
583 .parallel()
584 }
585
586 pub fn comprehensive() -> FeatureSelectionBuilder {
588 FeatureSelectionBuilder::new()
589 .preset("comprehensive")
590 .verbose()
591 .validation_split(0.2)
592 }
593
594 pub fn time_series() -> FeatureSelectionBuilder {
596 FeatureSelectionBuilder::new()
597 .preset("time_series")
598 .scoring("mse")
599 }
600
601 pub fn text_classification() -> FeatureSelectionBuilder {
603 FeatureSelectionBuilder::new()
604 .preset("text_data")
605 .scoring("f1_score")
606 }
607
608 pub fn biomedical() -> FeatureSelectionBuilder {
610 FeatureSelectionBuilder::new()
611 .preset("biomedical")
612 .validation_split(0.3)
613 .random_state(42)
614 }
615
616 pub fn finance() -> FeatureSelectionBuilder {
618 FeatureSelectionBuilder::new()
619 .preset("finance")
620 .scoring("sharpe_ratio")
621 }
622
623 pub fn computer_vision() -> FeatureSelectionBuilder {
625 FeatureSelectionBuilder::new()
626 .preset("computer_vision")
627 .parallel()
628 }
629}
630
// Test code may use matrix-style upper-case names such as `X`.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Two step methods plus `verbose` give two steps and the verbose flag.
    #[test]
    fn test_fluent_api_basic() {
        let fluent = FeatureSelectionBuilder::new()
            .remove_low_variance(0.1)
            .select_k_best(10)
            .verbose();

        let FeatureSelectionBuilder { steps, config, .. } = &fluent;
        assert_eq!(steps.len(), 2);
        assert!(config.verbose);
    }

    /// The high-dimensional preset enables parallelism and adds three steps.
    #[test]
    fn test_preset_application() {
        let hd = FeatureSelectionBuilder::new().preset("high_dimensional");
        let expected_presets = ["high_dimensional"];

        assert!(hd.config.parallel);
        assert_eq!(hd.presets_applied, expected_presets);
        assert_eq!(hd.steps.len(), 3);
    }

    /// Every configuration setter is reflected in the config; step methods add steps.
    #[test]
    fn test_method_chaining() {
        let chained = FeatureSelectionBuilder::new()
            .parallel()
            .random_state(42)
            .verbose()
            .validation_split(0.2)
            .scoring("f1_score")
            .remove_low_variance(0.01)
            .select_k_best(100);
        let cfg = &chained.config;

        assert!(cfg.parallel && cfg.verbose);
        assert_eq!(cfg.random_state, Some(42));
        assert_eq!(cfg.validation_split, Some(0.2));
        assert_eq!(cfg.scoring_metric, "f1_score");
        assert_eq!(chained.steps.len(), 2);
    }

    /// Convenience constructors record their preset and apply their extra settings.
    #[test]
    fn test_convenience_presets() {
        let quick = presets::quick_eda();
        assert_eq!(quick.presets_applied, ["quick_filter"]);

        let comprehensive_cfg = presets::comprehensive().config;
        assert!(comprehensive_cfg.verbose);
        assert_eq!(comprehensive_cfg.validation_split, Some(0.2));

        let biomedical_cfg = presets::biomedical().config;
        assert_eq!(biomedical_cfg.random_state, Some(42));
        assert_eq!(biomedical_cfg.validation_split, Some(0.3));
    }
}