trustformers_debug/
advanced_ml_debugging.rs

1//! # Advanced ML Debugging Tools
2//!
3//! Advanced machine learning specific debugging techniques including layer-wise learning rate adaptation,
4//! model sensitivity analysis, gradient flow optimization, and neural architecture debugging.
5
6use anyhow::Result;
7use chrono::{DateTime, Utc};
8use scirs2_core::ndarray::*; // SciRS2 Integration Policy - was: use ndarray::{Array1, Array2, Array3, ArrayD};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Configuration for advanced ML debugging
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct AdvancedMLDebuggingConfig {
15    /// Enable layer-wise learning rate analysis
16    pub enable_layer_wise_lr_analysis: bool,
17    /// Enable model sensitivity analysis
18    pub enable_model_sensitivity_analysis: bool,
19    /// Enable gradient flow optimization analysis
20    pub enable_gradient_flow_optimization: bool,
21    /// Enable neural architecture debugging
22    pub enable_neural_architecture_debugging: bool,
23    /// Enable activation pattern analysis
24    pub enable_activation_pattern_analysis: bool,
25    /// Enable weight distribution analysis
26    pub enable_weight_distribution_analysis: bool,
27    /// Enable training dynamics analysis
28    pub enable_training_dynamics_analysis: bool,
29    /// Enable optimization landscape analysis
30    pub enable_optimization_landscape_analysis: bool,
31    /// Number of samples for sensitivity analysis
32    pub sensitivity_samples: usize,
33    /// Learning rate adaptation threshold
34    pub lr_adaptation_threshold: f64,
35    /// Maximum number of layers to analyze
36    pub max_layers_to_analyze: usize,
37}
38
impl Default for AdvancedMLDebuggingConfig {
    /// Default configuration: every analysis pass enabled, 1000 sensitivity
    /// samples, a 0.1 learning-rate adaptation threshold, and at most 50
    /// layers analyzed.
    fn default() -> Self {
        Self {
            enable_layer_wise_lr_analysis: true,
            enable_model_sensitivity_analysis: true,
            enable_gradient_flow_optimization: true,
            enable_neural_architecture_debugging: true,
            enable_activation_pattern_analysis: true,
            enable_weight_distribution_analysis: true,
            enable_training_dynamics_analysis: true,
            enable_optimization_landscape_analysis: true,
            sensitivity_samples: 1000,
            lr_adaptation_threshold: 0.1,
            max_layers_to_analyze: 50,
        }
    }
}
56
/// Layer-wise learning rate adaptation analysis result
///
/// Top-level output of `AdvancedMLDebugger::analyze_layer_wise_learning_rates`,
/// combining per-layer recommendations with model-wide insights and
/// forward-looking schedule predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerWiseLRAnalysisResult {
    /// Analysis timestamp
    pub timestamp: DateTime<Utc>,
    /// Learning rate recommendations per layer, keyed by layer identifier
    pub layer_lr_recommendations: HashMap<String, LayerLRRecommendation>,
    /// Global learning rate insights
    pub global_lr_insights: GlobalLRInsights,
    /// Learning rate adaptation strategy
    pub adaptation_strategy: LRAdaptationStrategy,
    /// Training phase recommendations
    pub training_phase_recommendations: Vec<TrainingPhaseRecommendation>,
    /// Performance predictions with different LR schedules
    pub lr_schedule_predictions: Vec<LRSchedulePrediction>,
}

/// Learning rate recommendation for a specific layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerLRRecommendation {
    /// Layer identifier
    pub layer_id: String,
    /// Layer type (e.g., "attention", "feedforward", "embedding")
    pub layer_type: String,
    /// Current learning rate
    pub current_lr: f64,
    /// Recommended learning rate
    pub recommended_lr: f64,
    /// Recommendation confidence (presumably in [0, 1]; the current
    /// implementation always emits a fixed 0.8 — see `analyze_single_layer_lr`)
    pub confidence: f64,
    /// Reasoning for recommendation
    pub reasoning: String,
    /// Layer-specific metrics
    pub layer_metrics: LayerLRMetrics,
    /// Sensitivity to learning rate changes (currently the
    /// recommended-to-current LR ratio computed in `analyze_single_layer_lr`)
    pub lr_sensitivity: f64,
    /// Adaptation urgency level
    pub urgency: AdaptationUrgency,
}

/// Layer-specific learning rate metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerLRMetrics {
    /// Gradient magnitude (mean absolute gradient value)
    pub gradient_magnitude: f64,
    /// Weight update magnitude (gradient magnitude scaled by the current LR)
    pub weight_update_magnitude: f64,
    /// Parameter norm (mean absolute weight value)
    pub parameter_norm: f64,
    /// Loss contribution
    pub loss_contribution: f64,
    /// Training stability score
    pub stability_score: f64,
    /// Convergence rate
    pub convergence_rate: f64,
    /// Learning efficiency (gradient magnitude relative to parameter norm)
    pub learning_efficiency: f64,
}
115
116/// Urgency level for learning rate adaptation
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub enum AdaptationUrgency {
119    Low,
120    Medium,
121    High,
122    Critical,
123}
124
/// Global learning rate insights
///
/// Model-wide aggregation of the per-layer recommendations, produced by
/// `generate_global_lr_insights`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalLRInsights {
    /// Overall model learning efficiency (mean of the per-layer
    /// `learning_efficiency` values)
    pub overall_efficiency: f64,
    /// Learning rate distribution health
    pub lr_distribution_health: f64,
    /// Gradient flow quality
    pub gradient_flow_quality: f64,
    /// Training stability assessment
    pub training_stability: TrainingStability,
    /// Recommended global adjustments
    pub global_adjustments: Vec<GlobalLRAdjustment>,
    /// Critical issues requiring immediate attention
    pub critical_issues: Vec<String>,
}

/// Training stability assessment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStability {
    /// Stability score (0-1)
    pub stability_score: f64,
    /// Instability indicators
    pub instability_indicators: Vec<InstabilityIndicator>,
    /// Stability trends over time
    pub stability_trends: Vec<StabilityTrendPoint>,
    /// Predicted stability with current settings
    pub predicted_stability: f64,
}

