Skip to main content

tensorlogic_train/callbacks/
histogram.rs

1//! Histogram monitoring callbacks for weight distributions.
2
3use crate::callbacks::core::Callback;
4use crate::{TrainResult, TrainingState};
5use std::collections::HashMap;
6
/// Weight histogram statistics for debugging and monitoring.
///
/// Built by [`HistogramStats::compute`]. For non-empty input, `bins` holds
/// `counts.len() + 1` ascending edges and bin `i` covers `[bins[i], bins[i + 1])`,
/// except the last bin, which is closed on the right so that `max` itself is counted.
#[derive(Debug, Clone)]
pub struct HistogramStats {
    /// Parameter name.
    pub name: String,
    /// Minimum value (0.0 for empty input).
    pub min: f64,
    /// Maximum value (0.0 for empty input).
    pub max: f64,
    /// Arithmetic mean (0.0 for empty input).
    pub mean: f64,
    /// Population standard deviation (divides by `n`, not `n - 1`).
    pub std: f64,
    /// Histogram bin edges; one more entry than `counts` (empty for empty input).
    pub bins: Vec<f64>,
    /// Number of values falling into each bin.
    pub counts: Vec<usize>,
}

impl HistogramStats {
    /// Compute histogram statistics from parameter values.
    ///
    /// # Arguments
    /// * `name` - Label stored in the result.
    /// * `values` - Samples to summarize.
    /// * `num_bins` - Number of equal-width bins; `0` is treated as `1`.
    ///
    /// Empty `values` yields zeroed statistics with no bins. If every value is
    /// equal the range is zero, so unit-width bins starting at `min` are used and
    /// every value lands in the first bin.
    ///
    /// NaN inputs are skipped by the min/max folds (`f64::min`/`f64::max` ignore
    /// NaN) but propagate into `mean`/`std`, and they are counted in the first bin
    /// because a NaN index saturates to 0 when cast to `usize`.
    pub fn compute(name: &str, values: &[f64], num_bins: usize) -> Self {
        if values.is_empty() {
            return Self {
                name: name.to_string(),
                min: 0.0,
                max: 0.0,
                mean: 0.0,
                std: 0.0,
                bins: Vec::new(),
                counts: Vec::new(),
            };
        }

        // At least one bin is required: with zero bins there is nowhere to put the
        // values and `num_bins - 1` below would underflow.
        let num_bins = num_bins.max(1);
        let n = values.len() as f64;

        // Basic statistics
        let min = values.iter().copied().fold(f64::INFINITY, f64::min);
        let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mean = values.iter().sum::<f64>() / n;
        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        let std = variance.sqrt();

        // Equal-width bins over [min, max]; a degenerate (zero) range falls back to
        // unit-width bins so the edges stay strictly increasing.
        let range = max - min;
        let bin_width = if range > 0.0 {
            range / num_bins as f64
        } else {
            1.0
        };

        let mut bins: Vec<f64> = (0..=num_bins)
            .map(|i| min + i as f64 * bin_width)
            .collect();
        if range > 0.0 {
            // Pin the upper edge to `max` so rounding in `min + n * width` cannot
            // leave the reported range short of the real maximum.
            if let Some(upper) = bins.last_mut() {
                *upper = max;
            }
        }

        let mut counts = vec![0usize; num_bins];
        for &value in values {
            let bin_idx = if range > 0.0 {
                ((value - min) / bin_width).floor() as usize
            } else {
                0
            };
            // `value == max` maps to index `num_bins`; fold it into the last bin.
            counts[bin_idx.min(num_bins - 1)] += 1;
        }

        Self {
            name: name.to_string(),
            min,
            max,
            mean,
            std,
            bins,
            counts,
        }
    }

    /// Pretty print histogram as ASCII art on stdout.
    ///
    /// Bars are scaled so the most populated bin spans `width` characters.
    pub fn display(&self, width: usize) {
        // Build the whole report first so stdout is written (and locked) once.
        print!("{}", self.render(width));
    }

