scirs2_integrate/analysis/ml_prediction/training.rs

//! Training Configuration for Neural Networks
//!
//! This module contains training configuration, optimization algorithms,
//! learning rate schedules, and regularization techniques.

/// Training configuration
#[derive(Debug, Clone)]
pub struct TrainingConfiguration {
    /// Learning rate schedule
    pub learning_rate: LearningRateSchedule,
    /// Optimization algorithm
    pub optimizer: Optimizer,
    /// Loss function
    pub loss_function: LossFunction,
    /// Regularization techniques
    pub regularization: RegularizationConfig,
    /// Training batch size
    pub batch_size: usize,
    /// Number of training epochs
    pub epochs: usize,
    /// Validation split ratio
    pub validation_split: f64,
    /// Early stopping configuration
    pub early_stopping: EarlyStoppingConfig,
}

/// Learning rate scheduling strategies
#[derive(Debug, Clone)]
pub enum LearningRateSchedule {
    /// Constant learning rate
    Constant(f64),
    /// Exponential decay
    ExponentialDecay {
        initial_lr: f64,
        decay_rate: f64,
        decay_steps: usize,
    },
    /// Cosine annealing
    CosineAnnealing {
        initial_lr: f64,
        min_lr: f64,
        cycle_length: usize,
    },
    /// Step decay
    StepDecay {
        initial_lr: f64,
        drop_rate: f64,
        epochs_drop: usize,
    },
    /// Adaptive learning rate
    Adaptive {
        initial_lr: f64,
        patience: usize,
        factor: f64,
    },
}
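
// Illustrative sketch only: `lr_at_epoch` is a hypothetical helper, not part of
// this module's API, showing how each schedule variant above could be evaluated.
// The `Adaptive` variant reacts to a monitored validation metric, which is not
// available here, so the sketch simply returns its initial rate.
impl LearningRateSchedule {
    /// Learning rate at a given epoch under this schedule (illustrative sketch).
    pub fn lr_at_epoch(&self, epoch: usize) -> f64 {
        match *self {
            LearningRateSchedule::Constant(lr) => lr,
            LearningRateSchedule::ExponentialDecay {
                initial_lr,
                decay_rate,
                decay_steps,
            } => initial_lr * decay_rate.powf(epoch as f64 / decay_steps.max(1) as f64),
            LearningRateSchedule::CosineAnnealing {
                initial_lr,
                min_lr,
                cycle_length,
            } => {
                // Cosine annealing within one cycle: lr moves from initial_lr to min_lr.
                let cycle = cycle_length.max(1);
                let t = (epoch % cycle) as f64 / cycle as f64;
                min_lr + 0.5 * (initial_lr - min_lr) * (1.0 + (std::f64::consts::PI * t).cos())
            }
            LearningRateSchedule::StepDecay {
                initial_lr,
                drop_rate,
                epochs_drop,
            } => initial_lr * drop_rate.powi((epoch / epochs_drop.max(1)) as i32),
            // Adaptive scheduling needs a metric history; placeholder only.
            LearningRateSchedule::Adaptive { initial_lr, .. } => initial_lr,
        }
    }
}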

/// Optimization algorithms
#[derive(Debug, Clone)]
pub enum Optimizer {
    /// Stochastic Gradient Descent
    SGD { momentum: f64, nesterov: bool },
    /// Adam optimizer
    Adam {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    /// AdamW (Adam with weight decay)
    AdamW {
        beta1: f64,
        beta2: f64,
        epsilon: f64,
        weight_decay: f64,
    },
    /// RMSprop optimizer
    RMSprop { alpha: f64, epsilon: f64 },
    /// AdaGrad optimizer
    AdaGrad { epsilon: f64 },
}
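
// Illustrative sketch only: `adam_step` is a hypothetical free function, not part
// of this module's API, showing the per-parameter update implied by the `Adam`
// variant above. `m` and `v` are the running first/second moment estimates and
// `t` is the 1-based step count.
#[allow(dead_code)]
fn adam_step(
    weight: &mut f64,
    grad: f64,
    m: &mut f64,
    v: &mut f64,
    t: u32,
    lr: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
) {
    // Update biased moment estimates.
    *m = beta1 * *m + (1.0 - beta1) * grad;
    *v = beta2 * *v + (1.0 - beta2) * grad * grad;
    // Bias correction for the early steps.
    let m_hat = *m / (1.0 - beta1.powi(t as i32));
    let v_hat = *v / (1.0 - beta2.powi(t as i32));
    // Parameter update.
    *weight -= lr * m_hat / (v_hat.sqrt() + epsilon);
}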

/// Loss function types
#[derive(Debug, Clone, Copy)]
pub enum LossFunction {
    /// Mean Squared Error (for regression)
    MSE,
    /// Cross-entropy (for classification)
    CrossEntropy,
    /// Focal loss (for imbalanced classification)
    FocalLoss(f64, f64), // alpha, gamma
    /// Huber loss (robust regression)
    HuberLoss(f64), // delta
    /// Custom weighted loss
    WeightedMSE,
}
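
// Illustrative sketch only: `huber_loss` is a hypothetical helper, not part of
// this module's API, showing the per-sample form of `LossFunction::HuberLoss(delta)`:
// quadratic for residuals within `delta`, linear beyond it (robust to outliers).
#[allow(dead_code)]
fn huber_loss(delta: f64, prediction: f64, target: f64) -> f64 {
    let residual = (prediction - target).abs();
    if residual <= delta {
        0.5 * residual * residual
    } else {
        delta * (residual - 0.5 * delta)
    }
}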

/// Regularization configuration
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    /// L1 regularization strength
    pub l1_lambda: f64,
    /// L2 regularization strength
    pub l2_lambda: f64,
    /// Dropout probability
    pub dropout_prob: f64,
    /// Data augmentation techniques
    pub data_augmentation: Vec<DataAugmentation>,
    /// Label smoothing factor
    pub label_smoothing: f64,
}

/// Data augmentation techniques
#[derive(Debug, Clone)]
pub enum DataAugmentation {
    /// Add Gaussian noise
    GaussianNoise(f64), // standard deviation
    /// Time shift augmentation
    TimeShift(f64), // maximum shift ratio
    /// Scaling augmentation
    Scaling(f64, f64), // min_scale, max_scale
    /// Feature permutation
    FeaturePermutation,
    /// Mixup augmentation
    Mixup(f64), // alpha parameter
}
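
// Illustrative sketch only: `mixup` is a hypothetical helper, not part of this
// module's API, showing the convex combination behind `DataAugmentation::Mixup`.
// In practice `lambda` is drawn from a Beta(alpha, alpha) distribution and the
// corresponding labels are mixed with the same `lambda`; sampling is left to the
// caller here to keep the sketch dependency-free.
#[allow(dead_code)]
fn mixup(x1: &[f64], x2: &[f64], lambda: f64) -> Vec<f64> {
    x1.iter()
        .zip(x2.iter())
        .map(|(a, b)| lambda * a + (1.0 - lambda) * b)
        .collect()
}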

/// Early stopping configuration
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Enable early stopping
    pub enabled: bool,
    /// Metric to monitor
    pub monitor: String,
    /// Minimum change to qualify as improvement
    pub min_delta: f64,
    /// Number of epochs with no improvement to stop
    pub patience: usize,
    /// Whether higher metric values are better
    pub maximize: bool,
}
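
// Illustrative sketch only: `should_stop` is a hypothetical helper, not part of
// this module's API, showing how the fields above combine into a stopping rule:
// stop once the monitored metric has failed to improve by at least `min_delta`
// for `patience` consecutive epochs.
#[allow(dead_code)]
fn should_stop(config: &EarlyStoppingConfig, metric_history: &[f64]) -> bool {
    if !config.enabled || metric_history.is_empty() {
        return false;
    }
    let mut best = metric_history[0];
    let mut epochs_since_improvement = 0;
    for &value in &metric_history[1..] {
        let improved = if config.maximize {
            value > best + config.min_delta
        } else {
            value < best - config.min_delta
        };
        if improved {
            best = value;
            epochs_since_improvement = 0;
        } else {
            epochs_since_improvement += 1;
        }
    }
    epochs_since_improvement >= config.patience
}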

impl Default for TrainingConfiguration {
    fn default() -> Self {
        Self {
            learning_rate: LearningRateSchedule::Constant(0.001),
            optimizer: Optimizer::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            loss_function: LossFunction::CrossEntropy,
            regularization: RegularizationConfig::default(),
            batch_size: 32,
            epochs: 100,
            validation_split: 0.2,
            early_stopping: EarlyStoppingConfig::default(),
        }
    }
}

impl Default for RegularizationConfig {
    fn default() -> Self {
        Self {
            l1_lambda: 0.0,
            l2_lambda: 0.001,
            dropout_prob: 0.1,
            data_augmentation: Vec::new(),
            label_smoothing: 0.0,
        }
    }
}

impl Default for EarlyStoppingConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            monitor: "val_loss".to_string(),
            min_delta: 1e-4,
            patience: 10,
            maximize: false,
        }
    }
}
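
// Illustrative sketch only: a hypothetical test showing how a custom
// configuration could be assembled from the types above (AdamW with cosine
// annealing in place of the defaults).
#[cfg(test)]
mod example {
    use super::*;

    #[test]
    fn custom_configuration_builds() {
        let config = TrainingConfiguration {
            learning_rate: LearningRateSchedule::CosineAnnealing {
                initial_lr: 1e-3,
                min_lr: 1e-5,
                cycle_length: 50,
            },
            optimizer: Optimizer::AdamW {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
                weight_decay: 0.01,
            },
            loss_function: LossFunction::HuberLoss(1.0),
            batch_size: 64,
            // Remaining fields keep their default values.
            ..TrainingConfiguration::default()
        };
        assert_eq!(config.batch_size, 64);
    }
}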