// trustformers_debug/activation_visualizer.rs

//! Activation visualization tools for layer-wise debugging.
//!
//! This module provides tools to inspect and visualize activations from different layers
//! of a neural network, including heatmaps, value distributions (histograms), and
//! statistical analysis.
5
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Write as _;
use std::path::Path;
10
// Note: scirs2_core types are available for advanced operations if they are needed in the future.
12
13/// Activation visualizer for inspecting layer outputs
14#[derive(Debug)]
15pub struct ActivationVisualizer {
16    /// Stored activations by layer name
17    activations: HashMap<String, ActivationData>,
18    /// Configuration for visualization
19    config: ActivationConfig,
20}
21
22/// Configuration for activation visualization
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ActivationConfig {
25    /// Number of histogram bins
26    pub num_bins: usize,
27    /// Whether to compute detailed statistics
28    pub detailed_stats: bool,
29    /// Threshold for outlier detection (in standard deviations)
30    pub outlier_threshold: f64,
31    /// Maximum number of activations to store (to prevent memory overflow)
32    pub max_stored_activations: usize,
33}
34
35impl Default for ActivationConfig {
36    fn default() -> Self {
37        Self {
38            num_bins: 50,
39            detailed_stats: true,
40            outlier_threshold: 3.0,
41            max_stored_activations: 10000,
42        }
43    }
44}
45
46/// Stored activation data for a layer
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ActivationData {
49    /// Layer name
50    pub layer_name: String,
51    /// Activation values (flattened)
52    pub values: Vec<f32>,
53    /// Original shape of the activation tensor
54    pub shape: Vec<usize>,
55    /// Statistics computed from the activations
56    pub statistics: ActivationStatistics,
57    /// Timestamp when captured
58    pub timestamp: u64,
59}
60
61/// Statistical summary of activations
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ActivationStatistics {
64    /// Mean activation value
65    pub mean: f64,
66    /// Standard deviation
67    pub std_dev: f64,
68    /// Minimum value
69    pub min: f64,
70    /// Maximum value
71    pub max: f64,
72    /// Median value
73    pub median: f64,
74    /// 25th percentile
75    pub q25: f64,
76    /// 75th percentile
77    pub q75: f64,
78    /// Number of zero activations
79    pub num_zeros: usize,
80    /// Number of negative activations
81    pub num_negative: usize,
82    /// Number of positive activations
83    pub num_positive: usize,
84    /// Outlier count (values beyond threshold std devs)
85    pub num_outliers: usize,
86    /// Sparsity ratio (fraction of zeros)
87    pub sparsity: f64,
88}
89
90/// Histogram data for activation distribution
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct ActivationHistogram {
93    /// Bin edges
94    pub bin_edges: Vec<f64>,
95    /// Bin counts
96    pub bin_counts: Vec<usize>,
97    /// Total count
98    pub total_count: usize,
99}
100
101/// Heatmap data for 2D activation visualization
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct ActivationHeatmap {
104    /// Layer name
105    pub layer_name: String,
106    /// 2D values for heatmap
107    pub values: Vec<Vec<f64>>,
108    /// Row labels (optional)
109    pub row_labels: Option<Vec<String>>,
110    /// Column labels (optional)
111    pub col_labels: Option<Vec<String>>,
112}
113
114impl ActivationVisualizer {
115    /// Create a new activation visualizer
116    ///
117    /// # Example
118    ///
119    /// ```
120    /// use trustformers_debug::ActivationVisualizer;
121    ///
122    /// let visualizer = ActivationVisualizer::new();
123    /// ```
124    pub fn new() -> Self {
125        Self {
126            activations: HashMap::new(),
127            config: ActivationConfig::default(),
128        }
129    }
130
131    /// Create a new activation visualizer with custom configuration
132    pub fn with_config(config: ActivationConfig) -> Self {
133        Self {
134            activations: HashMap::new(),
135            config,
136        }
137    }
138
139    /// Register activations from a layer
140    ///
141    /// # Arguments
142    ///
143    /// * `layer_name` - Name of the layer
144    /// * `values` - Flattened activation values
145    /// * `shape` - Original shape of the activation tensor
146    ///
147    /// # Example
148    ///
149    /// ```
150    /// # use trustformers_debug::ActivationVisualizer;
151    /// # let mut visualizer = ActivationVisualizer::new();
152    /// let activations = vec![0.1, 0.5, 0.3, 0.8];
153    /// visualizer.register("layer1", activations, vec![2, 2]).unwrap();
154    /// ```
155    pub fn register(
156        &mut self,
157        layer_name: &str,
158        values: Vec<f32>,
159        shape: Vec<usize>,
160    ) -> Result<()> {
161        // Limit stored activations to prevent memory overflow
162        let values = if values.len() > self.config.max_stored_activations {
163            values.into_iter().take(self.config.max_stored_activations).collect()
164        } else {
165            values
166        };
167
168        let statistics = self.compute_statistics(&values)?;
169
170        let timestamp =
171            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
172
173        let activation_data = ActivationData {
174            layer_name: layer_name.to_string(),
175            values,
176            shape,
177            statistics,
178            timestamp,
179        };
180
181        self.activations.insert(layer_name.to_string(), activation_data);
182        Ok(())
183    }
184
185    /// Get activations for a specific layer
186    pub fn get_activations(&self, layer_name: &str) -> Option<&ActivationData> {
187        self.activations.get(layer_name)
188    }
189
190    /// Get all layer names with registered activations
191    pub fn get_layer_names(&self) -> Vec<String> {
192        self.activations.keys().cloned().collect()
193    }
194
195    /// Compute statistics for activation values
196    fn compute_statistics(&self, values: &[f32]) -> Result<ActivationStatistics> {
197        if values.is_empty() {
198            return Ok(ActivationStatistics {
199                mean: 0.0,
200                std_dev: 0.0,
201                min: 0.0,
202                max: 0.0,
203                median: 0.0,
204                q25: 0.0,
205                q75: 0.0,
206                num_zeros: 0,
207                num_negative: 0,
208                num_positive: 0,
209                num_outliers: 0,
210                sparsity: 0.0,
211            });
212        }
213
214        let mean: f64 = values.iter().map(|&x| x as f64).sum::<f64>() / values.len() as f64;
215
216        let variance: f64 = values
217            .iter()
218            .map(|&x| {
219                let diff = x as f64 - mean;
220                diff * diff
221            })
222            .sum::<f64>()
223            / values.len() as f64;
224
225        let std_dev = variance.sqrt();
226
227        let min = values.iter().fold(f32::INFINITY, |a, &b| a.min(b)) as f64;
228        let max = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as f64;
229
230        // Count zeros, negatives, positives
231        let num_zeros = values.iter().filter(|&&x| x.abs() < 1e-8).count();
232        let num_negative = values.iter().filter(|&&x| x < 0.0).count();
233        let num_positive = values.iter().filter(|&&x| x > 0.0).count();
234
235        // Count outliers
236        let num_outliers = values
237            .iter()
238            .filter(|&&x| (x as f64 - mean).abs() > self.config.outlier_threshold * std_dev)
239            .count();
240
241        // Compute percentiles
242        let mut sorted_values: Vec<f32> = values.to_vec();
243        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
244
245        let median = percentile(&sorted_values, 50.0);
246        let q25 = percentile(&sorted_values, 25.0);
247        let q75 = percentile(&sorted_values, 75.0);
248
249        let sparsity = num_zeros as f64 / values.len() as f64;
250
251        Ok(ActivationStatistics {
252            mean,
253            std_dev,
254            min,
255            max,
256            median,
257            q25,
258            q75,
259            num_zeros,
260            num_negative,
261            num_positive,
262            num_outliers,
263            sparsity,
264        })
265    }
266
267    /// Create a histogram of activation values
268    pub fn create_histogram(&self, layer_name: &str) -> Result<ActivationHistogram> {
269        let activation = self
270            .activations
271            .get(layer_name)
272            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
273
274        let min = activation.statistics.min;
275        let max = activation.statistics.max;
276
277        let bin_width = (max - min) / self.config.num_bins as f64;
278        let mut bin_counts = vec![0; self.config.num_bins];
279
280        for &value in &activation.values {
281            let bin_idx = if bin_width > 0.0 {
282                ((value as f64 - min) / bin_width).floor() as usize
283            } else {
284                0
285            };
286            let bin_idx = bin_idx.min(self.config.num_bins - 1);
287            bin_counts[bin_idx] += 1;
288        }
289
290        let bin_edges: Vec<f64> =
291            (0..=self.config.num_bins).map(|i| min + i as f64 * bin_width).collect();
292
293        Ok(ActivationHistogram {
294            bin_edges,
295            bin_counts,
296            total_count: activation.values.len(),
297        })
298    }
299
300    /// Create a heatmap from 2D activations
301    ///
302    /// # Arguments
303    ///
304    /// * `layer_name` - Name of the layer
305    /// * `reshape` - Optional reshape dimensions (e.g., [height, width])
306    pub fn create_heatmap(
307        &self,
308        layer_name: &str,
309        reshape: Option<(usize, usize)>,
310    ) -> Result<ActivationHeatmap> {
311        let activation = self
312            .activations
313            .get(layer_name)
314            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
315
316        let (rows, cols) = if let Some((r, c)) = reshape {
317            (r, c)
318        } else {
319            // Try to infer 2D shape
320            if activation.shape.len() >= 2 {
321                let rows = activation.shape[activation.shape.len() - 2];
322                let cols = activation.shape[activation.shape.len() - 1];
323                (rows, cols)
324            } else {
325                // Fallback: make it as square as possible
326                let total = activation.values.len();
327                let cols = (total as f64).sqrt().ceil() as usize;
328                let rows = total.div_ceil(cols);
329                (rows, cols)
330            }
331        };
332
333        let mut values = vec![vec![0.0; cols]; rows];
334        for (i, &val) in activation.values.iter().enumerate().take(rows * cols) {
335            let row = i / cols;
336            let col = i % cols;
337            if row < rows {
338                values[row][col] = val as f64;
339            }
340        }
341
342        Ok(ActivationHeatmap {
343            layer_name: layer_name.to_string(),
344            values,
345            row_labels: None,
346            col_labels: None,
347        })
348    }
349
350    /// Export activation statistics to JSON
351    pub fn export_statistics(&self, layer_name: &str, output_path: &Path) -> Result<()> {
352        let activation = self
353            .activations
354            .get(layer_name)
355            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
356
357        let json = serde_json::to_string_pretty(&activation.statistics)?;
358        std::fs::write(output_path, json)?;
359
360        Ok(())
361    }
362
363    /// Plot distribution as ASCII histogram
364    pub fn plot_distribution_ascii(&self, layer_name: &str) -> Result<String> {
365        let histogram = self.create_histogram(layer_name)?;
366
367        let max_count = histogram.bin_counts.iter().max().unwrap_or(&0);
368        let scale = if *max_count > 0 { 50.0 / *max_count as f64 } else { 1.0 };
369
370        let mut output = String::new();
371        output.push_str(&format!("Activation Distribution: {}\n", layer_name));
372        output.push_str(&"=".repeat(60));
373        output.push('\n');
374
375        for i in 0..histogram.bin_counts.len() {
376            let bar_length = (histogram.bin_counts[i] as f64 * scale) as usize;
377            let bar = "█".repeat(bar_length);
378            output.push_str(&format!(
379                "{:8.3} - {:8.3} | {} ({})\n",
380                histogram.bin_edges[i],
381                histogram.bin_edges[i + 1],
382                bar,
383                histogram.bin_counts[i]
384            ));
385        }
386
387        Ok(output)
388    }
389
390    /// Print summary statistics for all layers
391    pub fn print_summary(&self) -> Result<String> {
392        let mut output = String::new();
393        output.push_str("Activation Summary\n");
394        output.push_str(&"=".repeat(80));
395        output.push('\n');
396
397        for (layer_name, activation) in &self.activations {
398            output.push_str(&format!("\nLayer: {}\n", layer_name));
399            output.push_str(&format!("  Shape: {:?}\n", activation.shape));
400            output.push_str(&format!("  Mean: {:.6}\n", activation.statistics.mean));
401            output.push_str(&format!(
402                "  Std Dev: {:.6}\n",
403                activation.statistics.std_dev
404            ));
405            output.push_str(&format!("  Min: {:.6}\n", activation.statistics.min));
406            output.push_str(&format!("  Max: {:.6}\n", activation.statistics.max));
407            output.push_str(&format!("  Median: {:.6}\n", activation.statistics.median));
408            output.push_str(&format!(
409                "  Sparsity: {:.2}%\n",
410                activation.statistics.sparsity * 100.0
411            ));
412            output.push_str(&format!(
413                "  Outliers: {} ({:.2}%)\n",
414                activation.statistics.num_outliers,
415                activation.statistics.num_outliers as f64 / activation.values.len() as f64 * 100.0
416            ));
417        }
418
419        Ok(output)
420    }
421
422    /// Clear all stored activations
423    pub fn clear(&mut self) {
424        self.activations.clear();
425    }
426
427    /// Get number of stored activations
428    pub fn num_layers(&self) -> usize {
429        self.activations.len()
430    }
431}
432
433impl Default for ActivationVisualizer {
434    fn default() -> Self {
435        Self::new()
436    }
437}
438
/// Nearest-rank percentile of an ascending-sorted slice.
///
/// `p` is a percentage (nominally 0–100). The fractional rank
/// `p / 100 * (len - 1)` is rounded to the nearest index and clamped to the last
/// element. A negative or NaN rank saturates to index 0. Returns `0.0` for an
/// empty slice.
fn percentile(sorted_values: &[f32], p: f64) -> f64 {
    let Some(last_index) = sorted_values.len().checked_sub(1) else {
        return 0.0;
    };
    let rank = (p / 100.0 * last_index as f64).round() as usize;
    f64::from(sorted_values[rank.min(last_index)])
}
448
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_activation_visualizer_creation() {
        let visualizer = ActivationVisualizer::new();
        assert_eq!(visualizer.num_layers(), 0);
    }

    #[test]
    fn test_register_activations() {
        let mut visualizer = ActivationVisualizer::new();
        let values = vec![0.1, 0.5, 0.3, 0.8, -0.2];

        visualizer
            .register("layer1", values.clone(), vec![5])
            .expect("operation failed in test");
        assert_eq!(visualizer.num_layers(), 1);

        let activation = visualizer.get_activations("layer1").expect("operation failed in test");
        assert_eq!(activation.values, values);
        assert_eq!(activation.shape, vec![5]);
    }

    #[test]
    fn test_register_truncates_to_configured_limit() {
        let config = ActivationConfig {
            max_stored_activations: 3,
            ..ActivationConfig::default()
        };
        let mut visualizer = ActivationVisualizer::with_config(config);

        visualizer
            .register("layer1", vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5])
            .expect("operation failed in test");

        let activation = visualizer.get_activations("layer1").expect("operation failed in test");
        assert_eq!(activation.values, vec![1.0, 2.0, 3.0]);
        // The shape is stored as supplied even though values were truncated.
        assert_eq!(activation.shape, vec![5]);
        // Statistics describe only the retained values.
        assert_eq!(activation.statistics.max, 3.0);
    }

    #[test]
    fn test_compute_statistics() {
        let visualizer = ActivationVisualizer::new();
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let stats = visualizer.compute_statistics(&values).expect("operation failed in test");
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std_dev > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.num_zeros, 0);
        assert_eq!(stats.num_positive, 5);
    }

    #[test]
    fn test_percentile_nearest_rank() {
        let sorted: [f32; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(percentile(&sorted, 0.0), 1.0);
        assert_eq!(percentile(&sorted, 50.0), 3.0);
        assert_eq!(percentile(&sorted, 100.0), 5.0);
        assert_eq!(percentile(&[], 50.0), 0.0);
    }

    #[test]
    fn test_create_histogram() {
        let mut visualizer = ActivationVisualizer::new();
        let values: Vec<f32> = (0..100).map(|x| x as f32).collect();

        visualizer
            .register("layer1", values, vec![100])
            .expect("operation failed in test");

        let histogram = visualizer.create_histogram("layer1").expect("operation failed in test");
        assert_eq!(histogram.bin_edges.len(), visualizer.config.num_bins + 1);
        assert_eq!(histogram.total_count, 100);
        assert_eq!(histogram.bin_counts.iter().sum::<usize>(), histogram.total_count);
    }

    #[test]
    fn test_create_heatmap() {
        let mut visualizer = ActivationVisualizer::new();
        let values: Vec<f32> = (0..16).map(|x| x as f32).collect();

        visualizer
            .register("layer1", values, vec![4, 4])
            .expect("operation failed in test");

        let heatmap = visualizer
            .create_heatmap("layer1", Some((4, 4)))
            .expect("operation failed in test");
        assert_eq!(heatmap.values.len(), 4);
        assert_eq!(heatmap.values[0].len(), 4);
    }

    #[test]
    fn test_create_heatmap_infers_last_two_dims() {
        let mut visualizer = ActivationVisualizer::new();
        let values: Vec<f32> = (0..6).map(|x| x as f32).collect();

        visualizer
            .register("layer1", values, vec![2, 3])
            .expect("operation failed in test");

        let heatmap = visualizer
            .create_heatmap("layer1", None)
            .expect("operation failed in test");
        assert_eq!(
            heatmap.values,
            vec![vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0]]
        );
    }

    #[test]
    fn test_unknown_layer_is_an_error() {
        let visualizer = ActivationVisualizer::new();
        assert!(visualizer.get_activations("missing").is_none());
        assert!(visualizer.create_histogram("missing").is_err());
        assert!(visualizer.create_heatmap("missing", None).is_err());
        assert!(visualizer.plot_distribution_ascii("missing").is_err());
    }

    #[test]
    fn test_export_statistics() {
        // Include the process id so concurrent test runs don't race on one file.
        let output_path =
            env::temp_dir().join(format!("activation_stats_{}.json", std::process::id()));

        let mut visualizer = ActivationVisualizer::new();
        let values = vec![1.0, 2.0, 3.0];

        visualizer
            .register("layer1", values, vec![3])
            .expect("operation failed in test");
        visualizer
            .export_statistics("layer1", &output_path)
            .expect("operation failed in test");

        let written = std::fs::read_to_string(&output_path);
        // Clean up before asserting so a failed assertion doesn't leave the file behind.
        let _ = std::fs::remove_file(&output_path);

        let json = written.expect("exported statistics file should be readable");
        assert!(json.contains("\"mean\""));
        assert!(json.contains("\"sparsity\""));
    }

    #[test]
    fn test_plot_distribution_ascii() {
        let mut visualizer = ActivationVisualizer::new();
        let values: Vec<f32> = (0..100).map(|x| x as f32 / 100.0).collect();

        visualizer
            .register("layer1", values, vec![100])
            .expect("operation failed in test");

        let ascii_plot =
            visualizer.plot_distribution_ascii("layer1").expect("operation failed in test");
        assert!(ascii_plot.contains("Activation Distribution"));
        assert!(ascii_plot.contains("layer1"));
    }

    #[test]
    fn test_print_summary() {
        let mut visualizer = ActivationVisualizer::new();

        visualizer
            .register("layer1", vec![1.0, 2.0, 3.0], vec![3])
            .expect("operation failed in test");
        visualizer
            .register("layer2", vec![4.0, 5.0, 6.0], vec![3])
            .expect("operation failed in test");

        let summary = visualizer.print_summary().expect("operation failed in test");
        assert!(summary.contains("layer1"));
        assert!(summary.contains("layer2"));
        assert!(summary.contains("Mean"));
        assert!(summary.contains("Std Dev"));
    }

    #[test]
    fn test_sparsity_calculation() {
        let visualizer = ActivationVisualizer::new();
        let values = vec![0.0, 0.0, 0.0, 1.0, 0.0];

        let stats = visualizer.compute_statistics(&values).expect("operation failed in test");
        assert_eq!(stats.num_zeros, 4);
        assert_eq!(stats.sparsity, 0.8);
    }

    #[test]
    fn test_clear_activations() {
        let mut visualizer = ActivationVisualizer::new();

        visualizer
            .register("layer1", vec![1.0], vec![1])
            .expect("operation failed in test");
        visualizer
            .register("layer2", vec![2.0], vec![1])
            .expect("operation failed in test");

        assert_eq!(visualizer.num_layers(), 2);

        visualizer.clear();
        assert_eq!(visualizer.num_layers(), 0);
    }
}