Skip to main content

trustformers_debug/utilities/
weight_analysis.rs

1//! Weight and gradient analysis utilities
2
3use anyhow::Result;
4use scirs2_core::ndarray::*; // SciRS2 Integration Policy - was: use ndarray::ArrayD;
5use serde::{Deserialize, Serialize};
6
/// Layer exploding gradient information
#[derive(Debug, Serialize, Deserialize)]
pub struct ExplodingLayer {
    /// Index of the layer within the gradient slice passed to the analyzer.
    pub layer_index: usize,
    /// L2 norm of this layer's gradient tensor.
    pub gradient_norm: f32,
    /// Severity of the explosion, judged relative to the other layers' norms.
    pub severity: ExplosionSeverity,
    /// Human-readable mitigation suggestion for this layer.
    pub recommended_action: String,
}
15
/// Severity levels for gradient explosion
#[derive(Debug, Serialize, Deserialize)]
pub enum ExplosionSeverity {
    /// Norms near the typical range; monitoring only.
    Low,
    /// Noticeably elevated gradient norms.
    Medium,
    /// Strongly elevated norms; intervention recommended.
    High,
    /// Extreme norms; training is likely unstable.
    Critical,
}
24
/// Comprehensive gradient explosion analysis
#[derive(Debug, Serialize, Deserialize)]
pub struct GradientExplosionAnalysis {
    /// Layers whose gradient L2 norm exceeded the detection threshold.
    pub exploding_layers: Vec<ExplodingLayer>,
    /// Largest per-layer gradient L2 norm observed.
    pub max_gradient_norm: f32,
    /// Mean of the per-layer gradient L2 norms.
    pub mean_gradient_norm: f32,
    /// Population standard deviation of the per-layer gradient L2 norms.
    pub std_gradient_norm: f32,
    /// Fraction of layers flagged as exploding (0.0..=1.0).
    pub explosion_ratio: f32,
    /// Aggregate severity derived from `explosion_ratio` and `max_gradient_norm`.
    pub overall_severity: ExplosionSeverity,
    /// Model-wide mitigation suggestions.
    pub mitigation_recommendations: Vec<String>,
}
36
/// Weight distribution analysis
#[derive(Debug, Serialize, Deserialize)]
pub struct WeightDistributionAnalysis {
    /// Per-layer statistics, health scores, issues, and recommendations.
    pub layer_analyses: Vec<LayerWeightAnalysis>,
    /// Statistics averaged across all analyzed layers.
    pub overall_statistics: WeightStatistics,
    /// Health assessment computed from the overall statistics.
    pub distribution_health: DistributionHealth,
    /// All per-layer weight outliers, concatenated across layers.
    pub outlier_detection: Vec<WeightOutlier>,
}
45
/// Individual layer weight analysis
#[derive(Debug, Serialize, Deserialize)]
pub struct LayerWeightAnalysis {
    /// Index of the layer within the weight slice passed to the analyzer.
    pub layer_index: usize,
    /// Statistics computed over this layer's weight tensor.
    pub statistics: WeightStatistics,
    /// Health score in 0.0..=100.0 (higher is healthier).
    pub health_score: f32,
    /// Human-readable descriptions of detected problems.
    pub issues: Vec<String>,
    /// Suggested fixes derived from `issues`.
    pub recommendations: Vec<String>,
}
55
/// Weight statistics for a layer or model
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WeightStatistics {
    /// Arithmetic mean of the weights.
    pub mean: f32,
    /// Population standard deviation of the weights.
    pub std_dev: f32,
    /// Third standardized moment (asymmetry of the distribution).
    pub skewness: f32,
    /// Excess kurtosis (fourth standardized moment minus 3).
    pub kurtosis: f32,
    /// Simplified entropy proxy: log2 of the standard deviation, floored at 0.
    pub entropy: f32,
    /// Smallest weight value.
    pub min: f32,
    /// Largest weight value.
    pub max: f32,
    /// Fraction of weights that are exactly 0.0.
    pub zero_fraction: f32,
}
68
69impl WeightStatistics {
70    pub fn accumulate(&mut self, other: &WeightStatistics) {
71        // Simple accumulation for overall statistics
72        self.mean += other.mean;
73        self.std_dev += other.std_dev;
74        self.skewness += other.skewness;
75        self.kurtosis += other.kurtosis;
76        self.entropy += other.entropy;
77        self.min = self.min.min(other.min);
78        self.max = self.max.max(other.max);
79        self.zero_fraction += other.zero_fraction;
80    }
81
82    pub fn finalize(&mut self, count: usize) {
83        if count > 0 {
84            let count_f32 = count as f32;
85            self.mean /= count_f32;
86            self.std_dev /= count_f32;
87            self.skewness /= count_f32;
88            self.kurtosis /= count_f32;
89            self.entropy /= count_f32;
90            self.zero_fraction /= count_f32;
91        }
92    }
93}
94
/// Weight health assessment
// NOTE(review): not constructed anywhere in this module — presumably
// populated by a caller elsewhere; verify before relying on field semantics.
#[derive(Debug, Serialize, Deserialize)]
pub struct WeightHealth {
    /// Health score (higher is healthier).
    pub score: f32,
    /// Human-readable descriptions of detected problems.
    pub issues: Vec<String>,
    /// Suggested fixes for the detected issues.
    pub recommendations: Vec<String>,
}
102
/// Distribution health status
#[derive(Debug, Serialize, Deserialize)]
pub struct DistributionHealth {
    /// Numeric health score in 0.0..=100.0 (higher is healthier).
    pub score: f32,
    /// Categorical status derived from `score`.
    pub status: DistributionHealthStatus,
}
109
/// Health status levels for weight distributions
///
/// Mapped from the numeric score in `assess_distribution_health`:
/// >= 90 Excellent, >= 75 Good, >= 60 Fair, >= 40 Poor, otherwise Critical.
#[derive(Debug, Serialize, Deserialize)]
pub enum DistributionHealthStatus {
    Excellent,
    Good,
    Fair,
    Poor,
    Critical,
}
119
/// Weight outlier information
#[derive(Debug, Serialize, Deserialize)]
pub struct WeightOutlier {
    /// Index of the layer the outlier belongs to.
    pub layer_index: usize,
    /// Flat index of the weight within the tensor's iteration order.
    pub weight_index: usize,
    /// The outlying weight value itself.
    pub value: f32,
    /// Absolute z-score of the value; outliers are flagged when it exceeds 3.
    pub z_score: f32,
    /// `High` when the absolute z-score exceeds 5, otherwise `Medium`.
    pub severity: OutlierSeverity,
}
129
/// Severity levels for weight outliers
#[derive(Debug, Serialize, Deserialize)]
pub enum OutlierSeverity {
    /// Absolute z-score in (3, 5].
    Medium,
    /// Absolute z-score above 5.
    High,
}
136
/// Weight drift analysis between model states
// NOTE(review): not produced anywhere in this module — semantics of the drift
// metric (e.g. per-weight delta norm) are defined by the producer; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightDriftAnalysis {
    /// Average drift magnitude across layers.
    pub mean_drift: f32,
    /// Largest drift magnitude observed in any layer.
    pub max_drift: f32,
    /// Categorical severity of the drift.
    pub severity: WeightDriftSeverity,
    /// Indices of layers considered affected by drift.
    pub affected_layers: Vec<usize>,
}
145
/// Drift severity levels
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightDriftSeverity {
    /// Negligible drift.
    Minimal,
    Low,
    Medium,
    /// Substantial drift between model states.
    High,
}
154
/// Weight and gradient analysis utilities
///
/// Stateless namespace type: all functionality is exposed through associated
/// functions (no instance state is ever read).
pub struct WeightAnalyzer;
157
158impl WeightAnalyzer {
159    /// Detect gradient explosion patterns in a set of gradients
160    pub fn detect_gradient_explosion(
161        gradients: &[ArrayD<f32>],
162        threshold: f32,
163    ) -> GradientExplosionAnalysis {
164        let mut exploding_layers = Vec::new();
165        let mut max_gradient_norm = 0.0f32;
166        let mut gradient_norms = Vec::new();
167
168        for (layer_idx, gradient) in gradients.iter().enumerate() {
169            let l2_norm = Self::compute_l2_norm(gradient);
170            gradient_norms.push(l2_norm);
171
172            if l2_norm > max_gradient_norm {
173                max_gradient_norm = l2_norm;
174            }
175
176            if l2_norm > threshold {
177                exploding_layers.push(ExplodingLayer {
178                    layer_index: layer_idx,
179                    gradient_norm: l2_norm,
180                    severity: Self::classify_explosion_severity(l2_norm, &gradient_norms),
181                    recommended_action: Self::recommend_explosion_mitigation(l2_norm),
182                });
183            }
184        }
185
186        let mean_norm = gradient_norms.iter().sum::<f32>() / gradient_norms.len() as f32;
187        let std_norm = {
188            let variance: f32 =
189                gradient_norms.iter().map(|&x| (x - mean_norm).powi(2)).sum::<f32>()
190                    / gradient_norms.len() as f32;
191            variance.sqrt()
192        };
193
194        let explosion_ratio = exploding_layers.len() as f32 / gradients.len() as f32;
195
196        let overall_severity = if explosion_ratio > 0.5 || max_gradient_norm > threshold * 10.0 {
197            ExplosionSeverity::Critical
198        } else if explosion_ratio > 0.3 || max_gradient_norm > threshold * 5.0 {
199            ExplosionSeverity::High
200        } else if explosion_ratio > 0.1 || max_gradient_norm > threshold * 2.0 {
201            ExplosionSeverity::Medium
202        } else {
203            ExplosionSeverity::Low
204        };
205
206        GradientExplosionAnalysis {
207            exploding_layers,
208            max_gradient_norm,
209            mean_gradient_norm: mean_norm,
210            std_gradient_norm: std_norm,
211            explosion_ratio,
212            overall_severity,
213            mitigation_recommendations: Self::generate_explosion_recommendations(
214                explosion_ratio,
215                max_gradient_norm,
216            ),
217        }
218    }
219
220    /// Analyze weight distributions across model layers
221    pub fn analyze_weight_distribution(
222        weights: &[ArrayD<f32>],
223    ) -> Result<WeightDistributionAnalysis> {
224        let mut layer_analyses = Vec::new();
225        let mut overall_stats = WeightStatistics::default();
226        let mut all_outliers = Vec::new();
227
228        for (layer_idx, weight_tensor) in weights.iter().enumerate() {
229            let layer_stats = Self::compute_weight_statistics(weight_tensor)?;
230            let health_score = Self::compute_weight_health_score(&layer_stats);
231            let outliers = Self::detect_weight_outliers(weight_tensor, layer_idx)?;
232
233            let issues = Self::identify_weight_issues(&layer_stats);
234            let recommendations = Self::generate_weight_recommendations(&issues);
235
236            layer_analyses.push(LayerWeightAnalysis {
237                layer_index: layer_idx,
238                statistics: layer_stats.clone(),
239                health_score,
240                issues,
241                recommendations,
242            });
243
244            overall_stats.accumulate(&layer_stats);
245            all_outliers.extend(outliers);
246        }
247
248        overall_stats.finalize(weights.len());
249
250        let distribution_health = Self::assess_distribution_health(&overall_stats);
251
252        Ok(WeightDistributionAnalysis {
253            layer_analyses,
254            overall_statistics: overall_stats,
255            distribution_health,
256            outlier_detection: all_outliers,
257        })
258    }
259
260    /// Compute L2 norm of a tensor
261    pub fn compute_l2_norm(tensor: &ArrayD<f32>) -> f32 {
262        tensor.iter().map(|&x| x * x).sum::<f32>().sqrt()
263    }
264
265    /// Compute comprehensive weight statistics for a tensor
266    fn compute_weight_statistics(tensor: &ArrayD<f32>) -> Result<WeightStatistics> {
267        let data: Vec<f32> = tensor.iter().cloned().collect();
268        let count = data.len();
269
270        if count == 0 {
271            return Ok(WeightStatistics::default());
272        }
273
274        // Basic statistics
275        let sum: f32 = data.iter().sum();
276        let mean = sum / count as f32;
277
278        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / count as f32;
279        let std_dev = variance.sqrt();
280
281        // Min/max
282        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
283        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
284
285        // Zero fraction
286        let zero_count = data.iter().filter(|&&x| x == 0.0).count();
287        let zero_fraction = zero_count as f32 / count as f32;
288
289        // Higher order moments
290        let skewness = Self::compute_skewness(&data, mean, std_dev);
291        let kurtosis = Self::compute_kurtosis(&data, mean, std_dev);
292        let entropy = Self::compute_entropy(&data);
293
294        Ok(WeightStatistics {
295            mean,
296            std_dev,
297            skewness,
298            kurtosis,
299            entropy,
300            min,
301            max,
302            zero_fraction,
303        })
304    }
305
306    /// Classify gradient explosion severity
307    fn classify_explosion_severity(norm: f32, all_norms: &[f32]) -> ExplosionSeverity {
308        let mean_norm = all_norms.iter().sum::<f32>() / all_norms.len() as f32;
309        let ratio = norm / (mean_norm + 1e-8);
310
311        if ratio > 100.0 {
312            ExplosionSeverity::Critical
313        } else if ratio > 50.0 {
314            ExplosionSeverity::High
315        } else if ratio > 10.0 {
316            ExplosionSeverity::Medium
317        } else {
318            ExplosionSeverity::Low
319        }
320    }
321
322    /// Recommend mitigation for gradient explosion
323    fn recommend_explosion_mitigation(norm: f32) -> String {
324        if norm > 100.0 {
325            "Critical gradient explosion: Reduce learning rate by 10x and implement gradient clipping".to_string()
326        } else if norm > 10.0 {
327            "High gradient explosion: Reduce learning rate and implement gradient clipping"
328                .to_string()
329        } else if norm > 5.0 {
330            "Moderate gradient explosion: Consider gradient clipping or learning rate reduction"
331                .to_string()
332        } else {
333            "Monitor gradients for stability".to_string()
334        }
335    }
336
337    /// Generate recommendations for gradient explosion mitigation
338    fn generate_explosion_recommendations(explosion_ratio: f32, max_norm: f32) -> Vec<String> {
339        let mut recommendations = Vec::new();
340
341        if explosion_ratio > 0.3 {
342            recommendations.push("High proportion of exploding gradients detected".to_string());
343            recommendations.push("Consider significant learning rate reduction".to_string());
344        }
345
346        if max_norm > 100.0 {
347            recommendations.push("Extremely large gradients detected".to_string());
348            recommendations.push("Implement gradient clipping with threshold < 1.0".to_string());
349        }
350
351        recommendations.push("Monitor gradient norms during training".to_string());
352        recommendations.push("Consider batch normalization or layer normalization".to_string());
353
354        recommendations
355    }
356
357    /// Compute weight health score
358    fn compute_weight_health_score(stats: &WeightStatistics) -> f32 {
359        let mut score: f32 = 100.0;
360
361        // Penalize extreme values
362        if stats.max.abs() > 10.0 || stats.min.abs() > 10.0 {
363            score -= 20.0;
364        }
365
366        // Penalize high zero fraction (dead neurons)
367        if stats.zero_fraction > 0.5 {
368            score -= 30.0;
369        }
370
371        // Penalize extreme skewness or kurtosis
372        if stats.skewness.abs() > 2.0 {
373            score -= 15.0;
374        }
375        if stats.kurtosis > 10.0 {
376            score -= 15.0;
377        }
378
379        score.max(0.0)
380    }
381
382    /// Detect weight outliers in a tensor
383    fn detect_weight_outliers(
384        tensor: &ArrayD<f32>,
385        layer_idx: usize,
386    ) -> Result<Vec<WeightOutlier>> {
387        let data: Vec<f32> = tensor.iter().cloned().collect();
388        let mean = data.iter().sum::<f32>() / data.len() as f32;
389        let std_dev = {
390            let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
391            variance.sqrt()
392        };
393
394        let mut outliers = Vec::new();
395
396        for (idx, &value) in data.iter().enumerate() {
397            let z_score = ((value - mean) / std_dev).abs();
398
399            if z_score > 3.0 {
400                let severity =
401                    if z_score > 5.0 { OutlierSeverity::High } else { OutlierSeverity::Medium };
402
403                outliers.push(WeightOutlier {
404                    layer_index: layer_idx,
405                    weight_index: idx,
406                    value,
407                    z_score,
408                    severity,
409                });
410            }
411        }
412
413        Ok(outliers)
414    }
415
416    /// Assess overall distribution health
417    fn assess_distribution_health(stats: &WeightStatistics) -> DistributionHealth {
418        let mut score = 100.0;
419
420        // Factor in various metrics
421        if stats.zero_fraction > 0.3 {
422            score -= 25.0;
423        }
424        if stats.skewness.abs() > 1.0 {
425            score -= 15.0;
426        }
427        if stats.kurtosis > 5.0 {
428            score -= 15.0;
429        }
430        if stats.max.abs() > 5.0 || stats.min.abs() > 5.0 {
431            score -= 20.0;
432        }
433
434        let status = match score {
435            s if s >= 90.0 => DistributionHealthStatus::Excellent,
436            s if s >= 75.0 => DistributionHealthStatus::Good,
437            s if s >= 60.0 => DistributionHealthStatus::Fair,
438            s if s >= 40.0 => DistributionHealthStatus::Poor,
439            _ => DistributionHealthStatus::Critical,
440        };
441
442        DistributionHealth { score, status }
443    }
444
445    /// Identify issues in weight statistics
446    fn identify_weight_issues(stats: &WeightStatistics) -> Vec<String> {
447        let mut issues = Vec::new();
448
449        if stats.zero_fraction > 0.5 {
450            issues.push("High proportion of zero weights (dead neurons)".to_string());
451        }
452
453        if stats.skewness.abs() > 2.0 {
454            issues.push("Highly skewed weight distribution".to_string());
455        }
456
457        if stats.kurtosis > 10.0 {
458            issues.push("Heavy-tailed weight distribution".to_string());
459        }
460
461        if stats.max.abs() > 10.0 || stats.min.abs() > 10.0 {
462            issues.push("Extreme weight values detected".to_string());
463        }
464
465        issues
466    }
467
468    /// Generate recommendations based on weight issues
469    fn generate_weight_recommendations(issues: &[String]) -> Vec<String> {
470        let mut recommendations = Vec::new();
471
472        for issue in issues {
473            match issue.as_str() {
474                s if s.contains("dead neurons") => {
475                    recommendations.push(
476                        "Consider reducing learning rate or changing activation function"
477                            .to_string(),
478                    );
479                },
480                s if s.contains("skewed") => {
481                    recommendations.push(
482                        "Consider weight normalization or different initialization".to_string(),
483                    );
484                },
485                s if s.contains("heavy-tailed") => {
486                    recommendations.push("Monitor for gradient instability".to_string());
487                },
488                s if s.contains("extreme") => {
489                    recommendations.push("Implement weight clipping or regularization".to_string());
490                },
491                _ => {},
492            }
493        }
494
495        recommendations
496    }
497
498    /// Compute skewness of data
499    fn compute_skewness(data: &[f32], mean: f32, std_dev: f32) -> f32 {
500        if std_dev == 0.0 || data.len() < 3 {
501            return 0.0;
502        }
503
504        let n = data.len() as f32;
505        data.iter().map(|&x| ((x - mean) / std_dev).powi(3)).sum::<f32>() / n
506    }
507
508    /// Compute kurtosis of data
509    fn compute_kurtosis(data: &[f32], mean: f32, std_dev: f32) -> f32 {
510        if std_dev == 0.0 || data.len() < 4 {
511            return 0.0;
512        }
513
514        let n = data.len() as f32;
515        data.iter().map(|&x| ((x - mean) / std_dev).powi(4)).sum::<f32>() / n - 3.0
516        // Excess kurtosis
517    }
518
519    /// Compute entropy of data (simplified)
520    fn compute_entropy(data: &[f32]) -> f32 {
521        // Simplified entropy computation
522        // In practice, this would discretize the data and compute proper entropy
523        let std_dev = {
524            let mean = data.iter().sum::<f32>() / data.len() as f32;
525            let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
526            variance.sqrt()
527        };
528
529        // Higher std_dev implies higher entropy (roughly)
530        std_dev.log2().max(0.0)
531    }
532}