// tensorlogic_train/callbacks/profiling.rs

//! Performance profiling callbacks for training.

use std::fmt;

use crate::callbacks::core::Callback;
use crate::{TrainResult, TrainingState};

/// Performance profiling statistics.
///
/// A plain data holder: every field starts at its `Default` value (zero or
/// empty). The human-readable report is available through the `Display`
/// implementation (e.g. `stats.to_string()`); [`ProfilingStats::display`]
/// prints that same report to standard output.
#[derive(Debug, Clone, Default)]
pub struct ProfilingStats {
    /// Total training time (seconds).
    pub total_time: f64,
    /// Time per epoch (seconds), one entry per recorded epoch.
    pub epoch_times: Vec<f64>,
    /// Samples per second.
    pub samples_per_sec: f64,
    /// Batches per second.
    pub batches_per_sec: f64,
    /// Average batch time (seconds).
    pub avg_batch_time: f64,
    /// Peak memory usage (MB) - placeholder.
    pub peak_memory_mb: f64,
}

impl ProfilingStats {
    /// Pretty print profiling statistics to standard output.
    ///
    /// Prints the `Display` rendering of `self` followed by a newline.
    pub fn display(&self) {
        println!("{}", self);
    }
}

impl fmt::Display for ProfilingStats {
    /// Writes the multi-line profiling report.
    ///
    /// The report starts with an empty line and the banner, followed by the
    /// aggregate figures. The "Epoch times" section (average / min / max) is
    /// only written when at least one epoch time has been recorded. No
    /// trailing newline is emitted; the caller decides how to terminate it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f)?;
        writeln!(f, "=== Profiling Statistics ===")?;
        writeln!(f, "Total time: {:.2}s", self.total_time)?;
        writeln!(f, "Samples/sec: {:.2}", self.samples_per_sec)?;
        writeln!(f, "Batches/sec: {:.2}", self.batches_per_sec)?;
        write!(f, "Avg batch time: {:.4}s", self.avg_batch_time)?;

        if !self.epoch_times.is_empty() {
            let epoch_count = self.epoch_times.len() as f64;
            let mean = self.epoch_times.iter().sum::<f64>() / epoch_count;
            let fastest = self
                .epoch_times
                .iter()
                .copied()
                .fold(f64::INFINITY, f64::min);
            let slowest = self
                .epoch_times
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max);

            // First newline terminates the "Avg batch time" line, the second
            // produces the blank separator line before the epoch section.
            writeln!(f)?;
            writeln!(f)?;
            writeln!(f, "Epoch times:")?;
            writeln!(f, "  Average: {:.2}s", mean)?;
            writeln!(f, "  Min: {:.2}s", fastest)?;
            write!(f, "  Max: {:.2}s", slowest)?;
        }

        Ok(())
    }
}

