trustformers_debug/gradient_debugger/
enhanced_analysis.rs

1//! Enhanced Layer Analysis and Network-Level Insights
2//!
3//! This module provides comprehensive enhanced analysis capabilities including
4//! detailed layer-wise analysis, network-level gradient insights, and optimization
5//! priority ranking for gradient debugging.
6
7use super::types::*;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Enhanced layer-wise gradient analysis.
///
/// Complete result of [`EnhancedGradientAnalyzer::generate_enhanced_analysis`]: the
/// per-layer breakdown plus the network-wide views derived from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedLayerGradientAnalysis {
    /// Per-layer details, keyed by the layer names of the input gradient histories.
    pub layer_details: HashMap<String, LayerGradientDetails>,
    /// Aggregate health, distribution, interaction and stability view over all layers.
    pub network_level_analysis: NetworkLevelAnalysis,
    /// Grouping of layers into layer groups and hierarchy levels.
    pub gradient_hierarchy: GradientHierarchy,
    /// Optimization targets produced from the layer details and the network analysis.
    /// NOTE(review): ordering is decided by `rank_optimization_priorities` — confirm it is
    /// highest-priority first before relying on index 0.
    pub optimization_priorities: Vec<OptimizationPriority>,
}
19
/// Detailed gradient analysis for a specific layer.
///
/// Every field is computed from the layer's own gradient-norm history, except
/// `comparative_analysis`, which also looks at every other layer's history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerGradientDetails {
    /// Name of the layer (identical to its key in the enclosing map).
    pub layer_name: String,
    /// Mean, spread, median, percentiles and higher moments of the gradient norms.
    pub gradient_statistics: GradientStatistics,
    /// How the norm evolves over time: consistency, smoothness, trend, oscillation, stability.
    pub flow_characteristics: FlowCharacteristics,
    /// Overall health verdict, supporting scores and the risk factors behind the verdict.
    pub health_metrics: LayerHealthMetrics,
    /// Heuristic suggestions triggered by this layer's statistics.
    pub optimization_suggestions: Vec<LayerOptimizationSuggestion>,
    /// How this layer's mean gradient norm compares with the other layers.
    pub comparative_analysis: ComparativeAnalysis,
}
30
/// Network-level gradient analysis, aggregated from the per-layer details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkLevelAnalysis {
    /// Single verdict derived from the share of `Critical` / `Warning` layers
    /// (layers whose health is `Unknown` are not counted).
    pub overall_gradient_health: LayerHealth,
    /// Distribution of the per-layer mean gradient norms.
    pub gradient_distribution: GradientDistribution,
    /// Pairwise layer relationships, one entry per unordered pair of layers.
    pub layer_interactions: Vec<LayerInteraction>,
    /// Indicators of whether training is converging.
    pub convergence_indicators: ConvergenceIndicators,
    /// Learning phase, momentum-style measures and plateau detection.
    pub training_dynamics: TrainingDynamics,
    /// Stability score, trend, identified instability sources and forecast.
    pub stability_assessment: StabilityAssessment,
}
41
/// Distribution of gradients across the network.
///
/// All statistics are taken over the per-layer *mean* gradient norms (one value per
/// layer), not over individual gradient samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientDistribution {
    /// Mean of the per-layer mean gradient norms.
    pub mean_gradient_norm: f64,
    /// Population variance (divided by the layer count) of the per-layer means.
    pub gradient_variance: f64,
    /// Standardised third moment of the per-layer means; `0.0` when they are all equal.
    pub gradient_skewness: f64,
    /// Excess kurtosis (normal = 0) of the per-layer means; `0.0` when they are all equal.
    pub gradient_kurtosis: f64,
    /// Each layer's mean gradient norm divided by `mean_gradient_norm`
    /// (`1.0` when the network mean is not positive).
    pub layer_gradient_ratios: HashMap<String, f64>,
    /// Coarse shape label derived from the skewness and kurtosis.
    pub distribution_type: DistributionType,
}

/// Coarse shape of a [`GradientDistribution`].
///
/// `classify_distribution_type` checks, in order: |skewness| > 2 → `Skewed`,
/// excess kurtosis > 3 → `HeavyTailed`, excess kurtosis < -1 → `Multimodal`,
/// otherwise `Normal`. `Degenerate` is used when there are no layers to analyse.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributionType {
    Normal,
    Skewed,
    HeavyTailed,
    Multimodal,
    Degenerate,
}
61
/// Interaction between two layers in gradient flow.
///
/// `analyze_layer_interactions` emits one entry per unordered pair of layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerInteraction {
    /// Name of the first layer of the pair.
    pub layer1: String,
    /// Name of the second layer of the pair.
    pub layer2: String,
    /// Strength reported by `compute_interaction_strength` for this pair.
    pub interaction_strength: f64,
    /// Category reported by `classify_interaction_type` for this pair.
    pub interaction_type: InteractionType,
    /// Estimated impact; currently set to half of `interaction_strength`.
    pub impact_score: f64,
}

/// Qualitative category of a [`LayerInteraction`].
///
/// NOTE(review): the criteria for each variant live in `classify_interaction_type`;
/// confirm the intended meaning there before acting on a specific variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InteractionType {
    Cooperative,
    Competitive,
    Neutral,
    Disruptive,
}
79
/// Indicators of training convergence at network level.
///
/// NOTE(review): the scale of the `*_score` fields is set by the code that fills this
/// struct (`analyze_convergence_indicators`); confirm it before comparing scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceIndicators {
    pub gradient_convergence_score: f64,
    pub parameter_convergence_score: f64,
    pub loss_convergence_score: f64,
    /// Overall direction of convergence.
    pub convergence_trend: ConvergenceTrend,
    /// Estimated number of steps until convergence; `None` when no estimate is produced.
    pub estimated_steps_to_convergence: Option<usize>,
}

/// Direction training appears to be heading; `Unknown` when it cannot be determined.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConvergenceTrend {
    Converging,
    Stable,
    Diverging,
    Oscillating,
    Unknown,
}
98
/// Training dynamics analysis for the whole network.
///
/// NOTE(review): the scale and units of the `f64` fields are defined by the code that
/// fills this struct (`analyze_training_dynamics`); confirm there before interpreting them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamics {
    /// Detected phase of training.
    pub learning_phase: LearningPhase,
    pub gradient_momentum: f64,
    pub learning_velocity: f64,
    pub adaptation_rate: f64,
    /// Whether training appears to have stalled, and suggested remedies.
    pub plateau_detection: PlateauDetection,
}

