sklears_feature_selection/automl/
method_selector.rs1use super::automl_core::{AutoMLMethod, DataCharacteristics, TargetType};
7use sklears_core::error::Result as SklResult;
8
9type Result<T> = SklResult<T>;
10
/// Stateless, rule-based chooser of candidate feature-selection methods.
///
/// The type carries no fields; all decisions are derived from the
/// characteristics passed to `select_methods`.
#[derive(Clone, Debug)]
pub struct MethodSelector;
14
15impl MethodSelector {
16 pub fn new() -> Self {
17 Self
18 }
19
20 pub fn select_methods(
21 &self,
22 characteristics: &DataCharacteristics,
23 ) -> Result<Vec<AutoMLMethod>> {
24 let mut selected_methods = Vec::new();
25
26 match characteristics.target_type {
28 TargetType::BinaryClassification | TargetType::MultiClassification => {
29 selected_methods.push(AutoMLMethod::UnivariateFiltering);
31
32 if characteristics.n_features > 100 {
33 selected_methods.push(AutoMLMethod::LassoBased);
34 }
35
36 if characteristics.computational_budget.allow_complex_methods {
37 selected_methods.push(AutoMLMethod::TreeBased);
38 }
39 }
40 TargetType::Regression => {
41 selected_methods.push(AutoMLMethod::CorrelationBased);
42
43 if characteristics.n_features > 50 {
44 selected_methods.push(AutoMLMethod::LassoBased);
45 }
46
47 if characteristics.computational_budget.allow_complex_methods {
48 selected_methods.push(AutoMLMethod::WrapperBased);
49 }
50 }
51 TargetType::MultiLabel => {
52 selected_methods.push(AutoMLMethod::UnivariateFiltering);
53 selected_methods.push(AutoMLMethod::EnsembleBased);
54 }
55 TargetType::Survival => {
56 selected_methods.push(AutoMLMethod::UnivariateFiltering);
57 selected_methods.push(AutoMLMethod::CorrelationBased);
58 }
59 }
60
61 if characteristics.correlation_structure.high_correlation_pairs
63 > characteristics.n_features / 4
64 && !selected_methods.contains(&AutoMLMethod::CorrelationBased)
65 {
66 selected_methods.push(AutoMLMethod::CorrelationBased);
67 }
68
69 if characteristics.feature_to_sample_ratio > 0.5
71 && characteristics.computational_budget.allow_complex_methods
72 {
73 selected_methods.push(AutoMLMethod::EnsembleBased);
74 }
75
76 if characteristics.n_features > 1000 && characteristics.n_samples > 1000 {
78 selected_methods.push(AutoMLMethod::Hybrid);
79 }
80
81 if characteristics.feature_to_sample_ratio > 2.0
83 && characteristics.computational_budget.allow_complex_methods
84 && !characteristics.computational_budget.prefer_speed
85 {
86 selected_methods.push(AutoMLMethod::NeuralArchitectureSearch);
87 }
88
89 if characteristics.n_features > 100
91 && characteristics.n_samples >= 500
92 && characteristics.computational_budget.allow_complex_methods
93 {
94 selected_methods.push(AutoMLMethod::TransferLearning);
95 }
96
97 if characteristics.n_features > 50
99 && characteristics.n_samples > 200
100 && characteristics.computational_budget.allow_complex_methods
101 && selected_methods.len() >= 2
102 {
103 selected_methods.push(AutoMLMethod::MetaLearningEnsemble);
104 }
105
106 if selected_methods.is_empty() {
107 selected_methods.push(AutoMLMethod::UnivariateFiltering);
109 }
110
111 Ok(selected_methods)
112 }
113}
114
115impl Default for MethodSelector {
116 fn default() -> Self {
117 Self::new()
118 }
119}