// trustformers_debug/utilities/weight_analysis.rs
use anyhow::Result;
use scirs2_core::ndarray::*;
use serde::{Deserialize, Serialize};
/// A single layer whose gradient L2 norm exceeded the explosion threshold.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExplodingLayer {
    /// Index of the layer within the gradient slice passed to the analyzer.
    pub layer_index: usize,
    /// L2 norm of this layer's gradient tensor.
    pub gradient_norm: f32,
    /// How severe the explosion is relative to the norms seen so far.
    pub severity: ExplosionSeverity,
    /// Human-readable mitigation suggestion for this layer.
    pub recommended_action: String,
}
15
/// Severity classification for gradient explosion, ordered from mildest to worst.
#[derive(Debug, Serialize, Deserialize)]
pub enum ExplosionSeverity {
    Low,
    Medium,
    High,
    Critical,
}
24
/// Aggregate report produced by `WeightAnalyzer::detect_gradient_explosion`.
#[derive(Debug, Serialize, Deserialize)]
pub struct GradientExplosionAnalysis {
    /// Layers whose gradient norm exceeded the threshold.
    pub exploding_layers: Vec<ExplodingLayer>,
    /// Largest per-layer gradient L2 norm observed.
    pub max_gradient_norm: f32,
    /// Mean of the per-layer gradient L2 norms.
    pub mean_gradient_norm: f32,
    /// Population standard deviation of the per-layer gradient L2 norms.
    pub std_gradient_norm: f32,
    /// Fraction of layers flagged as exploding (0.0..=1.0).
    pub explosion_ratio: f32,
    /// Overall severity derived from the ratio and the worst norm.
    pub overall_severity: ExplosionSeverity,
    /// Generic mitigation advice for the observed situation.
    pub mitigation_recommendations: Vec<String>,
}
36
/// Aggregate report produced by `WeightAnalyzer::analyze_weight_distribution`.
#[derive(Debug, Serialize, Deserialize)]
pub struct WeightDistributionAnalysis {
    /// Per-layer statistics, health score, issues and recommendations.
    pub layer_analyses: Vec<LayerWeightAnalysis>,
    /// Statistics averaged across all layers (min/max are global extrema).
    pub overall_statistics: WeightStatistics,
    /// Overall health score and status bucket for the weight distribution.
    pub distribution_health: DistributionHealth,
    /// All z-score outliers found across every layer.
    pub outlier_detection: Vec<WeightOutlier>,
}
45
/// Per-layer weight statistics together with a derived health assessment.
#[derive(Debug, Serialize, Deserialize)]
pub struct LayerWeightAnalysis {
    /// Index of the layer within the weight slice passed to the analyzer.
    pub layer_index: usize,
    /// Descriptive statistics for this layer's weight tensor.
    pub statistics: WeightStatistics,
    /// Heuristic health score in 0.0..=100.0 (higher is healthier).
    pub health_score: f32,
    /// Human-readable descriptions of detected problems.
    pub issues: Vec<String>,
    /// Recommendations derived from `issues`.
    pub recommendations: Vec<String>,
}
55
/// Descriptive statistics over a flat view of a weight tensor.
///
/// NOTE(review): `Default` initializes `min`/`max` to 0.0. When this struct is
/// used as a running accumulator (see `accumulate`), those zero seeds clamp
/// the min/max envelope toward zero — callers should seed them with
/// `f32::INFINITY` / `f32::NEG_INFINITY` if true extrema are wanted.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WeightStatistics {
    /// Arithmetic mean of the values.
    pub mean: f32,
    /// Population standard deviation.
    pub std_dev: f32,
    /// Standardized third moment (0.0 when std_dev is 0 or n < 3).
    pub skewness: f32,
    /// Excess kurtosis, i.e. fourth moment minus 3 (0.0 when std_dev is 0 or n < 4).
    pub kurtosis: f32,
    /// Spread proxy: log2 of std_dev, floored at 0 — not Shannon entropy.
    pub entropy: f32,
    /// Smallest value observed.
    pub min: f32,
    /// Largest value observed.
    pub max: f32,
    /// Fraction of values exactly equal to 0.0.
    pub zero_fraction: f32,
}
68
69impl WeightStatistics {
70 pub fn accumulate(&mut self, other: &WeightStatistics) {
71 self.mean += other.mean;
73 self.std_dev += other.std_dev;
74 self.skewness += other.skewness;
75 self.kurtosis += other.kurtosis;
76 self.entropy += other.entropy;
77 self.min = self.min.min(other.min);
78 self.max = self.max.max(other.max);
79 self.zero_fraction += other.zero_fraction;
80 }
81
82 pub fn finalize(&mut self, count: usize) {
83 if count > 0 {
84 let count_f32 = count as f32;
85 self.mean /= count_f32;
86 self.std_dev /= count_f32;
87 self.skewness /= count_f32;
88 self.kurtosis /= count_f32;
89 self.entropy /= count_f32;
90 self.zero_fraction /= count_f32;
91 }
92 }
93}
94
/// Standalone health assessment for a set of weights.
///
/// NOTE(review): not constructed anywhere in this file — presumably used by
/// other modules; verify before removing.
#[derive(Debug, Serialize, Deserialize)]
pub struct WeightHealth {
    /// Heuristic health score (higher is healthier).
    pub score: f32,
    /// Human-readable descriptions of detected problems.
    pub issues: Vec<String>,
    /// Recommendations derived from the issues.
    pub recommendations: Vec<String>,
}
102
/// Overall health of a weight distribution: numeric score plus status bucket.
#[derive(Debug, Serialize, Deserialize)]
pub struct DistributionHealth {
    /// Score in 0.0..=100.0; see `WeightAnalyzer::assess_distribution_health`.
    pub score: f32,
    /// Bucketed interpretation of `score`.
    pub status: DistributionHealthStatus,
}
109
/// Health buckets for a weight distribution, from best to worst.
/// Thresholds: >=90 Excellent, >=75 Good, >=60 Fair, >=40 Poor, else Critical.
#[derive(Debug, Serialize, Deserialize)]
pub enum DistributionHealthStatus {
    Excellent,
    Good,
    Fair,
    Poor,
    Critical,
}
119
/// A single weight flagged as a statistical outlier within its layer.
#[derive(Debug, Serialize, Deserialize)]
pub struct WeightOutlier {
    /// Layer the weight belongs to.
    pub layer_index: usize,
    /// Flat (iteration-order) index of the weight inside the tensor.
    pub weight_index: usize,
    /// The outlying weight value itself.
    pub value: f32,
    /// Absolute z-score of the value within its layer (> 3.0 by construction).
    pub z_score: f32,
    /// Medium for |z| > 3, High for |z| > 5.
    pub severity: OutlierSeverity,
}
129
/// Severity of a weight outlier. Only two tiers exist because values with
/// |z| <= 3 are never reported at all.
#[derive(Debug, Serialize, Deserialize)]
pub enum OutlierSeverity {
    Medium,
    High,
}
136
/// Summary of how far weights have drifted between two snapshots.
///
/// NOTE(review): no producer for this type is visible in this file — verify
/// against the module that performs the drift comparison.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightDriftAnalysis {
    /// Mean drift magnitude across layers.
    pub mean_drift: f32,
    /// Largest drift magnitude observed.
    pub max_drift: f32,
    /// Bucketed interpretation of the drift.
    pub severity: WeightDriftSeverity,
    /// Indices of layers considered affected by drift.
    pub affected_layers: Vec<usize>,
}
145
/// Severity buckets for weight drift, ordered from mildest to worst.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightDriftSeverity {
    Minimal,
    Low,
    Medium,
    High,
}
154
/// Stateless namespace for weight/gradient analysis routines; all methods are
/// associated functions.
pub struct WeightAnalyzer;
157
158impl WeightAnalyzer {
159 pub fn detect_gradient_explosion(
161 gradients: &[ArrayD<f32>],
162 threshold: f32,
163 ) -> GradientExplosionAnalysis {
164 let mut exploding_layers = Vec::new();
165 let mut max_gradient_norm = 0.0f32;
166 let mut gradient_norms = Vec::new();
167
168 for (layer_idx, gradient) in gradients.iter().enumerate() {
169 let l2_norm = Self::compute_l2_norm(gradient);
170 gradient_norms.push(l2_norm);
171
172 if l2_norm > max_gradient_norm {
173 max_gradient_norm = l2_norm;
174 }
175
176 if l2_norm > threshold {
177 exploding_layers.push(ExplodingLayer {
178 layer_index: layer_idx,
179 gradient_norm: l2_norm,
180 severity: Self::classify_explosion_severity(l2_norm, &gradient_norms),
181 recommended_action: Self::recommend_explosion_mitigation(l2_norm),
182 });
183 }
184 }
185
186 let mean_norm = gradient_norms.iter().sum::<f32>() / gradient_norms.len() as f32;
187 let std_norm = {
188 let variance: f32 =
189 gradient_norms.iter().map(|&x| (x - mean_norm).powi(2)).sum::<f32>()
190 / gradient_norms.len() as f32;
191 variance.sqrt()
192 };
193
194 let explosion_ratio = exploding_layers.len() as f32 / gradients.len() as f32;
195
196 let overall_severity = if explosion_ratio > 0.5 || max_gradient_norm > threshold * 10.0 {
197 ExplosionSeverity::Critical
198 } else if explosion_ratio > 0.3 || max_gradient_norm > threshold * 5.0 {
199 ExplosionSeverity::High
200 } else if explosion_ratio > 0.1 || max_gradient_norm > threshold * 2.0 {
201 ExplosionSeverity::Medium
202 } else {
203 ExplosionSeverity::Low
204 };
205
206 GradientExplosionAnalysis {
207 exploding_layers,
208 max_gradient_norm,
209 mean_gradient_norm: mean_norm,
210 std_gradient_norm: std_norm,
211 explosion_ratio,
212 overall_severity,
213 mitigation_recommendations: Self::generate_explosion_recommendations(
214 explosion_ratio,
215 max_gradient_norm,
216 ),
217 }
218 }
219
220 pub fn analyze_weight_distribution(
222 weights: &[ArrayD<f32>],
223 ) -> Result<WeightDistributionAnalysis> {
224 let mut layer_analyses = Vec::new();
225 let mut overall_stats = WeightStatistics::default();
226 let mut all_outliers = Vec::new();
227
228 for (layer_idx, weight_tensor) in weights.iter().enumerate() {
229 let layer_stats = Self::compute_weight_statistics(weight_tensor)?;
230 let health_score = Self::compute_weight_health_score(&layer_stats);
231 let outliers = Self::detect_weight_outliers(weight_tensor, layer_idx)?;
232
233 let issues = Self::identify_weight_issues(&layer_stats);
234 let recommendations = Self::generate_weight_recommendations(&issues);
235
236 layer_analyses.push(LayerWeightAnalysis {
237 layer_index: layer_idx,
238 statistics: layer_stats.clone(),
239 health_score,
240 issues,
241 recommendations,
242 });
243
244 overall_stats.accumulate(&layer_stats);
245 all_outliers.extend(outliers);
246 }
247
248 overall_stats.finalize(weights.len());
249
250 let distribution_health = Self::assess_distribution_health(&overall_stats);
251
252 Ok(WeightDistributionAnalysis {
253 layer_analyses,
254 overall_statistics: overall_stats,
255 distribution_health,
256 outlier_detection: all_outliers,
257 })
258 }
259
260 pub fn compute_l2_norm(tensor: &ArrayD<f32>) -> f32 {
262 tensor.iter().map(|&x| x * x).sum::<f32>().sqrt()
263 }
264
265 fn compute_weight_statistics(tensor: &ArrayD<f32>) -> Result<WeightStatistics> {
267 let data: Vec<f32> = tensor.iter().cloned().collect();
268 let count = data.len();
269
270 if count == 0 {
271 return Ok(WeightStatistics::default());
272 }
273
274 let sum: f32 = data.iter().sum();
276 let mean = sum / count as f32;
277
278 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / count as f32;
279 let std_dev = variance.sqrt();
280
281 let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
283 let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
284
285 let zero_count = data.iter().filter(|&&x| x == 0.0).count();
287 let zero_fraction = zero_count as f32 / count as f32;
288
289 let skewness = Self::compute_skewness(&data, mean, std_dev);
291 let kurtosis = Self::compute_kurtosis(&data, mean, std_dev);
292 let entropy = Self::compute_entropy(&data);
293
294 Ok(WeightStatistics {
295 mean,
296 std_dev,
297 skewness,
298 kurtosis,
299 entropy,
300 min,
301 max,
302 zero_fraction,
303 })
304 }
305
306 fn classify_explosion_severity(norm: f32, all_norms: &[f32]) -> ExplosionSeverity {
308 let mean_norm = all_norms.iter().sum::<f32>() / all_norms.len() as f32;
309 let ratio = norm / (mean_norm + 1e-8);
310
311 if ratio > 100.0 {
312 ExplosionSeverity::Critical
313 } else if ratio > 50.0 {
314 ExplosionSeverity::High
315 } else if ratio > 10.0 {
316 ExplosionSeverity::Medium
317 } else {
318 ExplosionSeverity::Low
319 }
320 }
321
322 fn recommend_explosion_mitigation(norm: f32) -> String {
324 if norm > 100.0 {
325 "Critical gradient explosion: Reduce learning rate by 10x and implement gradient clipping".to_string()
326 } else if norm > 10.0 {
327 "High gradient explosion: Reduce learning rate and implement gradient clipping"
328 .to_string()
329 } else if norm > 5.0 {
330 "Moderate gradient explosion: Consider gradient clipping or learning rate reduction"
331 .to_string()
332 } else {
333 "Monitor gradients for stability".to_string()
334 }
335 }
336
337 fn generate_explosion_recommendations(explosion_ratio: f32, max_norm: f32) -> Vec<String> {
339 let mut recommendations = Vec::new();
340
341 if explosion_ratio > 0.3 {
342 recommendations.push("High proportion of exploding gradients detected".to_string());
343 recommendations.push("Consider significant learning rate reduction".to_string());
344 }
345
346 if max_norm > 100.0 {
347 recommendations.push("Extremely large gradients detected".to_string());
348 recommendations.push("Implement gradient clipping with threshold < 1.0".to_string());
349 }
350
351 recommendations.push("Monitor gradient norms during training".to_string());
352 recommendations.push("Consider batch normalization or layer normalization".to_string());
353
354 recommendations
355 }
356
357 fn compute_weight_health_score(stats: &WeightStatistics) -> f32 {
359 let mut score: f32 = 100.0;
360
361 if stats.max.abs() > 10.0 || stats.min.abs() > 10.0 {
363 score -= 20.0;
364 }
365
366 if stats.zero_fraction > 0.5 {
368 score -= 30.0;
369 }
370
371 if stats.skewness.abs() > 2.0 {
373 score -= 15.0;
374 }
375 if stats.kurtosis > 10.0 {
376 score -= 15.0;
377 }
378
379 score.max(0.0)
380 }
381
382 fn detect_weight_outliers(
384 tensor: &ArrayD<f32>,
385 layer_idx: usize,
386 ) -> Result<Vec<WeightOutlier>> {
387 let data: Vec<f32> = tensor.iter().cloned().collect();
388 let mean = data.iter().sum::<f32>() / data.len() as f32;
389 let std_dev = {
390 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
391 variance.sqrt()
392 };
393
394 let mut outliers = Vec::new();
395
396 for (idx, &value) in data.iter().enumerate() {
397 let z_score = ((value - mean) / std_dev).abs();
398
399 if z_score > 3.0 {
400 let severity =
401 if z_score > 5.0 { OutlierSeverity::High } else { OutlierSeverity::Medium };
402
403 outliers.push(WeightOutlier {
404 layer_index: layer_idx,
405 weight_index: idx,
406 value,
407 z_score,
408 severity,
409 });
410 }
411 }
412
413 Ok(outliers)
414 }
415
416 fn assess_distribution_health(stats: &WeightStatistics) -> DistributionHealth {
418 let mut score = 100.0;
419
420 if stats.zero_fraction > 0.3 {
422 score -= 25.0;
423 }
424 if stats.skewness.abs() > 1.0 {
425 score -= 15.0;
426 }
427 if stats.kurtosis > 5.0 {
428 score -= 15.0;
429 }
430 if stats.max.abs() > 5.0 || stats.min.abs() > 5.0 {
431 score -= 20.0;
432 }
433
434 let status = match score {
435 s if s >= 90.0 => DistributionHealthStatus::Excellent,
436 s if s >= 75.0 => DistributionHealthStatus::Good,
437 s if s >= 60.0 => DistributionHealthStatus::Fair,
438 s if s >= 40.0 => DistributionHealthStatus::Poor,
439 _ => DistributionHealthStatus::Critical,
440 };
441
442 DistributionHealth { score, status }
443 }
444
445 fn identify_weight_issues(stats: &WeightStatistics) -> Vec<String> {
447 let mut issues = Vec::new();
448
449 if stats.zero_fraction > 0.5 {
450 issues.push("High proportion of zero weights (dead neurons)".to_string());
451 }
452
453 if stats.skewness.abs() > 2.0 {
454 issues.push("Highly skewed weight distribution".to_string());
455 }
456
457 if stats.kurtosis > 10.0 {
458 issues.push("Heavy-tailed weight distribution".to_string());
459 }
460
461 if stats.max.abs() > 10.0 || stats.min.abs() > 10.0 {
462 issues.push("Extreme weight values detected".to_string());
463 }
464
465 issues
466 }
467
468 fn generate_weight_recommendations(issues: &[String]) -> Vec<String> {
470 let mut recommendations = Vec::new();
471
472 for issue in issues {
473 match issue.as_str() {
474 s if s.contains("dead neurons") => {
475 recommendations.push(
476 "Consider reducing learning rate or changing activation function"
477 .to_string(),
478 );
479 },
480 s if s.contains("skewed") => {
481 recommendations.push(
482 "Consider weight normalization or different initialization".to_string(),
483 );
484 },
485 s if s.contains("heavy-tailed") => {
486 recommendations.push("Monitor for gradient instability".to_string());
487 },
488 s if s.contains("extreme") => {
489 recommendations.push("Implement weight clipping or regularization".to_string());
490 },
491 _ => {},
492 }
493 }
494
495 recommendations
496 }
497
498 fn compute_skewness(data: &[f32], mean: f32, std_dev: f32) -> f32 {
500 if std_dev == 0.0 || data.len() < 3 {
501 return 0.0;
502 }
503
504 let n = data.len() as f32;
505 data.iter().map(|&x| ((x - mean) / std_dev).powi(3)).sum::<f32>() / n
506 }
507
508 fn compute_kurtosis(data: &[f32], mean: f32, std_dev: f32) -> f32 {
510 if std_dev == 0.0 || data.len() < 4 {
511 return 0.0;
512 }
513
514 let n = data.len() as f32;
515 data.iter().map(|&x| ((x - mean) / std_dev).powi(4)).sum::<f32>() / n - 3.0
516 }
518
519 fn compute_entropy(data: &[f32]) -> f32 {
521 let std_dev = {
524 let mean = data.iter().sum::<f32>() / data.len() as f32;
525 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
526 variance.sqrt()
527 };
528
529 std_dev.log2().max(0.0)
531 }
532}