1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
/// Analyzes flat slices of layer weights and keeps the most recent result per layer name.
#[derive(Debug)]
pub struct WeightAnalyzer {
    /// Latest analysis for each layer, keyed by the layer name passed to `analyze`.
    analyses: HashMap<String, WeightAnalysis>,
    /// Settings applied to every analysis this analyzer performs.
    config: WeightAnalyzerConfig,
}
19
/// Tunable settings for [`WeightAnalyzer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightAnalyzerConfig {
    /// A weight whose absolute value is strictly below this is reported as "dead".
    pub dead_neuron_threshold: f64,
    /// Number of equal-width histogram bins; must be non-zero for analysis to succeed.
    pub num_bins: usize,
    /// When `false`, the initialization heuristic is skipped entirely.
    pub check_initialization: bool,
    /// NOTE(review): not read anywhere in the visible analyzer code — presumably meant to
    /// flag layers whose detected scheme is not in this list; confirm intent.
    pub expected_init_schemes: Vec<InitializationScheme>,
}
32
33impl Default for WeightAnalyzerConfig {
34 fn default() -> Self {
35 Self {
36 dead_neuron_threshold: 1e-8,
37 num_bins: 50,
38 check_initialization: true,
39 expected_init_schemes: vec![
40 InitializationScheme::XavierUniform,
41 InitializationScheme::HeNormal,
42 ],
43 }
44 }
45}
46
/// Weight-initialization schemes the analyzer can name.
///
/// The built-in heuristic in `WeightAnalyzer` currently only ever reports `XavierNormal`,
/// `Normal`, or `Uniform`; the other variants are available for configuration
/// (`WeightAnalyzerConfig::expected_init_schemes`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InitializationScheme {
    /// Xavier (Glorot) uniform.
    XavierUniform,
    /// Xavier (Glorot) normal.
    XavierNormal,
    /// He (Kaiming) uniform.
    HeUniform,
    /// He (Kaiming) normal.
    HeNormal,
    /// LeCun normal.
    LeCunNormal,
    /// Orthogonal matrix initialization.
    Orthogonal,
    /// Generic uniform distribution.
    Uniform,
    /// Generic normal distribution.
    Normal,
}
67
/// Complete result of analyzing one layer's weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightAnalysis {
    /// Layer name this analysis was stored under.
    pub layer_name: String,
    /// Summary statistics over all weights of the layer.
    pub statistics: WeightStatistics,
    /// Indices (into the flat weight slice) of individual weights whose magnitude is below
    /// the configured dead-neuron threshold.
    pub dead_neurons: Vec<usize>,
    /// Equal-width histogram of the weights.
    pub histogram: WeightHistogram,
    /// Heuristic guess at the initialization scheme; `None` when checking is disabled or
    /// no rule matched.
    pub likely_init_scheme: Option<InitializationScheme>,
    /// Human-readable warnings produced by the initialization check.
    pub init_warnings: Vec<String>,
}
84
/// Descriptive statistics of a layer's weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightStatistics {
    /// Arithmetic mean.
    pub mean: f64,
    /// Population standard deviation (variance divides by `n`, not `n - 1`).
    pub std_dev: f64,
    /// Smallest weight.
    pub min: f64,
    /// Largest weight.
    pub max: f64,
    /// 50th percentile (nearest-rank, no interpolation).
    pub median: f64,
    /// 25th percentile (nearest-rank, no interpolation).
    pub q25: f64,
    /// 75th percentile (nearest-rank, no interpolation).
    pub q75: f64,
    /// Third standardized moment; `0.0` when `std_dev` is zero.
    pub skewness: f64,
    /// Excess kurtosis (fourth standardized moment minus 3); `0.0` when `std_dev` is zero.
    pub kurtosis: f64,
    /// Sum of absolute values.
    pub l1_norm: f64,
    /// Euclidean norm.
    pub l2_norm: f64,
    /// Count of weights with absolute value below `1e-10`.
    pub num_zeros: usize,
    /// `num_zeros` as a fraction of the total weight count.
    pub sparsity: f64,
}
115
/// Equal-width histogram spanning the minimum to the maximum weight.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightHistogram {
    /// `num_bins + 1` bin boundaries, starting at the minimum weight.
    pub bin_edges: Vec<f64>,
    /// Number of weights per bin; the maximum weight is counted in the last bin, and every
    /// weight lands in bin 0 when all weights are equal.
    pub bin_counts: Vec<usize>,
    /// Total number of weights that were binned.
    pub total_count: usize,
}
126
127impl WeightAnalyzer {
128 pub fn new() -> Self {
138 Self {
139 analyses: HashMap::new(),
140 config: WeightAnalyzerConfig::default(),
141 }
142 }
143
144 pub fn with_config(config: WeightAnalyzerConfig) -> Self {
146 Self {
147 analyses: HashMap::new(),
148 config,
149 }
150 }
151
152 pub fn analyze(&mut self, layer_name: &str, weights: &[f64]) -> Result<&WeightAnalysis> {
168 let statistics = self.compute_statistics(weights)?;
169 let dead_neurons = self.detect_dead_neurons(weights);
170 let histogram = self.compute_histogram(weights)?;
171 let (likely_init_scheme, init_warnings) = if self.config.check_initialization {
172 self.check_initialization(&statistics)
173 } else {
174 (None, Vec::new())
175 };
176
177 let analysis = WeightAnalysis {
178 layer_name: layer_name.to_string(),
179 statistics,
180 dead_neurons,
181 histogram,
182 likely_init_scheme,
183 init_warnings,
184 };
185
186 self.analyses.insert(layer_name.to_string(), analysis);
187 Ok(self.analyses.get(layer_name).unwrap())
188 }
189
190 fn compute_statistics(&self, weights: &[f64]) -> Result<WeightStatistics> {
192 if weights.is_empty() {
193 anyhow::bail!("Cannot compute statistics for empty weight array");
194 }
195
196 let n = weights.len() as f64;
197 let mean = weights.iter().sum::<f64>() / n;
198
199 let variance = weights
200 .iter()
201 .map(|&x| {
202 let diff = x - mean;
203 diff * diff
204 })
205 .sum::<f64>()
206 / n;
207 let std_dev = variance.sqrt();
208
209 let mut sorted = weights.to_vec();
210 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
211
212 let min = sorted[0];
213 let max = sorted[sorted.len() - 1];
214 let median = percentile(&sorted, 50.0);
215 let q25 = percentile(&sorted, 25.0);
216 let q75 = percentile(&sorted, 75.0);
217
218 let skewness = if std_dev > 0.0 {
220 weights
221 .iter()
222 .map(|&x| {
223 let z = (x - mean) / std_dev;
224 z * z * z
225 })
226 .sum::<f64>()
227 / n
228 } else {
229 0.0
230 };
231
232 let kurtosis = if std_dev > 0.0 {
234 weights
235 .iter()
236 .map(|&x| {
237 let z = (x - mean) / std_dev;
238 z * z * z * z
239 })
240 .sum::<f64>()
241 / n
242 - 3.0
243 } else {
244 0.0
245 };
246
247 let l1_norm = weights.iter().map(|x| x.abs()).sum::<f64>();
248 let l2_norm = weights.iter().map(|x| x * x).sum::<f64>().sqrt();
249
250 let num_zeros = weights.iter().filter(|&&x| x.abs() < 1e-10).count();
251 let sparsity = num_zeros as f64 / n;
252
253 Ok(WeightStatistics {
254 mean,
255 std_dev,
256 min,
257 max,
258 median,
259 q25,
260 q75,
261 skewness,
262 kurtosis,
263 l1_norm,
264 l2_norm,
265 num_zeros,
266 sparsity,
267 })
268 }
269
270 fn detect_dead_neurons(&self, weights: &[f64]) -> Vec<usize> {
272 weights
273 .iter()
274 .enumerate()
275 .filter_map(
276 |(i, &w)| {
277 if w.abs() < self.config.dead_neuron_threshold {
278 Some(i)
279 } else {
280 None
281 }
282 },
283 )
284 .collect()
285 }
286
287 fn compute_histogram(&self, weights: &[f64]) -> Result<WeightHistogram> {
289 if weights.is_empty() {
290 anyhow::bail!("Cannot compute histogram for empty weight array");
291 }
292
293 let min = weights.iter().fold(f64::INFINITY, |a, &b| a.min(b));
294 let max = weights.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
295
296 let bin_width = (max - min) / self.config.num_bins as f64;
297 let mut bin_counts = vec![0; self.config.num_bins];
298
299 for &weight in weights {
300 let bin_idx =
301 if bin_width > 0.0 { ((weight - min) / bin_width).floor() as usize } else { 0 };
302 let bin_idx = bin_idx.min(self.config.num_bins - 1);
303 bin_counts[bin_idx] += 1;
304 }
305
306 let bin_edges: Vec<f64> =
307 (0..=self.config.num_bins).map(|i| min + i as f64 * bin_width).collect();
308
309 Ok(WeightHistogram {
310 bin_edges,
311 bin_counts,
312 total_count: weights.len(),
313 })
314 }
315
316 fn check_initialization(
318 &self,
319 stats: &WeightStatistics,
320 ) -> (Option<InitializationScheme>, Vec<String>) {
321 let mut warnings = Vec::new();
322 let mut likely_scheme = None;
323
324 if stats.sparsity > 0.99 {
326 warnings.push("Weights appear to be uninitialized (all zeros)".to_string());
327 return (None, warnings);
328 }
329
330 if stats.std_dev > 1.0 {
332 warnings.push(format!(
333 "Weights have high variance (std_dev={:.4}), may cause gradient explosion",
334 stats.std_dev
335 ));
336 }
337
338 if stats.std_dev < 0.001 {
340 warnings.push(format!(
341 "Weights have very low variance (std_dev={:.4}), may cause gradient vanishing",
342 stats.std_dev
343 ));
344 }
345
346 if stats.mean.abs() < 0.01 {
352 if stats.std_dev > 0.01 && stats.std_dev < 0.2 {
354 if stats.skewness.abs() < 0.5 && stats.kurtosis.abs() < 1.0 {
356 likely_scheme = Some(InitializationScheme::XavierNormal);
357 } else {
358 likely_scheme = Some(InitializationScheme::Normal);
359 }
360 } else if stats.std_dev < 0.01 {
361 likely_scheme = Some(InitializationScheme::Uniform);
362 }
363 }
364
365 (likely_scheme, warnings)
366 }
367
368 pub fn get_analysis(&self, layer_name: &str) -> Option<&WeightAnalysis> {
370 self.analyses.get(layer_name)
371 }
372
373 pub fn get_layer_names(&self) -> Vec<String> {
375 self.analyses.keys().cloned().collect()
376 }
377
378 pub fn print_summary(&self) -> String {
380 let mut output = String::new();
381 output.push_str("Weight Distribution Summary\n");
382 output.push_str(&"=".repeat(80));
383 output.push('\n');
384
385 for (layer_name, analysis) in &self.analyses {
386 output.push_str(&format!("\nLayer: {}\n", layer_name));
387 output.push_str(&format!(" Mean: {:.6}\n", analysis.statistics.mean));
388 output.push_str(&format!(" Std Dev: {:.6}\n", analysis.statistics.std_dev));
389 output.push_str(&format!(
390 " Range: [{:.6}, {:.6}]\n",
391 analysis.statistics.min, analysis.statistics.max
392 ));
393 output.push_str(&format!(" Median: {:.6}\n", analysis.statistics.median));
394 output.push_str(&format!(
395 " Sparsity: {:.2}%\n",
396 analysis.statistics.sparsity * 100.0
397 ));
398 output.push_str(&format!(
399 " Dead Neurons: {} ({:.2}%)\n",
400 analysis.dead_neurons.len(),
401 analysis.dead_neurons.len() as f64 / analysis.histogram.total_count as f64 * 100.0
402 ));
403
404 if let Some(scheme) = analysis.likely_init_scheme {
405 output.push_str(&format!(" Likely Init: {:?}\n", scheme));
406 }
407
408 if !analysis.init_warnings.is_empty() {
409 output.push_str(" Warnings:\n");
410 for warning in &analysis.init_warnings {
411 output.push_str(&format!(" - {}\n", warning));
412 }
413 }
414 }
415
416 output
417 }
418
419 pub fn export_to_json(&self, layer_name: &str, output_path: &Path) -> Result<()> {
421 let analysis = self
422 .analyses
423 .get(layer_name)
424 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
425
426 let json = serde_json::to_string_pretty(analysis)?;
427 std::fs::write(output_path, json)?;
428
429 Ok(())
430 }
431
432 pub fn plot_distribution_ascii(&self, layer_name: &str) -> Result<String> {
434 let analysis = self
435 .analyses
436 .get(layer_name)
437 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
438
439 let histogram = &analysis.histogram;
440 let max_count = histogram.bin_counts.iter().max().unwrap_or(&0);
441 let scale = if *max_count > 0 { 50.0 / *max_count as f64 } else { 1.0 };
442
443 let mut output = String::new();
444 output.push_str(&format!("Weight Distribution: {}\n", layer_name));
445 output.push_str(&"=".repeat(60));
446 output.push('\n');
447
448 for i in 0..histogram.bin_counts.len() {
449 let bar_length = (histogram.bin_counts[i] as f64 * scale) as usize;
450 let bar = "█".repeat(bar_length);
451 output.push_str(&format!(
452 "{:8.3} - {:8.3} | {} ({})\n",
453 histogram.bin_edges[i],
454 histogram.bin_edges[i + 1],
455 bar,
456 histogram.bin_counts[i]
457 ));
458 }
459
460 output.push_str("\nStatistics:\n");
461 output.push_str(&format!(" Mean: {:.6}\n", analysis.statistics.mean));
462 output.push_str(&format!(" Std Dev: {:.6}\n", analysis.statistics.std_dev));
463 output.push_str(&format!(
464 " Skewness: {:.6}\n",
465 analysis.statistics.skewness
466 ));
467 output.push_str(&format!(
468 " Kurtosis: {:.6}\n",
469 analysis.statistics.kurtosis
470 ));
471
472 Ok(output)
473 }
474
475 pub fn clear(&mut self) {
477 self.analyses.clear();
478 }
479
480 pub fn num_layers(&self) -> usize {
482 self.analyses.len()
483 }
484}
485
486impl Default for WeightAnalyzer {
487 fn default() -> Self {
488 Self::new()
489 }
490}
491
/// Returns the `p`-th percentile (on a 0–100 scale) of an ascending-sorted slice.
///
/// Picks the element at the rounded index `p / 100 * (len - 1)` with no interpolation.
/// Indices beyond the end clamp to the last element, negative ones saturate to the first,
/// and an empty slice yields `0.0`.
fn percentile(sorted_values: &[f64], p: f64) -> f64 {
    match sorted_values.len().checked_sub(1) {
        None => 0.0,
        Some(last) => {
            let rank = (p / 100.0 * last as f64).round() as usize;
            sorted_values[rank.min(last)]
        }
    }
}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_weight_analyzer_creation() {
        let analyzer = WeightAnalyzer::new();
        assert_eq!(analyzer.num_layers(), 0);
    }

    #[test]
    fn test_analyze_weights() {
        let mut analyzer = WeightAnalyzer::new();
        let weights = vec![0.1, 0.2, 0.15, 0.3, 0.25];

        let analysis = analyzer.analyze("layer1", &weights).unwrap();
        assert_eq!(analysis.layer_name, "layer1");
        assert!(analysis.statistics.mean > 0.0);
        assert!(analysis.statistics.std_dev > 0.0);
    }

    #[test]
    fn test_analyze_rejects_empty_weights() {
        let mut analyzer = WeightAnalyzer::new();
        assert!(analyzer.analyze("empty", &[]).is_err());
        // A failed analysis must not leave a partial entry behind.
        assert_eq!(analyzer.num_layers(), 0);
    }

    #[test]
    fn test_reanalyze_replaces_previous_result() {
        let mut analyzer = WeightAnalyzer::new();
        analyzer.analyze("layer1", &[1.0, 2.0, 3.0]).unwrap();

        let analysis = analyzer.analyze("layer1", &[10.0, 20.0, 30.0]).unwrap();
        assert_eq!(analysis.statistics.mean, 20.0);
        assert_eq!(analyzer.num_layers(), 1);
    }

    #[test]
    fn test_dead_neuron_detection() {
        let mut analyzer = WeightAnalyzer::new();
        let weights = vec![0.1, 0.0, 0.2, 0.0, 0.3];

        let analysis = analyzer.analyze("layer1", &weights).unwrap();
        assert_eq!(analysis.dead_neurons, vec![1, 3]);
    }

    #[test]
    fn test_compute_histogram() {
        let analyzer = WeightAnalyzer::new();
        let weights: Vec<f64> = (0..100).map(|x| x as f64 / 100.0).collect();

        let histogram = analyzer.compute_histogram(&weights).unwrap();
        assert_eq!(histogram.bin_edges.len(), analyzer.config.num_bins + 1);
        assert_eq!(histogram.bin_counts.len(), analyzer.config.num_bins);
        assert_eq!(histogram.bin_counts.iter().sum::<usize>(), 100);
        assert_eq!(histogram.total_count, 100);
    }

    #[test]
    fn test_weight_statistics() {
        let analyzer = WeightAnalyzer::new();
        let weights = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let stats = analyzer.compute_statistics(&weights).unwrap();
        assert_eq!(stats.mean, 3.0);
        // Population variance of 1..=5 is 2.
        assert!((stats.std_dev - 2f64.sqrt()).abs() < 1e-12);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.median, 3.0);
        assert_eq!(stats.q25, 2.0);
        assert_eq!(stats.q75, 4.0);
        assert!(stats.skewness.abs() < 1e-12);
    }

    #[test]
    fn test_initialization_check() {
        let analyzer = WeightAnalyzer::new();

        let stats = WeightStatistics {
            mean: 0.001,
            std_dev: 0.05,
            min: -0.15,
            max: 0.15,
            median: 0.0,
            q25: -0.03,
            q75: 0.03,
            skewness: 0.1,
            kurtosis: 0.2,
            l1_norm: 10.0,
            l2_norm: 5.0,
            num_zeros: 0,
            sparsity: 0.0,
        };

        let (scheme, warnings) = analyzer.check_initialization(&stats);
        assert_eq!(scheme, Some(InitializationScheme::XavierNormal));
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_export_to_json() {
        // Include the process id so concurrent test runs do not clobber each other's file.
        let output_path =
            env::temp_dir().join(format!("weight_analysis_{}.json", std::process::id()));

        let mut analyzer = WeightAnalyzer::new();
        analyzer.analyze("layer1", &[1.0, 2.0, 3.0]).unwrap();

        analyzer.export_to_json("layer1", &output_path).unwrap();
        let contents = std::fs::read_to_string(&output_path).unwrap();
        let _ = std::fs::remove_file(&output_path);

        let parsed: WeightAnalysis = serde_json::from_str(&contents).unwrap();
        assert_eq!(parsed.layer_name, "layer1");
        assert_eq!(parsed.histogram.total_count, 3);
    }

    #[test]
    fn test_export_missing_layer_fails() {
        let analyzer = WeightAnalyzer::new();
        let output_path = env::temp_dir().join("weight_analysis_missing_layer.json");
        assert!(analyzer.export_to_json("missing", &output_path).is_err());
    }

    #[test]
    fn test_plot_distribution_ascii() {
        let mut analyzer = WeightAnalyzer::new();
        let weights: Vec<f64> = (0..100).map(|x| x as f64 / 100.0).collect();

        analyzer.analyze("layer1", &weights).unwrap();

        let ascii_plot = analyzer.plot_distribution_ascii("layer1").unwrap();
        assert!(ascii_plot.contains("Weight Distribution"));
        assert!(ascii_plot.contains("layer1"));
        assert!(ascii_plot.contains("Statistics"));
        assert!(analyzer.plot_distribution_ascii("missing").is_err());
    }

    #[test]
    fn test_print_summary() {
        let mut analyzer = WeightAnalyzer::new();

        analyzer.analyze("layer1", &[1.0, 2.0, 3.0]).unwrap();
        analyzer.analyze("layer2", &[0.5, 1.0, 1.5]).unwrap();

        let summary = analyzer.print_summary();
        assert!(summary.contains("layer1"));
        assert!(summary.contains("layer2"));
        assert!(summary.contains("Mean"));
        assert!(summary.contains("Std Dev"));
    }

    #[test]
    fn test_get_layer_names() {
        let mut analyzer = WeightAnalyzer::new();
        analyzer.analyze("b", &[1.0]).unwrap();
        analyzer.analyze("a", &[2.0]).unwrap();

        let mut names = analyzer.get_layer_names();
        names.sort();
        assert_eq!(names, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_sparsity_calculation() {
        let analyzer = WeightAnalyzer::new();
        let weights = vec![0.0, 0.0, 0.0, 1.0, 0.0];

        let stats = analyzer.compute_statistics(&weights).unwrap();
        assert_eq!(stats.num_zeros, 4);
        assert_eq!(stats.sparsity, 0.8);
    }

    #[test]
    fn test_clear_analyses() {
        let mut analyzer = WeightAnalyzer::new();

        analyzer.analyze("layer1", &[1.0]).unwrap();
        analyzer.analyze("layer2", &[2.0]).unwrap();

        assert_eq!(analyzer.num_layers(), 2);

        analyzer.clear();
        assert_eq!(analyzer.num_layers(), 0);
    }
}