/// Coarse phase of training as classified by the training-dynamics analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearningPhase {
    InitialLearning,
    RapidLearning,
    Refinement,
    Convergence,
    Plateau,
    Overfitting,
}
118
/// Plateau detection in training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauDetection {
    /// `true` when a plateau is currently detected.
    pub is_plateau: bool,
    /// Length of the plateau. NOTE(review): presumably counted in recorded steps — confirm.
    pub plateau_duration: usize,
    /// How severe the detected plateau is.
    pub plateau_severity: PlateauSeverity,
    /// Human-readable suggestions for escaping the plateau.
    pub suggested_actions: Vec<String>,
}

/// Severity of a detected plateau, from `Mild` to `Severe`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PlateauSeverity {
    Mild,
    Moderate,
    Severe,
}
134
/// Network stability assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityAssessment {
    /// Aggregate stability score.
    /// NOTE(review): presumably compared against the analyzer's `stability_threshold`;
    /// confirm its scale in `assess_network_stability`.
    pub overall_stability: f64,
    /// Direction in which stability is moving.
    pub stability_trend: StabilityTrend,
    /// Individual causes of instability that were identified.
    pub instability_sources: Vec<InstabilitySource>,
    /// Expected short- and long-term stability.
    pub stability_forecast: StabilityForecast,
}

/// Direction in which network stability is moving.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StabilityTrend {
    Improving,
    Stable,
    Degrading,
}
150
/// Source of instability in the network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstabilitySource {
    /// Kind of instability observed.
    pub source_type: InstabilityType,
    /// Names of the layers exhibiting it.
    pub affected_layers: Vec<String>,
    /// Severity score. NOTE(review): scale is not fixed by this type — confirm with producer.
    pub severity: f64,
    /// Human-readable explanation.
    pub description: String,
}

/// Category of an [`InstabilitySource`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InstabilityType {
    GradientExplosion,
    GradientVanishing,
    Oscillation,
    Stagnation,
    Chaos,
}
168
/// Forecast of stability trends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityForecast {
    /// Expected near-term stability.
    pub short_term_outlook: StabilityOutlook,
    /// Expected longer-term stability.
    pub long_term_outlook: StabilityOutlook,
    /// Confidence in the forecast. NOTE(review): presumably in `[0, 1]` — confirm.
    pub confidence_level: f64,
    /// Signals recommended for continued monitoring (free text).
    pub recommended_monitoring: Vec<String>,
}

/// Expected stability over a forecast horizon; `Uncertain` when no clear outlook exists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StabilityOutlook {
    Stable,
    Improving,
    Deteriorating,
    Uncertain,
}
185
/// Hierarchical organization of gradient information.
///
/// Layers are collected into [`LayerGroup`]s, groups are arranged into
/// [`HierarchyLevel`]s, and [`CrossLevelInteraction`]s describe coupling between levels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientHierarchy {
    pub layer_groups: Vec<LayerGroup>,
    pub hierarchy_levels: Vec<HierarchyLevel>,
    pub cross_level_interactions: Vec<CrossLevelInteraction>,
}
193
/// Group of related layers analysed together.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerGroup {
    /// Identifier of the group.
    pub group_name: String,
    /// Names of the member layers.
    pub layers: Vec<String>,
    /// Aggregate gradient characteristics of the members.
    pub group_characteristics: GroupCharacteristics,
    /// How consistently the members behave.
    /// NOTE(review): scale is set by `build_gradient_hierarchy` — confirm.
    pub internal_coherence: f64,
}
202
/// Characteristics of a layer group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupCharacteristics {
    /// Average gradient norm across the group's layers.
    pub average_gradient_norm: f64,
    /// How closely the members' gradients move together.
    /// NOTE(review): scale not fixed by this type — confirm with producer.
    pub gradient_synchronization: f64,
    /// Sensitivity of the group to learning-rate changes (scale to be confirmed).
    pub learning_rate_sensitivity: f64,
    /// Qualitative difficulty of optimizing the group.
    pub optimization_difficulty: OptimizationDifficulty,
}

/// Qualitative optimization difficulty, from `Easy` to `VeryDifficult`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationDifficulty {
    Easy,
    Moderate,
    Difficult,
    VeryDifficult,
}
219
/// Level in the gradient hierarchy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchyLevel {
    /// Numeric identifier of the level.
    /// NOTE(review): presumably the value `CrossLevelInteraction::{from_level, to_level}`
    /// refer to — confirm.
    pub level_id: usize,
    /// Human-readable name of the level.
    pub level_name: String,
    /// Names of the layer groups on this level (presumably `LayerGroup::group_name` values).
    pub layer_groups: Vec<String>,
    /// Relative importance of the level (scale to be confirmed).
    pub level_importance: f64,
    /// Expected effect of optimizing this level (scale to be confirmed).
    pub optimization_impact: f64,
}
229
/// Interaction between two hierarchy levels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossLevelInteraction {
    /// Source level identifier.
    pub from_level: usize,
    /// Target level identifier.
    pub to_level: usize,
    /// Strength of the coupling (scale to be confirmed with the producer).
    pub interaction_strength: f64,
    /// Which way influence flows between the two levels.
    pub interaction_direction: InteractionDirection,
}

/// Direction of influence in a [`CrossLevelInteraction`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InteractionDirection {
    TopDown,
    BottomUp,
    Bidirectional,
}
245
/// Optimization priority for a single layer, a layer group or the whole network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationPriority {
    /// Name of the targeted layer or group; `target_type` says which.
    pub target_name: String,
    /// Scope of the target.
    pub target_type: OptimizationTarget,
    /// Score used when ranking priorities.
    /// NOTE(review): confirm scale and sort direction in `rank_optimization_priorities`.
    pub priority_score: f64,
    /// How soon the target should be addressed.
    pub urgency_level: UrgencyLevel,
    /// Estimated room for improvement for this target.
    pub optimization_potential: f64,
    /// Concrete actions proposed for this target.
    pub recommended_actions: Vec<PrioritizedAction>,
}

/// Scope that an [`OptimizationPriority`] applies to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationTarget {
    IndividualLayer,
    LayerGroup,
    NetworkLevel,
}

/// Urgency of an [`OptimizationPriority`], listed from least to most urgent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum UrgencyLevel {
    Low,
    Medium,
    High,
    Critical,
}
271
/// One concrete optimization action attached to an [`OptimizationPriority`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrioritizedAction {
    /// Short human-readable name of the action.
    pub action_name: String,
    /// Category of the action.
    pub action_type: ActionType,
    /// Expected benefit (scale to be confirmed with the producer).
    pub expected_impact: f64,
    /// Effort needed to carry out the action.
    pub implementation_effort: ImplementationEffort,
    /// Actions or conditions that must be in place first (free text).
    pub prerequisites: Vec<String>,
}

