// tensorlogic_train/callbacks/histogram.rs
use std::collections::HashMap;

use crate::callbacks::core::Callback;
use crate::{TrainResult, TrainingState};

/// Summary statistics plus an equal-width histogram for one named set of values.
#[derive(Debug, Clone)]
pub struct HistogramStats {
    /// Label shown in the rendered histogram header.
    pub name: String,
    /// Smallest value (`0.0` for empty input).
    pub min: f64,
    /// Largest value (`0.0` for empty input).
    pub max: f64,
    /// Arithmetic mean (`0.0` for empty input).
    pub mean: f64,
    /// Population standard deviation (divides by `n`, not `n - 1`).
    pub std: f64,
    /// Bin edges as produced by [`HistogramStats::compute`]: `counts.len() + 1`
    /// entries, or empty when there are no bins.
    pub bins: Vec<f64>,
    /// Number of values falling in each bin.
    pub counts: Vec<usize>,
}

impl HistogramStats {
    /// Computes summary statistics and an equal-width histogram of `values`.
    ///
    /// `[min, max]` is split into `num_bins` bins of equal width. Every bin is
    /// half-open `[left, right)` except the last, which is closed so that `max`
    /// itself is counted. When all values are equal (zero range) every value
    /// goes into the first bin and the edges are spaced `1.0` apart.
    ///
    /// Empty `values` yield all-zero statistics and no bins. `num_bins == 0`
    /// yields the summary statistics but empty `bins` and `counts`.
    ///
    /// Non-finite inputs are not filtered: NaN or infinities propagate into
    /// `mean`/`std` and can produce degenerate bin edges.
    #[must_use]
    pub fn compute(name: &str, values: &[f64], num_bins: usize) -> Self {
        if values.is_empty() {
            return Self {
                name: name.to_string(),
                min: 0.0,
                max: 0.0,
                mean: 0.0,
                std: 0.0,
                bins: Vec::new(),
                counts: Vec::new(),
            };
        }

        let n = values.len() as f64;
        let min = values.iter().copied().fold(f64::INFINITY, f64::min);
        let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mean = values.iter().sum::<f64>() / n;
        // Population variance: the values are treated as the whole distribution.
        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        let std = variance.sqrt();

        let (bins, counts) = Self::bin_values(values, min, max, num_bins);

        Self {
            name: name.to_string(),
            min,
            max,
            mean,
            std,
            bins,
            counts,
        }
    }

    /// Splits `[min, max]` into `num_bins` equal-width bins and counts `values`
    /// into them, returning `(edges, counts)`.
    fn bin_values(values: &[f64], min: f64, max: f64, num_bins: usize) -> (Vec<f64>, Vec<usize>) {
        // No bins requested: nothing to fill, and `num_bins - 1` below would
        // underflow.
        if num_bins == 0 {
            return (Vec::new(), Vec::new());
        }

        let range = max - min;
        // A zero range would make every bin zero-width; fall back to unit-width
        // edges and put every value in the first bin.
        let bin_width = if range > 0.0 {
            range / num_bins as f64
        } else {
            1.0
        };

        let edges: Vec<f64> = (0..=num_bins).map(|i| min + i as f64 * bin_width).collect();

        let mut counts = vec![0usize; num_bins];
        let last = num_bins - 1;
        for &value in values {
            let idx = if range > 0.0 {
                ((value - min) / bin_width).floor() as usize
            } else {
                0
            };
            // `value == max` (or a value pushed past it by rounding) computes an
            // index of `num_bins` or more; clamp it into the last, closed bin.
            counts[idx.min(last)] += 1;
        }

        (edges, counts)
    }

    /// Renders the summary and an ASCII bar chart as a multi-line string.
    ///
    /// `width` is the length, in characters, of the bar for the fullest bin;
    /// other bars are scaled proportionally. This is exactly the text that
    /// [`HistogramStats::display`] prints.
    #[must_use]
    pub fn render(&self, width: usize) -> String {
        use std::fmt::Write as _;

        let mut out = String::new();
        // Writing into a `String` cannot fail, so the `fmt::Result`s are ignored.
        let _ = writeln!(out, "\n=== Histogram: {} ===", self.name);
        let _ = writeln!(out, " Min: {:.6}, Max: {:.6}", self.min, self.max);
        let _ = writeln!(out, " Mean: {:.6}, Std: {:.6}", self.mean, self.std);
        let _ = writeln!(out, "\n Distribution:");

        if self.counts.is_empty() {
            let _ = writeln!(out, " (empty)");
            return out;
        }

        let max_count = self.counts.iter().copied().max().unwrap_or(0);
        let last = self.counts.len() - 1;
        for (i, &count) in self.counts.iter().enumerate() {
            let bar_len = if max_count > 0 {
                (count as f64 / max_count as f64 * width as f64) as usize
            } else {
                0
            };
            // `bins` is public and may disagree with `counts` if built by hand;
            // show NaN for a missing edge instead of panicking.
            let left = self.bins.get(i).copied().unwrap_or(f64::NAN);
            let right = self.bins.get(i + 1).copied().unwrap_or(f64::NAN);
            // The last bin contains `max`, so it is closed on the right.
            let close = if i == last { ']' } else { ')' };
            let _ = writeln!(
                out,
                " [{:>8.3}, {:>8.3}{}: {:>6} {}",
                left,
                right,
                close,
                count,
                "█".repeat(bar_len)
            );
        }
        out
    }