/// Indicator of training instability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstabilityIndicator {
    /// Type of instability
    pub instability_type: InstabilityType,
    /// Severity level (higher presumably means worse; scale not fixed here —
    /// TODO confirm against the producer)
    pub severity: f64,
    /// Affected layers
    pub affected_layers: Vec<String>,
    /// Recommended actions
    pub recommended_actions: Vec<String>,
}
167
168/// Type of training instability
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub enum InstabilityType {
171    GradientExplosion,
172    GradientVanishing,
173    OscillatingLoss,
174    SlowConvergence,
175    WeightDivergence,
176    NumericalInstability,
177}
178
/// Point in stability trend analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityTrendPoint {
    /// Time step or epoch
    pub time_step: usize,
    /// Stability score at this point
    pub stability_score: f64,
    /// Contributing factors, keyed by factor name (value presumably a
    /// weight or contribution magnitude — confirm against the producer)
    pub contributing_factors: HashMap<String, f64>,
}

/// Global learning rate adjustment recommendation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalLRAdjustment {
    /// Adjustment type
    pub adjustment_type: GlobalAdjustmentType,
    /// Adjustment magnitude
    pub magnitude: f64,
    /// Expected impact
    pub expected_impact: f64,
    /// Implementation priority
    pub priority: AdjustmentPriority,
    /// Implementation instructions (free-form, human-readable)
    pub instructions: String,
}
204
205/// Type of global learning rate adjustment
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub enum GlobalAdjustmentType {
208    UniformScaling,
209    LayerTypeSpecific,
210    DepthDependent,
211    AdaptiveScheduling,
212    WarmupAdjustment,
213    DecayRateModification,
214}
215
216/// Priority level for adjustments
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub enum AdjustmentPriority {
219    Low,
220    Medium,
221    High,
222    Immediate,
223}
224
/// Learning rate adaptation strategy
///
/// Produced by `create_lr_adaptation_strategy`; a human-actionable plan for
/// changing learning rates, with steps, expected benefits and risks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRAdaptationStrategy {
    /// Strategy name
    pub strategy_name: String,
    /// Strategy description
    pub description: String,
    /// Implementation steps
    pub implementation_steps: Vec<ImplementationStep>,
    /// Expected benefits
    pub expected_benefits: Vec<String>,
    /// Potential risks
    pub potential_risks: Vec<String>,
    /// Success metrics
    pub success_metrics: Vec<String>,
    /// Monitoring requirements
    pub monitoring_requirements: Vec<String>,
}

/// Step in implementing an adaptation strategy
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImplementationStep {
    /// Step number (presumably 1-based ordering within the strategy)
    pub step_number: usize,
    /// Step description
    pub description: String,
    /// Code changes required
    pub code_changes: Vec<String>,
    /// Expected timeline (free-form text)
    pub timeline: String,
    /// Dependencies
    pub dependencies: Vec<String>,
}

/// Training phase recommendation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingPhaseRecommendation {
    /// Phase name
    pub phase_name: String,
    /// Phase duration (epochs)
    pub duration_epochs: usize,
    /// Learning rate schedule for this phase
    pub lr_schedule: LRSchedule,
    /// Phase objectives
    pub objectives: Vec<String>,
    /// Success criteria
    pub success_criteria: Vec<String>,
    /// Transition conditions
    pub transition_conditions: Vec<String>,
}

/// Learning rate schedule definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRSchedule {
    /// Schedule type
    pub schedule_type: LRScheduleType,
    /// Initial learning rate
    pub initial_lr: f64,
    /// Schedule parameters, keyed by parameter name (meaning depends on
    /// `schedule_type`)
    pub parameters: HashMap<String, f64>,
    /// Layer-specific multipliers, keyed by layer identifier
    pub layer_multipliers: HashMap<String, f64>,
}
288
289/// Type of learning rate schedule
290#[derive(Debug, Clone, Serialize, Deserialize)]
291pub enum LRScheduleType {
292    Constant,
293    LinearDecay,
294    ExponentialDecay,
295    CosineAnnealing,
296    StepDecay,
297    CyclicalLR,
298    OneCycleLR,
299    AdaptiveSchedule,
300}
301
/// Prediction of performance with different LR schedules
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRSchedulePrediction {
    /// Schedule being evaluated
    pub schedule: LRSchedule,
    /// Predicted final accuracy
    pub predicted_accuracy: f64,
    /// Predicted convergence time, in epochs
    pub predicted_convergence_epochs: usize,
    /// Predicted training stability
    pub predicted_stability: f64,
    /// Confidence in prediction (presumably in [0, 1] — confirm)
    pub prediction_confidence: f64,
    /// Risk assessment
    pub risk_assessment: RiskAssessment,
}

/// Risk assessment for a learning rate schedule
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskAssessment {
    /// Overall risk level
    pub overall_risk: RiskLevel,
    /// Specific risks
    pub specific_risks: Vec<SpecificRisk>,
    /// Mitigation strategies
    pub mitigation_strategies: Vec<String>,
}
329
330/// Risk level assessment
331#[derive(Debug, Clone, Serialize, Deserialize)]
332pub enum RiskLevel {
333    VeryLow,
334    Low,
335    Medium,
336    High,
337    VeryHigh,
338}
339
/// Specific risk in training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificRisk {
    /// Risk type (free-form label)
    pub risk_type: String,
    /// Probability of occurrence (presumably in [0, 1] — confirm)
    pub probability: f64,
    /// Impact severity
    pub impact: f64,
    /// Description
    pub description: String,
}

/// Model sensitivity analysis result
///
/// Top-level output of `AdvancedMLDebugger::analyze_model_sensitivity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSensitivityAnalysisResult {
    /// Analysis timestamp
    pub timestamp: DateTime<Utc>,
    /// Hyperparameter sensitivity analysis
    pub hyperparameter_sensitivity: HyperparameterSensitivity,
    /// Architecture sensitivity analysis
    pub architecture_sensitivity: ArchitectureSensitivity,
    /// Data sensitivity analysis
    pub data_sensitivity: DataSensitivity,
    /// Training procedure sensitivity
    pub training_sensitivity: TrainingSensitivity,
    /// Overall sensitivity insights
    pub sensitivity_insights: SensitivityInsights,
}

/// Hyperparameter sensitivity analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSensitivity {
    /// Learning rate sensitivity
    pub learning_rate_sensitivity: ParameterSensitivity,
    /// Batch size sensitivity
    pub batch_size_sensitivity: ParameterSensitivity,
    /// Regularization sensitivity
    pub regularization_sensitivity: ParameterSensitivity,
    /// Architecture parameter sensitivity, keyed by parameter name
    pub architecture_param_sensitivity: HashMap<String, ParameterSensitivity>,
    /// Most sensitive parameters
    pub most_sensitive_params: Vec<String>,
    /// Least sensitive parameters
    pub least_sensitive_params: Vec<String>,
    /// Parameter interaction effects
    pub interaction_effects: Vec<ParameterInteraction>,
}

