1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
/// Collects per-layer activation values together with summary statistics and
/// renders them as histograms, heatmaps, and text reports.
#[derive(Debug)]
pub struct ActivationVisualizer {
    /// Registered activations keyed by layer name; registering a name again
    /// replaces the earlier entry.
    activations: HashMap<String, ActivationData>,
    /// Settings controlling histogram binning, outlier detection, and the
    /// maximum number of values retained per layer.
    config: ActivationConfig,
}
21
/// Tunable settings for [`ActivationVisualizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationConfig {
    /// Number of equal-width bins used when building a histogram.
    pub num_bins: usize,
    /// Flag for detailed statistics.
    // NOTE(review): no code visible in this file reads this flag — confirm
    // whether it is consumed elsewhere or is vestigial.
    pub detailed_stats: bool,
    /// A value counts as an outlier when its distance from the mean exceeds
    /// this many standard deviations.
    pub outlier_threshold: f64,
    /// Upper bound on the number of values kept per layer; longer inputs are
    /// truncated to their first `max_stored_activations` elements.
    pub max_stored_activations: usize,
}
34
35impl Default for ActivationConfig {
36 fn default() -> Self {
37 Self {
38 num_bins: 50,
39 detailed_stats: true,
40 outlier_threshold: 3.0,
41 max_stored_activations: 10000,
42 }
43 }
44}
45
/// Activation values recorded for a single layer, plus derived metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationData {
    /// Name the layer was registered under.
    pub layer_name: String,
    /// Flattened activation values, possibly truncated to the configured
    /// `max_stored_activations`.
    pub values: Vec<f32>,
    /// Shape supplied by the caller at registration; it is stored as given and
    /// is not checked against `values.len()`.
    pub shape: Vec<usize>,
    /// Summary statistics computed over `values` (after any truncation).
    pub statistics: ActivationStatistics,
    /// Registration time as whole seconds since the Unix epoch.
    pub timestamp: u64,
}
60
/// Summary statistics over one layer's activation values.
///
/// Every field is zero when the statistics were computed from an empty slice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationStatistics {
    /// Arithmetic mean.
    pub mean: f64,
    /// Population standard deviation (variance divides by the element count).
    pub std_dev: f64,
    /// Smallest value.
    pub min: f64,
    /// Largest value.
    pub max: f64,
    /// 50th percentile of the sorted values.
    pub median: f64,
    /// 25th percentile of the sorted values.
    pub q25: f64,
    /// 75th percentile of the sorted values.
    pub q75: f64,
    /// Number of values whose magnitude is below `1e-8`.
    pub num_zeros: usize,
    /// Number of values strictly below zero.
    pub num_negative: usize,
    /// Number of values strictly above zero.
    pub num_positive: usize,
    /// Number of values farther from the mean than
    /// `outlier_threshold * std_dev`.
    pub num_outliers: usize,
    /// Fraction of values counted in `num_zeros` (`num_zeros / len`).
    pub sparsity: f64,
}
89
/// Equal-width histogram of one layer's activation values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationHistogram {
    /// Bin boundaries in ascending order, starting at the layer minimum;
    /// contains one more entry than `bin_counts`.
    pub bin_edges: Vec<f64>,
    /// Number of values that fell into each bin.
    pub bin_counts: Vec<usize>,
    /// Number of values the histogram was built from.
    pub total_count: usize,
}
100
/// Two-dimensional grid view of one layer's activation values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivationHeatmap {
    /// Name of the layer the grid was built from.
    pub layer_name: String,
    /// Grid cells stored as a vector of rows; cells with no corresponding
    /// activation value hold `0.0`.
    pub values: Vec<Vec<f64>>,
    /// Optional labels for the rows; `create_heatmap` leaves this as `None`.
    pub row_labels: Option<Vec<String>>,
    /// Optional labels for the columns; `create_heatmap` leaves this as `None`.
    pub col_labels: Option<Vec<String>>,
}
113
114impl ActivationVisualizer {
115 pub fn new() -> Self {
125 Self {
126 activations: HashMap::new(),
127 config: ActivationConfig::default(),
128 }
129 }
130
131 pub fn with_config(config: ActivationConfig) -> Self {
133 Self {
134 activations: HashMap::new(),
135 config,
136 }
137 }
138
139 pub fn register(
156 &mut self,
157 layer_name: &str,
158 values: Vec<f32>,
159 shape: Vec<usize>,
160 ) -> Result<()> {
161 let values = if values.len() > self.config.max_stored_activations {
163 values.into_iter().take(self.config.max_stored_activations).collect()
164 } else {
165 values
166 };
167
168 let statistics = self.compute_statistics(&values)?;
169
170 let timestamp =
171 std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
172
173 let activation_data = ActivationData {
174 layer_name: layer_name.to_string(),
175 values,
176 shape,
177 statistics,
178 timestamp,
179 };
180
181 self.activations.insert(layer_name.to_string(), activation_data);
182 Ok(())
183 }
184
185 pub fn get_activations(&self, layer_name: &str) -> Option<&ActivationData> {
187 self.activations.get(layer_name)
188 }
189
190 pub fn get_layer_names(&self) -> Vec<String> {
192 self.activations.keys().cloned().collect()
193 }
194
195 fn compute_statistics(&self, values: &[f32]) -> Result<ActivationStatistics> {
197 if values.is_empty() {
198 return Ok(ActivationStatistics {
199 mean: 0.0,
200 std_dev: 0.0,
201 min: 0.0,
202 max: 0.0,
203 median: 0.0,
204 q25: 0.0,
205 q75: 0.0,
206 num_zeros: 0,
207 num_negative: 0,
208 num_positive: 0,
209 num_outliers: 0,
210 sparsity: 0.0,
211 });
212 }
213
214 let mean: f64 = values.iter().map(|&x| x as f64).sum::<f64>() / values.len() as f64;
215
216 let variance: f64 = values
217 .iter()
218 .map(|&x| {
219 let diff = x as f64 - mean;
220 diff * diff
221 })
222 .sum::<f64>()
223 / values.len() as f64;
224
225 let std_dev = variance.sqrt();
226
227 let min = values.iter().fold(f32::INFINITY, |a, &b| a.min(b)) as f64;
228 let max = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as f64;
229
230 let num_zeros = values.iter().filter(|&&x| x.abs() < 1e-8).count();
232 let num_negative = values.iter().filter(|&&x| x < 0.0).count();
233 let num_positive = values.iter().filter(|&&x| x > 0.0).count();
234
235 let num_outliers = values
237 .iter()
238 .filter(|&&x| (x as f64 - mean).abs() > self.config.outlier_threshold * std_dev)
239 .count();
240
241 let mut sorted_values: Vec<f32> = values.to_vec();
243 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
244
245 let median = percentile(&sorted_values, 50.0);
246 let q25 = percentile(&sorted_values, 25.0);
247 let q75 = percentile(&sorted_values, 75.0);
248
249 let sparsity = num_zeros as f64 / values.len() as f64;
250
251 Ok(ActivationStatistics {
252 mean,
253 std_dev,
254 min,
255 max,
256 median,
257 q25,
258 q75,
259 num_zeros,
260 num_negative,
261 num_positive,
262 num_outliers,
263 sparsity,
264 })
265 }
266
267 pub fn create_histogram(&self, layer_name: &str) -> Result<ActivationHistogram> {
269 let activation = self
270 .activations
271 .get(layer_name)
272 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
273
274 let min = activation.statistics.min;
275 let max = activation.statistics.max;
276
277 let bin_width = (max - min) / self.config.num_bins as f64;
278 let mut bin_counts = vec![0; self.config.num_bins];
279
280 for &value in &activation.values {
281 let bin_idx = if bin_width > 0.0 {
282 ((value as f64 - min) / bin_width).floor() as usize
283 } else {
284 0
285 };
286 let bin_idx = bin_idx.min(self.config.num_bins - 1);
287 bin_counts[bin_idx] += 1;
288 }
289
290 let bin_edges: Vec<f64> =
291 (0..=self.config.num_bins).map(|i| min + i as f64 * bin_width).collect();
292
293 Ok(ActivationHistogram {
294 bin_edges,
295 bin_counts,
296 total_count: activation.values.len(),
297 })
298 }
299
300 pub fn create_heatmap(
307 &self,
308 layer_name: &str,
309 reshape: Option<(usize, usize)>,
310 ) -> Result<ActivationHeatmap> {
311 let activation = self
312 .activations
313 .get(layer_name)
314 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
315
316 let (rows, cols) = if let Some((r, c)) = reshape {
317 (r, c)
318 } else {
319 if activation.shape.len() >= 2 {
321 let rows = activation.shape[activation.shape.len() - 2];
322 let cols = activation.shape[activation.shape.len() - 1];
323 (rows, cols)
324 } else {
325 let total = activation.values.len();
327 let cols = (total as f64).sqrt().ceil() as usize;
328 let rows = total.div_ceil(cols);
329 (rows, cols)
330 }
331 };
332
333 let mut values = vec![vec![0.0; cols]; rows];
334 for (i, &val) in activation.values.iter().enumerate().take(rows * cols) {
335 let row = i / cols;
336 let col = i % cols;
337 if row < rows {
338 values[row][col] = val as f64;
339 }
340 }
341
342 Ok(ActivationHeatmap {
343 layer_name: layer_name.to_string(),
344 values,
345 row_labels: None,
346 col_labels: None,
347 })
348 }
349
350 pub fn export_statistics(&self, layer_name: &str, output_path: &Path) -> Result<()> {
352 let activation = self
353 .activations
354 .get(layer_name)
355 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
356
357 let json = serde_json::to_string_pretty(&activation.statistics)?;
358 std::fs::write(output_path, json)?;
359
360 Ok(())
361 }
362
363 pub fn plot_distribution_ascii(&self, layer_name: &str) -> Result<String> {
365 let histogram = self.create_histogram(layer_name)?;
366
367 let max_count = histogram.bin_counts.iter().max().unwrap_or(&0);
368 let scale = if *max_count > 0 { 50.0 / *max_count as f64 } else { 1.0 };
369
370 let mut output = String::new();
371 output.push_str(&format!("Activation Distribution: {}\n", layer_name));
372 output.push_str(&"=".repeat(60));
373 output.push('\n');
374
375 for i in 0..histogram.bin_counts.len() {
376 let bar_length = (histogram.bin_counts[i] as f64 * scale) as usize;
377 let bar = "█".repeat(bar_length);
378 output.push_str(&format!(
379 "{:8.3} - {:8.3} | {} ({})\n",
380 histogram.bin_edges[i],
381 histogram.bin_edges[i + 1],
382 bar,
383 histogram.bin_counts[i]
384 ));
385 }
386
387 Ok(output)
388 }
389
390 pub fn print_summary(&self) -> Result<String> {
392 let mut output = String::new();
393 output.push_str("Activation Summary\n");
394 output.push_str(&"=".repeat(80));
395 output.push('\n');
396
397 for (layer_name, activation) in &self.activations {
398 output.push_str(&format!("\nLayer: {}\n", layer_name));
399 output.push_str(&format!(" Shape: {:?}\n", activation.shape));
400 output.push_str(&format!(" Mean: {:.6}\n", activation.statistics.mean));
401 output.push_str(&format!(
402 " Std Dev: {:.6}\n",
403 activation.statistics.std_dev
404 ));
405 output.push_str(&format!(" Min: {:.6}\n", activation.statistics.min));
406 output.push_str(&format!(" Max: {:.6}\n", activation.statistics.max));
407 output.push_str(&format!(" Median: {:.6}\n", activation.statistics.median));
408 output.push_str(&format!(
409 " Sparsity: {:.2}%\n",
410 activation.statistics.sparsity * 100.0
411 ));
412 output.push_str(&format!(
413 " Outliers: {} ({:.2}%)\n",
414 activation.statistics.num_outliers,
415 activation.statistics.num_outliers as f64 / activation.values.len() as f64 * 100.0
416 ));
417 }
418
419 Ok(output)
420 }
421
422 pub fn clear(&mut self) {
424 self.activations.clear();
425 }
426
427 pub fn num_layers(&self) -> usize {
429 self.activations.len()
430 }
431}
432
433impl Default for ActivationVisualizer {
434 fn default() -> Self {
435 Self::new()
436 }
437}
438
/// Returns the `p`-th percentile of `sorted_values`, linearly interpolating
/// between the two nearest ranks.
///
/// `sorted_values` must already be in ascending order. `p` is expressed in
/// percent and is clamped to `[0, 100]`; a NaN `p` is treated as `0`. An empty
/// slice yields `0.0`.
fn percentile(sorted_values: &[f32], p: f64) -> f64 {
    let Some(last) = sorted_values.len().checked_sub(1) else {
        return 0.0;
    };

    let fraction = if p.is_nan() { 0.0 } else { p.clamp(0.0, 100.0) / 100.0 };
    let rank = fraction * last as f64;

    // `rank` lies in `[0, last]`; the extra `min` guards against rounding for
    // lengths too large to be represented exactly as `f64`.
    let lower = (rank.floor() as usize).min(last);
    let upper = (lower + 1).min(last);
    let weight = rank - lower as f64;

    let lo = f64::from(sorted_values[lower]);
    let hi = f64::from(sorted_values[upper]);

    // Skip the arithmetic on exact ranks and equal neighbours so that
    // infinite endpoints cannot produce `inf * 0` or `inf - inf` NaNs.
    if weight == 0.0 || lo == hi {
        return lo;
    }
    lo + (hi - lo) * weight
}
448
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// A freshly constructed visualizer tracks no layers.
    #[test]
    fn test_activation_visualizer_creation() {
        assert_eq!(ActivationVisualizer::new().num_layers(), 0);
    }

    /// Registered values and shape can be read back unchanged.
    #[test]
    fn test_register_activations() {
        let samples = vec![0.1, 0.5, 0.3, 0.8, -0.2];
        let mut viz = ActivationVisualizer::new();

        viz.register("layer1", samples.clone(), vec![5]).unwrap();
        assert_eq!(viz.num_layers(), 1);

        let stored = viz.get_activations("layer1").unwrap();
        assert_eq!(stored.values, samples);
        assert_eq!(stored.shape, vec![5]);
    }

    /// Basic moments and sign counts for a small ascending sequence.
    #[test]
    fn test_compute_statistics() {
        let viz = ActivationVisualizer::new();
        let samples: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let summary = viz.compute_statistics(&samples).unwrap();

        assert_eq!(summary.mean, 3.0);
        assert!(summary.std_dev > 0.0);
        assert_eq!(summary.min, 1.0);
        assert_eq!(summary.max, 5.0);
        assert_eq!(summary.num_zeros, 0);
        assert_eq!(summary.num_positive, 5);
    }

    /// The histogram has one more edge than bins and counts every value.
    #[test]
    fn test_create_histogram() {
        let mut viz = ActivationVisualizer::new();
        let ramp: Vec<f32> = (0..100).map(|x| x as f32).collect();
        viz.register("layer1", ramp, vec![100]).unwrap();

        let hist = viz.create_histogram("layer1").unwrap();

        assert_eq!(hist.bin_edges.len(), viz.config.num_bins + 1);
        assert_eq!(hist.total_count, 100);
    }

    /// An explicit 4x4 reshape yields a grid of exactly that size.
    #[test]
    fn test_create_heatmap() {
        let mut viz = ActivationVisualizer::new();
        let ramp: Vec<f32> = (0..16).map(|x| x as f32).collect();
        viz.register("layer1", ramp, vec![4, 4]).unwrap();

        let grid = viz.create_heatmap("layer1", Some((4, 4))).unwrap();

        assert_eq!(grid.values.len(), 4);
        assert_eq!(grid.values[0].len(), 4);
    }

    /// Exporting statistics creates the target file.
    #[test]
    fn test_export_statistics() {
        let target = env::temp_dir().join("activation_stats.json");

        let mut viz = ActivationVisualizer::new();
        viz.register("layer1", vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        viz.export_statistics("layer1", &target).unwrap();

        assert!(target.exists());

        // Best-effort cleanup; a failure to delete must not fail the test.
        let _ = std::fs::remove_file(target);
    }

    /// The ASCII plot carries its title and the layer name.
    #[test]
    fn test_plot_distribution_ascii() {
        let mut viz = ActivationVisualizer::new();
        let ramp: Vec<f32> = (0..100).map(|x| x as f32 / 100.0).collect();
        viz.register("layer1", ramp, vec![100]).unwrap();

        let plot = viz.plot_distribution_ascii("layer1").unwrap();

        assert!(plot.contains("Activation Distribution"));
        assert!(plot.contains("layer1"));
    }

    /// The summary mentions every layer and the headline statistics.
    #[test]
    fn test_print_summary() {
        let mut viz = ActivationVisualizer::new();
        viz.register("layer1", vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        viz.register("layer2", vec![4.0, 5.0, 6.0], vec![3]).unwrap();

        let report = viz.print_summary().unwrap();

        for needle in ["layer1", "layer2", "Mean", "Std Dev"] {
            assert!(report.contains(needle));
        }
    }

    /// Four zeros out of five values give 80% sparsity.
    #[test]
    fn test_sparsity_calculation() {
        let viz = ActivationVisualizer::new();
        let samples: Vec<f32> = vec![0.0, 0.0, 0.0, 1.0, 0.0];

        let summary = viz.compute_statistics(&samples).unwrap();

        assert_eq!(summary.num_zeros, 4);
        assert_eq!(summary.sparsity, 0.8);
    }

    /// `clear` drops every registered layer.
    #[test]
    fn test_clear_activations() {
        let mut viz = ActivationVisualizer::new();
        viz.register("layer1", vec![1.0], vec![1]).unwrap();
        viz.register("layer2", vec![2.0], vec![1]).unwrap();
        assert_eq!(viz.num_layers(), 2);

        viz.clear();

        assert_eq!(viz.num_layers(), 0);
    }
}