    /// Prints the output of [`HistogramStats::render`] to stdout.
    pub fn display(&self, width: usize) {
        print!("{}", self.render(width));
    }
}

125pub struct HistogramCallback {
147 log_frequency: usize,
149 #[allow(dead_code)]
151 num_bins: usize,
153 verbose: bool,
155 pub history: Vec<HashMap<String, HistogramStats>>,
157}
158
159impl HistogramCallback {
160 pub fn new(log_frequency: usize, num_bins: usize, verbose: bool) -> Self {
167 Self {
168 log_frequency,
169 num_bins,
170 verbose,
171 history: Vec::new(),
172 }
173 }
174
175 #[allow(dead_code)] fn compute_histograms(&self, _state: &TrainingState) -> HashMap<String, HistogramStats> {
178 HashMap::new()
192 }
193}
194
195impl Callback for HistogramCallback {
196 fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
197 if (epoch + 1).is_multiple_of(self.log_frequency) {
198 let histograms = self.compute_histograms(state);
199
200 if self.verbose {
201 println!("\n--- Weight Histograms (Epoch {}) ---", epoch + 1);
202 for (_name, stats) in histograms.iter() {
203 stats.display(40); }
205 } else {
206 println!(
207 "Epoch {}: Computed histograms for {} parameters",
208 epoch + 1,
209 histograms.len()
210 );
211 }
212
213 self.history.push(histograms);
214 }
215
216 Ok(())
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TrainingState` with arbitrary values for driving the callback.
    fn dummy_state() -> TrainingState {
        TrainingState {
            epoch: 0,
            batch: 0,
            train_loss: 0.5,
            batch_loss: 0.5,
            val_loss: Some(0.6),
            learning_rate: 0.01,
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_histogram_stats() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let stats = HistogramStats::compute("test", &values, 5);

        assert_eq!(stats.name, "test");
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 10.0);
        assert!((stats.mean - 5.5).abs() < 1e-6);
        // Population std of 1..=10 is sqrt(8.25).
        assert!((stats.std - 8.25_f64.sqrt()).abs() < 1e-9);
        assert_eq!(stats.bins.len(), 6);
        assert_eq!(stats.counts.len(), 5);
        // Bins of width 1.8 split 1..=10 evenly; 10.0 is clamped into the last bin.
        assert_eq!(stats.counts, vec![2, 2, 2, 2, 2]);
        assert_eq!(stats.counts.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_histogram_stats_empty_input() {
        let stats = HistogramStats::compute("empty", &[], 4);

        assert_eq!(stats.name, "empty");
        assert_eq!(stats.min, 0.0);
        assert_eq!(stats.max, 0.0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
        assert!(stats.bins.is_empty());
        assert!(stats.counts.is_empty());
    }

    #[test]
    fn test_histogram_stats_constant_values() {
        let stats = HistogramStats::compute("const", &[3.0, 3.0, 3.0], 4);

        assert_eq!(stats.std, 0.0);
        // Zero range: unit-width edges, everything in the first bin.
        assert_eq!(stats.bins, vec![3.0, 4.0, 5.0, 6.0, 7.0]);
        assert_eq!(stats.counts, vec![3, 0, 0, 0]);
    }

    #[test]
    fn test_histogram_stats_max_lands_in_last_bin() {
        let stats = HistogramStats::compute("edges", &[0.0, 1.0], 2);

        assert_eq!(stats.counts, vec![1, 1]);
    }

    #[test]
    fn test_histogram_callback() {
        let mut callback = HistogramCallback::new(2, 10, false);
        let state = dummy_state();

        // Epoch index 0 is epoch 1, which is not a multiple of 2.
        callback.on_epoch_end(0, &state).unwrap();
        assert_eq!(callback.history.len(), 0);

        // Epoch index 1 is epoch 2, which is logged.
        callback.on_epoch_end(1, &state).unwrap();
        assert_eq!(callback.history.len(), 1);

        callback.on_epoch_end(2, &state).unwrap();
        assert_eq!(callback.history.len(), 1);

        callback.on_epoch_end(3, &state).unwrap();
        assert_eq!(callback.history.len(), 2);
    }

    #[test]
    fn test_histogram_callback_zero_frequency_never_logs() {
        let mut callback = HistogramCallback::new(0, 10, false);
        let state = dummy_state();

        for epoch in 0..4 {
            callback.on_epoch_end(epoch, &state).unwrap();
        }
        assert!(callback.history.is_empty());
    }
}