    /// Render the report printed by [`HistogramStats::display`].
    ///
    /// Bin edges missing from `bins` (only possible when the public fields were
    /// filled in by hand inconsistently) are shown as NaN instead of panicking.
    fn render(&self, width: usize) -> String {
        use std::fmt::Write as _;

        let mut out = String::new();
        // `fmt::Write` for `String` never fails, so the `fmt::Result`s are discarded.
        let _ = writeln!(out, "\n=== Histogram: {} ===", self.name);
        let _ = writeln!(out, "  Min: {:.6}, Max: {:.6}", self.min, self.max);
        let _ = writeln!(out, "  Mean: {:.6}, Std: {:.6}", self.mean, self.std);
        let _ = writeln!(out, "\n  Distribution:");

        if self.counts.is_empty() {
            let _ = writeln!(out, "    (empty)");
            return out;
        }

        let max_count = self.counts.iter().copied().max().unwrap_or(0);
        let last = self.counts.len() - 1;

        for (i, &count) in self.counts.iter().enumerate() {
            let bar_len = if max_count > 0 {
                (count as f64 / max_count as f64 * width as f64) as usize
            } else {
                0
            };
            let left = self.bins.get(i).copied().unwrap_or(f64::NAN);
            let right = self.bins.get(i + 1).copied().unwrap_or(left);
            // The last bin also contains `max`, so it is closed on the right.
            let close = if i == last { ']' } else { ')' };
            let _ = writeln!(
                out,
                "  [{:>8.3}, {:>8.3}{}: {:>6} {}",
                left,
                right,
                close,
                count,
                "█".repeat(bar_len)
            );
        }

        out
    }
}
124
125/// Callback for tracking weight histograms during training.
126///
127/// This callback computes and logs histogram statistics of model parameters
128/// at regular intervals. Useful for:
129/// - Detecting vanishing/exploding weights
130/// - Monitoring weight distribution changes
131/// - Debugging initialization issues
132/// - Understanding parameter evolution
133///
134/// # Example
135///
136/// ```no_run
137/// use tensorlogic_train::{CallbackList, HistogramCallback};
138///
139/// let mut callbacks = CallbackList::new();
140/// callbacks.add(Box::new(HistogramCallback::new(
141///     5,   // log_frequency: Every 5 epochs
142///     10,  // num_bins: 10 histogram bins
143///     true, // verbose: Print detailed histograms
144/// )));
145/// ```
146pub struct HistogramCallback {
147    /// Frequency of logging (every N epochs).
148    log_frequency: usize,
149    /// Number of histogram bins.
150    #[allow(dead_code)]
151    // Used in compute_histograms - will be active when parameters are accessible
152    num_bins: usize,
153    /// Whether to print detailed histograms.
154    verbose: bool,
155    /// History of histogram statistics.
156    pub history: Vec<HashMap<String, HistogramStats>>,
157}
158
159impl HistogramCallback {
160    /// Create a new histogram callback.
161    ///
162    /// # Arguments
163    /// * `log_frequency` - Log histograms every N epochs
164    /// * `num_bins` - Number of bins in each histogram
165    /// * `verbose` - Print detailed ASCII histograms
166    pub fn new(log_frequency: usize, num_bins: usize, verbose: bool) -> Self {
167        Self {
168            log_frequency,
169            num_bins,
170            verbose,
171            history: Vec::new(),
172        }
173    }
174
175    /// Compute histograms for all parameters in state.
176    #[allow(dead_code)] // Placeholder - will be used when TrainingState includes parameters
177    fn compute_histograms(&self, _state: &TrainingState) -> HashMap<String, HistogramStats> {
178        // In a real implementation, we would access parameters from state
179        // For now, this is a placeholder that would be populated when
180        // TrainingState includes parameter access
181
182        // Example of what this would look like with actual parameters:
183        // let mut histograms = HashMap::new();
184        // for (name, param) in state.parameters.iter() {
185        //     let values: Vec<f64> = param.iter().copied().collect();
186        //     let stats = HistogramStats::compute(name, &values, self.num_bins);
187        //     histograms.insert(name.clone(), stats);
188        // }
189        // histograms
190
191        HashMap::new()
192    }
193}
194
195impl Callback for HistogramCallback {
196    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
197        if (epoch + 1).is_multiple_of(self.log_frequency) {
198            let histograms = self.compute_histograms(state);
199
200            if self.verbose {
201                println!("\n--- Weight Histograms (Epoch {}) ---", epoch + 1);
202                for (_name, stats) in histograms.iter() {
203                    stats.display(40); // 40 character width for ASCII bars
204                }
205            } else {
206                println!(
207                    "Epoch {}: Computed histograms for {} parameters",
208                    epoch + 1,
209                    histograms.len()
210                );
211            }
212
213            self.history.push(histograms);
214        }
215
216        Ok(())
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal training state shared by the callback tests.
    fn dummy_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 0.5,
            batch_loss: 0.5,
            val_loss: Some(0.6),
            learning_rate: 0.01,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_histogram_stats() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let stats = HistogramStats::compute("test", &values, 5);

        assert_eq!(stats.name, "test");
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 10.0);
        assert!((stats.mean - 5.5).abs() < 1e-6);
        assert_eq!(stats.bins.len(), 6);
        assert_eq!(stats.counts.len(), 5);
        assert_eq!(stats.counts.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_histogram_stats_empty() {
        let stats = HistogramStats::compute("empty", &[], 4);

        assert!(stats.bins.is_empty());
        assert!(stats.counts.is_empty());
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
    }

    #[test]
    fn test_histogram_stats_constant_values() {
        let stats = HistogramStats::compute("const", &[2.0, 2.0, 2.0], 3);

        assert_eq!(stats.min, 2.0);
        assert_eq!(stats.max, 2.0);
        assert_eq!(stats.std, 0.0);
        assert_eq!(stats.counts, vec![3, 0, 0]);
    }

    #[test]
    fn test_histogram_display_does_not_panic() {
        HistogramStats::compute("d", &[0.0, 0.5, 1.0], 4).display(10);
        HistogramStats::compute("e", &[], 4).display(10);
    }

    #[test]
    fn test_histogram_callback() {
        let mut callback = HistogramCallback::new(2, 10, false);
        let state = dummy_state();

        // Should not log on epoch 0
        callback.on_epoch_end(0, &state).unwrap();
        assert_eq!(callback.history.len(), 0);

        // Should log on epoch 1 (frequency=2, so every 2 epochs)
        callback.on_epoch_end(1, &state).unwrap();
        assert_eq!(callback.history.len(), 1);
    }

    #[test]
    fn test_histogram_callback_verbose() {
        let mut callback = HistogramCallback::new(1, 4, true);
        let state = dummy_state();

        callback.on_epoch_end(0, &state).unwrap();
        callback.on_epoch_end(1, &state).unwrap();
        assert_eq!(callback.history.len(), 2);
    }

    #[test]
    fn test_histogram_callback_zero_frequency_never_logs() {
        let mut callback = HistogramCallback::new(0, 10, false);
        let state = dummy_state();

        for epoch in 0..5 {
            callback.on_epoch_end(epoch, &state).unwrap();
        }
        assert!(callback.history.is_empty());
    }
}