/// Sensitivity analysis for a specific parameter
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSensitivity {
    /// Parameter name
    pub parameter_name: String,
    /// Current value
    pub current_value: f64,
    /// Sensitivity score
    pub sensitivity_score: f64,
    /// Optimal value range as (low, high)
    pub optimal_range: (f64, f64),
    /// Performance impact curve — presumably (parameter value, performance)
    /// pairs; confirm the ordering against the producer
    pub impact_curve: Vec<(f64, f64)>,
    /// Stability region as (low, high)
    pub stability_region: (f64, f64),
    /// Critical thresholds
    pub critical_thresholds: Vec<f64>,
}

/// Interaction between parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterInteraction {
    /// First parameter
    pub param1: String,
    /// Second parameter
    pub param2: String,
    /// Interaction strength
    pub interaction_strength: f64,
    /// Interaction type
    pub interaction_type: InteractionType,
    /// Joint optimal region: per-parameter (low, high) ranges keyed by name
    pub joint_optimal_region: HashMap<String, (f64, f64)>,
}
422
423/// Type of parameter interaction
424#[derive(Debug, Clone, Serialize, Deserialize)]
425pub enum InteractionType {
426    Synergistic,
427    Antagonistic,
428    Independent,
429    Conditional,
430}
431
/// Architecture sensitivity analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSensitivity {
    /// Layer depth sensitivity
    pub depth_sensitivity: ArchitecturalSensitivity,
    /// Layer width sensitivity
    pub width_sensitivity: ArchitecturalSensitivity,
    /// Attention head sensitivity
    pub attention_head_sensitivity: ArchitecturalSensitivity,
    /// Skip connection sensitivity
    pub skip_connection_sensitivity: ArchitecturalSensitivity,
    /// Architectural component importance, keyed by component name
    pub component_importance: HashMap<String, f64>,
    /// Architectural bottlenecks
    pub bottlenecks: Vec<ArchitecturalBottleneck>,
}

/// Sensitivity analysis for architectural component
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitecturalSensitivity {
    /// Component name
    pub component_name: String,
    /// Sensitivity to changes
    pub change_sensitivity: f64,
    /// Performance degradation curve — presumably (change, performance)
    /// pairs; confirm ordering against the producer
    pub degradation_curve: Vec<(f64, f64)>,
    /// Minimum viable configuration
    pub min_viable_config: f64,
    /// Optimal configuration
    pub optimal_config: f64,
    /// Diminishing returns threshold
    pub diminishing_returns_threshold: f64,
}

/// Architectural bottleneck
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitecturalBottleneck {
    /// Bottleneck location (free-form label, e.g. a layer identifier)
    pub location: String,
    /// Bottleneck type
    pub bottleneck_type: BottleneckType,
    /// Severity
    pub severity: f64,
    /// Performance impact
    pub performance_impact: f64,
    /// Resolution recommendations
    pub resolution_recommendations: Vec<String>,
}
480
481/// Type of architectural bottleneck
482#[derive(Debug, Clone, Serialize, Deserialize)]
483pub enum BottleneckType {
484    ComputationalBottleneck,
485    MemoryBottleneck,
486    InformationBottleneck,
487    CapacityBottleneck,
488    CommunicationBottleneck,
489}
490
/// Data sensitivity analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataSensitivity {
    /// Training data size sensitivity
    pub data_size_sensitivity: DataSizeSensitivity,
    /// Data quality sensitivity
    pub data_quality_sensitivity: DataQualitySensitivity,
    /// Data distribution sensitivity
    pub distribution_sensitivity: DistributionSensitivity,
    /// Feature sensitivity analysis
    pub feature_sensitivity: FeatureSensitivityAnalysis,
}

/// Sensitivity to training data size
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataSizeSensitivity {
    /// Current data size (number of samples)
    pub current_size: usize,
    /// Minimum effective size
    pub minimum_effective_size: usize,
    /// Performance vs size curve — presumably (size, performance) pairs
    pub performance_curve: Vec<(usize, f64)>,
    /// Data efficiency score
    pub data_efficiency: f64,
    /// Diminishing returns point
    pub diminishing_returns_point: usize,
}

/// Sensitivity to data quality
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataQualitySensitivity {
    /// Noise tolerance
    pub noise_tolerance: f64,
    /// Label quality importance
    pub label_quality_importance: f64,
    /// Feature quality importance
    pub feature_quality_importance: f64,
    /// Quality degradation impact — presumably (quality, impact) pairs
    pub quality_impact_curve: Vec<(f64, f64)>,
}

/// Sensitivity to data distribution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionSensitivity {
    /// Distribution shift sensitivity
    pub shift_sensitivity: f64,
    /// Class imbalance sensitivity
    pub imbalance_sensitivity: f64,
    /// Domain adaptation requirements
    pub domain_adaptation_requirements: Vec<String>,
    /// Robustness to distribution changes
    pub distribution_robustness: f64,
}

/// Feature-level sensitivity analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureSensitivityAnalysis {
    /// Most important features
    pub most_important_features: Vec<String>,
    /// Least important features
    pub least_important_features: Vec<String>,
    /// Feature interaction importance, keyed by feature-name pair.
    /// NOTE(review): serde_json cannot serialize maps whose keys are not
    /// strings, so serializing this `(String, String)`-keyed map to JSON
    /// will fail at runtime — confirm the target format, or mirror this as
    /// a `Vec` of entries / use a `serialize_with` adapter.
    pub feature_interactions: HashMap<(String, String), f64>,
    /// Feature stability analysis, keyed by feature name
    pub feature_stability: HashMap<String, f64>,
}
557
/// Training procedure sensitivity
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSensitivity {
    /// Initialization sensitivity
    pub initialization_sensitivity: InitializationSensitivity,
    /// Optimization method sensitivity
    pub optimization_sensitivity: OptimizationSensitivity,
    /// Training schedule sensitivity
    pub schedule_sensitivity: ScheduleSensitivity,
    /// Regularization sensitivity
    pub regularization_sensitivity: RegularizationSensitivity,
}

