// tensorlogic_train/callbacks/profiling.rs

use crate::callbacks::core::Callback;
use crate::{TrainResult, TrainingState};
/// Aggregated wall-clock timing statistics collected during a training run.
#[derive(Debug, Clone, Default)]
pub struct ProfilingStats {
    /// Total training wall-clock time, in seconds.
    pub total_time: f64,
    /// Duration of each completed epoch, in seconds, in order.
    pub epoch_times: Vec<f64>,
    /// Throughput in samples per second.
    /// NOTE(review): nothing in this file populates this field — confirm a
    /// producer exists elsewhere, otherwise it always prints as 0.00.
    pub samples_per_sec: f64,
    /// Throughput in batches per second over the whole run.
    pub batches_per_sec: f64,
    /// Average wall-clock time per batch (total time / batch count), in seconds.
    pub avg_batch_time: f64,
    /// Peak memory usage, in megabytes.
    /// NOTE(review): neither set nor displayed anywhere in this file — verify.
    pub peak_memory_mb: f64,
}
22
23impl ProfilingStats {
24 pub fn display(&self) {
26 println!("\n=== Profiling Statistics ===");
27 println!("Total time: {:.2}s", self.total_time);
28 println!("Samples/sec: {:.2}", self.samples_per_sec);
29 println!("Batches/sec: {:.2}", self.batches_per_sec);
30 println!("Avg batch time: {:.4}s", self.avg_batch_time);
31
32 if !self.epoch_times.is_empty() {
33 let avg_epoch = self.epoch_times.iter().sum::<f64>() / self.epoch_times.len() as f64;
34 let min_epoch = self
35 .epoch_times
36 .iter()
37 .copied()
38 .fold(f64::INFINITY, f64::min);
39 let max_epoch = self
40 .epoch_times
41 .iter()
42 .copied()
43 .fold(f64::NEG_INFINITY, f64::max);
44
45 println!("\nEpoch times:");
46 println!(" Average: {:.2}s", avg_epoch);
47 println!(" Min: {:.2}s", min_epoch);
48 println!(" Max: {:.2}s", max_epoch);
49 }
50 }
51}
52
/// Callback that measures wall-clock timing of the training run, each epoch,
/// and each batch, accumulating results into [`ProfilingStats`].
pub struct ProfilingCallback {
    // When true, prints timing information to stdout as training progresses.
    verbose: bool,
    // Emit epoch-level logs every `log_frequency` epochs (1 = every epoch).
    log_frequency: usize,
    // Set in `on_train_begin`; anchor for computing the total training time.
    start_time: Option<std::time::Instant>,
    // Set in `on_epoch_begin`; anchor for the current epoch's duration.
    epoch_start_time: Option<std::time::Instant>,
    // Set in `on_batch_begin`; anchor for the current batch's duration.
    batch_start_time: Option<std::time::Instant>,
    /// Statistics gathered so far; finalized in `on_train_end`.
    pub stats: ProfilingStats,
    // Batch durations (seconds) for the epoch in progress; cleared each epoch.
    current_epoch_batch_times: Vec<f64>,
    // Count of batches completed across the entire run.
    total_batches: usize,
}
91
92impl ProfilingCallback {
93 pub fn new(verbose: bool, log_frequency: usize) -> Self {
99 Self {
100 verbose,
101 log_frequency,
102 start_time: None,
103 epoch_start_time: None,
104 batch_start_time: None,
105 stats: ProfilingStats::default(),
106 current_epoch_batch_times: Vec::new(),
107 total_batches: 0,
108 }
109 }
110
111 pub fn get_stats(&self) -> &ProfilingStats {
113 &self.stats
114 }
115}
116
117impl Callback for ProfilingCallback {
118 fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
119 self.start_time = Some(std::time::Instant::now());
120 if self.verbose {
121 println!("Profiling started");
122 }
123 Ok(())
124 }
125
126 fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
127 if let Some(start) = self.start_time {
128 self.stats.total_time = start.elapsed().as_secs_f64();
129
130 if self.total_batches > 0 {
132 self.stats.avg_batch_time = self.stats.total_time / self.total_batches as f64;
133 self.stats.batches_per_sec = self.total_batches as f64 / self.stats.total_time;
134 }
135
136 if self.verbose {
137 println!("\nProfiling completed");
138 self.stats.display();
139 }
140 }
141 Ok(())
142 }
143
144 fn on_epoch_begin(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
145 self.epoch_start_time = Some(std::time::Instant::now());
146 self.current_epoch_batch_times.clear();
147
148 if self.verbose && (epoch + 1).is_multiple_of(self.log_frequency) {
149 println!("\nEpoch {} profiling started", epoch + 1);
150 }
151 Ok(())
152 }
153
154 fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
155 if let Some(epoch_start) = self.epoch_start_time {
156 let epoch_time = epoch_start.elapsed().as_secs_f64();
157 self.stats.epoch_times.push(epoch_time);
158
159 if self.verbose && (epoch + 1).is_multiple_of(self.log_frequency) {
160 let avg_batch = if !self.current_epoch_batch_times.is_empty() {
161 self.current_epoch_batch_times.iter().sum::<f64>()
162 / self.current_epoch_batch_times.len() as f64
163 } else {
164 0.0
165 };
166
167 println!("Epoch {} completed:", epoch + 1);
168 println!(" Time: {:.2}s", epoch_time);
169 println!(
170 " Batches: {} ({:.4}s avg)",
171 self.current_epoch_batch_times.len(),
172 avg_batch
173 );
174 }
175 }
176 Ok(())
177 }
178
179 fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
180 self.batch_start_time = Some(std::time::Instant::now());
181 Ok(())
182 }
183
184 fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
185 if let Some(batch_start) = self.batch_start_time {
186 let batch_time = batch_start.elapsed().as_secs_f64();
187 self.current_epoch_batch_times.push(batch_time);
188 self.total_batches += 1;
189 }
190 Ok(())
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Minimal training state for exercising the callback hooks.
    fn dummy_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 0.5,
            batch_loss: 0.5,
            val_loss: Some(0.6),
            learning_rate: 0.01,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_profiling_callback() {
        let mut cb = ProfilingCallback::new(false, 1);
        let state = dummy_state();

        // Train start stamps the global timer.
        cb.on_train_begin(&state).unwrap();
        assert!(cb.start_time.is_some());

        // Epoch start stamps the epoch timer.
        cb.on_epoch_begin(0, &state).unwrap();
        assert!(cb.epoch_start_time.is_some());

        // A single timed batch is recorded and counted.
        cb.on_batch_begin(0, &state).unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        cb.on_batch_end(0, &state).unwrap();

        assert_eq!(cb.total_batches, 1);
        assert_eq!(cb.current_epoch_batch_times.len(), 1);

        // Epoch end appends exactly one epoch duration.
        cb.on_epoch_end(0, &state).unwrap();
        assert_eq!(cb.stats.epoch_times.len(), 1);

        // Train end finalizes a positive total time.
        cb.on_train_end(&state).unwrap();
        assert!(cb.stats.total_time > 0.0);
    }
}