//! Ensemble classifier configuration types for bifurcation prediction.
//!
//! Source path: `scirs2_integrate/analysis/ml_prediction/ensemble.rs`.

use scirs2_core::ndarray::Array1;

8#[derive(Debug, Clone)]
10pub struct BifurcationEnsembleClassifier {
11 pub base_classifiers: Vec<BaseClassifier>,
13 pub meta_learner: Option<MetaLearner>,
15 pub training_strategy: EnsembleTrainingStrategy,
17 pub cross_validation: CrossValidationConfig,
19 pub feature_selection: FeatureSelectionConfig,
21}
22
23#[derive(Debug, Clone)]
25pub enum BaseClassifier {
26 NeuralNetwork(Box<super::neural_network::BifurcationPredictionNetwork>),
28 RandomForest {
30 n_trees: usize,
31 max_depth: Option<usize>,
32 min_samples_split: usize,
33 min_samples_leaf: usize,
34 },
35 SVM {
37 kernel: SVMKernel,
38 c_parameter: f64,
39 gamma: Option<f64>,
40 },
41 GradientBoosting {
43 n_estimators: usize,
44 learning_rate: f64,
45 max_depth: usize,
46 subsample: f64,
47 },
48 KNN {
50 n_neighbors: usize,
51 weights: KNNWeights,
52 distance_metric: DistanceMetric,
53 },
54}
55
/// Kernel function used by [`BaseClassifier::SVM`].
#[derive(Debug, Clone, Copy)]
pub enum SVMKernel {
    /// Linear kernel.
    Linear,
    /// Radial basis function kernel.
    RBF,
    /// Polynomial kernel; the payload is presumably the polynomial degree —
    /// TODO confirm against the kernel implementation.
    Polynomial(usize),
    /// Sigmoid kernel.
    Sigmoid,
}

/// Neighbour weighting scheme used by [`BaseClassifier::KNN`].
#[derive(Debug, Clone, Copy)]
pub enum KNNWeights {
    /// Every neighbour contributes equally.
    Uniform,
    /// Neighbours are weighted by their distance (presumably inverse
    /// distance — confirm against the KNN implementation).
    Distance,
}

/// Distance metric used by [`BaseClassifier::KNN`].
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance.
    Euclidean,
    /// Manhattan (L1) distance.
    Manhattan,
    /// Minkowski distance; the payload is presumably the order `p` —
    /// TODO confirm against the distance implementation.
    Minkowski(f64),
    /// Cosine distance.
    Cosine,
    /// Hamming distance.
    Hamming,
}

82#[derive(Debug, Clone)]
84pub enum MetaLearner {
85 LinearCombination { weights: Array1<f64> },
87 LogisticRegression { regularization: f64 },
89 NeuralNetwork { hidden_layers: Vec<usize> },
91 DecisionTree { max_depth: Option<usize> },
93}
94
/// How training data is used to fit the members of a
/// [`BifurcationEnsembleClassifier`].
#[derive(Debug, Clone)]
pub enum EnsembleTrainingStrategy {
    /// Every base classifier is trained on the full dataset.
    FullDataset,
    /// Bootstrap aggregation: each base classifier sees `n_samples` drawn
    /// samples, with or without `replacement`.
    Bagging { n_samples: usize, replacement: bool },
    /// Cross-validated training over `n_folds` folds, optionally
    /// `stratified`.
    CrossValidation { n_folds: usize, stratified: bool },
    /// Stacking with a holdout split. `holdout_ratio` is presumably a
    /// fraction in `(0, 1)` reserved for the meta-learner — TODO confirm.
    Stacking { holdout_ratio: f64 },
}

/// Cross-validation settings for a [`BifurcationEnsembleClassifier`].
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    /// Number of folds.
    pub n_folds: usize,
    /// Whether folds are stratified by class.
    pub stratified: bool,
    /// Seed for the random number generator. NOTE(review): `None` presumably
    /// means a non-deterministic seed — confirm.
    pub random_seed: Option<u64>,
    /// Whether the data is shuffled before being split into folds.
    pub shuffle: bool,
}

121#[derive(Debug, Clone)]
123pub struct FeatureSelectionConfig {
124 pub methods: Vec<FeatureSelectionMethod>,
126 pub n_features: Option<usize>,
128 pub threshold: Option<f64>,
130 pub cross_validate: bool,
132}
133
134#[derive(Debug, Clone)]
136pub enum FeatureSelectionMethod {
137 UnivariateSelection { score_func: ScoreFunction },
139 RecursiveElimination {
141 estimator: String, },
143 L1BasedSelection { alpha: f64 },
145 TreeBasedSelection { importance_threshold: f64 },
147 MutualInformation,
149 PCA { n_components: usize },
151}
152
/// Scoring function for [`FeatureSelectionMethod::UnivariateSelection`].
///
/// Variant names mirror scikit-learn's univariate scoring functions. The
/// scoring itself is not implemented in this definition.
#[derive(Debug, Clone, Copy)]
pub enum ScoreFunction {
    /// F-test score for classification targets.
    FClassif,
    /// Chi-squared statistic.
    Chi2,
    /// Mutual information for classification targets.
    MutualInfoClassif,
    /// F-test score for regression targets.
    FRegression,
    /// Mutual information for regression targets.
    MutualInfoRegression,
}