/// Category of a [`PrioritizedAction`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActionType {
    ParameterAdjustment,
    ArchitecturalChange,
    OptimizationTechnique,
    RegularizationMethod,
    LearningRateScheduling,
}

/// Effort required for a [`PrioritizedAction`], from `Minimal` to `Extensive`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImplementationEffort {
    Minimal,
    Low,
    Moderate,
    High,
    Extensive,
}
299
/// Layer-specific optimization suggestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerOptimizationSuggestion {
    /// Category of change being suggested.
    pub suggestion_type: SuggestionType,
    /// What to do, as human-readable text.
    pub description: String,
    /// Why the suggestion was made.
    pub rationale: String,
    /// Heuristic estimate of the benefit; the suggestions produced by
    /// `generate_layer_optimization_suggestions` use values between 0.5 and 0.7.
    pub expected_improvement: f64,
    /// How hard the change is to apply.
    pub implementation_complexity: ImplementationComplexity,
    /// Possible downsides of applying the suggestion (free text).
    pub side_effects: Vec<String>,
}

/// Category of a [`LayerOptimizationSuggestion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SuggestionType {
    WeightInitialization,
    LearningRateAdjustment,
    RegularizationAdd,
    ArchitecturalModification,
    OptimizationAlgorithm,
    BatchNormalization,
    DropoutAdjustment,
}

/// How hard a suggestion is to apply, from `Simple` up to `RequiresRetraining`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImplementationComplexity {
    Simple,
    Moderate,
    Complex,
    RequiresRetraining,
}
329
/// Enhanced gradient analyzer.
///
/// Turns per-layer gradient-norm histories into an [`EnhancedLayerGradientAnalysis`]
/// via [`EnhancedGradientAnalyzer::generate_enhanced_analysis`].
#[derive(Debug)]
pub struct EnhancedGradientAnalyzer {
    /// Requested analysis depth. Stored but not yet read by the analysis code,
    /// which is why `dead_code` is allowed on it.
    #[allow(dead_code)]
    analysis_depth: AnalysisDepth,
    /// Number of most recent gradient norms used when scoring convergence; histories
    /// shorter than this receive a neutral convergence score.
    convergence_window: usize,
    /// Stability threshold supplied at construction.
    /// NOTE(review): confirm which stability check consumes it.
    stability_threshold: f64,
}

