Skip to main content

trustformers_debug/gradient_debugger/
debugger.rs

1//! Main Gradient Debugger Implementation
2//!
3//! This module provides the main GradientDebugger that orchestrates all gradient
4//! debugging capabilities including monitoring, anomaly detection, performance tracking,
5//! conflict analysis, visualization, and enhanced analysis.
6
7use super::anomaly_detection::*;
8use super::conflict_analysis::*;
9use super::enhanced_analysis::*;
10use super::monitoring::*;
11use super::performance_tracking::*;
12use super::types::*;
13use super::visualization::*;
14use crate::DebugConfig;
15use anyhow::Result;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18
19/// Flow analysis for gradient flow patterns
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct FlowAnalysis {
22    pub layer_analyses: HashMap<String, LayerFlowAnalysis>,
23}
24
25/// Analysis of gradient flow for a specific layer
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct LayerFlowAnalysis {
28    pub layer_name: String,
29    pub is_vanishing: bool,
30    pub is_exploding: bool,
31    pub gradient_norm: f64,
32    pub flow_consistency: f64,
33}
34
35/// Main gradient debugger
36#[derive(Debug)]
37pub struct GradientDebugger {
38    #[allow(dead_code)]
39    config: DebugConfig,
40    gradient_config: GradientDebugConfig,
41    gradient_histories: HashMap<String, GradientHistory>,
42    current_step: usize,
43    alerts: Vec<GradientAlert>,
44    layer_no_gradient_count: HashMap<String, usize>,
45
46    // Advanced features
47    adaptive_thresholds: HashMap<String, AdaptiveThresholds>,
48    real_time_monitors: HashMap<String, RealTimeGradientMonitor>,
49    anomaly_detector: GradientAnomalyDetector,
50    performance_tracker: GradientPerformanceTracker,
51    conflict_analyzer: GradientConflictAnalyzer,
52    flow_visualizer: GradientFlowVisualizer,
53    enhanced_analyzer: EnhancedGradientAnalyzer,
54}
55
56impl GradientDebugger {
57    /// Create a new gradient debugger
58    pub fn new(config: DebugConfig) -> Self {
59        let gradient_config = GradientDebugConfig::default();
60
61        Self {
62            config,
63            gradient_config: gradient_config.clone(),
64            gradient_histories: HashMap::new(),
65            current_step: 0,
66            alerts: Vec::new(),
67            layer_no_gradient_count: HashMap::new(),
68            adaptive_thresholds: HashMap::new(),
69            real_time_monitors: HashMap::new(),
70            anomaly_detector: GradientAnomalyDetector::default(),
71            performance_tracker: GradientPerformanceTracker::default(),
72            conflict_analyzer: GradientConflictAnalyzer::default(),
73            flow_visualizer: GradientFlowVisualizer::default(),
74            enhanced_analyzer: EnhancedGradientAnalyzer::default(),
75        }
76    }
77
78    /// Create with custom gradient configuration
79    pub fn with_gradient_config(config: DebugConfig, gradient_config: GradientDebugConfig) -> Self {
80        Self {
81            config,
82            gradient_config: gradient_config.clone(),
83            gradient_histories: HashMap::new(),
84            current_step: 0,
85            alerts: Vec::new(),
86            layer_no_gradient_count: HashMap::new(),
87            adaptive_thresholds: HashMap::new(),
88            real_time_monitors: HashMap::new(),
89            anomaly_detector: GradientAnomalyDetector::default(),
90            performance_tracker: GradientPerformanceTracker::default(),
91            conflict_analyzer: GradientConflictAnalyzer::default(),
92            flow_visualizer: GradientFlowVisualizer::default(),
93            enhanced_analyzer: EnhancedGradientAnalyzer::default(),
94        }
95    }
96
97    /// Record gradient flow for a layer
98    pub fn record_gradient_flow(
99        &mut self,
100        layer_name: &str,
101        gradient_norm: f64,
102        gradient_mean: f64,
103        gradient_std: f64,
104    ) -> Result<()> {
105        let flow = GradientFlow {
106            layer_name: layer_name.to_string(),
107            step: self.current_step,
108            gradient_norm,
109            gradient_mean,
110            gradient_std,
111            gradient_max: gradient_mean + gradient_std,
112            gradient_min: gradient_mean - gradient_std,
113            dead_neurons_ratio: self.estimate_dead_neurons_ratio(gradient_norm),
114            active_neurons_ratio: 1.0 - self.estimate_dead_neurons_ratio(gradient_norm),
115            timestamp: chrono::Utc::now(),
116        };
117
118        // Update gradient history
119        {
120            let history = self
121                .gradient_histories
122                .entry(layer_name.to_string())
123                .or_insert_with(|| GradientHistory::new(layer_name.to_string(), 1000));
124            history.add_gradient_flow(&flow);
125        }
126
127        // Update adaptive thresholds
128        let thresholds =
129            self.adaptive_thresholds.entry(layer_name.to_string()).or_insert_with(|| {
130                AdaptiveThresholds::new(
131                    layer_name.to_string(),
132                    self.gradient_config.vanishing_threshold,
133                    self.gradient_config.exploding_threshold,
134                )
135            });
136        thresholds.update_thresholds(gradient_norm);
137
138        // Update real-time monitor
139        let monitor = self
140            .real_time_monitors
141            .entry(layer_name.to_string())
142            .or_insert_with(|| RealTimeGradientMonitor::new(layer_name.to_string()));
143        monitor.update(gradient_norm);
144
145        // Check for alerts
146        self.check_gradient_alerts(layer_name, &flow)?;
147
148        // Record performance metrics
149        let timer = self.performance_tracker.start_timing(layer_name);
150        let (_, computation_time) = timer.finish();
151        self.performance_tracker
152            .record_layer_performance(layer_name, computation_time, 0); // Memory usage simplified
153
154        // Detect anomalies
155        let anomalies =
156            self.anomaly_detector
157                .detect_anomalies(layer_name, gradient_norm, self.current_step);
158        for anomaly in anomalies {
159            self.alerts.push(GradientAlert::GradientOscillation {
160                layer_name: anomaly.layer_name,
161                variance: anomaly.severity,
162            });
163        }
164
165        // Establish baseline if needed
166        if let Some(history) = self.gradient_histories.get(layer_name) {
167            if history.gradient_norms.len() == 50 {
168                let gradient_values: Vec<f64> = history.gradient_norms.iter().cloned().collect();
169                self.anomaly_detector.establish_baseline(layer_name, &gradient_values);
170            }
171        }
172
173        Ok(())
174    }
175
176    /// Get current gradient debugging status
177    pub fn get_status(&self) -> GradientDebugStatus {
178        let layer_statuses: HashMap<String, LayerGradientStatus> = self
179            .gradient_histories
180            .iter()
181            .map(|(layer_name, history)| {
182                let status = self.compute_layer_status(layer_name, history);
183                (layer_name.clone(), status)
184            })
185            .collect();
186
187        let overall_health = self.compute_overall_health(&layer_statuses);
188        let recent_alerts: Vec<GradientAlert> =
189            self.alerts.iter().rev().take(10).cloned().collect();
190
191        GradientDebugStatus {
192            current_step: self.current_step,
193            overall_health,
194            layer_statuses,
195            recent_alerts,
196            total_alerts: self.alerts.len(),
197            active_layers: self.gradient_histories.len(),
198        }
199    }
200
201    /// Generate flow analysis for report generation
202    fn generate_flow_analysis(&self) -> FlowAnalysis {
203        let mut layer_analyses = HashMap::new();
204
205        for (layer_name, history) in &self.gradient_histories {
206            let latest_gradient = history.gradient_norms.back().cloned().unwrap_or(0.0);
207
208            // Determine if gradients are vanishing or exploding
209            let is_vanishing = latest_gradient < 1e-8
210                || (history.gradient_norms.len() > 5
211                    && history.gradient_norms.iter().rev().take(5).all(|&g| g < 1e-6));
212
213            let is_exploding = latest_gradient > 100.0
214                || (history.gradient_norms.len() > 3
215                    && history.gradient_norms.iter().rev().take(3).any(|&g| g > 50.0));
216
217            // Calculate flow consistency (variance in gradient norms)
218            let flow_consistency = if history.gradient_norms.len() > 1 {
219                let mean = history.gradient_norms.iter().sum::<f64>()
220                    / history.gradient_norms.len() as f64;
221                let variance =
222                    history.gradient_norms.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
223                        / history.gradient_norms.len() as f64;
224                1.0 / (1.0 + variance) // Higher consistency = lower variance
225            } else {
226                1.0
227            };
228
229            layer_analyses.insert(
230                layer_name.clone(),
231                LayerFlowAnalysis {
232                    layer_name: layer_name.clone(),
233                    is_vanishing,
234                    is_exploding,
235                    gradient_norm: latest_gradient,
236                    flow_consistency,
237                },
238            );
239        }
240
241        FlowAnalysis { layer_analyses }
242    }
243
244    /// Generate comprehensive debugging report
245    pub fn generate_comprehensive_report(&self) -> Result<ComprehensiveGradientReport> {
246        let status = self.get_status();
247        let conflict_analysis = self.conflict_analyzer.analyze_conflicts(&self.gradient_histories);
248        let visualization = self
249            .flow_visualizer
250            .generate_visualization(&self.gradient_histories, self.current_step);
251        let enhanced_analysis =
252            self.enhanced_analyzer.generate_enhanced_analysis(&self.gradient_histories);
253        let performance_snapshot = self.performance_tracker.take_performance_snapshot();
254        let anomaly_summary = self.anomaly_detector.get_anomaly_summary(None);
255
256        let flow_analysis = self.generate_flow_analysis();
257
258        Ok(ComprehensiveGradientReport {
259            timestamp: chrono::Utc::now(),
260            status,
261            conflict_analysis,
262            visualization,
263            enhanced_analysis,
264            flow_analysis,
265            performance_snapshot,
266            anomaly_summary,
267            recommendations: self.generate_comprehensive_recommendations()?,
268        })
269    }
270
271    /// Analyze gradient conflicts between layers
272    pub fn analyze_gradient_conflicts(&self) -> GradientConflictAnalysis {
273        self.conflict_analyzer.analyze_conflicts(&self.gradient_histories)
274    }
275
276    /// Generate gradient flow visualization
277    pub fn generate_gradient_flow_visualization(&self) -> GradientFlowVisualization {
278        self.flow_visualizer
279            .generate_visualization(&self.gradient_histories, self.current_step)
280    }
281
282    /// Generate enhanced layer analysis
283    pub fn generate_enhanced_layer_analysis(&self) -> EnhancedLayerGradientAnalysis {
284        self.enhanced_analyzer.generate_enhanced_analysis(&self.gradient_histories)
285    }
286
287    /// Get performance insights
288    pub fn get_performance_insights(&self) -> PerformanceInsights {
289        let trends = self.performance_tracker.get_performance_trends();
290        let recommendations = self.performance_tracker.generate_optimization_recommendations();
291        let bottlenecks = self.performance_tracker.bottleneck_layers.clone();
292
293        PerformanceInsights {
294            trends,
295            recommendations,
296            bottlenecks,
297            current_throughput: self.performance_tracker.throughput_gradients_per_second,
298            memory_usage: self.performance_tracker.memory_usage_bytes,
299        }
300    }
301
302    /// Advance to next step
303    pub fn next_step(&mut self) {
304        self.current_step += 1;
305
306        // Clear old alerts (keep last 100)
307        if self.alerts.len() > 100 {
308            self.alerts.drain(0..self.alerts.len() - 100);
309        }
310
311        // Update no-gradient counters
312        for (layer_name, history) in &self.gradient_histories {
313            if let Some(latest_norm) = history.gradient_norms.back() {
314                if *latest_norm < 1e-8 {
315                    *self.layer_no_gradient_count.entry(layer_name.clone()).or_insert(0) += 1;
316                } else {
317                    self.layer_no_gradient_count.insert(layer_name.clone(), 0);
318                }
319            }
320        }
321
322        // Check for no-gradient alerts
323        for (layer_name, &count) in &self.layer_no_gradient_count {
324            if count >= self.gradient_config.no_gradient_steps_threshold {
325                self.alerts.push(GradientAlert::NoGradientFlow {
326                    layer_name: layer_name.clone(),
327                    steps_without_gradient: count,
328                });
329            }
330        }
331    }
332
333    /// Reset debugger state
334    pub fn reset(&mut self) {
335        self.gradient_histories.clear();
336        self.current_step = 0;
337        self.alerts.clear();
338        self.layer_no_gradient_count.clear();
339        self.adaptive_thresholds.clear();
340        self.real_time_monitors.clear();
341        self.anomaly_detector = GradientAnomalyDetector::default();
342        self.performance_tracker = GradientPerformanceTracker::default();
343    }
344
345    /// Get alerts for a specific layer
346    pub fn get_layer_alerts(&self, layer_name: &str) -> Vec<&GradientAlert> {
347        self.alerts
348            .iter()
349            .filter(|alert| match alert {
350                GradientAlert::VanishingGradients {
351                    layer_name: name, ..
352                } => name == layer_name,
353                GradientAlert::ExplodingGradients {
354                    layer_name: name, ..
355                } => name == layer_name,
356                GradientAlert::DeadNeurons {
357                    layer_name: name, ..
358                } => name == layer_name,
359                GradientAlert::GradientOscillation {
360                    layer_name: name, ..
361                } => name == layer_name,
362                GradientAlert::NoGradientFlow {
363                    layer_name: name, ..
364                } => name == layer_name,
365            })
366            .collect()
367    }
368
369    /// Get gradient history for a layer
370    pub fn get_layer_history(&self, layer_name: &str) -> Option<&GradientHistory> {
371        self.gradient_histories.get(layer_name)
372    }
373
374    /// Get all monitored layers
375    pub fn get_monitored_layers(&self) -> Vec<&String> {
376        self.gradient_histories.keys().collect()
377    }
378
379    // Private helper methods
380
381    fn estimate_dead_neurons_ratio(&self, gradient_norm: f64) -> f64 {
382        // Simplified estimation - in practice would analyze individual neuron gradients
383        if gradient_norm < 1e-6 {
384            0.9 // Assume 90% dead if very low gradient
385        } else if gradient_norm < 1e-4 {
386            0.3 // Assume 30% dead if low gradient
387        } else {
388            0.05 // Assume 5% dead for normal gradients
389        }
390    }
391
392    fn check_gradient_alerts(&mut self, layer_name: &str, flow: &GradientFlow) -> Result<()> {
393        // Check adaptive thresholds first
394        if let Some(thresholds) = self.adaptive_thresholds.get(layer_name) {
395            let threshold_alerts = thresholds.check_thresholds(flow.gradient_norm);
396            self.alerts.extend(threshold_alerts);
397        } else {
398            // Fallback to static thresholds
399            if flow.gradient_norm < self.gradient_config.vanishing_threshold {
400                self.alerts.push(GradientAlert::VanishingGradients {
401                    layer_name: layer_name.to_string(),
402                    norm: flow.gradient_norm,
403                    threshold: self.gradient_config.vanishing_threshold,
404                });
405            }
406
407            if flow.gradient_norm > self.gradient_config.exploding_threshold {
408                self.alerts.push(GradientAlert::ExplodingGradients {
409                    layer_name: layer_name.to_string(),
410                    norm: flow.gradient_norm,
411                    threshold: self.gradient_config.exploding_threshold,
412                });
413            }
414        }
415
416        // Check dead neurons
417        if flow.dead_neurons_ratio > self.gradient_config.dead_neuron_threshold {
418            self.alerts.push(GradientAlert::DeadNeurons {
419                layer_name: layer_name.to_string(),
420                ratio: flow.dead_neurons_ratio,
421                threshold: self.gradient_config.dead_neuron_threshold,
422            });
423        }
424
425        // Check oscillation
426        if let Some(monitor) = self.real_time_monitors.get(layer_name) {
427            if monitor.is_oscillating() {
428                self.alerts.push(GradientAlert::GradientOscillation {
429                    layer_name: layer_name.to_string(),
430                    variance: monitor.get_stability_score(),
431                });
432            }
433        }
434
435        Ok(())
436    }
437
438    fn compute_layer_status(
439        &self,
440        layer_name: &str,
441        history: &GradientHistory,
442    ) -> LayerGradientStatus {
443        let latest_norm = history.gradient_norms.back().cloned().unwrap_or(0.0);
444        let health = self.classify_layer_health(layer_name, history);
445        let alerts = self.get_layer_alerts(layer_name).len();
446        let trend = history.get_trend_slope().unwrap_or(0.0);
447
448        LayerGradientStatus {
449            layer_name: layer_name.to_string(),
450            health,
451            latest_gradient_norm: latest_norm,
452            gradient_trend: trend,
453            alert_count: alerts,
454            steps_recorded: history.gradient_norms.len(),
455        }
456    }
457
458    fn classify_layer_health(&self, layer_name: &str, history: &GradientHistory) -> LayerHealth {
459        let latest_norm = history.gradient_norms.back().cloned().unwrap_or(0.0);
460        let alert_count = self.get_layer_alerts(layer_name).len();
461
462        if !(1e-7..=100.0).contains(&latest_norm) || alert_count > 3 {
463            LayerHealth::Critical
464        } else if !(1e-5..=10.0).contains(&latest_norm) || alert_count > 0 {
465            LayerHealth::Warning
466        } else {
467            LayerHealth::Healthy
468        }
469    }
470
471    fn compute_overall_health(
472        &self,
473        layer_statuses: &HashMap<String, LayerGradientStatus>,
474    ) -> LayerHealth {
475        if layer_statuses.is_empty() {
476            return LayerHealth::Healthy;
477        }
478
479        let critical_count =
480            layer_statuses.values().filter(|s| s.health == LayerHealth::Critical).count();
481        let warning_count =
482            layer_statuses.values().filter(|s| s.health == LayerHealth::Warning).count();
483        let total = layer_statuses.len();
484
485        if critical_count > 0 || warning_count as f64 / total as f64 > 0.5 {
486            LayerHealth::Critical
487        } else if warning_count > 0 {
488            LayerHealth::Warning
489        } else {
490            LayerHealth::Healthy
491        }
492    }
493
494    fn generate_comprehensive_recommendations(&self) -> Result<Vec<GradientRecommendation>> {
495        let mut recommendations = Vec::new();
496
497        // Performance recommendations
498        let perf_recs = self.performance_tracker.generate_optimization_recommendations();
499        for rec in perf_recs {
500            recommendations.push(GradientRecommendation {
501                recommendation_type: RecommendationType::Performance,
502                title: rec.layer_name,
503                description: format!("{:?}: {}", rec.issue_type, rec.recommendations.join(", ")),
504                priority: match rec.severity {
505                    OptimizationSeverity::Critical => GradientRecommendationPriority::High,
506                    OptimizationSeverity::High => GradientRecommendationPriority::High,
507                    OptimizationSeverity::Medium => GradientRecommendationPriority::Medium,
508                    OptimizationSeverity::Low => GradientRecommendationPriority::Low,
509                },
510                expected_impact: rec.expected_improvement,
511            });
512        }
513
514        // Conflict recommendations
515        let conflict_analysis = self.conflict_analyzer.analyze_conflicts(&self.gradient_histories);
516        for strategy in conflict_analysis.mitigation_strategies {
517            recommendations.push(GradientRecommendation {
518                recommendation_type: RecommendationType::Conflict,
519                title: strategy.strategy_name,
520                description: strategy.description,
521                priority: match strategy.implementation_complexity {
522                    MitigationComplexity::Simple => GradientRecommendationPriority::High,
523                    MitigationComplexity::Moderate => GradientRecommendationPriority::Medium,
524                    MitigationComplexity::Complex => GradientRecommendationPriority::Medium,
525                    MitigationComplexity::RequiresArchitectureChange => {
526                        GradientRecommendationPriority::Low
527                    },
528                },
529                expected_impact: strategy.effectiveness,
530            });
531        }
532
533        // Anomaly recommendations
534        let anomaly_summary = self.anomaly_detector.get_anomaly_summary(None);
535        for rec_text in anomaly_summary.recommendations {
536            recommendations.push(GradientRecommendation {
537                recommendation_type: RecommendationType::Anomaly,
538                title: "Anomaly Mitigation".to_string(),
539                description: rec_text,
540                priority: if anomaly_summary.average_severity > 0.7 {
541                    GradientRecommendationPriority::High
542                } else {
543                    GradientRecommendationPriority::Medium
544                },
545                expected_impact: 1.0 - anomaly_summary.average_severity,
546            });
547        }
548
549        // Sort by priority and expected impact
550        recommendations.sort_by(|a, b| {
551            let priority_cmp = b.priority.cmp(&a.priority);
552            if priority_cmp == std::cmp::Ordering::Equal {
553                b.expected_impact
554                    .partial_cmp(&a.expected_impact)
555                    .unwrap_or(std::cmp::Ordering::Equal)
556            } else {
557                priority_cmp
558            }
559        });
560
561        Ok(recommendations)
562    }
563
564    /// Generate recommendations based on current analysis
565    pub fn generate_recommendations(&self) -> Result<Vec<GradientRecommendation>> {
566        self.generate_comprehensive_recommendations()
567    }
568
569    /// Start the gradient debugger
570    pub async fn start(&mut self) -> Result<()> {
571        // Initialize monitoring systems
572        self.performance_tracker.start_monitoring();
573
574        // Reset state for a new debugging session
575        self.current_step = 0;
576        self.alerts.clear();
577
578        // Initialize adaptive thresholds for existing histories
579        for (layer_name, history) in &self.gradient_histories {
580            if !history.gradient_norms.is_empty() {
581                let thresholds = AdaptiveThresholds::from_history(history);
582                self.adaptive_thresholds.insert(layer_name.clone(), thresholds);
583            }
584        }
585
586        Ok(())
587    }
588
589    /// Generate comprehensive gradient report
590    pub async fn generate_report(&self) -> Result<ComprehensiveGradientReport> {
591        let status = GradientDebugStatus {
592            current_step: self.current_step,
593            overall_health: self.evaluate_overall_health(),
594            layer_statuses: self.get_layer_statuses(),
595            recent_alerts: self.alerts.iter().rev().take(10).cloned().collect(),
596            total_alerts: self.alerts.len(),
597            active_layers: self.gradient_histories.len(),
598        };
599
600        let conflict_analysis = self.conflict_analyzer.analyze_conflicts(&self.gradient_histories);
601        let visualization = self.flow_visualizer.create_visualization(&self.gradient_histories);
602        let enhanced_analysis = self.enhanced_analyzer.analyze_gradients(&self.gradient_histories);
603        let performance_snapshot = self.performance_tracker.take_performance_snapshot();
604        let anomaly_summary = self.anomaly_detector.get_anomaly_summary(None);
605        let recommendations = self.generate_recommendations().unwrap_or_default();
606
607        let flow_analysis = self.generate_flow_analysis();
608
609        Ok(ComprehensiveGradientReport {
610            timestamp: chrono::Utc::now(),
611            status,
612            conflict_analysis,
613            visualization,
614            enhanced_analysis,
615            flow_analysis,
616            performance_snapshot,
617            anomaly_summary,
618            recommendations,
619        })
620    }
621
622    /// Quick analysis for immediate insights
623    pub async fn quick_analysis(&self) -> Result<GradientQuickAnalysis> {
624        let mut problematic_layers = Vec::new();
625        let mut total_gradients = 0f64;
626        let mut active_layers = 0;
627
628        for (layer_name, history) in &self.gradient_histories {
629            if !history.gradient_norms.is_empty() {
630                active_layers += 1;
631                let latest_norm = history
632                    .gradient_norms
633                    .back()
634                    .expect("gradient_norms should not be empty after is_empty check");
635                total_gradients += latest_norm;
636
637                // Check for basic problems
638                if *latest_norm < 1e-8 {
639                    problematic_layers.push(format!("{}: Vanishing gradients", layer_name));
640                } else if *latest_norm > 100.0 {
641                    problematic_layers.push(format!("{}: Exploding gradients", layer_name));
642                }
643            }
644        }
645
646        let average_gradient =
647            if active_layers > 0 { total_gradients / active_layers as f64 } else { 0.0 };
648
649        let health_score = self.calculate_quick_health_score();
650
651        Ok(GradientQuickAnalysis {
652            overall_health: if health_score > 0.8 {
653                LayerHealth::Healthy
654            } else if health_score > 0.5 {
655                LayerHealth::Warning
656            } else {
657                LayerHealth::Critical
658            },
659            active_layers,
660            problematic_layers,
661            average_gradient_norm: average_gradient,
662            recent_alerts_count: self.alerts.len(),
663            timestamp: chrono::Utc::now(),
664        })
665    }
666
667    /// Evaluate overall gradient health
668    fn evaluate_overall_health(&self) -> LayerHealth {
669        if self.gradient_histories.is_empty() {
670            return LayerHealth::Unknown;
671        }
672
673        let mut healthy_count = 0;
674        let mut warning_count = 0;
675        let mut critical_count = 0;
676
677        for history in self.gradient_histories.values() {
678            if let Some(latest_norm) = history.gradient_norms.back() {
679                if *latest_norm < 1e-8 || *latest_norm > 100.0 {
680                    critical_count += 1;
681                } else if *latest_norm < 1e-6 || *latest_norm > 10.0 {
682                    warning_count += 1;
683                } else {
684                    healthy_count += 1;
685                }
686            }
687        }
688
689        let total = healthy_count + warning_count + critical_count;
690        let critical_ratio = critical_count as f64 / total as f64;
691        let warning_ratio = (warning_count + critical_count) as f64 / total as f64;
692
693        if critical_ratio > 0.3 {
694            LayerHealth::Critical
695        } else if warning_ratio > 0.5 {
696            LayerHealth::Warning
697        } else {
698            LayerHealth::Healthy
699        }
700    }
701
702    /// Get status for each layer
703    fn get_layer_statuses(&self) -> HashMap<String, LayerGradientStatus> {
704        let mut statuses = HashMap::new();
705
706        for (layer_name, history) in &self.gradient_histories {
707            let status = if let Some(latest_norm) = history.gradient_norms.back() {
708                LayerGradientStatus {
709                    layer_name: layer_name.clone(),
710                    latest_gradient_norm: *latest_norm,
711                    gradient_trend: self.calculate_trend_value(history),
712                    health: if *latest_norm < 1e-8 {
713                        LayerHealth::Critical
714                    } else if *latest_norm > 100.0 {
715                        LayerHealth::Critical
716                    } else if *latest_norm < 1e-6 || *latest_norm > 10.0 {
717                        LayerHealth::Warning
718                    } else {
719                        LayerHealth::Healthy
720                    },
721                    alert_count: self.get_layer_alerts(layer_name).len(),
722                    steps_recorded: history.gradient_norms.len(),
723                }
724            } else {
725                LayerGradientStatus {
726                    layer_name: layer_name.clone(),
727                    latest_gradient_norm: 0.0,
728                    gradient_trend: 0.0,
729                    health: LayerHealth::Unknown,
730                    alert_count: 0,
731                    steps_recorded: 0,
732                }
733            };
734
735            statuses.insert(layer_name.clone(), status);
736        }
737
738        statuses
739    }
740
741    /// Calculate gradient trend for a layer
742    #[allow(dead_code)]
743    fn calculate_trend(&self, history: &GradientHistory) -> GradientTrend {
744        if history.gradient_norms.len() < 3 {
745            return GradientTrend::Unknown;
746        }
747
748        let recent: Vec<f64> = history.gradient_norms.iter().rev().take(3).cloned().collect();
749
750        if recent[0] > recent[1] && recent[1] > recent[2] {
751            GradientTrend::Increasing
752        } else if recent[0] < recent[1] && recent[1] < recent[2] {
753            GradientTrend::Decreasing
754        } else {
755            GradientTrend::Stable
756        }
757    }
758
759    /// Calculate gradient trend as numeric value for a layer
760    fn calculate_trend_value(&self, history: &GradientHistory) -> f64 {
761        if history.gradient_norms.len() < 2 {
762            return 0.0;
763        }
764
765        let recent: Vec<f64> = history.gradient_norms.iter().rev().take(10).cloned().collect();
766        if recent.len() < 2 {
767            return 0.0;
768        }
769
770        // Calculate linear trend slope
771        let n = recent.len() as f64;
772        let sum_x = (0..recent.len()).sum::<usize>() as f64;
773        let sum_y = recent.iter().sum::<f64>();
774        let sum_xy = recent.iter().enumerate().map(|(i, &y)| i as f64 * y).sum::<f64>();
775        let sum_x2 = (0..recent.len()).map(|i| (i * i) as f64).sum::<f64>();
776
777        (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x)
778    }
779
780    /// Calculate quick health score
781    fn calculate_quick_health_score(&self) -> f64 {
782        if self.gradient_histories.is_empty() {
783            return 0.0;
784        }
785
786        let mut score = 0.0;
787        let mut count = 0;
788
789        for history in self.gradient_histories.values() {
790            if let Some(latest_norm) = history.gradient_norms.back() {
791                // Score based on gradient magnitude (ideal range: 1e-4 to 1.0)
792                let norm_score = if *latest_norm >= 1e-4 && *latest_norm <= 1.0 {
793                    1.0
794                } else if *latest_norm >= 1e-6 && *latest_norm <= 10.0 {
795                    0.7
796                } else if *latest_norm >= 1e-8 && *latest_norm <= 100.0 {
797                    0.3
798                } else {
799                    0.0
800                };
801
802                score += norm_score;
803                count += 1;
804            }
805        }
806
807        if count == 0 {
808            0.0
809        } else {
810            score / count as f64
811        }
812    }
813}
814
/// Snapshot of the gradient debugger's current state.
///
/// Embedded as [`ComprehensiveGradientReport::status`]. Its `layer_statuses` entries are
/// what `has_vanishing_gradients` / `has_exploding_gradients` inspect.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GradientDebugStatus {
    /// Step counter at snapshot time (presumably the debugger's `current_step` — TODO confirm).
    pub current_step: usize,
    /// Aggregate health across layers; how it is derived is decided by the producer.
    pub overall_health: LayerHealth,
    /// Per-layer gradient status, keyed by layer name.
    pub layer_statuses: HashMap<String, LayerGradientStatus>,
    /// Most recent alerts; the size of this window is chosen by the producer.
    pub recent_alerts: Vec<GradientAlert>,
    /// Total alert count (presumably all alerts so far, not just `recent_alerts` — confirm).
    pub total_alerts: usize,
    /// Number of layers considered active; criterion defined by the producer — TODO document.
    pub active_layers: usize,
}
825
/// Comprehensive gradient debugging report.
///
/// Bundles the debugger's status with the outputs of the individual analysis components
/// and the recommendations derived from them. It derives serde traits, so it can be
/// serialized as a whole.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ComprehensiveGradientReport {
    /// UTC time associated with the report (presumably when it was generated — TODO confirm).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Overall and per-layer gradient status.
    pub status: GradientDebugStatus,
    /// Gradient conflict analysis results.
    pub conflict_analysis: GradientConflictAnalysis,
    /// Gradient-flow visualization data.
    pub visualization: GradientFlowVisualization,
    /// Enhanced per-layer gradient analysis results.
    pub enhanced_analysis: EnhancedLayerGradientAnalysis,
    /// Per-layer vanishing/exploding flow analysis (see [`FlowAnalysis`]).
    pub flow_analysis: FlowAnalysis,
    /// Snapshot of gradient-processing performance metrics.
    pub performance_snapshot: PerformanceSnapshot,
    /// Summary of detected anomalies; its `anomalies` list is scanned by
    /// `has_vanishing_gradients` and `has_exploding_gradients`.
    pub anomaly_summary: AnomalySummary,
    /// Recommendations derived from the analyses above.
    pub recommendations: Vec<GradientRecommendation>,
}
839
840impl ComprehensiveGradientReport {
841    /// Check if there are vanishing gradient issues
842    pub fn has_vanishing_gradients(&self) -> bool {
843        // Check if any layers have very small gradients
844        for layer_status in self.status.layer_statuses.values() {
845            if layer_status.latest_gradient_norm < 1e-8 {
846                return true;
847            }
848        }
849
850        // Check anomaly summary for vanishing gradient patterns
851        for anomaly in &self.anomaly_summary.anomalies {
852            if matches!(
853                anomaly.anomaly_type,
854                crate::anomaly_detector::AnomalyType::GradientVanishing
855            ) {
856                return true;
857            }
858        }
859
860        false
861    }
862
863    /// Check if there are exploding gradient issues
864    pub fn has_exploding_gradients(&self) -> bool {
865        // Check if any layers have very large gradients
866        for layer_status in self.status.layer_statuses.values() {
867            if layer_status.latest_gradient_norm > 100.0 {
868                return true;
869            }
870        }
871
872        // Check anomaly summary for exploding gradient patterns
873        for anomaly in &self.anomaly_summary.anomalies {
874            if matches!(
875                anomaly.anomaly_type,
876                crate::anomaly_detector::AnomalyType::GradientExplosion
877                    | crate::anomaly_detector::AnomalyType::NumericalInstability
878            ) {
879                return true;
880            }
881        }
882
883        false
884    }
885}
886
/// Summary of gradient-processing performance insights.
///
/// Unlike the report types in this module, it does not derive serde traits.
#[derive(Debug, Clone)]
pub struct PerformanceInsights {
    /// Observed performance trends.
    pub trends: PerformanceTrends,
    /// Optimization recommendations (presumably from the performance tracker — confirm).
    pub recommendations: Vec<OptimizationRecommendation>,
    /// Descriptions of identified bottlenecks.
    pub bottlenecks: Vec<String>,
    /// Current throughput; units are defined by the producer — TODO document.
    pub current_throughput: f64,
    /// Current memory usage; presumably in bytes — TODO confirm.
    pub memory_usage: usize,
}
896
/// A single actionable recommendation produced by gradient debugging.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GradientRecommendation {
    /// Area of analysis this recommendation concerns.
    pub recommendation_type: RecommendationType,
    /// Short headline.
    pub title: String,
    /// Explanation of the issue and the suggested action.
    pub description: String,
    /// Urgency; the derived ordering is `Low < Medium < High`.
    pub priority: GradientRecommendationPriority,
    /// Estimated impact of following the recommendation.
    /// NOTE(review): scale/range is set by the producer — confirm and document.
    pub expected_impact: f64,
}
906
/// Category of a [`GradientRecommendation`].
///
/// Variant meanings below are inferred from their names; the code that builds
/// recommendations defines exact usage. Variant order matters for non-self-describing
/// serde formats, so append new variants rather than reordering.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum RecommendationType {
    /// Gradient-processing performance.
    Performance,
    /// Gradient conflicts.
    Conflict,
    /// Detected gradient anomalies.
    Anomaly,
    /// Model architecture.
    Architecture,
    /// Optimization / training setup.
    Optimization,
}
915
/// Priority of a [`GradientRecommendation`].
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so `Low < Medium < High`.
/// Reordering the variants would silently change how recommendations sort.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
pub enum GradientRecommendationPriority {
    Low,
    Medium,
    High,
}
922
/// Quick analysis result giving an at-a-glance view of gradient health.
#[derive(Debug, Clone)]
pub struct GradientQuickAnalysis {
    /// Aggregate health across layers.
    pub overall_health: LayerHealth,
    /// Number of layers considered active; criterion defined by the producer.
    pub active_layers: usize,
    /// Names of layers flagged as problematic; criterion defined by the producer — TODO document.
    pub problematic_layers: Vec<String>,
    /// Average gradient norm (presumably the mean of per-layer latest norms — TODO confirm).
    pub average_gradient_norm: f64,
    /// Number of recent alerts.
    pub recent_alerts_count: usize,
    /// UTC time associated with the analysis.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
933
/// Gradient status for a single layer.
///
/// For a layer with no recorded history, the status builder in this file fills in
/// placeholders: `latest_gradient_norm = 0.0`, `gradient_trend = 0.0`,
/// `health = LayerHealth::Unknown`, `alert_count = 0` and `steps_recorded = 0`.
/// Check `steps_recorded` before interpreting the numeric fields.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LayerGradientStatus {
    /// Name of the layer this status describes.
    pub layer_name: String,
    /// Health derived from the latest gradient norm. The `Critical` condition is checked
    /// first; otherwise norms below `1e-6` or above `10.0` are `Warning`, and the rest are
    /// `Healthy`. Layers without history are `Unknown`.
    pub health: LayerHealth,
    /// Latest recorded gradient norm for the layer (`0.0` placeholder when there is no history).
    pub latest_gradient_norm: f64,
    /// Numeric trend of recent gradient norms (presumably from `calculate_trend_value` — confirm).
    pub gradient_trend: f64,
    /// Number of alerts returned by `get_layer_alerts` for this layer.
    pub alert_count: usize,
    /// Number of gradient norms currently held in the layer's history.
    pub steps_recorded: usize,
}
944
/// Categorical direction of a layer's recent gradient norms.
///
/// Produced by `GradientDebugger::calculate_trend` from the three most recent norms.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GradientTrend {
    /// Fewer than three norms have been recorded.
    Unknown,
    /// The last three norms strictly increase from oldest to newest.
    Increasing,
    /// The last three norms strictly decrease from oldest to newest.
    Decreasing,
    /// Any other pattern over the last three norms (ties, oscillation, or NaN comparisons).
    Stable,
}