scirs2_integrate/analysis/ml_prediction/training.rs

/// Hyperparameters governing a single model-training run.
#[derive(Debug, Clone)]
pub struct TrainingConfiguration {
    /// Learning-rate schedule applied across training.
    pub learning_rate: LearningRateSchedule,
    /// Gradient-based optimizer and its parameters.
    pub optimizer: Optimizer,
    /// Objective minimized during training.
    pub loss_function: LossFunction,
    /// Penalties and augmentation used to reduce overfitting.
    pub regularization: RegularizationConfig,
    /// Number of samples per gradient step.
    pub batch_size: usize,
    /// Maximum number of passes over the training set.
    pub epochs: usize,
    /// Fraction of the data held out for validation (0.0..1.0).
    pub validation_split: f64,
    /// Criteria for stopping before `epochs` is reached.
    pub early_stopping: EarlyStoppingConfig,
}
/// Schedule controlling how the learning rate evolves during training.
#[derive(Debug, Clone)]
pub enum LearningRateSchedule {
    /// Fixed learning rate for the whole run.
    Constant(f64),
    /// Multiply the rate by `decay_rate` every `decay_steps` steps.
    ExponentialDecay {
        initial_lr: f64,
        decay_rate: f64,
        decay_steps: usize,
    },
    /// Cosine decay from `initial_lr` to `min_lr` over each cycle.
    CosineAnnealing {
        initial_lr: f64,
        min_lr: f64,
        cycle_length: usize,
    },
    /// Scale the rate by `drop_rate` every `epochs_drop` epochs.
    StepDecay {
        initial_lr: f64,
        drop_rate: f64,
        epochs_drop: usize,
    },
    /// Reduce the rate by `factor` after `patience` epochs without improvement.
    Adaptive {
        initial_lr: f64,
        patience: usize,
        factor: f64,
    },
}
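// Illustrative sketch (not part of the crate): how each schedule variant could
// map a step index to a learning rate. The method name `lr_at` and the exact
// decay formulas are assumptions; `Adaptive` depends on validation history, so
// a pure function of the step can only return its initial rate.
impl LearningRateSchedule {
    pub fn lr_at(&self, step: usize) -> f64 {
        match self {
            LearningRateSchedule::Constant(lr) => *lr,
            LearningRateSchedule::ExponentialDecay {
                initial_lr,
                decay_rate,
                decay_steps,
            } => initial_lr * decay_rate.powf(step as f64 / *decay_steps as f64),
            LearningRateSchedule::CosineAnnealing {
                initial_lr,
                min_lr,
                cycle_length,
            } => {
                // Position within the current cycle, in [0, 1).
                let t = (step % cycle_length) as f64 / *cycle_length as f64;
                min_lr + 0.5 * (initial_lr - min_lr) * (1.0 + (std::f64::consts::PI * t).cos())
            }
            LearningRateSchedule::StepDecay {
                initial_lr,
                drop_rate,
                epochs_drop,
            } => initial_lr * drop_rate.powi((step / epochs_drop) as i32),
            LearningRateSchedule::Adaptive { initial_lr, .. } => *initial_lr,
        }
    }
}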
/// Gradient-based optimizers and their hyperparameters.
#[derive(Debug, Clone)]
pub enum Optimizer {
    /// Stochastic gradient descent with optional (Nesterov) momentum.
    SGD { momentum: f64, nesterov: bool },
    /// Adam with moment-decay rates `beta1`/`beta2`.
    Adam {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    /// Adam with decoupled weight decay.
    AdamW {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
        weight_decay: f64,
    },
    /// RMSprop with smoothing constant `alpha`.
    RMSprop { alpha: f64, epsilon: f64 },
    /// AdaGrad with numerical-stability term `epsilon`.
    AdaGrad { epsilon: f64 },
}
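// Illustrative sketch (not part of the crate): one Adam update, showing the
// conventional roles of the `beta1`, `beta2`, and `epsilon` fields above.
// `AdamState` and `adam_step` are hypothetical names introduced for
// illustration only.
pub struct AdamState {
    /// First-moment (mean) estimate per parameter.
    pub m: Vec<f64>,
    /// Second-moment (uncentered variance) estimate per parameter.
    pub v: Vec<f64>,
    /// Number of updates taken, used for bias correction.
    pub t: usize,
}

pub fn adam_step(
    params: &mut [f64],
    grads: &[f64],
    state: &mut AdamState,
    lr: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
) {
    state.t += 1;
    let t = state.t as i32;
    for i in 0..params.len() {
        state.m[i] = beta1 * state.m[i] + (1.0 - beta1) * grads[i];
        state.v[i] = beta2 * state.v[i] + (1.0 - beta2) * grads[i] * grads[i];
        // Bias-corrected moment estimates.
        let m_hat = state.m[i] / (1.0 - beta1.powi(t));
        let v_hat = state.v[i] / (1.0 - beta2.powi(t));
        params[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}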
/// Training objectives.
#[derive(Debug, Clone, Copy)]
pub enum LossFunction {
    /// Mean squared error.
    MSE,
    /// Cross-entropy for classification targets.
    CrossEntropy,
    /// Focal loss with its two scalar shape parameters.
    FocalLoss(f64, f64),
    /// Huber loss with transition point `delta`.
    HuberLoss(f64),
    /// Mean squared error with per-sample weights.
    WeightedMSE,
}
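// Illustrative sketch (not part of the crate): the piecewise Huber loss
// selected by `LossFunction::HuberLoss(delta)`, quadratic for residuals
// within `delta` and linear beyond it. `huber_loss` is a hypothetical helper.
pub fn huber_loss(prediction: f64, target: f64, delta: f64) -> f64 {
    let r = (prediction - target).abs();
    if r <= delta {
        0.5 * r * r
    } else {
        delta * (r - 0.5 * delta)
    }
}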
/// Regularization settings applied during training.
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    /// L1 (lasso) penalty coefficient.
    pub l1_lambda: f64,
    /// L2 (ridge / weight-decay) penalty coefficient.
    pub l2_lambda: f64,
    /// Probability of zeroing an activation during training.
    pub dropout_prob: f64,
    /// Augmentations applied to training samples.
    pub data_augmentation: Vec<DataAugmentation>,
    /// Probability mass moved off the true class (0.0 disables).
    pub label_smoothing: f64,
}
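// Illustrative sketch (not part of the crate): folding the L1/L2 terms of a
// `RegularizationConfig` into a scalar loss. `penalized_loss` is a
// hypothetical helper; dropout and augmentation act elsewhere in a pipeline.
pub fn penalized_loss(base_loss: f64, weights: &[f64], reg: &RegularizationConfig) -> f64 {
    let l1: f64 = weights.iter().map(|w| w.abs()).sum();
    let l2: f64 = weights.iter().map(|w| w * w).sum();
    base_loss + reg.l1_lambda * l1 + reg.l2_lambda * l2
}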
/// Augmentations for training samples.
#[derive(Debug, Clone)]
pub enum DataAugmentation {
    /// Additive Gaussian noise (parameter: standard deviation).
    GaussianNoise(f64),
    /// Random shift along the time axis.
    TimeShift(f64),
    /// Random rescaling within a range.
    Scaling(f64, f64),
    /// Random permutation of feature order.
    FeaturePermutation,
    /// Mixup interpolation between sample pairs.
    Mixup(f64),
}
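// Illustrative sketch (not part of the crate): mixing one sample pair as in
// `DataAugmentation::Mixup(alpha)`. The mixing weight `lambda` is assumed to
// be pre-sampled from Beta(alpha, alpha); sampling is omitted to keep the
// sketch dependency-free. `mixup_pair` is a hypothetical helper.
pub fn mixup_pair(x_a: &[f64], x_b: &[f64], lambda: f64) -> Vec<f64> {
    x_a.iter()
        .zip(x_b)
        .map(|(a, b)| lambda * a + (1.0 - lambda) * b)
        .collect()
}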
/// Early-stopping criteria monitored during training.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Whether early stopping is active.
    pub enabled: bool,
    /// Name of the metric to monitor (e.g. "val_loss").
    pub monitor: String,
    /// Smallest change in the metric that counts as an improvement.
    pub min_delta: f64,
    /// Number of epochs without improvement before stopping.
    pub patience: usize,
    /// Whether larger metric values are better (false for losses).
    pub maximize: bool,
}
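// Illustrative sketch (not part of the crate): the bookkeeping implied by
// `EarlyStoppingConfig`. `EarlyStopper` is a hypothetical helper introduced
// for illustration.
pub struct EarlyStopper {
    config: EarlyStoppingConfig,
    best: f64,
    epochs_without_improvement: usize,
}

impl EarlyStopper {
    pub fn new(config: EarlyStoppingConfig) -> Self {
        let best = if config.maximize {
            f64::NEG_INFINITY
        } else {
            f64::INFINITY
        };
        Self {
            config,
            best,
            epochs_without_improvement: 0,
        }
    }

    /// Record one epoch's monitored metric; returns true if training should stop.
    pub fn update(&mut self, value: f64) -> bool {
        if !self.config.enabled {
            return false;
        }
        // An improvement must beat the best value seen so far by `min_delta`.
        let improved = if self.config.maximize {
            value > self.best + self.config.min_delta
        } else {
            value < self.best - self.config.min_delta
        };
        if improved {
            self.best = value;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= self.config.patience
    }
}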
impl Default for TrainingConfiguration {
    fn default() -> Self {
        Self {
            learning_rate: LearningRateSchedule::Constant(0.001),
            optimizer: Optimizer::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            loss_function: LossFunction::CrossEntropy,
            regularization: RegularizationConfig::default(),
            batch_size: 32,
            epochs: 100,
            validation_split: 0.2,
            early_stopping: EarlyStoppingConfig::default(),
        }
    }
}
impl Default for RegularizationConfig {
    fn default() -> Self {
        Self {
            l1_lambda: 0.0,
            l2_lambda: 0.001,
            dropout_prob: 0.1,
            data_augmentation: Vec::new(),
            label_smoothing: 0.0,
        }
    }
}
impl Default for EarlyStoppingConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            monitor: "val_loss".to_string(),
            min_delta: 1e-4,
            patience: 10,
            maximize: false,
        }
    }
}
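// Illustrative usage: start from the defaults and override individual fields
// with struct-update syntax. The training entry point itself is not shown in
// this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_custom_configuration() {
        let config = TrainingConfiguration {
            learning_rate: LearningRateSchedule::CosineAnnealing {
                initial_lr: 1e-3,
                min_lr: 1e-5,
                cycle_length: 50,
            },
            optimizer: Optimizer::AdamW {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
                weight_decay: 0.01,
            },
            batch_size: 64,
            ..TrainingConfiguration::default()
        };
        // Fields not overridden keep their default values.
        assert_eq!(config.epochs, 100);
        assert!(config.early_stopping.enabled);
    }
}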