scirs2_neural/training/mod.rs

//! Training utilities and infrastructure
//!
//! This module provides comprehensive utilities for training neural networks,
//! including advanced features like gradient accumulation, mixed precision training,
//! distributed training, and sophisticated training loop management.
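//!
//! # Examples
//!
//! A minimal sketch of an epoch loop tracked with [`TrainingSession`]; the
//! crate path `scirs2_neural` is assumed from this file's location:
//!
//! ```
//! use scirs2_neural::training::{TrainingConfig, TrainingSession};
//!
//! let config = TrainingConfig::default();
//! let mut session: TrainingSession<f32> = TrainingSession::new(config.clone());
//!
//! for _epoch in 0..config.epochs {
//!     // ... run the forward/backward passes for one epoch here ...
//!     session.add_metric("train_loss", 0.5);
//!     session.next_epoch();
//! }
//! assert_eq!(session.epochs_trained, config.epochs);
//! ```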

use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;

pub mod gradient_accumulation;
pub mod gradient_checkpointing;
pub mod mixed_precision;
pub mod progress_monitor;
pub mod quantization_aware;
pub mod sparse_training;

// Re-export submodule contents
pub use gradient_accumulation::*;
pub use gradient_checkpointing::*;
pub use mixed_precision::*;
pub use progress_monitor::*;
pub use quantization_aware::*;
pub use sparse_training::*;

/// Configuration structure for training neural networks
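///
/// # Examples
///
/// A minimal sketch using struct-update syntax over the defaults (the crate
/// path `scirs2_neural` is assumed):
///
/// ```
/// use scirs2_neural::training::TrainingConfig;
///
/// let config = TrainingConfig {
///     batch_size: 64,
///     epochs: 20,
///     ..Default::default()
/// };
/// assert!(config.shuffle);
/// ```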
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of samples in each training batch
    pub batch_size: usize,
    /// Whether to shuffle the training data between epochs
    pub shuffle: bool,
    /// Number of parallel workers for data loading
    pub num_workers: usize,
    /// Base learning rate for the optimizer
    pub learning_rate: f64,
    /// Number of complete passes through the training dataset
    pub epochs: usize,
    /// Verbosity level for training output
    pub verbose: usize,
    /// Validation configuration
    pub validation: Option<ValidationSettings>,
    /// Gradient accumulation configuration
    pub gradient_accumulation: Option<GradientAccumulationConfig>,
    /// Mixed precision training configuration
    pub mixed_precision: Option<MixedPrecisionConfig>,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            batch_size: 32,
            shuffle: true,
            num_workers: 0,
            learning_rate: 0.001,
            epochs: 10,
            verbose: 1,
            validation: None,
            gradient_accumulation: None,
            mixed_precision: None,
        }
    }
}

/// Configuration for validation during training
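///
/// # Examples
///
/// A minimal sketch: hold out 10% of the training data for validation (the
/// crate path `scirs2_neural` is assumed):
///
/// ```
/// use scirs2_neural::training::ValidationSettings;
///
/// let settings = ValidationSettings {
///     validation_split: 0.1,
///     ..Default::default()
/// };
/// assert!(settings.enabled);
/// ```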
#[derive(Debug, Clone)]
pub struct ValidationSettings {
    /// Whether to enable validation during training
    pub enabled: bool,
    /// Fraction of training data to use for validation (0.0 to 1.0)
    pub validation_split: f64,
    /// Batch size for validation
    pub batch_size: usize,
    /// Number of parallel workers for validation data loading
    pub num_workers: usize,
}

impl Default for ValidationSettings {
    fn default() -> Self {
        Self {
            enabled: true,
            validation_split: 0.2,
            batch_size: 32,
            num_workers: 0,
        }
    }
}

/// Training session that tracks per-metric history and progress across epochs
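///
/// # Examples
///
/// A minimal sketch of metric tracking (the crate path `scirs2_neural` is
/// assumed):
///
/// ```
/// use scirs2_neural::training::{TrainingConfig, TrainingSession};
///
/// let mut session: TrainingSession<f64> = TrainingSession::new(TrainingConfig::default());
/// session.add_metric("val_accuracy", 0.91);
/// session.add_metric("val_accuracy", 0.93);
/// assert_eq!(session.get_metric_history("val_accuracy").map(Vec::len), Some(2));
/// ```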
#[derive(Debug, Clone)]
pub struct TrainingSession<F: Float + Debug + ScalarOperand> {
    /// Training metrics history
    pub history: HashMap<String, Vec<F>>,
    /// Initial learning rate
    pub initial_learning_rate: F,
    /// Number of epochs trained
    pub epochs_trained: usize,
    /// Current epoch number
    pub current_epoch: usize,
    /// Best validation score achieved
    pub best_validation_score: Option<F>,
    /// Whether training has been stopped early
    pub early_stopped: bool,
}

impl<F: Float + Debug + ScalarOperand> TrainingSession<F> {
    /// Create a new training session from a [`TrainingConfig`]
    pub fn new(config: TrainingConfig) -> Self {
        Self {
            history: HashMap::new(),
            initial_learning_rate: F::from(config.learning_rate)
                .expect("learning rate must be representable as F"),
            epochs_trained: 0,
            current_epoch: 0,
            best_validation_score: None,
            early_stopped: false,
        }
    }

    /// Add a metric value to the history
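    ///
    /// # Examples
    ///
    /// A minimal doc-test sketch (the crate path `scirs2_neural` is assumed):
    ///
    /// ```
    /// # use scirs2_neural::training::{TrainingConfig, TrainingSession};
    /// let mut session: TrainingSession<f32> = TrainingSession::new(TrainingConfig::default());
    /// session.add_metric("loss", 1.25);
    /// assert_eq!(session.get_metric_history("loss"), Some(&vec![1.25]));
    /// ```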
    pub fn add_metric(&mut self, metric_name: &str, value: F) {
        self.history
            .entry(metric_name.to_string())
            .or_default()
            .push(value);
    }

    /// Get the history for a specific metric
    pub fn get_metric_history(&self, metric_name: &str) -> Option<&Vec<F>> {
        self.history.get(metric_name)
    }

    /// Get all metric names
    pub fn get_metric_names(&self) -> Vec<&String> {
        self.history.keys().collect()
    }

    /// Advance to the next epoch, incrementing both `current_epoch` and `epochs_trained`
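    ///
    /// A minimal doc-test sketch (crate path assumed):
    ///
    /// ```
    /// # use scirs2_neural::training::{TrainingConfig, TrainingSession};
    /// let mut session: TrainingSession<f32> = TrainingSession::new(TrainingConfig::default());
    /// session.next_epoch();
    /// assert_eq!(session.current_epoch, 1);
    /// ```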
    pub fn next_epoch(&mut self) {
        self.current_epoch += 1;
        self.epochs_trained += 1;
    }

    /// Mark training as completed
    pub fn finish_training(&mut self) {
        // No-op: normal completion needs no state change; `early_stopped`
        // stays `false` and the epoch counters already reflect the run.
    }

    /// Mark training as early stopped
    pub fn early_stop(&mut self) {
        self.early_stopped = true;
    }
}

impl<F: Float + Debug + ScalarOperand> Default for TrainingSession<F> {
    fn default() -> Self {
        Self {
            history: HashMap::new(),
            initial_learning_rate: F::from(0.001)
                .expect("0.001 must be representable as F"),
            epochs_trained: 0,
            current_epoch: 0,
            best_validation_score: None,
            early_stopped: false,
        }
    }
}