scirs2_neural/training/
mod.rs1use scirs2_core::ndarray::ScalarOperand;
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12pub mod gradient_accumulation;
14pub mod gradient_checkpointing;
15pub mod mixed_precision;
16pub mod progress_monitor;
17pub mod quantization_aware;
18pub mod sparse_training;
19
20pub use gradient_accumulation::*;
21pub use gradient_checkpointing::*;
22pub use mixed_precision::*;
23pub use progress_monitor::*;
24pub use quantization_aware::*;
25pub use sparse_training::*;
26
27#[derive(Debug, Clone)]
29pub struct TrainingConfig {
30 pub batch_size: usize,
32 pub shuffle: bool,
34 pub num_workers: usize,
36 pub learning_rate: f64,
38 pub epochs: usize,
40 pub verbose: usize,
42 pub validation: Option<ValidationSettings>,
44 pub gradient_accumulation: Option<GradientAccumulationConfig>,
46 pub mixed_precision: Option<MixedPrecisionConfig>,
48}
49
50impl Default for TrainingConfig {
51 fn default() -> Self {
52 Self {
53 batch_size: 32,
54 shuffle: true,
55 num_workers: 0,
56 learning_rate: 0.001,
57 epochs: 10,
58 verbose: 1,
59 validation: None,
60 gradient_accumulation: None,
61 mixed_precision: None,
62 }
63 }
64}
65
/// Settings for the validation phase of a training run.
///
/// Defaults: enabled, a `0.2` split, batch size `32` and `0` workers.
#[derive(Debug, Clone)]
pub struct ValidationSettings {
    /// Whether validation is enabled (default: `true`).
    pub enabled: bool,
    /// Validation split (default: `0.2`). Presumably the fraction of data
    /// held out for validation; the value is not range-checked here.
    pub validation_split: f64,
    /// Batch size used for validation (default: `32`).
    pub batch_size: usize,
    /// Number of workers used for validation (default: `0`).
    pub num_workers: usize,
}

impl Default for ValidationSettings {
    /// Returns the settings listed in the type-level documentation.
    fn default() -> Self {
        ValidationSettings {
            batch_size: 32,
            num_workers: 0,
            validation_split: 0.2,
            enabled: true,
        }
    }
}
89
/// State of a single training run: per-metric history, epoch counters and
/// early-stopping status.
///
/// The data structure itself places no trait bounds on `F`; the `impl`
/// blocks that need numeric behaviour (`Float + Debug + ScalarOperand`)
/// declare those bounds themselves. Keeping bounds off the definition means
/// code that merely stores or passes a `TrainingSession<F>` does not have to
/// repeat them.
#[derive(Debug, Clone)]
pub struct TrainingSession<F> {
    /// Recorded values keyed by metric name; each vector grows by one entry
    /// per `add_metric` call for that name.
    pub history: HashMap<String, Vec<F>>,
    /// Learning rate the session was created with.
    pub initial_learning_rate: F,
    /// Counter advanced by `next_epoch`; starts at `0`.
    pub epochs_trained: usize,
    /// Counter advanced by `next_epoch`; starts at `0`.
    pub current_epoch: usize,
    /// Best validation score, if one has been recorded. No method on this
    /// type writes it; it is set through the public field.
    pub best_validation_score: Option<F>,
    /// Set to `true` by `early_stop`; `false` on construction.
    pub early_stopped: bool,
}
106
107impl<F: Float + Debug + ScalarOperand> TrainingSession<F> {
108 pub fn new(config: TrainingConfig) -> Self {
110 Self {
111 history: HashMap::new(),
112 initial_learning_rate: F::from(config.learning_rate).unwrap(),
113 epochs_trained: 0,
114 current_epoch: 0,
115 best_validation_score: None,
116 early_stopped: false,
117 }
118 }
119
120 pub fn add_metric(&mut self, metricname: &str, value: F) {
122 self.history
123 .entry(metricname.to_string())
124 .or_default()
125 .push(value);
126 }
127
128 pub fn get_metric_history(&self, metricname: &str) -> Option<&Vec<F>> {
130 self.history.get(metricname)
131 }
132
133 pub fn get_metric_names(&self) -> Vec<&String> {
135 self.history.keys().collect()
136 }
137
138 pub fn next_epoch(&mut self) {
140 self.current_epoch += 1;
141 self.epochs_trained += 1;
142 }
143
144 pub fn finish_training(&mut self) {
146 }
148
149 pub fn early_stop(&mut self) {
151 self.early_stopped = true;
152 }
153}
154
155impl<F: Float + Debug + ScalarOperand> Default for TrainingSession<F> {
156 fn default() -> Self {
157 Self {
158 history: HashMap::new(),
159 initial_learning_rate: F::from(0.001).unwrap(),
160 epochs_trained: 0,
161 current_epoch: 0,
162 best_validation_score: None,
163 early_stopped: false,
164 }
165 }
166}