/// How thorough the analysis should be, from `Basic` to `Expert`.
#[derive(Debug, Clone)]
pub enum AnalysisDepth {
    Basic,
    Standard,
    Comprehensive,
    Expert,
}
346
347impl Default for EnhancedGradientAnalyzer {
348    fn default() -> Self {
349        Self {
350            analysis_depth: AnalysisDepth::Standard,
351            convergence_window: 100,
352            stability_threshold: 0.8,
353        }
354    }
355}
356
357impl EnhancedGradientAnalyzer {
358    pub fn new(depth: AnalysisDepth, window: usize, threshold: f64) -> Self {
359        Self {
360            analysis_depth: depth,
361            convergence_window: window,
362            stability_threshold: threshold,
363        }
364    }
365
366    pub fn generate_enhanced_analysis(
367        &self,
368        gradient_histories: &HashMap<String, GradientHistory>,
369    ) -> EnhancedLayerGradientAnalysis {
370        let layer_details = self.generate_layer_details(gradient_histories);
371        let network_level_analysis = self.analyze_network_level_gradients(&layer_details);
372        let gradient_hierarchy = self.build_gradient_hierarchy(&layer_details);
373        let optimization_priorities =
374            self.rank_optimization_priorities(&layer_details, &network_level_analysis);
375
376        EnhancedLayerGradientAnalysis {
377            layer_details,
378            network_level_analysis,
379            gradient_hierarchy,
380            optimization_priorities,
381        }
382    }
383
384    fn generate_layer_details(
385        &self,
386        gradient_histories: &HashMap<String, GradientHistory>,
387    ) -> HashMap<String, LayerGradientDetails> {
388        let mut layer_details = HashMap::new();
389
390        for (layer_name, history) in gradient_histories {
391            let gradient_statistics = self.compute_detailed_gradient_stats(history);
392            let flow_characteristics = self.analyze_flow_characteristics(history);
393            let health_metrics = self.compute_layer_health_metrics(history);
394            let optimization_suggestions =
395                self.generate_layer_optimization_suggestions(layer_name, history);
396            let comparative_analysis =
397                self.compare_with_other_layers(layer_name, history, gradient_histories);
398
399            let analysis = LayerGradientDetails {
400                layer_name: layer_name.clone(),
401                gradient_statistics,
402                flow_characteristics,
403                health_metrics,
404                optimization_suggestions,
405                comparative_analysis,
406            };
407
408            layer_details.insert(layer_name.clone(), analysis);
409        }
410
411        layer_details
412    }
413
414    fn compute_detailed_gradient_stats(&self, history: &GradientHistory) -> GradientStatistics {
415        if history.gradient_norms.is_empty() {
416            return GradientStatistics {
417                mean: 0.0,
418                std: 0.0,
419                median: 0.0,
420                percentile_95: 0.0,
421                percentile_5: 0.0,
422                samples: 0,
423                variance: 0.0,
424                skewness: 0.0,
425                kurtosis: 0.0,
426            };
427        }
428
429        let values: Vec<f64> = history.gradient_norms.iter().cloned().collect();
430        let n = values.len() as f64;
431        let mean = values.iter().sum::<f64>() / n;
432        let variance = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
433        let std = variance.sqrt();
434
435        let mut sorted_values = values.clone();
436        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
437
438        let median_idx = values.len() / 2;
439        let median = if values.len() % 2 == 0 {
440            (sorted_values[median_idx - 1] + sorted_values[median_idx]) / 2.0
441        } else {
442            sorted_values[median_idx]
443        };
444
445        let percentile_5_idx = (values.len() as f64 * 0.05) as usize;
446        let percentile_95_idx = (values.len() as f64 * 0.95) as usize;
447        let percentile_5 = sorted_values[percentile_5_idx];
448        let percentile_95 = sorted_values[percentile_95_idx.min(sorted_values.len() - 1)];
449
450        // Compute skewness and kurtosis
451        let skewness = if std > 0.0 {
452            values.iter().map(|&x| ((x - mean) / std).powi(3)).sum::<f64>() / n
453        } else {
454            0.0
455        };
456
457        let kurtosis = if std > 0.0 {
458            values.iter().map(|&x| ((x - mean) / std).powi(4)).sum::<f64>() / n - 3.0
459        } else {
460            0.0
461        };
462
463        GradientStatistics {
464            mean,
465            std,
466            median,
467            percentile_95,
468            percentile_5,
469            samples: values.len(),
470            variance,
471            skewness,
472            kurtosis,
473        }
474    }
475
476    fn analyze_flow_characteristics(&self, history: &GradientHistory) -> FlowCharacteristics {
477        let consistency_score = self.compute_flow_consistency(history);
478        let smoothness_index = self.compute_smoothness_index(history);
479        let trend_strength = self.compute_trend_strength(history);
480        let oscillation_frequency = self.compute_oscillation_frequency(history);
481        let stability_measure = self.compute_stability_measure(history);
482
483        FlowCharacteristics {
484            consistency_score,
485            smoothness_index,
486            trend_strength,
487            oscillation_frequency,
488            stability_measure,
489        }
490    }
491
492    fn compute_flow_consistency(&self, history: &GradientHistory) -> f64 {
493        if history.gradient_norms.len() < 2 {
494            return 1.0;
495        }
496
497        let variations: Vec<f64> = history
498            .gradient_norms
499            .iter()
500            .collect::<Vec<&f64>>()
501            .windows(2)
502            .map(|pair| (*pair[1] - *pair[0]).abs() / (*pair[0] + 1e-8))
503            .collect();
504
505        let avg_variation = variations.iter().sum::<f64>() / variations.len() as f64;
506        (1.0_f64 / (1.0 + avg_variation)).min(1.0)
507    }
508
509    fn compute_smoothness_index(&self, history: &GradientHistory) -> f64 {
510        if history.gradient_norms.len() < 3 {
511            return 1.0;
512        }
513
514        // Compute second derivatives to measure smoothness
515        let second_derivatives: Vec<f64> = history
516            .gradient_norms
517            .iter()
518            .collect::<Vec<&f64>>()
519            .windows(3)
520            .map(|window| *window[2] - 2.0 * *window[1] + *window[0])
521            .collect();
522
523        let avg_second_derivative = second_derivatives.iter().map(|&x| x.abs()).sum::<f64>()
524            / second_derivatives.len() as f64;
525        (1.0_f64 / (1.0 + avg_second_derivative)).min(1.0)
526    }
527
528    fn compute_trend_strength(&self, history: &GradientHistory) -> f64 {
529        history.get_trend_slope().map(|slope| slope.abs().min(1.0)).unwrap_or(0.0)
530    }
531
532    fn compute_oscillation_frequency(&self, history: &GradientHistory) -> f64 {
533        if history.gradient_norms.len() < 4 {
534            return 0.0;
535        }
536
537        let sign_changes = history
538            .gradient_norms
539            .iter()
540            .collect::<Vec<&f64>>()
541            .windows(2)
542            .map(|pair| *pair[1] - *pair[0])
543            .collect::<Vec<f64>>()
544            .windows(2)
545            .filter(|pair| pair[0] * pair[1] < 0.0)
546            .count();
547
548        sign_changes as f64 / history.gradient_norms.len() as f64
549    }
550
551    fn compute_stability_measure(&self, history: &GradientHistory) -> f64 {
552        if history.gradient_norms.is_empty() {
553            return 0.0;
554        }
555
556        let mean = history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
557        let variance = history.gradient_norms.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
558            / history.gradient_norms.len() as f64;
559
560        if mean == 0.0 {
561            return 0.0;
562        }
563
564        let coefficient_of_variation = variance.sqrt() / mean;
565        (1.0 / (1.0 + coefficient_of_variation)).min(1.0)
566    }
567
568    fn compute_layer_health_metrics(&self, history: &GradientHistory) -> LayerHealthMetrics {
569        let gradient_statistics = self.compute_detailed_gradient_stats(history);
570        let flow_characteristics = self.analyze_flow_characteristics(history);
571
572        let gradient_stability = flow_characteristics.stability_measure;
573        let information_flow_rate = self.compute_information_flow_rate(history);
574        let neuron_activity_ratio = self.estimate_neuron_activity_ratio(history);
575        let convergence_indicator = self.compute_convergence_indicator(history);
576
577        let mut risk_factors = Vec::new();
578        if gradient_statistics.mean < 1e-5 {
579            risk_factors.push("Very low gradient magnitude".to_string());
580        }
581        if gradient_statistics.mean > 100.0 {
582            risk_factors.push("Very high gradient magnitude".to_string());
583        }
584        if gradient_stability < 0.5 {
585            risk_factors.push("High gradient instability".to_string());
586        }
587        if flow_characteristics.oscillation_frequency > 0.5 {
588            risk_factors.push("High oscillation frequency".to_string());
589        }
590
591        let overall_health = if !risk_factors.is_empty() {
592            if risk_factors.len() > 2 {
593                LayerHealth::Critical
594            } else {
595                LayerHealth::Warning
596            }
597        } else {
598            LayerHealth::Healthy
599        };
600
601        LayerHealthMetrics {
602            overall_health,
603            gradient_stability,
604            information_flow_rate,
605            neuron_activity_ratio,
606            convergence_indicator,
607            risk_factors,
608        }
609    }
610
611    fn compute_information_flow_rate(&self, history: &GradientHistory) -> f64 {
612        if history.gradient_norms.len() < 2 {
613            return 0.0;
614        }
615
616        let total_change: f64 = history
617            .gradient_norms
618            .iter()
619            .collect::<Vec<&f64>>()
620            .windows(2)
621            .map(|pair| (*pair[1] - *pair[0]).abs())
622            .sum();
623
624        total_change / history.gradient_norms.len() as f64
625    }
626
627    fn estimate_neuron_activity_ratio(&self, history: &GradientHistory) -> f64 {
628        // Simplified estimation based on gradient magnitude
629        let mean_gradient =
630            history.gradient_norms.iter().sum::<f64>() / history.gradient_norms.len() as f64;
631        (mean_gradient / (mean_gradient + 1e-5)).min(1.0)
632    }
633
634    fn compute_convergence_indicator(&self, history: &GradientHistory) -> f64 {
635        if history.gradient_norms.len() < self.convergence_window {
636            return 0.5; // Neutral score if insufficient data
637        }
638
639        let recent: Vec<f64> = history
640            .gradient_norms
641            .iter()
642            .rev()
643            .take(self.convergence_window)
644            .cloned()
645            .collect();
646        let trend_slope = self.compute_trend_for_values(&recent);
647
648        // Negative slope indicates convergence (decreasing gradients)
649        if trend_slope < 0.0 {
650            (-trend_slope).min(1.0)
651        } else {
652            0.0
653        }
654    }
655
656    fn compute_trend_for_values(&self, values: &[f64]) -> f64 {
657        if values.len() < 3 {
658            return 0.0;
659        }
660
661        let n = values.len() as f64;
662        let sum_x: f64 = (0..values.len()).map(|i| i as f64).sum();
663        let sum_y: f64 = values.iter().sum();
664        let sum_xy: f64 = values.iter().enumerate().map(|(i, &y)| i as f64 * y).sum();
665        let sum_x2: f64 = (0..values.len()).map(|i| (i as f64).powi(2)).sum();
666
667        (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x.powi(2))
668    }
669
670    fn generate_layer_optimization_suggestions(
671        &self,
672        _layer_name: &str,
673        history: &GradientHistory,
674    ) -> Vec<LayerOptimizationSuggestion> {
675        let mut suggestions = Vec::new();
676        let stats = self.compute_detailed_gradient_stats(history);
677        let flow = self.analyze_flow_characteristics(history);
678
679        // Low gradient suggestions
680        if stats.mean < 1e-5 {
681            suggestions.push(LayerOptimizationSuggestion {
682                suggestion_type: SuggestionType::WeightInitialization,
683                description: "Consider better weight initialization methods".to_string(),
684                rationale: "Very low gradients may indicate poor initialization".to_string(),
685                expected_improvement: 0.7,
686                implementation_complexity: ImplementationComplexity::Simple,
687                side_effects: vec!["May require retraining from scratch".to_string()],
688            });
689        }
690
691        // High gradient suggestions
692        if stats.mean > 10.0 {
693            suggestions.push(LayerOptimizationSuggestion {
694                suggestion_type: SuggestionType::LearningRateAdjustment,
695                description: "Reduce learning rate for this layer".to_string(),
696                rationale: "High gradients may indicate learning rate is too large".to_string(),
697                expected_improvement: 0.6,
698                implementation_complexity: ImplementationComplexity::Simple,
699                side_effects: vec!["May slow down convergence".to_string()],
700            });
701        }
702
703        // High oscillation suggestions
704        if flow.oscillation_frequency > 0.5 {
705            suggestions.push(LayerOptimizationSuggestion {
706                suggestion_type: SuggestionType::RegularizationAdd,
707                description: "Add dropout or weight decay".to_string(),
708                rationale: "High oscillation may indicate overfitting or instability".to_string(),
709                expected_improvement: 0.5,
710                implementation_complexity: ImplementationComplexity::Moderate,
711                side_effects: vec!["May reduce model capacity".to_string()],
712            });
713        }
714
715        suggestions
716    }
717
718    fn compare_with_other_layers(
719        &self,
720        layer_name: &str,
721        history: &GradientHistory,
722        all_histories: &HashMap<String, GradientHistory>,
723    ) -> ComparativeAnalysis {
724        let current_stats = self.compute_detailed_gradient_stats(history);
725        let mut other_means = Vec::new();
726
727        for (other_name, other_history) in all_histories {
728            if other_name != layer_name {
729                let other_stats = self.compute_detailed_gradient_stats(other_history);
730                other_means.push(other_stats.mean);
731            }
732        }
733
734        if other_means.is_empty() {
735            return ComparativeAnalysis {
736                relative_performance: 1.0,
737                rank_among_layers: 1,
738                similar_layers: vec![],
739                performance_gap: 0.0,
740                optimization_potential: 0.5,
741            };
742        }
743
744        other_means.sort_by(|a, b| b.partial_cmp(a).unwrap());
745        let rank = other_means
746            .iter()
747            .position(|&x| x <= current_stats.mean)
748            .unwrap_or(other_means.len())
749            + 1;
750
751        let avg_other_mean = other_means.iter().sum::<f64>() / other_means.len() as f64;
752        let relative_performance =
753            if avg_other_mean > 0.0 { current_stats.mean / avg_other_mean } else { 1.0 };
754
755        let performance_gap = (current_stats.mean - avg_other_mean).abs();
756
757        // Find similar layers (within 20% of performance)
758        let similar_layers: Vec<String> = all_histories
759            .iter()
760            .filter(|(other_name, other_history)| {
761                if *other_name == layer_name {
762                    return false;
763                }
764                let other_stats = self.compute_detailed_gradient_stats(other_history);
765                let ratio = (current_stats.mean / (other_stats.mean + 1e-8))
766                    .max(other_stats.mean / (current_stats.mean + 1e-8));
767                ratio <= 1.2
768            })
769            .map(|(name, _)| name.clone())
770            .collect();
771
772        let optimization_potential = if relative_performance < 0.5 {
773            0.8
774        } else if relative_performance < 0.8 {
775            0.6
776        } else {
777            0.3
778        };
779
780        ComparativeAnalysis {
781            relative_performance,
782            rank_among_layers: rank,
783            similar_layers,
784            performance_gap,
785            optimization_potential,
786        }
787    }
788
789    fn analyze_network_level_gradients(
790        &self,
791        layer_details: &HashMap<String, LayerGradientDetails>,
792    ) -> NetworkLevelAnalysis {
793        let overall_gradient_health = self.assess_overall_health(layer_details);
794        let gradient_distribution = self.analyze_gradient_distribution(layer_details);
795        let layer_interactions = self.analyze_layer_interactions(layer_details);
796        let convergence_indicators = self.analyze_convergence_indicators(layer_details);
797        let training_dynamics = self.analyze_training_dynamics(layer_details);
798        let stability_assessment = self.assess_network_stability(layer_details);
799
800        NetworkLevelAnalysis {
801            overall_gradient_health,
802            gradient_distribution,
803            layer_interactions,
804            convergence_indicators,
805            training_dynamics,
806            stability_assessment,
807        }
808    }
809
810    fn assess_overall_health(
811        &self,
812        layer_details: &HashMap<String, LayerGradientDetails>,
813    ) -> LayerHealth {
814        let health_counts = layer_details
815            .values()
816            .map(|details| &details.health_metrics.overall_health)
817            .fold([0, 0, 0], |mut acc, health| {
818                match health {
819                    LayerHealth::Healthy => acc[0] += 1,
820                    LayerHealth::Warning => acc[1] += 1,
821                    LayerHealth::Critical => acc[2] += 1,
822                    LayerHealth::Unknown => {}, // Ignore unknown health status
823                }
824                acc
825            });
826
827        let total = health_counts.iter().sum::<usize>();
828        if total == 0 {
829            return LayerHealth::Healthy;
830        }
831
832        let critical_ratio = health_counts[2] as f64 / total as f64;
833        let warning_ratio = health_counts[1] as f64 / total as f64;
834
835        if critical_ratio > 0.3 {
836            LayerHealth::Critical
837        } else if critical_ratio > 0.1 || warning_ratio > 0.5 {
838            LayerHealth::Warning
839        } else {
840            LayerHealth::Healthy
841        }
842    }
843
844    fn analyze_gradient_distribution(
845        &self,
846        layer_details: &HashMap<String, LayerGradientDetails>,
847    ) -> GradientDistribution {
848        let gradient_means: Vec<f64> =
849            layer_details.values().map(|details| details.gradient_statistics.mean).collect();
850
851        if gradient_means.is_empty() {
852            return GradientDistribution {
853                mean_gradient_norm: 0.0,
854                gradient_variance: 0.0,
855                gradient_skewness: 0.0,
856                gradient_kurtosis: 0.0,
857                layer_gradient_ratios: HashMap::new(),
858                distribution_type: DistributionType::Degenerate,
859            };
860        }
861
862        let n = gradient_means.len() as f64;
863        let mean_gradient_norm = gradient_means.iter().sum::<f64>() / n;
864        let gradient_variance =
865            gradient_means.iter().map(|&x| (x - mean_gradient_norm).powi(2)).sum::<f64>() / n;
866
867        let std_dev = gradient_variance.sqrt();
868        let gradient_skewness = if std_dev > 0.0 {
869            gradient_means
870                .iter()
871                .map(|&x| ((x - mean_gradient_norm) / std_dev).powi(3))
872                .sum::<f64>()
873                / n
874        } else {
875            0.0
876        };
877
878        let gradient_kurtosis = if std_dev > 0.0 {
879            gradient_means
880                .iter()
881                .map(|&x| ((x - mean_gradient_norm) / std_dev).powi(4))
882                .sum::<f64>()
883                / n
884                - 3.0
885        } else {
886            0.0
887        };
888
889        let mut layer_gradient_ratios = HashMap::new();
890        for (layer_name, details) in layer_details {
891            let ratio = if mean_gradient_norm > 0.0 {
892                details.gradient_statistics.mean / mean_gradient_norm
893            } else {
894                1.0
895            };
896            layer_gradient_ratios.insert(layer_name.clone(), ratio);
897        }
898
899        let distribution_type =
900            self.classify_distribution_type(gradient_skewness, gradient_kurtosis);
901
902        GradientDistribution {
903            mean_gradient_norm,
904            gradient_variance,
905            gradient_skewness,
906            gradient_kurtosis,
907            layer_gradient_ratios,
908            distribution_type,
909        }
910    }
911
912    fn classify_distribution_type(&self, skewness: f64, kurtosis: f64) -> DistributionType {
913        if skewness.abs() > 2.0 {
914            DistributionType::Skewed
915        } else if kurtosis > 3.0 {
916            DistributionType::HeavyTailed
917        } else if kurtosis < -1.0 {
918            DistributionType::Multimodal
919        } else {
920            DistributionType::Normal
921        }
922    }
923
924    fn analyze_layer_interactions(
925        &self,
926        layer_details: &HashMap<String, LayerGradientDetails>,
927    ) -> Vec<LayerInteraction> {
928        let mut interactions = Vec::new();
929
930        let layer_names: Vec<String> = layer_details.keys().cloned().collect();
931        for i in 0..layer_names.len() {
932            for j in (i + 1)..layer_names.len() {
933                let layer1 = &layer_names[i];
934                let layer2 = &layer_names[j];
935
936                if let (Some(details1), Some(details2)) =
937                    (layer_details.get(layer1), layer_details.get(layer2))
938                {
939                    let interaction_strength =
940                        self.compute_interaction_strength(details1, details2);
941                    let interaction_type = self.classify_interaction_type(details1, details2);
942                    let impact_score = interaction_strength * 0.5; // Simplified impact calculation
943
944                    interactions.push(LayerInteraction {
945                        layer1: layer1.clone(),
946                        layer2: layer2.clone(),
947                        interaction_strength,
948                        interaction_type,
949                        impact_score,
950                    });
951                }
952            }
953        }
954
955        interactions
956    }
957
958    fn compute_interaction_strength(
959        &self,
960        details1: &LayerGradientDetails,
961        details2: &LayerGradientDetails,
962    ) -> f64 {
963        let mean_diff =
964            (details1.gradient_statistics.mean - details2.gradient_statistics.mean).abs();
965        let stability_diff = (details1.flow_characteristics.stability_measure
966            - details2.flow_characteristics.stability_measure)
967            .abs();
968
969        // Interaction strength is inversely related to differences
970        let combined_diff = mean_diff + stability_diff;
971        1.0 / (1.0 + combined_diff)
972    }
973
974    fn classify_interaction_type(
975        &self,
976        details1: &LayerGradientDetails,
977        details2: &LayerGradientDetails,
978    ) -> InteractionType {
979        let convergence_diff = (details1.health_metrics.convergence_indicator
980            - details2.health_metrics.convergence_indicator)
981            .abs();
982
983        if convergence_diff < 0.1 {
984            InteractionType::Cooperative
985        } else if convergence_diff > 0.5 {
986            InteractionType::Competitive
987        } else {
988            InteractionType::Neutral
989        }
990    }
991
992    fn analyze_convergence_indicators(
993        &self,
994        layer_details: &HashMap<String, LayerGradientDetails>,
995    ) -> ConvergenceIndicators {
996        let convergence_scores: Vec<f64> = layer_details
997            .values()
998            .map(|details| details.health_metrics.convergence_indicator)
999            .collect();
1000
1001        let gradient_convergence_score =
1002            convergence_scores.iter().sum::<f64>() / convergence_scores.len().max(1) as f64;
1003        let parameter_convergence_score = gradient_convergence_score * 0.8; // Simplified
1004        let loss_convergence_score = gradient_convergence_score * 0.9; // Simplified
1005
1006        let convergence_trend = if gradient_convergence_score > 0.8 {
1007            ConvergenceTrend::Converging
1008        } else if gradient_convergence_score > 0.6 {
1009            ConvergenceTrend::Stable
1010        } else if gradient_convergence_score < 0.3 {
1011            ConvergenceTrend::Diverging
1012        } else {
1013            ConvergenceTrend::Unknown
1014        };
1015
1016        let estimated_steps_to_convergence = if gradient_convergence_score > 0.1 {
1017            Some(((1.0 - gradient_convergence_score) * 1000.0) as usize)
1018        } else {
1019            None
1020        };
1021
1022        ConvergenceIndicators {
1023            gradient_convergence_score,
1024            parameter_convergence_score,
1025            loss_convergence_score,
1026            convergence_trend,
1027            estimated_steps_to_convergence,
1028        }
1029    }
1030
1031    fn analyze_training_dynamics(
1032        &self,
1033        layer_details: &HashMap<String, LayerGradientDetails>,
1034    ) -> TrainingDynamics {
1035        let avg_convergence = layer_details
1036            .values()
1037            .map(|details| details.health_metrics.convergence_indicator)
1038            .sum::<f64>()
1039            / layer_details.len().max(1) as f64;
1040
1041        let learning_phase = match avg_convergence {
1042            x if x < 0.2 => LearningPhase::InitialLearning,
1043            x if x < 0.4 => LearningPhase::RapidLearning,
1044            x if x < 0.6 => LearningPhase::Refinement,
1045            x if x < 0.8 => LearningPhase::Convergence,
1046            _ => LearningPhase::Plateau,
1047        };
1048
1049        let gradient_momentum = avg_convergence * 0.8; // Simplified
1050        let learning_velocity = avg_convergence * 1.2; // Simplified
1051        let adaptation_rate = 1.0 - avg_convergence; // Simplified
1052
1053        let plateau_detection = PlateauDetection {
1054            is_plateau: avg_convergence > 0.9,
1055            plateau_duration: if avg_convergence > 0.9 { 10 } else { 0 },
1056            plateau_severity: if avg_convergence > 0.95 {
1057                PlateauSeverity::Severe
1058            } else {
1059                PlateauSeverity::Mild
1060            },
1061            suggested_actions: if avg_convergence > 0.9 {
1062                vec![
1063                    "Consider learning rate reduction".to_string(),
1064                    "Add regularization".to_string(),
1065                ]
1066            } else {
1067                vec![]
1068            },
1069        };
1070
1071        TrainingDynamics {
1072            learning_phase,
1073            gradient_momentum,
1074            learning_velocity,
1075            adaptation_rate,
1076            plateau_detection,
1077        }
1078    }
1079
1080    fn assess_network_stability(
1081        &self,
1082        layer_details: &HashMap<String, LayerGradientDetails>,
1083    ) -> StabilityAssessment {
1084        let stability_scores: Vec<f64> = layer_details
1085            .values()
1086            .map(|details| details.flow_characteristics.stability_measure)
1087            .collect();
1088
1089        let overall_stability =
1090            stability_scores.iter().sum::<f64>() / stability_scores.len().max(1) as f64;
1091
1092        let stability_trend = if overall_stability > self.stability_threshold {
1093            StabilityTrend::Stable
1094        } else {
1095            StabilityTrend::Degrading
1096        };
1097
1098        let instability_sources = self.identify_instability_sources(layer_details);
1099
1100        let stability_forecast = StabilityForecast {
1101            short_term_outlook: if overall_stability > 0.7 {
1102                StabilityOutlook::Stable
1103            } else {
1104                StabilityOutlook::Deteriorating
1105            },
1106            long_term_outlook: if overall_stability > 0.8 {
1107                StabilityOutlook::Stable
1108            } else {
1109                StabilityOutlook::Uncertain
1110            },
1111            confidence_level: overall_stability,
1112            recommended_monitoring: vec![
1113                "Monitor gradient norms".to_string(),
1114                "Track convergence indicators".to_string(),
1115            ],
1116        };
1117
1118        StabilityAssessment {
1119            overall_stability,
1120            stability_trend,
1121            instability_sources,
1122            stability_forecast,
1123        }
1124    }
1125
1126    fn identify_instability_sources(
1127        &self,
1128        layer_details: &HashMap<String, LayerGradientDetails>,
1129    ) -> Vec<InstabilitySource> {
1130        let mut sources = Vec::new();
1131
1132        for (layer_name, details) in layer_details {
1133            if details.gradient_statistics.mean > 100.0 {
1134                sources.push(InstabilitySource {
1135                    source_type: InstabilityType::GradientExplosion,
1136                    affected_layers: vec![layer_name.clone()],
1137                    severity: details.gradient_statistics.mean / 100.0,
1138                    description: format!("High gradient magnitude in layer {}", layer_name),
1139                });
1140            }
1141
1142            if details.gradient_statistics.mean < 1e-5 {
1143                sources.push(InstabilitySource {
1144                    source_type: InstabilityType::GradientVanishing,
1145                    affected_layers: vec![layer_name.clone()],
1146                    severity: 1.0 - (details.gradient_statistics.mean * 1e5),
1147                    description: format!("Very low gradient magnitude in layer {}", layer_name),
1148                });
1149            }
1150
1151            if details.flow_characteristics.oscillation_frequency > 0.5 {
1152                sources.push(InstabilitySource {
1153                    source_type: InstabilityType::Oscillation,
1154                    affected_layers: vec![layer_name.clone()],
1155                    severity: details.flow_characteristics.oscillation_frequency,
1156                    description: format!("High oscillation frequency in layer {}", layer_name),
1157                });
1158            }
1159        }
1160
1161        sources
1162    }
1163
1164    fn build_gradient_hierarchy(
1165        &self,
1166        layer_details: &HashMap<String, LayerGradientDetails>,
1167    ) -> GradientHierarchy {
1168        // Simplified hierarchy building - group layers by similar characteristics
1169        let mut layer_groups = Vec::new();
1170        let mut hierarchy_levels = Vec::new();
1171        let cross_level_interactions = Vec::new(); // Simplified - would compute actual interactions
1172
1173        // Group by gradient magnitude ranges
1174        let high_gradient_layers: Vec<String> = layer_details
1175            .iter()
1176            .filter(|(_, details)| details.gradient_statistics.mean > 1.0)
1177            .map(|(name, _)| name.clone())
1178            .collect();
1179
1180        let medium_gradient_layers: Vec<String> = layer_details
1181            .iter()
1182            .filter(|(_, details)| {
1183                details.gradient_statistics.mean >= 0.1 && details.gradient_statistics.mean <= 1.0
1184            })
1185            .map(|(name, _)| name.clone())
1186            .collect();
1187
1188        let low_gradient_layers: Vec<String> = layer_details
1189            .iter()
1190            .filter(|(_, details)| details.gradient_statistics.mean < 0.1)
1191            .map(|(name, _)| name.clone())
1192            .collect();
1193
1194        if !high_gradient_layers.is_empty() {
1195            layer_groups.push(LayerGroup {
1196                group_name: "High Gradient Layers".to_string(),
1197                layers: high_gradient_layers.clone(),
1198                group_characteristics: GroupCharacteristics {
1199                    average_gradient_norm: 2.0, // Simplified
1200                    gradient_synchronization: 0.8,
1201                    learning_rate_sensitivity: 0.9,
1202                    optimization_difficulty: OptimizationDifficulty::Difficult,
1203                },
1204                internal_coherence: 0.7,
1205            });
1206
1207            hierarchy_levels.push(HierarchyLevel {
1208                level_id: 0,
1209                level_name: "High Gradient Level".to_string(),
1210                layer_groups: vec!["High Gradient Layers".to_string()],
1211                level_importance: 0.9,
1212                optimization_impact: 0.8,
1213            });
1214        }
1215
1216        if !medium_gradient_layers.is_empty() {
1217            layer_groups.push(LayerGroup {
1218                group_name: "Medium Gradient Layers".to_string(),
1219                layers: medium_gradient_layers,
1220                group_characteristics: GroupCharacteristics {
1221                    average_gradient_norm: 0.5,
1222                    gradient_synchronization: 0.6,
1223                    learning_rate_sensitivity: 0.5,
1224                    optimization_difficulty: OptimizationDifficulty::Moderate,
1225                },
1226                internal_coherence: 0.8,
1227            });
1228
1229            hierarchy_levels.push(HierarchyLevel {
1230                level_id: 1,
1231                level_name: "Medium Gradient Level".to_string(),
1232                layer_groups: vec!["Medium Gradient Layers".to_string()],
1233                level_importance: 0.7,
1234                optimization_impact: 0.6,
1235            });
1236        }
1237
1238        if !low_gradient_layers.is_empty() {
1239            layer_groups.push(LayerGroup {
1240                group_name: "Low Gradient Layers".to_string(),
1241                layers: low_gradient_layers,
1242                group_characteristics: GroupCharacteristics {
1243                    average_gradient_norm: 0.05,
1244                    gradient_synchronization: 0.4,
1245                    learning_rate_sensitivity: 0.3,
1246                    optimization_difficulty: OptimizationDifficulty::Easy,
1247                },
1248                internal_coherence: 0.5,
1249            });
1250
1251            hierarchy_levels.push(HierarchyLevel {
1252                level_id: 2,
1253                level_name: "Low Gradient Level".to_string(),
1254                layer_groups: vec!["Low Gradient Layers".to_string()],
1255                level_importance: 0.5,
1256                optimization_impact: 0.4,
1257            });
1258        }
1259
1260        GradientHierarchy {
1261            layer_groups,
1262            hierarchy_levels,
1263            cross_level_interactions,
1264        }
1265    }
1266
1267    fn rank_optimization_priorities(
1268        &self,
1269        layer_details: &HashMap<String, LayerGradientDetails>,
1270        network_analysis: &NetworkLevelAnalysis,
1271    ) -> Vec<OptimizationPriority> {
1272        let mut priorities = Vec::new();
1273
1274        for (layer_name, details) in layer_details {
1275            let priority_score = self.calculate_priority_score(details, network_analysis);
1276            let urgency_level = self.determine_urgency_level(details);
1277            let optimization_potential = details.comparative_analysis.optimization_potential;
1278            let recommended_actions = self.generate_prioritized_actions(details);
1279
1280            priorities.push(OptimizationPriority {
1281                target_name: layer_name.clone(),
1282                target_type: OptimizationTarget::IndividualLayer,
1283                priority_score,
1284                urgency_level,
1285                optimization_potential,
1286                recommended_actions,
1287            });
1288        }
1289
1290        // Sort by priority score
1291        priorities.sort_by(|a, b| b.priority_score.partial_cmp(&a.priority_score).unwrap());
1292
1293        priorities
1294    }
1295
1296    fn calculate_priority_score(
1297        &self,
1298        details: &LayerGradientDetails,
1299        network_analysis: &NetworkLevelAnalysis,
1300    ) -> f64 {
1301        let health_weight = match details.health_metrics.overall_health {
1302            LayerHealth::Critical => 1.0,
1303            LayerHealth::Warning => 0.7,
1304            LayerHealth::Healthy => 0.3,
1305            LayerHealth::Unknown => 0.5, // Default moderate weight for unknown health
1306        };
1307
1308        let stability_weight = 1.0 - details.flow_characteristics.stability_measure;
1309        let optimization_weight = details.comparative_analysis.optimization_potential;
1310        let network_impact_weight = details.health_metrics.information_flow_rate
1311            / network_analysis.gradient_distribution.mean_gradient_norm.max(1e-8);
1312
1313        (health_weight * 0.4
1314            + stability_weight * 0.3
1315            + optimization_weight * 0.2
1316            + network_impact_weight * 0.1)
1317            .min(1.0)
1318    }
1319
1320    fn determine_urgency_level(&self, details: &LayerGradientDetails) -> UrgencyLevel {
1321        match details.health_metrics.overall_health {
1322            LayerHealth::Critical => UrgencyLevel::Critical,
1323            LayerHealth::Warning => {
1324                if details.flow_characteristics.stability_measure < 0.3 {
1325                    UrgencyLevel::High
1326                } else {
1327                    UrgencyLevel::Medium
1328                }
1329            },
1330            LayerHealth::Healthy => UrgencyLevel::Low,
1331            LayerHealth::Unknown => UrgencyLevel::Medium, // Default moderate urgency for unknown health
1332        }
1333    }
1334
1335    fn generate_prioritized_actions(
1336        &self,
1337        details: &LayerGradientDetails,
1338    ) -> Vec<PrioritizedAction> {
1339        let mut actions = Vec::new();
1340
1341        if details.gradient_statistics.mean < 1e-5 {
1342            actions.push(PrioritizedAction {
1343                action_name: "Weight Initialization Improvement".to_string(),
1344                action_type: ActionType::ParameterAdjustment,
1345                expected_impact: 0.8,
1346                implementation_effort: ImplementationEffort::Moderate,
1347                prerequisites: vec!["Model architecture review".to_string()],
1348            });
1349        }
1350
1351        if details.gradient_statistics.mean > 10.0 {
1352            actions.push(PrioritizedAction {
1353                action_name: "Learning Rate Reduction".to_string(),
1354                action_type: ActionType::LearningRateScheduling,
1355                expected_impact: 0.7,
1356                implementation_effort: ImplementationEffort::Minimal,
1357                prerequisites: vec![],
1358            });
1359        }
1360
1361        if details.flow_characteristics.stability_measure < 0.5 {
1362            actions.push(PrioritizedAction {
1363                action_name: "Gradient Clipping".to_string(),
1364                action_type: ActionType::OptimizationTechnique,
1365                expected_impact: 0.6,
1366                implementation_effort: ImplementationEffort::Low,
1367                prerequisites: vec!["Hyperparameter tuning".to_string()],
1368            });
1369        }
1370
1371        actions
1372    }
1373
1374    /// Analyze gradients and generate enhanced analysis results
1375    pub fn analyze_gradients(
1376        &self,
1377        gradient_histories: &HashMap<String, GradientHistory>,
1378    ) -> EnhancedLayerGradientAnalysis {
1379        // Use existing method to generate the analysis
1380        self.generate_enhanced_analysis(gradient_histories)
1381    }
1382}