scirs2_integrate/analysis/ml_prediction/
ensemble.rs

//! Ensemble Learning for Bifurcation Classification
//!
//! This module contains ensemble learning methods, feature selection,
//! and cross-validation configurations for bifurcation prediction.

use scirs2_core::ndarray::Array1;

8/// Advanced ensemble learning for bifurcation classification
9#[derive(Debug, Clone)]
10pub struct BifurcationEnsembleClassifier {
11    /// Individual classifiers in the ensemble
12    pub base_classifiers: Vec<BaseClassifier>,
13    /// Meta-learner for ensemble combination
14    pub meta_learner: Option<MetaLearner>,
15    /// Ensemble training strategy
16    pub training_strategy: EnsembleTrainingStrategy,
17    /// Cross-validation configuration
18    pub cross_validation: CrossValidationConfig,
19    /// Feature selection methods
20    pub feature_selection: FeatureSelectionConfig,
21}
22
23/// Base classifier types for ensemble
24#[derive(Debug, Clone)]
25pub enum BaseClassifier {
26    /// Neural network classifier
27    NeuralNetwork(Box<super::neural_network::BifurcationPredictionNetwork>),
28    /// Random forest classifier
29    RandomForest {
30        n_trees: usize,
31        max_depth: Option<usize>,
32        min_samples_split: usize,
33        min_samples_leaf: usize,
34    },
35    /// Support Vector Machine
36    SVM {
37        kernel: SVMKernel,
38        c_parameter: f64,
39        gamma: Option<f64>,
40    },
41    /// Gradient boosting classifier
42    GradientBoosting {
43        n_estimators: usize,
44        learning_rate: f64,
45        max_depth: usize,
46        subsample: f64,
47    },
48    /// K-Nearest Neighbors
49    KNN {
50        n_neighbors: usize,
51        weights: KNNWeights,
52        distance_metric: DistanceMetric,
53    },
54}
55
/// SVM kernel types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SVMKernel {
    /// Linear kernel.
    Linear,
    /// RBF kernel.
    RBF,
    /// Polynomial kernel; the payload is the degree.
    Polynomial(usize),
    /// Sigmoid kernel.
    Sigmoid,
}

/// KNN weight functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KNNWeights {
    /// Uniform weights.
    Uniform,
    /// Distance-based weights.
    Distance,
}

/// Distance metrics for KNN.
///
/// Only `PartialEq` (not `Eq`) is derived because
/// [`DistanceMetric::Minkowski`] carries an `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceMetric {
    /// Euclidean distance.
    Euclidean,
    /// Manhattan distance.
    Manhattan,
    /// Minkowski distance; the payload is the `p` parameter.
    Minkowski(f64),
    /// Cosine distance.
    Cosine,
    /// Hamming distance.
    Hamming,
}

82/// Meta-learner for ensemble combination
83#[derive(Debug, Clone)]
84pub enum MetaLearner {
85    /// Linear combination
86    LinearCombination { weights: Array1<f64> },
87    /// Logistic regression meta-learner
88    LogisticRegression { regularization: f64 },
89    /// Neural network meta-learner
90    NeuralNetwork { hidden_layers: Vec<usize> },
91    /// Decision tree meta-learner
92    DecisionTree { max_depth: Option<usize> },
93}
94
/// Ensemble training strategies.
///
/// Only `PartialEq` (not `Eq`) is derived because
/// [`EnsembleTrainingStrategy::Stacking`] carries an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum EnsembleTrainingStrategy {
    /// Train all models on full dataset.
    FullDataset,
    /// Bootstrap aggregating (bagging).
    Bagging {
        /// Number of samples.
        n_samples: usize,
        /// Whether replacement is enabled.
        replacement: bool,
    },
    /// Cross-validation based training.
    CrossValidation {
        /// Number of folds.
        n_folds: usize,
        /// Whether the folds are stratified.
        stratified: bool,
    },
    /// Stacking with holdout validation.
    Stacking {
        /// Holdout ratio.
        holdout_ratio: f64,
    },
}

/// Cross-validation configuration.
///
/// A plain settings record; this block attaches no validation to the values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CrossValidationConfig {
    /// Number of folds.
    pub n_folds: usize,
    /// Use stratified CV.
    pub stratified: bool,
    /// Random seed for reproducibility (`None` when no seed is set).
    pub random_seed: Option<u64>,
    /// Shuffle data before splitting.
    pub shuffle: bool,
}

121/// Feature selection configuration
122#[derive(Debug, Clone)]
123pub struct FeatureSelectionConfig {
124    /// Feature selection methods to apply
125    pub methods: Vec<FeatureSelectionMethod>,
126    /// Number of features to select
127    pub n_features: Option<usize>,
128    /// Selection threshold
129    pub threshold: Option<f64>,
130    /// Cross-validation for feature selection
131    pub cross_validate: bool,
132}
133
134/// Feature selection methods
135#[derive(Debug, Clone)]
136pub enum FeatureSelectionMethod {
137    /// Univariate statistical tests
138    UnivariateSelection { score_func: ScoreFunction },
139    /// Recursive feature elimination
140    RecursiveElimination {
141        estimator: String, // estimator type
142    },
143    /// L1-based selection (Lasso)
144    L1BasedSelection { alpha: f64 },
145    /// Tree-based feature importance
146    TreeBasedSelection { importance_threshold: f64 },
147    /// Mutual information
148    MutualInformation,
149    /// Principal component analysis
150    PCA { n_components: usize },
151}
152
/// Statistical score functions for feature selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreFunction {
    /// F-statistic for classification.
    FClassif,
    /// Chi-squared test.
    Chi2,
    /// Mutual information for classification.
    MutualInfoClassif,
    /// F-statistic for regression.
    FRegression,
    /// Mutual information for regression.
    MutualInfoRegression,
}