// trustformers_debug/weight_analyzer.rs
//! Weight distribution analysis tools
//!
//! This module provides tools to analyze weight distributions in neural networks,
//! including dead neuron detection, initialization validation, and statistical analysis.

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
10
/// Weight distribution analyzer for model inspection.
///
/// Computes per-layer statistics, dead-neuron indices, histograms, and an
/// initialization-scheme heuristic, caching each result under its layer name.
#[derive(Debug)]
pub struct WeightAnalyzer {
    /// Stored weight analyses, keyed by layer name (re-analysis overwrites)
    analyses: HashMap<String, WeightAnalysis>,
    /// Configuration controlling thresholds, bin count, and init checks
    config: WeightAnalyzerConfig,
}
19
/// Configuration for weight analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightAnalyzerConfig {
    /// Threshold for dead neuron detection: a weight with absolute value
    /// below this is reported as dead (default 1e-8)
    pub dead_neuron_threshold: f64,
    /// Number of histogram bins (default 50); must be non-zero for
    /// histogram computation
    pub num_bins: usize,
    /// Whether `analyze` runs the initialization heuristics
    pub check_initialization: bool,
    /// Expected initialization schemes to validate against
    // NOTE(review): this field is not consulted by `check_initialization`
    // in this module — confirm whether validation against it is still planned.
    pub expected_init_schemes: Vec<InitializationScheme>,
}
32
33impl Default for WeightAnalyzerConfig {
34    fn default() -> Self {
35        Self {
36            dead_neuron_threshold: 1e-8,
37            num_bins: 50,
38            check_initialization: true,
39            expected_init_schemes: vec![
40                InitializationScheme::XavierUniform,
41                InitializationScheme::HeNormal,
42            ],
43        }
44    }
45}
46
/// Initialization schemes for validation.
///
/// Used both as the expected schemes in [`WeightAnalyzerConfig`] and as the
/// inferred scheme reported in [`WeightAnalysis::likely_init_scheme`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InitializationScheme {
    /// Xavier/Glorot uniform initialization
    XavierUniform,
    /// Xavier/Glorot normal initialization
    XavierNormal,
    /// He uniform initialization
    HeUniform,
    /// He normal initialization
    HeNormal,
    /// LeCun normal initialization
    LeCunNormal,
    /// Orthogonal initialization
    Orthogonal,
    /// Uniform initialization
    Uniform,
    /// Normal initialization
    Normal,
}
67
/// Analysis results for a single layer's weights, as produced by
/// [`WeightAnalyzer::analyze`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightAnalysis {
    /// Layer name the analysis was stored under
    pub layer_name: String,
    /// Statistical summary of the weight values
    pub statistics: WeightStatistics,
    /// Indices of weights below the dead-neuron threshold
    pub dead_neurons: Vec<usize>,
    /// Fixed-width histogram of the weight distribution
    pub histogram: WeightHistogram,
    /// Initialization scheme inferred from the statistics, if any
    pub likely_init_scheme: Option<InitializationScheme>,
    /// Human-readable warnings from the initialization heuristics
    pub init_warnings: Vec<String>,
}
84
/// Statistical summary of weights.
///
/// Standard deviation is the population form (divide by n), and kurtosis is
/// excess kurtosis (3.0 subtracted), matching the computation in
/// `WeightAnalyzer::compute_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightStatistics {
    /// Mean weight value
    pub mean: f64,
    /// Population standard deviation
    pub std_dev: f64,
    /// Minimum value
    pub min: f64,
    /// Maximum value
    pub max: f64,
    /// Median value (nearest-rank percentile)
    pub median: f64,
    /// 25th percentile
    pub q25: f64,
    /// 75th percentile
    pub q75: f64,
    /// Skewness (third standardized moment; 0.0 when std_dev is 0)
    pub skewness: f64,
    /// Excess kurtosis (fourth standardized moment minus 3; 0.0 when std_dev is 0)
    pub kurtosis: f64,
    /// L1 norm (sum of absolute values)
    pub l1_norm: f64,
    /// L2 norm (Euclidean norm)
    pub l2_norm: f64,
    /// Number of weights with absolute value below 1e-10
    pub num_zeros: usize,
    /// Fraction of near-zero weights: num_zeros / total
    pub sparsity: f64,
}
115
/// Histogram of a weight distribution.
///
/// Invariant: `bin_edges.len() == bin_counts.len() + 1`, and the counts sum
/// to `total_count`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightHistogram {
    /// Bin edges (one more entry than `bin_counts`)
    pub bin_edges: Vec<f64>,
    /// Per-bin occurrence counts
    pub bin_counts: Vec<usize>,
    /// Total number of weights binned
    pub total_count: usize,
}
126
127impl WeightAnalyzer {
128    /// Create a new weight analyzer
129    ///
130    /// # Example
131    ///
132    /// ```
133    /// use trustformers_debug::WeightAnalyzer;
134    ///
135    /// let analyzer = WeightAnalyzer::new();
136    /// ```
137    pub fn new() -> Self {
138        Self {
139            analyses: HashMap::new(),
140            config: WeightAnalyzerConfig::default(),
141        }
142    }
143
144    /// Create a weight analyzer with custom configuration
145    pub fn with_config(config: WeightAnalyzerConfig) -> Self {
146        Self {
147            analyses: HashMap::new(),
148            config,
149        }
150    }
151
152    /// Analyze weights from a layer
153    ///
154    /// # Arguments
155    ///
156    /// * `layer_name` - Name of the layer
157    /// * `weights` - Weight values
158    ///
159    /// # Example
160    ///
161    /// ```
162    /// # use trustformers_debug::WeightAnalyzer;
163    /// # let mut analyzer = WeightAnalyzer::new();
164    /// let weights = vec![0.1, 0.2, 0.05, 0.15, 0.3];
165    /// let analysis = analyzer.analyze("layer1", &weights).unwrap();
166    /// ```
167    pub fn analyze(&mut self, layer_name: &str, weights: &[f64]) -> Result<&WeightAnalysis> {
168        let statistics = self.compute_statistics(weights)?;
169        let dead_neurons = self.detect_dead_neurons(weights);
170        let histogram = self.compute_histogram(weights)?;
171        let (likely_init_scheme, init_warnings) = if self.config.check_initialization {
172            self.check_initialization(&statistics)
173        } else {
174            (None, Vec::new())
175        };
176
177        let analysis = WeightAnalysis {
178            layer_name: layer_name.to_string(),
179            statistics,
180            dead_neurons,
181            histogram,
182            likely_init_scheme,
183            init_warnings,
184        };
185
186        self.analyses.insert(layer_name.to_string(), analysis);
187        Ok(self.analyses.get(layer_name).expect("analysis should exist after insert"))
188    }
189
190    /// Compute statistics for weights
191    fn compute_statistics(&self, weights: &[f64]) -> Result<WeightStatistics> {
192        if weights.is_empty() {
193            anyhow::bail!("Cannot compute statistics for empty weight array");
194        }
195
196        let n = weights.len() as f64;
197        let mean = weights.iter().sum::<f64>() / n;
198
199        let variance = weights
200            .iter()
201            .map(|&x| {
202                let diff = x - mean;
203                diff * diff
204            })
205            .sum::<f64>()
206            / n;
207        let std_dev = variance.sqrt();
208
209        let mut sorted = weights.to_vec();
210        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
211
212        let min = sorted[0];
213        let max = sorted[sorted.len() - 1];
214        let median = percentile(&sorted, 50.0);
215        let q25 = percentile(&sorted, 25.0);
216        let q75 = percentile(&sorted, 75.0);
217
218        // Compute skewness
219        let skewness = if std_dev > 0.0 {
220            weights
221                .iter()
222                .map(|&x| {
223                    let z = (x - mean) / std_dev;
224                    z * z * z
225                })
226                .sum::<f64>()
227                / n
228        } else {
229            0.0
230        };
231
232        // Compute kurtosis
233        let kurtosis = if std_dev > 0.0 {
234            weights
235                .iter()
236                .map(|&x| {
237                    let z = (x - mean) / std_dev;
238                    z * z * z * z
239                })
240                .sum::<f64>()
241                / n
242                - 3.0
243        } else {
244            0.0
245        };
246
247        let l1_norm = weights.iter().map(|x| x.abs()).sum::<f64>();
248        let l2_norm = weights.iter().map(|x| x * x).sum::<f64>().sqrt();
249
250        let num_zeros = weights.iter().filter(|&&x| x.abs() < 1e-10).count();
251        let sparsity = num_zeros as f64 / n;
252
253        Ok(WeightStatistics {
254            mean,
255            std_dev,
256            min,
257            max,
258            median,
259            q25,
260            q75,
261            skewness,
262            kurtosis,
263            l1_norm,
264            l2_norm,
265            num_zeros,
266            sparsity,
267        })
268    }
269
270    /// Detect dead neurons (weights close to zero)
271    fn detect_dead_neurons(&self, weights: &[f64]) -> Vec<usize> {
272        weights
273            .iter()
274            .enumerate()
275            .filter_map(
276                |(i, &w)| {
277                    if w.abs() < self.config.dead_neuron_threshold {
278                        Some(i)
279                    } else {
280                        None
281                    }
282                },
283            )
284            .collect()
285    }
286
287    /// Compute histogram of weight distribution
288    fn compute_histogram(&self, weights: &[f64]) -> Result<WeightHistogram> {
289        if weights.is_empty() {
290            anyhow::bail!("Cannot compute histogram for empty weight array");
291        }
292
293        let min = weights.iter().fold(f64::INFINITY, |a, &b| a.min(b));
294        let max = weights.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
295
296        let bin_width = (max - min) / self.config.num_bins as f64;
297        let mut bin_counts = vec![0; self.config.num_bins];
298
299        for &weight in weights {
300            let bin_idx =
301                if bin_width > 0.0 { ((weight - min) / bin_width).floor() as usize } else { 0 };
302            let bin_idx = bin_idx.min(self.config.num_bins - 1);
303            bin_counts[bin_idx] += 1;
304        }
305
306        let bin_edges: Vec<f64> =
307            (0..=self.config.num_bins).map(|i| min + i as f64 * bin_width).collect();
308
309        Ok(WeightHistogram {
310            bin_edges,
311            bin_counts,
312            total_count: weights.len(),
313        })
314    }
315
316    /// Check initialization scheme and detect issues
317    fn check_initialization(
318        &self,
319        stats: &WeightStatistics,
320    ) -> (Option<InitializationScheme>, Vec<String>) {
321        let mut warnings = Vec::new();
322        let mut likely_scheme = None;
323
324        // Check if weights are all zero (not initialized)
325        if stats.sparsity > 0.99 {
326            warnings.push("Weights appear to be uninitialized (all zeros)".to_string());
327            return (None, warnings);
328        }
329
330        // Check if weights are too large
331        if stats.std_dev > 1.0 {
332            warnings.push(format!(
333                "Weights have high variance (std_dev={:.4}), may cause gradient explosion",
334                stats.std_dev
335            ));
336        }
337
338        // Check if weights are too small
339        if stats.std_dev < 0.001 {
340            warnings.push(format!(
341                "Weights have very low variance (std_dev={:.4}), may cause gradient vanishing",
342                stats.std_dev
343            ));
344        }
345
346        // Infer initialization scheme based on statistics
347        // Xavier: mean ~ 0, std ~ sqrt(2 / (fan_in + fan_out))
348        // He: mean ~ 0, std ~ sqrt(2 / fan_in)
349        // Normal: mean ~ 0, std ~ some constant
350
351        if stats.mean.abs() < 0.01 {
352            // Mean is close to zero, likely a good initialization
353            if stats.std_dev > 0.01 && stats.std_dev < 0.2 {
354                // Check distribution shape
355                if stats.skewness.abs() < 0.5 && stats.kurtosis.abs() < 1.0 {
356                    likely_scheme = Some(InitializationScheme::XavierNormal);
357                } else {
358                    likely_scheme = Some(InitializationScheme::Normal);
359                }
360            } else if stats.std_dev < 0.01 {
361                likely_scheme = Some(InitializationScheme::Uniform);
362            }
363        }
364
365        (likely_scheme, warnings)
366    }
367
368    /// Get analysis for a specific layer
369    pub fn get_analysis(&self, layer_name: &str) -> Option<&WeightAnalysis> {
370        self.analyses.get(layer_name)
371    }
372
373    /// Get all layer names with analyses
374    pub fn get_layer_names(&self) -> Vec<String> {
375        self.analyses.keys().cloned().collect()
376    }
377
378    /// Print summary of all analyses
379    pub fn print_summary(&self) -> String {
380        let mut output = String::new();
381        output.push_str("Weight Distribution Summary\n");
382        output.push_str(&"=".repeat(80));
383        output.push('\n');
384
385        for (layer_name, analysis) in &self.analyses {
386            output.push_str(&format!("\nLayer: {}\n", layer_name));
387            output.push_str(&format!("  Mean: {:.6}\n", analysis.statistics.mean));
388            output.push_str(&format!("  Std Dev: {:.6}\n", analysis.statistics.std_dev));
389            output.push_str(&format!(
390                "  Range: [{:.6}, {:.6}]\n",
391                analysis.statistics.min, analysis.statistics.max
392            ));
393            output.push_str(&format!("  Median: {:.6}\n", analysis.statistics.median));
394            output.push_str(&format!(
395                "  Sparsity: {:.2}%\n",
396                analysis.statistics.sparsity * 100.0
397            ));
398            output.push_str(&format!(
399                "  Dead Neurons: {} ({:.2}%)\n",
400                analysis.dead_neurons.len(),
401                analysis.dead_neurons.len() as f64 / analysis.histogram.total_count as f64 * 100.0
402            ));
403
404            if let Some(scheme) = analysis.likely_init_scheme {
405                output.push_str(&format!("  Likely Init: {:?}\n", scheme));
406            }
407
408            if !analysis.init_warnings.is_empty() {
409                output.push_str("  Warnings:\n");
410                for warning in &analysis.init_warnings {
411                    output.push_str(&format!("    - {}\n", warning));
412                }
413            }
414        }
415
416        output
417    }
418
419    /// Export analysis to JSON
420    pub fn export_to_json(&self, layer_name: &str, output_path: &Path) -> Result<()> {
421        let analysis = self
422            .analyses
423            .get(layer_name)
424            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
425
426        let json = serde_json::to_string_pretty(analysis)?;
427        std::fs::write(output_path, json)?;
428
429        Ok(())
430    }
431
432    /// Plot weight distribution as ASCII histogram
433    pub fn plot_distribution_ascii(&self, layer_name: &str) -> Result<String> {
434        let analysis = self
435            .analyses
436            .get(layer_name)
437            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
438
439        let histogram = &analysis.histogram;
440        let max_count = histogram.bin_counts.iter().max().unwrap_or(&0);
441        let scale = if *max_count > 0 { 50.0 / *max_count as f64 } else { 1.0 };
442
443        let mut output = String::new();
444        output.push_str(&format!("Weight Distribution: {}\n", layer_name));
445        output.push_str(&"=".repeat(60));
446        output.push('\n');
447
448        for i in 0..histogram.bin_counts.len() {
449            let bar_length = (histogram.bin_counts[i] as f64 * scale) as usize;
450            let bar = "█".repeat(bar_length);
451            output.push_str(&format!(
452                "{:8.3} - {:8.3} | {} ({})\n",
453                histogram.bin_edges[i],
454                histogram.bin_edges[i + 1],
455                bar,
456                histogram.bin_counts[i]
457            ));
458        }
459
460        output.push_str("\nStatistics:\n");
461        output.push_str(&format!("  Mean: {:.6}\n", analysis.statistics.mean));
462        output.push_str(&format!("  Std Dev: {:.6}\n", analysis.statistics.std_dev));
463        output.push_str(&format!(
464            "  Skewness: {:.6}\n",
465            analysis.statistics.skewness
466        ));
467        output.push_str(&format!(
468            "  Kurtosis: {:.6}\n",
469            analysis.statistics.kurtosis
470        ));
471
472        Ok(output)
473    }
474
475    /// Clear all stored analyses
476    pub fn clear(&mut self) {
477        self.analyses.clear();
478    }
479
480    /// Get number of analyzed layers
481    pub fn num_layers(&self) -> usize {
482        self.analyses.len()
483    }
484}
485
486impl Default for WeightAnalyzer {
487    fn default() -> Self {
488        Self::new()
489    }
490}
491
/// Nearest-rank percentile lookup on an already-sorted slice.
///
/// `p` is a percentage in `[0, 100]`. Returns 0.0 for an empty slice. The
/// rank is chosen with `f64::round` (half away from zero), and clamped to
/// the last element as a safety net for `p > 100`.
fn percentile(sorted_values: &[f64], p: f64) -> f64 {
    match sorted_values {
        [] => 0.0,
        values => {
            let last = values.len() - 1;
            let rank = (p / 100.0 * last as f64).round() as usize;
            values[rank.min(last)]
        }
    }
}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// A fresh analyzer starts with no cached layer analyses.
    #[test]
    fn test_weight_analyzer_creation() {
        let analyzer = WeightAnalyzer::new();
        assert_eq!(analyzer.num_layers(), 0);
    }

    /// Basic analysis populates the layer name and positive statistics.
    #[test]
    fn test_analyze_weights() {
        let mut analyzer = WeightAnalyzer::new();
        let weights = vec![0.1, 0.2, 0.15, 0.3, 0.25];

        let analysis = analyzer.analyze("layer1", &weights).expect("operation failed in test");
        assert_eq!(analysis.layer_name, "layer1");
        assert!(analysis.statistics.mean > 0.0);
        assert!(analysis.statistics.std_dev > 0.0);
    }

    /// Exact zeros fall below the 1e-8 default threshold and are flagged.
    #[test]
    fn test_dead_neuron_detection() {
        let mut analyzer = WeightAnalyzer::new();
        let weights = vec![0.1, 0.0, 0.2, 0.0, 0.3]; // Two dead neurons

        let analysis = analyzer.analyze("layer1", &weights).expect("operation failed in test");
        assert_eq!(analysis.dead_neurons.len(), 2);
    }

    /// Histogram has num_bins + 1 edges and counts every input value.
    #[test]
    fn test_compute_histogram() {
        let analyzer = WeightAnalyzer::new();
        let weights: Vec<f64> = (0..100).map(|x| x as f64 / 100.0).collect();

        let histogram = analyzer.compute_histogram(&weights).expect("operation failed in test");
        assert_eq!(histogram.bin_edges.len(), analyzer.config.num_bins + 1);
        assert_eq!(histogram.total_count, 100);
    }

    /// Mean/min/max of a simple arithmetic sequence are exact.
    #[test]
    fn test_weight_statistics() {
        let analyzer = WeightAnalyzer::new();
        let weights = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let stats = analyzer.compute_statistics(&weights).expect("operation failed in test");
        assert_eq!(stats.mean, 3.0);
        assert!(stats.std_dev > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
    }

    /// Near-zero mean with moderate std dev should infer some scheme
    /// without tripping the variance warnings.
    #[test]
    fn test_initialization_check() {
        let analyzer = WeightAnalyzer::new();

        // Simulate Xavier-like initialization
        let stats = WeightStatistics {
            mean: 0.001,
            std_dev: 0.05,
            min: -0.15,
            max: 0.15,
            median: 0.0,
            q25: -0.03,
            q75: 0.03,
            skewness: 0.1,
            kurtosis: 0.2,
            l1_norm: 10.0,
            l2_norm: 5.0,
            num_zeros: 0,
            sparsity: 0.0,
        };

        let (scheme, warnings) = analyzer.check_initialization(&stats);
        assert!(scheme.is_some());
        assert!(warnings.is_empty() || warnings.len() <= 1);
    }

    /// Round-trips an analysis to a JSON file in the OS temp directory.
    #[test]
    fn test_export_to_json() {
        let temp_dir = env::temp_dir();
        let output_path = temp_dir.join("weight_analysis.json");

        let mut analyzer = WeightAnalyzer::new();
        analyzer.analyze("layer1", &[1.0, 2.0, 3.0]).expect("operation failed in test");

        analyzer
            .export_to_json("layer1", &output_path)
            .expect("operation failed in test");
        assert!(output_path.exists());

        // Clean up
        let _ = std::fs::remove_file(output_path);
    }

    /// The ASCII plot contains the header, layer name, and stats footer.
    #[test]
    fn test_plot_distribution_ascii() {
        let mut analyzer = WeightAnalyzer::new();
        let weights: Vec<f64> = (0..100).map(|x| x as f64 / 100.0).collect();

        analyzer.analyze("layer1", &weights).expect("operation failed in test");

        let ascii_plot =
            analyzer.plot_distribution_ascii("layer1").expect("operation failed in test");
        assert!(ascii_plot.contains("Weight Distribution"));
        assert!(ascii_plot.contains("layer1"));
        assert!(ascii_plot.contains("Statistics"));
    }

    /// Summary mentions every analyzed layer and the key statistic labels.
    #[test]
    fn test_print_summary() {
        let mut analyzer = WeightAnalyzer::new();

        analyzer.analyze("layer1", &[1.0, 2.0, 3.0]).expect("operation failed in test");
        analyzer.analyze("layer2", &[0.5, 1.0, 1.5]).expect("operation failed in test");

        let summary = analyzer.print_summary();
        assert!(summary.contains("layer1"));
        assert!(summary.contains("layer2"));
        assert!(summary.contains("Mean"));
        assert!(summary.contains("Std Dev"));
    }

    /// Sparsity is the fraction of near-zero weights (4 of 5 here).
    #[test]
    fn test_sparsity_calculation() {
        let analyzer = WeightAnalyzer::new();
        let weights = vec![0.0, 0.0, 0.0, 1.0, 0.0];

        let stats = analyzer.compute_statistics(&weights).expect("operation failed in test");
        assert_eq!(stats.num_zeros, 4);
        assert_eq!(stats.sparsity, 0.8);
    }

    /// clear() drops all cached analyses.
    #[test]
    fn test_clear_analyses() {
        let mut analyzer = WeightAnalyzer::new();

        analyzer.analyze("layer1", &[1.0]).expect("operation failed in test");
        analyzer.analyze("layer2", &[2.0]).expect("operation failed in test");

        assert_eq!(analyzer.num_layers(), 2);

        analyzer.clear();
        assert_eq!(analyzer.num_layers(), 0);
    }
}