53/// Callback for profiling training performance.
54///
55/// Tracks timing information and throughput metrics during training.
56/// Useful for:
57/// - Identifying performance bottlenecks
58/// - Comparing different configurations
59/// - Monitoring training speed
60/// - Resource utilization tracking
61///
62/// # Example
63///
64/// ```no_run
65/// use tensorlogic_train::{CallbackList, ProfilingCallback};
66///
67/// let mut callbacks = CallbackList::new();
68/// callbacks.add(Box::new(ProfilingCallback::new(
69///     true,  // verbose: Print detailed stats
70///     5,     // log_frequency: Every 5 epochs
71/// )));
72/// ```
73pub struct ProfilingCallback {
74    /// Whether to print detailed profiling info.
75    verbose: bool,
76    /// Frequency of logging (every N epochs).
77    log_frequency: usize,
78    /// Training start time.
79    start_time: Option<std::time::Instant>,
80    /// Last epoch start time.
81    epoch_start_time: Option<std::time::Instant>,
82    /// Batch start time.
83    batch_start_time: Option<std::time::Instant>,
84    /// Accumulated statistics.
85    pub stats: ProfilingStats,
86    /// Batch times for current epoch.
87    current_epoch_batch_times: Vec<f64>,
88    /// Total batches processed.
89    total_batches: usize,
90}
91
92impl ProfilingCallback {
93    /// Create a new profiling callback.
94    ///
95    /// # Arguments
96    /// * `verbose` - Print detailed profiling information
97    /// * `log_frequency` - Log stats every N epochs
98    pub fn new(verbose: bool, log_frequency: usize) -> Self {
99        Self {
100            verbose,
101            log_frequency,
102            start_time: None,
103            epoch_start_time: None,
104            batch_start_time: None,
105            stats: ProfilingStats::default(),
106            current_epoch_batch_times: Vec::new(),
107            total_batches: 0,
108        }
109    }
110
111    /// Get profiling statistics.
112    pub fn get_stats(&self) -> &ProfilingStats {
113        &self.stats
114    }
115}
116
117impl Callback for ProfilingCallback {
118    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
119        self.start_time = Some(std::time::Instant::now());
120        if self.verbose {
121            println!("Profiling started");
122        }
123        Ok(())
124    }
125
126    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
127        if let Some(start) = self.start_time {
128            self.stats.total_time = start.elapsed().as_secs_f64();
129
130            // Compute aggregate statistics
131            if self.total_batches > 0 {
132                self.stats.avg_batch_time = self.stats.total_time / self.total_batches as f64;
133                self.stats.batches_per_sec = self.total_batches as f64 / self.stats.total_time;
134            }
135
136            if self.verbose {
137                println!("\nProfiling completed");
138                self.stats.display();
139            }
140        }
141        Ok(())
142    }
143
144    fn on_epoch_begin(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
145        self.epoch_start_time = Some(std::time::Instant::now());
146        self.current_epoch_batch_times.clear();
147
148        if self.verbose && (epoch + 1).is_multiple_of(self.log_frequency) {
149            println!("\nEpoch {} profiling started", epoch + 1);
150        }
151        Ok(())
152    }
153
154    fn on_epoch_end(&mut self, epoch: usize, _state: &TrainingState) -> TrainResult<()> {
155        if let Some(epoch_start) = self.epoch_start_time {
156            let epoch_time = epoch_start.elapsed().as_secs_f64();
157            self.stats.epoch_times.push(epoch_time);
158
159            if self.verbose && (epoch + 1).is_multiple_of(self.log_frequency) {
160                let avg_batch = if !self.current_epoch_batch_times.is_empty() {
161                    self.current_epoch_batch_times.iter().sum::<f64>()
162                        / self.current_epoch_batch_times.len() as f64
163                } else {
164                    0.0
165                };
166
167                println!("Epoch {} completed:", epoch + 1);
168                println!("    Time: {:.2}s", epoch_time);
169                println!(
170                    "    Batches: {} ({:.4}s avg)",
171                    self.current_epoch_batch_times.len(),
172                    avg_batch
173                );
174            }
175        }
176        Ok(())
177    }
178
179    fn on_batch_begin(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
180        self.batch_start_time = Some(std::time::Instant::now());
181        Ok(())
182    }
183
184    fn on_batch_end(&mut self, _batch: usize, _state: &TrainingState) -> TrainResult<()> {
185        if let Some(batch_start) = self.batch_start_time {
186            let batch_time = batch_start.elapsed().as_secs_f64();
187            self.current_epoch_batch_times.push(batch_time);
188            self.total_batches += 1;
189        }
190        Ok(())
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds the fixed training state shared by all tests; the callback under
    /// test never reads it.
    fn sample_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 0.5,
            batch_loss: 0.5,
            val_loss: Some(0.6),
            learning_rate: 0.01,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_profiling_callback() {
        let mut callback = ProfilingCallback::new(false, 1);
        let state = sample_state();

        callback.on_train_begin(&state).unwrap();
        assert!(callback.start_time.is_some());

        callback.on_epoch_begin(0, &state).unwrap();
        assert!(callback.epoch_start_time.is_some());

        callback.on_batch_begin(0, &state).unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        callback.on_batch_end(0, &state).unwrap();

        assert_eq!(callback.total_batches, 1);
        assert_eq!(callback.current_epoch_batch_times.len(), 1);
        assert!(callback.current_epoch_batch_times[0] > 0.0);

        callback.on_epoch_end(0, &state).unwrap();
        assert_eq!(callback.stats.epoch_times.len(), 1);
        // The epoch encloses the batch, so it cannot be shorter.
        assert!(callback.stats.epoch_times[0] >= callback.current_epoch_batch_times[0]);

        callback.on_train_end(&state).unwrap();
        let stats = callback.get_stats();
        assert!(stats.total_time > 0.0);
        // Training encloses the epoch, so it cannot be shorter.
        assert!(stats.total_time >= stats.epoch_times[0]);
        assert!(stats.avg_batch_time > 0.0);
        assert!(stats.batches_per_sec > 0.0);
        assert!(stats.batches_per_sec.is_finite());
    }

    #[test]
    fn test_zero_log_frequency_does_not_panic() {
        // Verbose mode exercises the logging predicate with a zero frequency.
        let mut callback = ProfilingCallback::new(true, 0);
        let state = sample_state();

        callback.on_train_begin(&state).unwrap();
        callback.on_epoch_begin(0, &state).unwrap();
        callback.on_epoch_end(0, &state).unwrap();
        callback.on_train_end(&state).unwrap();

        assert_eq!(callback.stats.epoch_times.len(), 1);
    }

    #[test]
    fn test_train_end_without_begin_is_noop() {
        let mut callback = ProfilingCallback::new(false, 1);
        let state = sample_state();

        callback.on_train_end(&state).unwrap();

        assert!(callback.start_time.is_none());
        assert_eq!(callback.stats.total_time, 0.0);
        assert_eq!(callback.stats.avg_batch_time, 0.0);
        assert_eq!(callback.stats.batches_per_sec, 0.0);
    }
}
231}