/// Sensitivity to initialization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializationSensitivity {
    /// Weight initialization sensitivity
    pub weight_init_sensitivity: f64,
    /// Bias initialization sensitivity
    pub bias_init_sensitivity: f64,
    /// Random seed sensitivity
    pub seed_sensitivity: f64,
    /// Initialization scheme importance, keyed by scheme name
    pub scheme_importance: HashMap<String, f64>,
}

/// Sensitivity to optimization method
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSensitivity {
    /// Optimizer choice sensitivity
    pub optimizer_sensitivity: f64,
    /// Momentum parameter sensitivity
    pub momentum_sensitivity: f64,
    /// Second-order moment sensitivity
    pub second_moment_sensitivity: f64,
    /// Optimizer comparison, keyed by optimizer name
    pub optimizer_comparison: HashMap<String, f64>,
}

/// Sensitivity to training schedule
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleSensitivity {
    /// Learning rate schedule sensitivity
    pub lr_schedule_sensitivity: f64,
    /// Training duration sensitivity
    pub duration_sensitivity: f64,
    /// Warmup sensitivity
    pub warmup_sensitivity: f64,
    /// Schedule parameter importance, keyed by parameter name
    pub schedule_param_importance: HashMap<String, f64>,
}

/// Sensitivity to regularization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegularizationSensitivity {
    /// Dropout sensitivity
    pub dropout_sensitivity: f64,
    /// Weight decay sensitivity
    pub weight_decay_sensitivity: f64,
    /// Batch normalization sensitivity
    pub batch_norm_sensitivity: f64,
    /// Regularization method comparison, keyed by method name
    pub method_comparison: HashMap<String, f64>,
}

/// Overall sensitivity insights
///
/// Distilled summary across the hyperparameter, architecture, data and
/// training sensitivity analyses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityInsights {
    /// Most critical factors
    pub most_critical_factors: Vec<String>,
    /// Least critical factors
    pub least_critical_factors: Vec<String>,
    /// Surprising findings
    pub surprising_findings: Vec<String>,
    /// Robustness assessment
    pub robustness_assessment: RobustnessAssessment,
    /// Optimization recommendations
    pub optimization_recommendations: Vec<String>,
}

/// Model robustness assessment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustnessAssessment {
    /// Overall robustness score
    pub overall_robustness: f64,
    /// Robustness breakdown by category, keyed by category name
    pub category_robustness: HashMap<String, f64>,
    /// Vulnerability areas
    pub vulnerabilities: Vec<Vulnerability>,
    /// Strength areas
    pub strengths: Vec<String>,
}

/// Model vulnerability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vulnerability {
    /// Vulnerability type (free-form label)
    pub vulnerability_type: String,
    /// Severity level
    pub severity: f64,
    /// Impact description
    pub impact: String,
    /// Mitigation strategies
    pub mitigation_strategies: Vec<String>,
}
663
/// Advanced ML debugger
///
/// Owns the configuration and accumulates every analysis result produced
/// over its lifetime (results are appended after each `analyze_*` call and
/// never evicted; `generate_report` reads the three most recent of each).
#[derive(Debug)]
pub struct AdvancedMLDebugger {
    /// Toggles and thresholds controlling which analyses run.
    config: AdvancedMLDebuggingConfig,
    /// History of layer-wise learning-rate analyses, in call order.
    lr_analysis_results: Vec<LayerWiseLRAnalysisResult>,
    /// History of model sensitivity analyses, in call order.
    sensitivity_analysis_results: Vec<ModelSensitivityAnalysisResult>,
}
671
672impl AdvancedMLDebugger {
673    /// Create a new advanced ML debugger
674    pub fn new(config: AdvancedMLDebuggingConfig) -> Self {
675        Self {
676            config,
677            lr_analysis_results: Vec::new(),
678            sensitivity_analysis_results: Vec::new(),
679        }
680    }
681
682    /// Perform layer-wise learning rate analysis
683    pub async fn analyze_layer_wise_learning_rates(
684        &mut self,
685        layer_gradients: &HashMap<String, ArrayD<f32>>,
686        layer_weights: &HashMap<String, ArrayD<f32>>,
687        current_lr: f64,
688        loss_history: &[f64],
689    ) -> Result<LayerWiseLRAnalysisResult> {
690        if !self.config.enable_layer_wise_lr_analysis {
691            return Err(anyhow::anyhow!(
692                "Layer-wise learning rate analysis is disabled"
693            ));
694        }
695
696        let mut layer_lr_recommendations = HashMap::new();
697
698        // Analyze each layer
699        for (layer_id, gradients) in layer_gradients {
700            if let Some(weights) = layer_weights.get(layer_id) {
701                let recommendation = self.analyze_single_layer_lr(
702                    layer_id,
703                    gradients,
704                    weights,
705                    current_lr,
706                    loss_history,
707                );
708                layer_lr_recommendations.insert(layer_id.clone(), recommendation);
709            }
710        }
711
712        // Generate global insights
713        let global_lr_insights =
714            self.generate_global_lr_insights(&layer_lr_recommendations, loss_history);
715
716        // Create adaptation strategy
717        let adaptation_strategy =
718            self.create_lr_adaptation_strategy(&layer_lr_recommendations, &global_lr_insights);
719
720        // Generate training phase recommendations
721        let training_phase_recommendations =
722            self.generate_training_phase_recommendations(&adaptation_strategy);
723
724        // Predict performance with different schedules
725        let lr_schedule_predictions =
726            self.predict_lr_schedule_performance(&layer_lr_recommendations);
727
728        let result = LayerWiseLRAnalysisResult {
729            timestamp: Utc::now(),
730            layer_lr_recommendations,
731            global_lr_insights,
732            adaptation_strategy,
733            training_phase_recommendations,
734            lr_schedule_predictions,
735        };
736
737        self.lr_analysis_results.push(result.clone());
738        Ok(result)
739    }
740
741    /// Perform comprehensive model sensitivity analysis
742    pub async fn analyze_model_sensitivity(
743        &mut self,
744        model_params: &HashMap<String, f64>,
745        performance_metrics: &[f64],
746        architecture_config: &HashMap<String, f64>,
747    ) -> Result<ModelSensitivityAnalysisResult> {
748        if !self.config.enable_model_sensitivity_analysis {
749            return Err(anyhow::anyhow!("Model sensitivity analysis is disabled"));
750        }
751
752        // Analyze hyperparameter sensitivity
753        let hyperparameter_sensitivity =
754            self.analyze_hyperparameter_sensitivity(model_params, performance_metrics);
755
756        // Analyze architecture sensitivity
757        let architecture_sensitivity =
758            self.analyze_architecture_sensitivity(architecture_config, performance_metrics);
759
760        // Analyze data sensitivity (simulated)
761        let data_sensitivity = self.analyze_data_sensitivity(performance_metrics);
762
763        // Analyze training sensitivity
764        let training_sensitivity =
765            self.analyze_training_sensitivity(model_params, performance_metrics);
766
767        // Generate overall insights
768        let sensitivity_insights = self.generate_sensitivity_insights(
769            &hyperparameter_sensitivity,
770            &architecture_sensitivity,
771            &data_sensitivity,
772            &training_sensitivity,
773        );
774
775        let result = ModelSensitivityAnalysisResult {
776            timestamp: Utc::now(),
777            hyperparameter_sensitivity,
778            architecture_sensitivity,
779            data_sensitivity,
780            training_sensitivity,
781            sensitivity_insights,
782        };
783
784        self.sensitivity_analysis_results.push(result.clone());
785        Ok(result)
786    }
787
788    /// Generate comprehensive advanced ML debugging report
789    pub async fn generate_report(&self) -> Result<AdvancedMLDebuggingReport> {
790        Ok(AdvancedMLDebuggingReport {
791            timestamp: Utc::now(),
792            config: self.config.clone(),
793            lr_analysis_count: self.lr_analysis_results.len(),
794            sensitivity_analysis_count: self.sensitivity_analysis_results.len(),
795            recent_lr_analyses: self.lr_analysis_results.iter().rev().take(3).cloned().collect(),
796            recent_sensitivity_analyses: self
797                .sensitivity_analysis_results
798                .iter()
799                .rev()
800                .take(3)
801                .cloned()
802                .collect(),
803            advanced_insights: self.generate_advanced_insights(),
804        })
805    }
806
807    // Helper methods for layer-wise LR analysis
808
809    fn analyze_single_layer_lr(
810        &self,
811        layer_id: &str,
812        gradients: &ArrayD<f32>,
813        weights: &ArrayD<f32>,
814        current_lr: f64,
815        loss_history: &[f64],
816    ) -> LayerLRRecommendation {
817        // Calculate gradient statistics
818        let gradient_magnitude =
819            gradients.iter().map(|&x| x.abs() as f64).sum::<f64>() / gradients.len() as f64;
820        let weight_magnitude =
821            weights.iter().map(|&x| x.abs() as f64).sum::<f64>() / weights.len() as f64;
822
823        // Estimate optimal learning rate based on gradient properties
824        let gradient_variance =
825            gradients.iter().map(|&x| (x as f64 - gradient_magnitude).powi(2)).sum::<f64>()
826                / gradients.len() as f64;
827
828        let gradient_norm = gradient_magnitude;
829        let recommended_lr = if gradient_norm > 0.0 {
830            // Adaptive learning rate based on gradient properties
831            let base_lr = 0.001;
832            let adaptation_factor = (1.0 / (1.0 + gradient_variance)).sqrt();
833            let magnitude_factor = (1.0 / (1.0 + gradient_norm)).sqrt();
834            base_lr * adaptation_factor * magnitude_factor * 10.0
835        } else {
836            current_lr
837        };
838
839        // Calculate layer metrics
840        let layer_metrics = LayerLRMetrics {
841            gradient_magnitude,
842            weight_update_magnitude: gradient_magnitude * current_lr,
843            parameter_norm: weight_magnitude,
844            loss_contribution: self.estimate_layer_loss_contribution(loss_history),
845            stability_score: self.calculate_layer_stability(gradients),
846            convergence_rate: self.estimate_convergence_rate(loss_history),
847            learning_efficiency: gradient_magnitude / (weight_magnitude + 1e-8),
848        };
849
850        // Determine urgency
851        let lr_ratio = recommended_lr / current_lr;
852        let urgency = if lr_ratio > 10.0 || lr_ratio < 0.1 {
853            AdaptationUrgency::Critical
854        } else if lr_ratio > 3.0 || lr_ratio < 0.33 {
855            AdaptationUrgency::High
856        } else if lr_ratio > 1.5 || lr_ratio < 0.67 {
857            AdaptationUrgency::Medium
858        } else {
859            AdaptationUrgency::Low
860        };
861
862        // Generate reasoning
863        let reasoning = if recommended_lr > current_lr * 1.2 {
864            "Layer shows slow learning with small gradients, increase learning rate".to_string()
865        } else if recommended_lr < current_lr * 0.8 {
866            "Layer shows instability or large gradients, decrease learning rate".to_string()
867        } else {
868            "Current learning rate appears appropriate for this layer".to_string()
869        };
870
871        LayerLRRecommendation {
872            layer_id: layer_id.to_string(),
873            layer_type: self.infer_layer_type(layer_id),
874            current_lr,
875            recommended_lr,
876            confidence: 0.8, // Would be calculated based on statistical confidence
877            reasoning,
878            layer_metrics,
879            lr_sensitivity: lr_ratio.abs(),
880            urgency,
881        }
882    }
883
884    fn generate_global_lr_insights(
885        &self,
886        layer_recommendations: &HashMap<String, LayerLRRecommendation>,
887        loss_history: &[f64],
888    ) -> GlobalLRInsights {
889        let overall_efficiency = layer_recommendations
890            .values()
891            .map(|rec| rec.layer_metrics.learning_efficiency)
892            .sum::<f64>()
893            / layer_recommendations.len() as f64;
894
895        let lr_distribution_health = self.calculate_lr_distribution_health(layer_recommendations);
896        let gradient_flow_quality = self.calculate_gradient_flow_quality(layer_recommendations);
897        let training_stability =
898            self.assess_training_stability(layer_recommendations, loss_history);
899        let global_adjustments = self.generate_global_adjustments(layer_recommendations);
900        let critical_issues = self.identify_critical_issues(layer_recommendations);
901
902        GlobalLRInsights {
903            overall_efficiency,
904            lr_distribution_health,
905            gradient_flow_quality,
906            training_stability,
907            global_adjustments,
908            critical_issues,
909        }
910    }
911
912    fn create_lr_adaptation_strategy(
913        &self,
914        _layer_recommendations: &HashMap<String, LayerLRRecommendation>,
915        global_insights: &GlobalLRInsights,
916    ) -> LRAdaptationStrategy {
917        // Simplified strategy creation
918        let strategy_name = if global_insights.overall_efficiency < 0.5 {
919            "Aggressive Learning Rate Adaptation".to_string()
920        } else {
921            "Conservative Learning Rate Tuning".to_string()
922        };
923
924        LRAdaptationStrategy {
925            strategy_name: strategy_name.clone(),
926            description: format!(
927                "Strategy to optimize learning rates based on current model state"
928            ),
929            implementation_steps: vec![ImplementationStep {
930                step_number: 1,
931                description: "Implement layer-wise learning rate multipliers".to_string(),
932                code_changes: vec!["Add lr_multipliers to optimizer config".to_string()],
933                timeline: "1-2 days".to_string(),
934                dependencies: vec!["Optimizer modification".to_string()],
935            }],
936            expected_benefits: vec![
937                "Improved convergence speed".to_string(),
938                "Better training stability".to_string(),
939                "Reduced overfitting risk".to_string(),
940            ],
941            potential_risks: vec!["Initial instability during adaptation".to_string()],
942            success_metrics: vec![
943                "Faster loss reduction".to_string(),
944                "Improved validation accuracy".to_string(),
945            ],
946            monitoring_requirements: vec!["Track per-layer gradient norms".to_string()],
947        }
948    }
949
950    fn generate_training_phase_recommendations(
951        &self,
952        _strategy: &LRAdaptationStrategy,
953    ) -> Vec<TrainingPhaseRecommendation> {
954        vec![TrainingPhaseRecommendation {
955            phase_name: "Warmup Phase".to_string(),
956            duration_epochs: 5,
957            lr_schedule: LRSchedule {
958                schedule_type: LRScheduleType::LinearDecay,
959                initial_lr: 0.0001,
960                parameters: HashMap::new(),
961                layer_multipliers: HashMap::new(),
962            },
963            objectives: vec!["Stabilize training".to_string()],
964            success_criteria: vec!["Decreasing loss".to_string()],
965            transition_conditions: vec!["Stable gradient norms".to_string()],
966        }]
967    }
968
969    fn predict_lr_schedule_performance(
970        &self,
971        _layer_recommendations: &HashMap<String, LayerLRRecommendation>,
972    ) -> Vec<LRSchedulePrediction> {
973        vec![LRSchedulePrediction {
974            schedule: LRSchedule {
975                schedule_type: LRScheduleType::ExponentialDecay,
976                initial_lr: 0.001,
977                parameters: HashMap::new(),
978                layer_multipliers: HashMap::new(),
979            },
980            predicted_accuracy: 0.92,
981            predicted_convergence_epochs: 50,
982            predicted_stability: 0.8,
983            prediction_confidence: 0.7,
984            risk_assessment: RiskAssessment {
985                overall_risk: RiskLevel::Medium,
986                specific_risks: vec![],
987                mitigation_strategies: vec![],
988            },
989        }]
990    }
991
992    // Helper methods for sensitivity analysis
993
994    fn analyze_hyperparameter_sensitivity(
995        &self,
996        params: &HashMap<String, f64>,
997        _metrics: &[f64],
998    ) -> HyperparameterSensitivity {
999        let learning_rate_sensitivity = ParameterSensitivity {
1000            parameter_name: "learning_rate".to_string(),
1001            current_value: params.get("learning_rate").copied().unwrap_or(0.001),
1002            sensitivity_score: 0.8,
1003            optimal_range: (0.0001, 0.01),
1004            impact_curve: vec![(0.0001, 0.7), (0.001, 0.9), (0.01, 0.85)],
1005            stability_region: (0.0005, 0.005),
1006            critical_thresholds: vec![0.0001, 0.1],
1007        };
1008
1009        let batch_size_sensitivity = ParameterSensitivity {
1010            parameter_name: "batch_size".to_string(),
1011            current_value: params.get("batch_size").copied().unwrap_or(32.0),
1012            sensitivity_score: 0.6,
1013            optimal_range: (16.0, 128.0),
1014            impact_curve: vec![(16.0, 0.85), (32.0, 0.9), (64.0, 0.88), (128.0, 0.82)],
1015            stability_region: (16.0, 64.0),
1016            critical_thresholds: vec![8.0, 256.0],
1017        };
1018
1019        let regularization_sensitivity = ParameterSensitivity {
1020            parameter_name: "weight_decay".to_string(),
1021            current_value: params.get("weight_decay").copied().unwrap_or(0.01),
1022            sensitivity_score: 0.4,
1023            optimal_range: (0.001, 0.1),
1024            impact_curve: vec![(0.001, 0.88), (0.01, 0.9), (0.1, 0.87)],
1025            stability_region: (0.005, 0.05),
1026            critical_thresholds: vec![0.0001, 1.0],
1027        };
1028
1029        HyperparameterSensitivity {
1030            learning_rate_sensitivity,
1031            batch_size_sensitivity,
1032            regularization_sensitivity,
1033            architecture_param_sensitivity: HashMap::new(),
1034            most_sensitive_params: vec!["learning_rate".to_string(), "batch_size".to_string()],
1035            least_sensitive_params: vec!["weight_decay".to_string()],
1036            interaction_effects: vec![],
1037        }
1038    }
1039
1040    fn analyze_architecture_sensitivity(
1041        &self,
1042        _config: &HashMap<String, f64>,
1043        _metrics: &[f64],
1044    ) -> ArchitectureSensitivity {
1045        ArchitectureSensitivity {
1046            depth_sensitivity: ArchitecturalSensitivity {
1047                component_name: "model_depth".to_string(),
1048                change_sensitivity: 0.7,
1049                degradation_curve: vec![(6.0, 0.85), (12.0, 0.9), (24.0, 0.88)],
1050                min_viable_config: 6.0,
1051                optimal_config: 12.0,
1052                diminishing_returns_threshold: 18.0,
1053            },
1054            width_sensitivity: ArchitecturalSensitivity {
1055                component_name: "hidden_size".to_string(),
1056                change_sensitivity: 0.6,
1057                degradation_curve: vec![(256.0, 0.82), (512.0, 0.9), (1024.0, 0.91)],
1058                min_viable_config: 256.0,
1059                optimal_config: 512.0,
1060                diminishing_returns_threshold: 768.0,
1061            },
1062            attention_head_sensitivity: ArchitecturalSensitivity {
1063                component_name: "num_attention_heads".to_string(),
1064                change_sensitivity: 0.5,
1065                degradation_curve: vec![(4.0, 0.87), (8.0, 0.9), (16.0, 0.89)],
1066                min_viable_config: 4.0,
1067                optimal_config: 8.0,
1068                diminishing_returns_threshold: 12.0,
1069            },
1070            skip_connection_sensitivity: ArchitecturalSensitivity {
1071                component_name: "skip_connections".to_string(),
1072                change_sensitivity: 0.8,
1073                degradation_curve: vec![(0.0, 0.75), (1.0, 0.9)],
1074                min_viable_config: 1.0,
1075                optimal_config: 1.0,
1076                diminishing_returns_threshold: 1.0,
1077            },
1078            component_importance: HashMap::new(),
1079            bottlenecks: vec![],
1080        }
1081    }
1082
1083    fn analyze_data_sensitivity(&self, _metrics: &[f64]) -> DataSensitivity {
1084        DataSensitivity {
1085            data_size_sensitivity: DataSizeSensitivity {
1086                current_size: 10000,
1087                minimum_effective_size: 1000,
1088                performance_curve: vec![(1000, 0.7), (5000, 0.85), (10000, 0.9), (20000, 0.92)],
1089                data_efficiency: 0.85,
1090                diminishing_returns_point: 15000,
1091            },
1092            data_quality_sensitivity: DataQualitySensitivity {
1093                noise_tolerance: 0.1,
1094                label_quality_importance: 0.9,
1095                feature_quality_importance: 0.7,
1096                quality_impact_curve: vec![(0.9, 0.9), (0.8, 0.85), (0.7, 0.75)],
1097            },
1098            distribution_sensitivity: DistributionSensitivity {
1099                shift_sensitivity: 0.6,
1100                imbalance_sensitivity: 0.5,
1101                domain_adaptation_requirements: vec!["Gradual domain adaptation".to_string()],
1102                distribution_robustness: 0.7,
1103            },
1104            feature_sensitivity: FeatureSensitivityAnalysis {
1105                most_important_features: vec!["feature_1".to_string(), "feature_2".to_string()],
1106                least_important_features: vec!["feature_10".to_string()],
1107                feature_interactions: HashMap::new(),
1108                feature_stability: HashMap::new(),
1109            },
1110        }
1111    }
1112
1113    fn analyze_training_sensitivity(
1114        &self,
1115        _params: &HashMap<String, f64>,
1116        _metrics: &[f64],
1117    ) -> TrainingSensitivity {
1118        TrainingSensitivity {
1119            initialization_sensitivity: InitializationSensitivity {
1120                weight_init_sensitivity: 0.6,
1121                bias_init_sensitivity: 0.3,
1122                seed_sensitivity: 0.2,
1123                scheme_importance: HashMap::new(),
1124            },
1125            optimization_sensitivity: OptimizationSensitivity {
1126                optimizer_sensitivity: 0.7,
1127                momentum_sensitivity: 0.5,
1128                second_moment_sensitivity: 0.4,
1129                optimizer_comparison: HashMap::new(),
1130            },
1131            schedule_sensitivity: ScheduleSensitivity {
1132                lr_schedule_sensitivity: 0.8,
1133                duration_sensitivity: 0.6,
1134                warmup_sensitivity: 0.4,
1135                schedule_param_importance: HashMap::new(),
1136            },
1137            regularization_sensitivity: RegularizationSensitivity {
1138                dropout_sensitivity: 0.5,
1139                weight_decay_sensitivity: 0.4,
1140                batch_norm_sensitivity: 0.6,
1141                method_comparison: HashMap::new(),
1142            },
1143        }
1144    }
1145
1146    fn generate_sensitivity_insights(
1147        &self,
1148        _hyper_sens: &HyperparameterSensitivity,
1149        _arch_sens: &ArchitectureSensitivity,
1150        _data_sens: &DataSensitivity,
1151        _training_sens: &TrainingSensitivity,
1152    ) -> SensitivityInsights {
1153        SensitivityInsights {
1154            most_critical_factors: vec![
1155                "learning_rate".to_string(),
1156                "model_depth".to_string(),
1157                "skip_connections".to_string(),
1158            ],
1159            least_critical_factors: vec!["bias_initialization".to_string()],
1160            surprising_findings: vec!["Batch size has higher than expected impact".to_string()],
1161            robustness_assessment: RobustnessAssessment {
1162                overall_robustness: 0.7,
1163                category_robustness: HashMap::new(),
1164                vulnerabilities: vec![],
1165                strengths: vec!["Good hyperparameter stability".to_string()],
1166            },
1167            optimization_recommendations: vec![
1168                "Focus on learning rate tuning first".to_string(),
1169                "Consider architectural modifications second".to_string(),
1170            ],
1171        }
1172    }
1173
1174    // Additional helper methods
1175
1176    fn estimate_layer_loss_contribution(&self, loss_history: &[f64]) -> f64 {
1177        // Simplified estimation
1178        if loss_history.len() >= 2 {
1179            (loss_history[loss_history.len() - 2] - loss_history[loss_history.len() - 1]).abs()
1180        } else {
1181            0.1
1182        }
1183    }
1184
1185    fn calculate_layer_stability(&self, gradients: &ArrayD<f32>) -> f64 {
1186        let gradient_variance = gradients.iter().map(|&x| x as f64).collect::<Vec<_>>();
1187
1188        if gradient_variance.is_empty() {
1189            return 0.5;
1190        }
1191
1192        let mean = gradient_variance.iter().sum::<f64>() / gradient_variance.len() as f64;
1193        let variance = gradient_variance.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
1194            / gradient_variance.len() as f64;
1195
1196        1.0 / (1.0 + variance) // Higher stability for lower variance
1197    }
1198
1199    fn estimate_convergence_rate(&self, loss_history: &[f64]) -> f64 {
1200        if loss_history.len() < 3 {
1201            return 0.5;
1202        }
1203
1204        let recent_improvement =
1205            loss_history[loss_history.len() - 3] - loss_history[loss_history.len() - 1];
1206        recent_improvement.abs()
1207    }
1208
1209    fn infer_layer_type(&self, layer_id: &str) -> String {
1210        if layer_id.contains("attention") {
1211            "attention".to_string()
1212        } else if layer_id.contains("feedforward") || layer_id.contains("mlp") {
1213            "feedforward".to_string()
1214        } else if layer_id.contains("embedding") {
1215            "embedding".to_string()
1216        } else {
1217            "unknown".to_string()
1218        }
1219    }
1220
1221    fn calculate_lr_distribution_health(
1222        &self,
1223        recommendations: &HashMap<String, LayerLRRecommendation>,
1224    ) -> f64 {
1225        let lr_ratios: Vec<f64> = recommendations
1226            .values()
1227            .map(|rec| rec.recommended_lr / rec.current_lr)
1228            .collect();
1229
1230        if lr_ratios.is_empty() {
1231            return 0.5;
1232        }
1233
1234        let mean_ratio = lr_ratios.iter().sum::<f64>() / lr_ratios.len() as f64;
1235        let variance = lr_ratios.iter().map(|&x| (x - mean_ratio).powi(2)).sum::<f64>()
1236            / lr_ratios.len() as f64;
1237
1238        1.0 / (1.0 + variance) // Better health for lower variance
1239    }
1240
1241    fn calculate_gradient_flow_quality(
1242        &self,
1243        recommendations: &HashMap<String, LayerLRRecommendation>,
1244    ) -> f64 {
1245        recommendations
1246            .values()
1247            .map(|rec| rec.layer_metrics.stability_score)
1248            .sum::<f64>()
1249            / recommendations.len() as f64
1250    }
1251
1252    fn assess_training_stability(
1253        &self,
1254        recommendations: &HashMap<String, LayerLRRecommendation>,
1255        _loss_history: &[f64],
1256    ) -> TrainingStability {
1257        let stability_score = recommendations
1258            .values()
1259            .map(|rec| rec.layer_metrics.stability_score)
1260            .sum::<f64>()
1261            / recommendations.len() as f64;
1262
1263        TrainingStability {
1264            stability_score,
1265            instability_indicators: vec![],
1266            stability_trends: vec![],
1267            predicted_stability: stability_score * 0.9, // Slightly pessimistic prediction
1268        }
1269    }
1270
1271    fn generate_global_adjustments(
1272        &self,
1273        _recommendations: &HashMap<String, LayerLRRecommendation>,
1274    ) -> Vec<GlobalLRAdjustment> {
1275        vec![GlobalLRAdjustment {
1276            adjustment_type: GlobalAdjustmentType::LayerTypeSpecific,
1277            magnitude: 1.5,
1278            expected_impact: 0.1,
1279            priority: AdjustmentPriority::Medium,
1280            instructions: "Apply different learning rates to attention vs feedforward layers"
1281                .to_string(),
1282        }]
1283    }
1284
1285    fn identify_critical_issues(
1286        &self,
1287        recommendations: &HashMap<String, LayerLRRecommendation>,
1288    ) -> Vec<String> {
1289        let mut issues = Vec::new();
1290
1291        for recommendation in recommendations.values() {
1292            if matches!(recommendation.urgency, AdaptationUrgency::Critical) {
1293                issues.push(format!(
1294                    "Critical learning rate issue in layer {}",
1295                    recommendation.layer_id
1296                ));
1297            }
1298        }
1299
1300        issues
1301    }
1302
1303    fn generate_advanced_insights(&self) -> HashMap<String, String> {
1304        let mut insights = HashMap::new();
1305
1306        insights.insert(
1307            "total_lr_analyses".to_string(),
1308            self.lr_analysis_results.len().to_string(),
1309        );
1310        insights.insert(
1311            "total_sensitivity_analyses".to_string(),
1312            self.sensitivity_analysis_results.len().to_string(),
1313        );
1314
1315        if let Some(latest_lr) = self.lr_analysis_results.last() {
1316            insights.insert(
1317                "latest_lr_efficiency".to_string(),
1318                format!("{:.2}", latest_lr.global_lr_insights.overall_efficiency),
1319            );
1320        }
1321
1322        insights
1323    }
1324}
1325
/// Comprehensive advanced ML debugging report
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMLDebuggingReport {
    /// When this report was generated (UTC).
    pub timestamp: DateTime<Utc>,
    /// The debugger configuration in effect when the report was produced.
    pub config: AdvancedMLDebuggingConfig,
    /// Total number of layer-wise learning-rate analyses performed so far.
    pub lr_analysis_count: usize,
    /// Total number of model sensitivity analyses performed so far.
    pub sensitivity_analysis_count: usize,
    /// Most recent layer-wise learning-rate analysis results.
    pub recent_lr_analyses: Vec<LayerWiseLRAnalysisResult>,
    /// Most recent model sensitivity analysis results.
    pub recent_sensitivity_analyses: Vec<ModelSensitivityAnalysisResult>,
    /// Free-form key/value summary insights (counts, latest efficiency, …).
    pub advanced_insights: HashMap<String, String>,
}
1337
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh debugger starts with no analysis history. This check is fully
    /// synchronous, so a plain `#[test]` is used — `#[tokio::test]` would
    /// spin up an async runtime for nothing.
    #[test]
    fn test_advanced_ml_debugger_creation() {
        let config = AdvancedMLDebuggingConfig::default();
        let debugger = AdvancedMLDebugger::new(config);
        assert_eq!(debugger.lr_analysis_results.len(), 0);
    }

    /// End-to-end layer-wise LR analysis over a single synthetic 10x10 layer:
    /// the analysis must succeed and produce exactly one recommendation,
    /// keyed by the layer id.
    #[tokio::test]
    async fn test_layer_wise_lr_analysis() {
        let config = AdvancedMLDebuggingConfig::default();
        let mut debugger = AdvancedMLDebugger::new(config);

        let mut layer_gradients = HashMap::new();
        let mut layer_weights = HashMap::new();

        // Deterministic synthetic gradients/weights for one layer.
        let gradients =
            ArrayD::from_shape_vec(vec![10, 10], (0..100).map(|x| x as f32 * 0.01).collect())
                .unwrap();
        let weights =
            ArrayD::from_shape_vec(vec![10, 10], (0..100).map(|x| x as f32 * 0.1).collect())
                .unwrap();

        layer_gradients.insert("layer_0".to_string(), gradients);
        layer_weights.insert("layer_0".to_string(), weights);

        // Monotonically decreasing loss simulating a converging run.
        let loss_history = vec![1.0, 0.8, 0.6, 0.5];

        let result = debugger
            .analyze_layer_wise_learning_rates(
                &layer_gradients,
                &layer_weights,
                0.001,
                &loss_history,
            )
            .await;
        assert!(result.is_ok());

        let analysis = result.unwrap();
        assert_eq!(analysis.layer_lr_recommendations.len(), 1);
        assert!(analysis.layer_lr_recommendations.contains_key("layer_0"));
    }
}