Skip to main content

trustformers_debug/
advanced_ml_debugging.rs

1//! # Advanced ML Debugging Tools
2//!
3//! Advanced machine learning specific debugging techniques including layer-wise learning rate adaptation,
4//! model sensitivity analysis, gradient flow optimization, and neural architecture debugging.
5
6use anyhow::Result;
7use chrono::{DateTime, Utc};
8use scirs2_core::ndarray::*; // SciRS2 Integration Policy - was: use ndarray::{Array1, Array2, Array3, ArrayD};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Configuration for advanced ML debugging.
///
/// Each `enable_*` flag gates one analysis family; the gated entry points
/// (e.g. `analyze_layer_wise_learning_rates`) return an `Err` when their flag
/// is off. The numeric fields are analysis budgets/thresholds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMLDebuggingConfig {
    /// Enable layer-wise learning rate analysis
    pub enable_layer_wise_lr_analysis: bool,
    /// Enable model sensitivity analysis
    pub enable_model_sensitivity_analysis: bool,
    /// Enable gradient flow optimization analysis
    pub enable_gradient_flow_optimization: bool,
    /// Enable neural architecture debugging
    pub enable_neural_architecture_debugging: bool,
    /// Enable activation pattern analysis
    pub enable_activation_pattern_analysis: bool,
    /// Enable weight distribution analysis
    pub enable_weight_distribution_analysis: bool,
    /// Enable training dynamics analysis
    pub enable_training_dynamics_analysis: bool,
    /// Enable optimization landscape analysis
    pub enable_optimization_landscape_analysis: bool,
    /// Number of samples for sensitivity analysis
    // NOTE(review): not referenced by the analysis code visible in this file —
    // confirm it is actually enforced somewhere.
    pub sensitivity_samples: usize,
    /// Learning rate adaptation threshold
    pub lr_adaptation_threshold: f64,
    /// Maximum number of layers to analyze
    // NOTE(review): the per-layer loop in `analyze_layer_wise_learning_rates`
    // does not apply this cap — confirm intended behavior.
    pub max_layers_to_analyze: usize,
}
38
39impl Default for AdvancedMLDebuggingConfig {
40    fn default() -> Self {
41        Self {
42            enable_layer_wise_lr_analysis: true,
43            enable_model_sensitivity_analysis: true,
44            enable_gradient_flow_optimization: true,
45            enable_neural_architecture_debugging: true,
46            enable_activation_pattern_analysis: true,
47            enable_weight_distribution_analysis: true,
48            enable_training_dynamics_analysis: true,
49            enable_optimization_landscape_analysis: true,
50            sensitivity_samples: 1000,
51            lr_adaptation_threshold: 0.1,
52            max_layers_to_analyze: 50,
53        }
54    }
55}
56
/// Layer-wise learning rate adaptation analysis result.
///
/// Top-level output of `AdvancedMLDebugger::analyze_layer_wise_learning_rates`;
/// one of these is appended to the debugger's history per analysis run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerWiseLRAnalysisResult {
    /// Analysis timestamp
    pub timestamp: DateTime<Utc>,
    /// Learning rate recommendations per layer, keyed by layer id
    pub layer_lr_recommendations: HashMap<String, LayerLRRecommendation>,
    /// Global learning rate insights
    pub global_lr_insights: GlobalLRInsights,
    /// Learning rate adaptation strategy
    pub adaptation_strategy: LRAdaptationStrategy,
    /// Training phase recommendations
    pub training_phase_recommendations: Vec<TrainingPhaseRecommendation>,
    /// Performance predictions with different LR schedules
    pub lr_schedule_predictions: Vec<LRSchedulePrediction>,
}

/// Learning rate recommendation for a specific layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerLRRecommendation {
    /// Layer identifier
    pub layer_id: String,
    /// Layer type (e.g., "attention", "feedforward", "embedding")
    pub layer_type: String,
    /// Current learning rate
    pub current_lr: f64,
    /// Recommended learning rate
    pub recommended_lr: f64,
    /// Recommendation confidence
    // Currently a fixed heuristic (0.8) in `analyze_single_layer_lr`.
    pub confidence: f64,
    /// Reasoning for recommendation (human-readable)
    pub reasoning: String,
    /// Layer-specific metrics
    pub layer_metrics: LayerLRMetrics,
    /// Sensitivity to learning rate changes
    // Computed as the ratio recommended_lr / current_lr.
    pub lr_sensitivity: f64,
    /// Adaptation urgency level
    pub urgency: AdaptationUrgency,
}

/// Layer-specific learning rate metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerLRMetrics {
    /// Gradient magnitude (mean absolute gradient over the layer)
    pub gradient_magnitude: f64,
    /// Weight update magnitude (gradient magnitude x current LR)
    pub weight_update_magnitude: f64,
    /// Parameter norm (mean absolute weight over the layer)
    pub parameter_norm: f64,
    /// Loss contribution
    pub loss_contribution: f64,
    /// Training stability score
    pub stability_score: f64,
    /// Convergence rate
    pub convergence_rate: f64,
    /// Learning efficiency: mean |grad| / (mean |weight| + 1e-8)
    pub learning_efficiency: f64,
}

/// Urgency level for learning rate adaptation.
///
/// Derived in `analyze_single_layer_lr` from how far the recommended LR sits
/// from the current LR (larger ratio mismatch => higher urgency).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AdaptationUrgency {
    /// Recommended LR within roughly [0.67x, 1.5x] of the current LR.
    Low,
    /// Recommended LR within roughly [0.33x, 3x] of the current LR.
    Medium,
    /// Recommended LR within roughly [0.1x, 10x] of the current LR.
    High,
    /// Recommended LR more than 10x away (either direction) from the current LR.
    Critical,
}
124
/// Global learning rate insights
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalLRInsights {
    /// Overall model learning efficiency (mean of per-layer learning efficiency)
    pub overall_efficiency: f64,
    /// Learning rate distribution health
    pub lr_distribution_health: f64,
    /// Gradient flow quality
    pub gradient_flow_quality: f64,
    /// Training stability assessment
    pub training_stability: TrainingStability,
    /// Recommended global adjustments
    pub global_adjustments: Vec<GlobalLRAdjustment>,
    /// Critical issues requiring immediate attention
    pub critical_issues: Vec<String>,
}

/// Training stability assessment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStability {
    /// Stability score (0-1)
    pub stability_score: f64,
    /// Instability indicators
    pub instability_indicators: Vec<InstabilityIndicator>,
    /// Stability trends over time
    pub stability_trends: Vec<StabilityTrendPoint>,
    /// Predicted stability with current settings
    pub predicted_stability: f64,
}

/// Indicator of training instability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstabilityIndicator {
    /// Type of instability
    pub instability_type: InstabilityType,
    /// Severity level
    pub severity: f64,
    /// Affected layers (layer ids)
    pub affected_layers: Vec<String>,
    /// Recommended actions
    pub recommended_actions: Vec<String>,
}

/// Type of training instability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InstabilityType {
    /// Gradient norms growing without bound.
    GradientExplosion,
    /// Gradient norms shrinking toward zero.
    GradientVanishing,
    /// Loss oscillating instead of decreasing.
    OscillatingLoss,
    /// Loss decreasing, but far slower than expected.
    SlowConvergence,
    /// Weight norms diverging during training.
    WeightDivergence,
    /// NaN/Inf or other floating-point pathologies.
    NumericalInstability,
}

/// Point in stability trend analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityTrendPoint {
    /// Time step or epoch
    pub time_step: usize,
    /// Stability score at this point
    pub stability_score: f64,
    /// Contributing factors (factor name -> contribution)
    pub contributing_factors: HashMap<String, f64>,
}

/// Global learning rate adjustment recommendation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalLRAdjustment {
    /// Adjustment type
    pub adjustment_type: GlobalAdjustmentType,
    /// Adjustment magnitude
    pub magnitude: f64,
    /// Expected impact
    pub expected_impact: f64,
    /// Implementation priority
    pub priority: AdjustmentPriority,
    /// Implementation instructions (human-readable)
    pub instructions: String,
}

/// Type of global learning rate adjustment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GlobalAdjustmentType {
    /// Scale all layer LRs by one factor.
    UniformScaling,
    /// Adjust LRs per layer type (attention, feedforward, ...).
    LayerTypeSpecific,
    /// Adjust LRs as a function of layer depth.
    DepthDependent,
    /// Switch to an adaptive LR schedule.
    AdaptiveScheduling,
    /// Modify the warmup phase.
    WarmupAdjustment,
    /// Modify the decay rate of the current schedule.
    DecayRateModification,
}

/// Priority level for adjustments (ordinal, `Low` lowest).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AdjustmentPriority {
    Low,
    Medium,
    High,
    /// Apply before continuing training.
    Immediate,
}
224
/// Learning rate adaptation strategy
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRAdaptationStrategy {
    /// Strategy name
    pub strategy_name: String,
    /// Strategy description
    pub description: String,
    /// Implementation steps, in execution order
    pub implementation_steps: Vec<ImplementationStep>,
    /// Expected benefits
    pub expected_benefits: Vec<String>,
    /// Potential risks
    pub potential_risks: Vec<String>,
    /// Success metrics
    pub success_metrics: Vec<String>,
    /// Monitoring requirements
    pub monitoring_requirements: Vec<String>,
}

/// Step in implementing an adaptation strategy
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImplementationStep {
    /// Step number (1-based ordering within the strategy)
    pub step_number: usize,
    /// Step description
    pub description: String,
    /// Code changes required
    pub code_changes: Vec<String>,
    /// Expected timeline (free-form text)
    pub timeline: String,
    /// Dependencies on other steps or external work
    pub dependencies: Vec<String>,
}

/// Training phase recommendation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingPhaseRecommendation {
    /// Phase name
    pub phase_name: String,
    /// Phase duration (epochs)
    pub duration_epochs: usize,
    /// Learning rate schedule for this phase
    pub lr_schedule: LRSchedule,
    /// Phase objectives
    pub objectives: Vec<String>,
    /// Success criteria
    pub success_criteria: Vec<String>,
    /// Conditions for transitioning to the next phase
    pub transition_conditions: Vec<String>,
}

/// Learning rate schedule definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRSchedule {
    /// Schedule type
    pub schedule_type: LRScheduleType,
    /// Initial learning rate
    pub initial_lr: f64,
    /// Schedule parameters (name -> value; meaning depends on `schedule_type`)
    pub parameters: HashMap<String, f64>,
    /// Layer-specific multipliers applied on top of the base schedule
    pub layer_multipliers: HashMap<String, f64>,
}

/// Type of learning rate schedule
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LRScheduleType {
    /// Fixed LR for the whole run.
    Constant,
    /// Linear decrease toward a floor.
    LinearDecay,
    /// Multiplicative decay per step/epoch.
    ExponentialDecay,
    /// Cosine-shaped annealing.
    CosineAnnealing,
    /// Piecewise-constant drops at milestones.
    StepDecay,
    /// Cyclical LR between bounds.
    CyclicalLR,
    /// One-cycle policy (ramp up then down).
    OneCycleLR,
    /// Data-driven adaptive schedule.
    AdaptiveSchedule,
}
301
/// Prediction of performance with different LR schedules
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRSchedulePrediction {
    /// Schedule being evaluated
    pub schedule: LRSchedule,
    /// Predicted final accuracy
    pub predicted_accuracy: f64,
    /// Predicted convergence time, in epochs
    pub predicted_convergence_epochs: usize,
    /// Predicted training stability
    pub predicted_stability: f64,
    /// Confidence in prediction
    pub prediction_confidence: f64,
    /// Risk assessment
    pub risk_assessment: RiskAssessment,
}

/// Risk assessment for a learning rate schedule
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskAssessment {
    /// Overall risk level
    pub overall_risk: RiskLevel,
    /// Specific risks
    pub specific_risks: Vec<SpecificRisk>,
    /// Mitigation strategies
    pub mitigation_strategies: Vec<String>,
}

/// Risk level assessment (ordinal, `VeryLow` lowest).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RiskLevel {
    VeryLow,
    Low,
    Medium,
    High,
    VeryHigh,
}

/// Specific risk in training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificRisk {
    /// Risk type (free-form label)
    pub risk_type: String,
    /// Probability of occurrence
    pub probability: f64,
    /// Impact severity
    pub impact: f64,
    /// Human-readable description
    pub description: String,
}
352
/// Model sensitivity analysis result.
///
/// Top-level output of `AdvancedMLDebugger::analyze_model_sensitivity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSensitivityAnalysisResult {
    /// Analysis timestamp
    pub timestamp: DateTime<Utc>,
    /// Hyperparameter sensitivity analysis
    pub hyperparameter_sensitivity: HyperparameterSensitivity,
    /// Architecture sensitivity analysis
    pub architecture_sensitivity: ArchitectureSensitivity,
    /// Data sensitivity analysis
    pub data_sensitivity: DataSensitivity,
    /// Training procedure sensitivity
    pub training_sensitivity: TrainingSensitivity,
    /// Overall sensitivity insights distilled from the four analyses above
    pub sensitivity_insights: SensitivityInsights,
}

/// Hyperparameter sensitivity analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSensitivity {
    /// Learning rate sensitivity
    pub learning_rate_sensitivity: ParameterSensitivity,
    /// Batch size sensitivity
    pub batch_size_sensitivity: ParameterSensitivity,
    /// Regularization sensitivity
    pub regularization_sensitivity: ParameterSensitivity,
    /// Architecture parameter sensitivity, keyed by parameter name
    pub architecture_param_sensitivity: HashMap<String, ParameterSensitivity>,
    /// Most sensitive parameters
    pub most_sensitive_params: Vec<String>,
    /// Least sensitive parameters
    pub least_sensitive_params: Vec<String>,
    /// Parameter interaction effects
    pub interaction_effects: Vec<ParameterInteraction>,
}

/// Sensitivity analysis for a specific parameter
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSensitivity {
    /// Parameter name
    pub parameter_name: String,
    /// Current value
    pub current_value: f64,
    /// Sensitivity score
    pub sensitivity_score: f64,
    /// Optimal value range as (low, high)
    pub optimal_range: (f64, f64),
    /// Performance impact curve as (parameter value, performance) points
    pub impact_curve: Vec<(f64, f64)>,
    /// Stability region as (low, high)
    pub stability_region: (f64, f64),
    /// Critical thresholds
    pub critical_thresholds: Vec<f64>,
}

/// Interaction between parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterInteraction {
    /// First parameter
    pub param1: String,
    /// Second parameter
    pub param2: String,
    /// Interaction strength
    pub interaction_strength: f64,
    /// Interaction type
    pub interaction_type: InteractionType,
    /// Joint optimal region, keyed by parameter name -> (low, high)
    pub joint_optimal_region: HashMap<String, (f64, f64)>,
}

/// Type of parameter interaction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InteractionType {
    /// Parameters reinforce each other's effect.
    Synergistic,
    /// Improving one degrades the other.
    Antagonistic,
    /// No measurable interaction.
    Independent,
    /// Interaction depends on the operating regime.
    Conditional,
}
431
/// Architecture sensitivity analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSensitivity {
    /// Layer depth sensitivity
    pub depth_sensitivity: ArchitecturalSensitivity,
    /// Layer width sensitivity
    pub width_sensitivity: ArchitecturalSensitivity,
    /// Attention head sensitivity
    pub attention_head_sensitivity: ArchitecturalSensitivity,
    /// Skip connection sensitivity
    pub skip_connection_sensitivity: ArchitecturalSensitivity,
    /// Architectural component importance (component name -> importance)
    pub component_importance: HashMap<String, f64>,
    /// Architectural bottlenecks
    pub bottlenecks: Vec<ArchitecturalBottleneck>,
}

/// Sensitivity analysis for architectural component
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitecturalSensitivity {
    /// Component name
    pub component_name: String,
    /// Sensitivity to changes
    pub change_sensitivity: f64,
    /// Performance degradation curve as (configuration value, performance) points
    pub degradation_curve: Vec<(f64, f64)>,
    /// Minimum viable configuration
    pub min_viable_config: f64,
    /// Optimal configuration
    pub optimal_config: f64,
    /// Diminishing returns threshold
    pub diminishing_returns_threshold: f64,
}

/// Architectural bottleneck
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitecturalBottleneck {
    /// Bottleneck location (e.g. a layer or module name)
    pub location: String,
    /// Bottleneck type
    pub bottleneck_type: BottleneckType,
    /// Severity
    pub severity: f64,
    /// Performance impact
    pub performance_impact: f64,
    /// Resolution recommendations
    pub resolution_recommendations: Vec<String>,
}

/// Type of architectural bottleneck
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BottleneckType {
    /// Compute-bound component.
    ComputationalBottleneck,
    /// Memory-bound component.
    MemoryBottleneck,
    /// Component restricting information flow.
    InformationBottleneck,
    /// Component with insufficient model capacity.
    CapacityBottleneck,
    /// Component limited by data movement/communication.
    CommunicationBottleneck,
}
490
/// Data sensitivity analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataSensitivity {
    /// Training data size sensitivity
    pub data_size_sensitivity: DataSizeSensitivity,
    /// Data quality sensitivity
    pub data_quality_sensitivity: DataQualitySensitivity,
    /// Data distribution sensitivity
    pub distribution_sensitivity: DistributionSensitivity,
    /// Feature sensitivity analysis
    pub feature_sensitivity: FeatureSensitivityAnalysis,
}

/// Sensitivity to training data size
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataSizeSensitivity {
    /// Current data size (number of samples)
    pub current_size: usize,
    /// Minimum effective size
    pub minimum_effective_size: usize,
    /// Performance vs size curve as (data size, performance) points
    pub performance_curve: Vec<(usize, f64)>,
    /// Data efficiency score
    pub data_efficiency: f64,
    /// Diminishing returns point (data size beyond which gains flatten)
    pub diminishing_returns_point: usize,
}

/// Sensitivity to data quality
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataQualitySensitivity {
    /// Noise tolerance
    pub noise_tolerance: f64,
    /// Label quality importance
    pub label_quality_importance: f64,
    /// Feature quality importance
    pub feature_quality_importance: f64,
    /// Quality degradation impact as (quality level, performance) points
    pub quality_impact_curve: Vec<(f64, f64)>,
}

/// Sensitivity to data distribution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionSensitivity {
    /// Distribution shift sensitivity
    pub shift_sensitivity: f64,
    /// Class imbalance sensitivity
    pub imbalance_sensitivity: f64,
    /// Domain adaptation requirements
    pub domain_adaptation_requirements: Vec<String>,
    /// Robustness to distribution changes
    pub distribution_robustness: f64,
}

/// Feature-level sensitivity analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureSensitivityAnalysis {
    /// Most important features
    pub most_important_features: Vec<String>,
    /// Least important features
    pub least_important_features: Vec<String>,
    /// Feature interaction importance, keyed by feature-name pair
    // NOTE(review): a (String, String) map key is not representable as a JSON
    // object key; serializing this with serde_json will fail at runtime —
    // confirm the intended serialization format.
    pub feature_interactions: HashMap<(String, String), f64>,
    /// Feature stability analysis (feature name -> stability)
    pub feature_stability: HashMap<String, f64>,
}
557
/// Training procedure sensitivity
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSensitivity {
    /// Initialization sensitivity
    pub initialization_sensitivity: InitializationSensitivity,
    /// Optimization method sensitivity
    pub optimization_sensitivity: OptimizationSensitivity,
    /// Training schedule sensitivity
    pub schedule_sensitivity: ScheduleSensitivity,
    /// Regularization sensitivity
    pub regularization_sensitivity: RegularizationSensitivity,
}

/// Sensitivity to initialization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializationSensitivity {
    /// Weight initialization sensitivity
    pub weight_init_sensitivity: f64,
    /// Bias initialization sensitivity
    pub bias_init_sensitivity: f64,
    /// Random seed sensitivity
    pub seed_sensitivity: f64,
    /// Initialization scheme importance (scheme name -> importance)
    pub scheme_importance: HashMap<String, f64>,
}

/// Sensitivity to optimization method
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSensitivity {
    /// Optimizer choice sensitivity
    pub optimizer_sensitivity: f64,
    /// Momentum parameter sensitivity
    pub momentum_sensitivity: f64,
    /// Second-order moment sensitivity
    pub second_moment_sensitivity: f64,
    /// Optimizer comparison (optimizer name -> score)
    pub optimizer_comparison: HashMap<String, f64>,
}

/// Sensitivity to training schedule
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleSensitivity {
    /// Learning rate schedule sensitivity
    pub lr_schedule_sensitivity: f64,
    /// Training duration sensitivity
    pub duration_sensitivity: f64,
    /// Warmup sensitivity
    pub warmup_sensitivity: f64,
    /// Schedule parameter importance (parameter name -> importance)
    pub schedule_param_importance: HashMap<String, f64>,
}

/// Sensitivity to regularization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegularizationSensitivity {
    /// Dropout sensitivity
    pub dropout_sensitivity: f64,
    /// Weight decay sensitivity
    pub weight_decay_sensitivity: f64,
    /// Batch normalization sensitivity
    pub batch_norm_sensitivity: f64,
    /// Regularization method comparison (method name -> score)
    pub method_comparison: HashMap<String, f64>,
}
622
/// Overall sensitivity insights
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityInsights {
    /// Most critical factors
    pub most_critical_factors: Vec<String>,
    /// Least critical factors
    pub least_critical_factors: Vec<String>,
    /// Surprising findings (human-readable notes)
    pub surprising_findings: Vec<String>,
    /// Robustness assessment
    pub robustness_assessment: RobustnessAssessment,
    /// Optimization recommendations
    pub optimization_recommendations: Vec<String>,
}

/// Model robustness assessment
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustnessAssessment {
    /// Overall robustness score
    pub overall_robustness: f64,
    /// Robustness breakdown by category (category name -> score)
    pub category_robustness: HashMap<String, f64>,
    /// Vulnerability areas
    pub vulnerabilities: Vec<Vulnerability>,
    /// Strength areas
    pub strengths: Vec<String>,
}

/// Model vulnerability
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vulnerability {
    /// Vulnerability type (free-form label)
    pub vulnerability_type: String,
    /// Severity level
    pub severity: f64,
    /// Impact description
    pub impact: String,
    /// Mitigation strategies
    pub mitigation_strategies: Vec<String>,
}
663
/// Advanced ML debugger.
///
/// Stateful facade over the analyses in this module: holds the configuration
/// and accumulates every analysis result it produces (the histories feed
/// `generate_report`).
#[derive(Debug)]
pub struct AdvancedMLDebugger {
    // Analysis configuration (feature gates and budgets).
    config: AdvancedMLDebuggingConfig,
    // History of layer-wise LR analyses, oldest first (results are appended).
    lr_analysis_results: Vec<LayerWiseLRAnalysisResult>,
    // History of sensitivity analyses, oldest first (results are appended).
    sensitivity_analysis_results: Vec<ModelSensitivityAnalysisResult>,
}
671
672impl AdvancedMLDebugger {
673    /// Create a new advanced ML debugger
674    pub fn new(config: AdvancedMLDebuggingConfig) -> Self {
675        Self {
676            config,
677            lr_analysis_results: Vec::new(),
678            sensitivity_analysis_results: Vec::new(),
679        }
680    }
681
682    /// Perform layer-wise learning rate analysis
683    pub async fn analyze_layer_wise_learning_rates(
684        &mut self,
685        layer_gradients: &HashMap<String, ArrayD<f32>>,
686        layer_weights: &HashMap<String, ArrayD<f32>>,
687        current_lr: f64,
688        loss_history: &[f64],
689    ) -> Result<LayerWiseLRAnalysisResult> {
690        if !self.config.enable_layer_wise_lr_analysis {
691            return Err(anyhow::anyhow!(
692                "Layer-wise learning rate analysis is disabled"
693            ));
694        }
695
696        let mut layer_lr_recommendations = HashMap::new();
697
698        // Analyze each layer
699        for (layer_id, gradients) in layer_gradients {
700            if let Some(weights) = layer_weights.get(layer_id) {
701                let recommendation = self.analyze_single_layer_lr(
702                    layer_id,
703                    gradients,
704                    weights,
705                    current_lr,
706                    loss_history,
707                );
708                layer_lr_recommendations.insert(layer_id.clone(), recommendation);
709            }
710        }
711
712        // Generate global insights
713        let global_lr_insights =
714            self.generate_global_lr_insights(&layer_lr_recommendations, loss_history);
715
716        // Create adaptation strategy
717        let adaptation_strategy =
718            self.create_lr_adaptation_strategy(&layer_lr_recommendations, &global_lr_insights);
719
720        // Generate training phase recommendations
721        let training_phase_recommendations =
722            self.generate_training_phase_recommendations(&adaptation_strategy);
723
724        // Predict performance with different schedules
725        let lr_schedule_predictions =
726            self.predict_lr_schedule_performance(&layer_lr_recommendations);
727
728        let result = LayerWiseLRAnalysisResult {
729            timestamp: Utc::now(),
730            layer_lr_recommendations,
731            global_lr_insights,
732            adaptation_strategy,
733            training_phase_recommendations,
734            lr_schedule_predictions,
735        };
736
737        self.lr_analysis_results.push(result.clone());
738        Ok(result)
739    }
740
741    /// Perform comprehensive model sensitivity analysis
742    pub async fn analyze_model_sensitivity(
743        &mut self,
744        model_params: &HashMap<String, f64>,
745        performance_metrics: &[f64],
746        architecture_config: &HashMap<String, f64>,
747    ) -> Result<ModelSensitivityAnalysisResult> {
748        if !self.config.enable_model_sensitivity_analysis {
749            return Err(anyhow::anyhow!("Model sensitivity analysis is disabled"));
750        }
751
752        // Analyze hyperparameter sensitivity
753        let hyperparameter_sensitivity =
754            self.analyze_hyperparameter_sensitivity(model_params, performance_metrics);
755
756        // Analyze architecture sensitivity
757        let architecture_sensitivity =
758            self.analyze_architecture_sensitivity(architecture_config, performance_metrics);
759
760        // Analyze data sensitivity (simulated)
761        let data_sensitivity = self.analyze_data_sensitivity(performance_metrics);
762
763        // Analyze training sensitivity
764        let training_sensitivity =
765            self.analyze_training_sensitivity(model_params, performance_metrics);
766
767        // Generate overall insights
768        let sensitivity_insights = self.generate_sensitivity_insights(
769            &hyperparameter_sensitivity,
770            &architecture_sensitivity,
771            &data_sensitivity,
772            &training_sensitivity,
773        );
774
775        let result = ModelSensitivityAnalysisResult {
776            timestamp: Utc::now(),
777            hyperparameter_sensitivity,
778            architecture_sensitivity,
779            data_sensitivity,
780            training_sensitivity,
781            sensitivity_insights,
782        };
783
784        self.sensitivity_analysis_results.push(result.clone());
785        Ok(result)
786    }
787
788    /// Generate comprehensive advanced ML debugging report
789    pub async fn generate_report(&self) -> Result<AdvancedMLDebuggingReport> {
790        Ok(AdvancedMLDebuggingReport {
791            timestamp: Utc::now(),
792            config: self.config.clone(),
793            lr_analysis_count: self.lr_analysis_results.len(),
794            sensitivity_analysis_count: self.sensitivity_analysis_results.len(),
795            recent_lr_analyses: self.lr_analysis_results.iter().rev().take(3).cloned().collect(),
796            recent_sensitivity_analyses: self
797                .sensitivity_analysis_results
798                .iter()
799                .rev()
800                .take(3)
801                .cloned()
802                .collect(),
803            advanced_insights: self.generate_advanced_insights(),
804        })
805    }
806
807    // Helper methods for layer-wise LR analysis
808
809    fn analyze_single_layer_lr(
810        &self,
811        layer_id: &str,
812        gradients: &ArrayD<f32>,
813        weights: &ArrayD<f32>,
814        current_lr: f64,
815        loss_history: &[f64],
816    ) -> LayerLRRecommendation {
817        // Calculate gradient statistics
818        let gradient_magnitude =
819            gradients.iter().map(|&x| x.abs() as f64).sum::<f64>() / gradients.len() as f64;
820        let weight_magnitude =
821            weights.iter().map(|&x| x.abs() as f64).sum::<f64>() / weights.len() as f64;
822
823        // Estimate optimal learning rate based on gradient properties
824        let gradient_variance =
825            gradients.iter().map(|&x| (x as f64 - gradient_magnitude).powi(2)).sum::<f64>()
826                / gradients.len() as f64;
827
828        let gradient_norm = gradient_magnitude;
829        let recommended_lr = if gradient_norm > 0.0 {
830            // Adaptive learning rate based on gradient properties
831            let base_lr = 0.001;
832            let adaptation_factor = (1.0 / (1.0 + gradient_variance)).sqrt();
833            let magnitude_factor = (1.0 / (1.0 + gradient_norm)).sqrt();
834            base_lr * adaptation_factor * magnitude_factor * 10.0
835        } else {
836            current_lr
837        };
838
839        // Calculate layer metrics
840        let layer_metrics = LayerLRMetrics {
841            gradient_magnitude,
842            weight_update_magnitude: gradient_magnitude * current_lr,
843            parameter_norm: weight_magnitude,
844            loss_contribution: self.estimate_layer_loss_contribution(loss_history),
845            stability_score: self.calculate_layer_stability(gradients),
846            convergence_rate: self.estimate_convergence_rate(loss_history),
847            learning_efficiency: gradient_magnitude / (weight_magnitude + 1e-8),
848        };
849
850        // Determine urgency
851        let lr_ratio = recommended_lr / current_lr;
852        let urgency = if !(0.1..=10.0).contains(&lr_ratio) {
853            AdaptationUrgency::Critical
854        } else if !(0.33..=3.0).contains(&lr_ratio) {
855            AdaptationUrgency::High
856        } else if !(0.67..=1.5).contains(&lr_ratio) {
857            AdaptationUrgency::Medium
858        } else {
859            AdaptationUrgency::Low
860        };
861
862        // Generate reasoning
863        let reasoning = if recommended_lr > current_lr * 1.2 {
864            "Layer shows slow learning with small gradients, increase learning rate".to_string()
865        } else if recommended_lr < current_lr * 0.8 {
866            "Layer shows instability or large gradients, decrease learning rate".to_string()
867        } else {
868            "Current learning rate appears appropriate for this layer".to_string()
869        };
870
871        LayerLRRecommendation {
872            layer_id: layer_id.to_string(),
873            layer_type: self.infer_layer_type(layer_id),
874            current_lr,
875            recommended_lr,
876            confidence: 0.8, // Would be calculated based on statistical confidence
877            reasoning,
878            layer_metrics,
879            lr_sensitivity: lr_ratio.abs(),
880            urgency,
881        }
882    }
883
884    fn generate_global_lr_insights(
885        &self,
886        layer_recommendations: &HashMap<String, LayerLRRecommendation>,
887        loss_history: &[f64],
888    ) -> GlobalLRInsights {
889        let overall_efficiency = layer_recommendations
890            .values()
891            .map(|rec| rec.layer_metrics.learning_efficiency)
892            .sum::<f64>()
893            / layer_recommendations.len() as f64;
894
895        let lr_distribution_health = self.calculate_lr_distribution_health(layer_recommendations);
896        let gradient_flow_quality = self.calculate_gradient_flow_quality(layer_recommendations);
897        let training_stability =
898            self.assess_training_stability(layer_recommendations, loss_history);
899        let global_adjustments = self.generate_global_adjustments(layer_recommendations);
900        let critical_issues = self.identify_critical_issues(layer_recommendations);
901
902        GlobalLRInsights {
903            overall_efficiency,
904            lr_distribution_health,
905            gradient_flow_quality,
906            training_stability,
907            global_adjustments,
908            critical_issues,
909        }
910    }
911
912    fn create_lr_adaptation_strategy(
913        &self,
914        _layer_recommendations: &HashMap<String, LayerLRRecommendation>,
915        global_insights: &GlobalLRInsights,
916    ) -> LRAdaptationStrategy {
917        // Simplified strategy creation
918        let strategy_name = if global_insights.overall_efficiency < 0.5 {
919            "Aggressive Learning Rate Adaptation".to_string()
920        } else {
921            "Conservative Learning Rate Tuning".to_string()
922        };
923
924        LRAdaptationStrategy {
925            strategy_name: strategy_name.clone(),
926            description: "Strategy to optimize learning rates based on current model state"
927                .to_string(),
928            implementation_steps: vec![ImplementationStep {
929                step_number: 1,
930                description: "Implement layer-wise learning rate multipliers".to_string(),
931                code_changes: vec!["Add lr_multipliers to optimizer config".to_string()],
932                timeline: "1-2 days".to_string(),
933                dependencies: vec!["Optimizer modification".to_string()],
934            }],
935            expected_benefits: vec![
936                "Improved convergence speed".to_string(),
937                "Better training stability".to_string(),
938                "Reduced overfitting risk".to_string(),
939            ],
940            potential_risks: vec!["Initial instability during adaptation".to_string()],
941            success_metrics: vec![
942                "Faster loss reduction".to_string(),
943                "Improved validation accuracy".to_string(),
944            ],
945            monitoring_requirements: vec!["Track per-layer gradient norms".to_string()],
946        }
947    }
948
949    fn generate_training_phase_recommendations(
950        &self,
951        _strategy: &LRAdaptationStrategy,
952    ) -> Vec<TrainingPhaseRecommendation> {
953        vec![TrainingPhaseRecommendation {
954            phase_name: "Warmup Phase".to_string(),
955            duration_epochs: 5,
956            lr_schedule: LRSchedule {
957                schedule_type: LRScheduleType::LinearDecay,
958                initial_lr: 0.0001,
959                parameters: HashMap::new(),
960                layer_multipliers: HashMap::new(),
961            },
962            objectives: vec!["Stabilize training".to_string()],
963            success_criteria: vec!["Decreasing loss".to_string()],
964            transition_conditions: vec!["Stable gradient norms".to_string()],
965        }]
966    }
967
968    fn predict_lr_schedule_performance(
969        &self,
970        _layer_recommendations: &HashMap<String, LayerLRRecommendation>,
971    ) -> Vec<LRSchedulePrediction> {
972        vec![LRSchedulePrediction {
973            schedule: LRSchedule {
974                schedule_type: LRScheduleType::ExponentialDecay,
975                initial_lr: 0.001,
976                parameters: HashMap::new(),
977                layer_multipliers: HashMap::new(),
978            },
979            predicted_accuracy: 0.92,
980            predicted_convergence_epochs: 50,
981            predicted_stability: 0.8,
982            prediction_confidence: 0.7,
983            risk_assessment: RiskAssessment {
984                overall_risk: RiskLevel::Medium,
985                specific_risks: vec![],
986                mitigation_strategies: vec![],
987            },
988        }]
989    }
990
991    // Helper methods for sensitivity analysis
992
993    fn analyze_hyperparameter_sensitivity(
994        &self,
995        params: &HashMap<String, f64>,
996        _metrics: &[f64],
997    ) -> HyperparameterSensitivity {
998        let learning_rate_sensitivity = ParameterSensitivity {
999            parameter_name: "learning_rate".to_string(),
1000            current_value: params.get("learning_rate").copied().unwrap_or(0.001),
1001            sensitivity_score: 0.8,
1002            optimal_range: (0.0001, 0.01),
1003            impact_curve: vec![(0.0001, 0.7), (0.001, 0.9), (0.01, 0.85)],
1004            stability_region: (0.0005, 0.005),
1005            critical_thresholds: vec![0.0001, 0.1],
1006        };
1007
1008        let batch_size_sensitivity = ParameterSensitivity {
1009            parameter_name: "batch_size".to_string(),
1010            current_value: params.get("batch_size").copied().unwrap_or(32.0),
1011            sensitivity_score: 0.6,
1012            optimal_range: (16.0, 128.0),
1013            impact_curve: vec![(16.0, 0.85), (32.0, 0.9), (64.0, 0.88), (128.0, 0.82)],
1014            stability_region: (16.0, 64.0),
1015            critical_thresholds: vec![8.0, 256.0],
1016        };
1017
1018        let regularization_sensitivity = ParameterSensitivity {
1019            parameter_name: "weight_decay".to_string(),
1020            current_value: params.get("weight_decay").copied().unwrap_or(0.01),
1021            sensitivity_score: 0.4,
1022            optimal_range: (0.001, 0.1),
1023            impact_curve: vec![(0.001, 0.88), (0.01, 0.9), (0.1, 0.87)],
1024            stability_region: (0.005, 0.05),
1025            critical_thresholds: vec![0.0001, 1.0],
1026        };
1027
1028        HyperparameterSensitivity {
1029            learning_rate_sensitivity,
1030            batch_size_sensitivity,
1031            regularization_sensitivity,
1032            architecture_param_sensitivity: HashMap::new(),
1033            most_sensitive_params: vec!["learning_rate".to_string(), "batch_size".to_string()],
1034            least_sensitive_params: vec!["weight_decay".to_string()],
1035            interaction_effects: vec![],
1036        }
1037    }
1038
1039    fn analyze_architecture_sensitivity(
1040        &self,
1041        _config: &HashMap<String, f64>,
1042        _metrics: &[f64],
1043    ) -> ArchitectureSensitivity {
1044        ArchitectureSensitivity {
1045            depth_sensitivity: ArchitecturalSensitivity {
1046                component_name: "model_depth".to_string(),
1047                change_sensitivity: 0.7,
1048                degradation_curve: vec![(6.0, 0.85), (12.0, 0.9), (24.0, 0.88)],
1049                min_viable_config: 6.0,
1050                optimal_config: 12.0,
1051                diminishing_returns_threshold: 18.0,
1052            },
1053            width_sensitivity: ArchitecturalSensitivity {
1054                component_name: "hidden_size".to_string(),
1055                change_sensitivity: 0.6,
1056                degradation_curve: vec![(256.0, 0.82), (512.0, 0.9), (1024.0, 0.91)],
1057                min_viable_config: 256.0,
1058                optimal_config: 512.0,
1059                diminishing_returns_threshold: 768.0,
1060            },
1061            attention_head_sensitivity: ArchitecturalSensitivity {
1062                component_name: "num_attention_heads".to_string(),
1063                change_sensitivity: 0.5,
1064                degradation_curve: vec![(4.0, 0.87), (8.0, 0.9), (16.0, 0.89)],
1065                min_viable_config: 4.0,
1066                optimal_config: 8.0,
1067                diminishing_returns_threshold: 12.0,
1068            },
1069            skip_connection_sensitivity: ArchitecturalSensitivity {
1070                component_name: "skip_connections".to_string(),
1071                change_sensitivity: 0.8,
1072                degradation_curve: vec![(0.0, 0.75), (1.0, 0.9)],
1073                min_viable_config: 1.0,
1074                optimal_config: 1.0,
1075                diminishing_returns_threshold: 1.0,
1076            },
1077            component_importance: HashMap::new(),
1078            bottlenecks: vec![],
1079        }
1080    }
1081
1082    fn analyze_data_sensitivity(&self, _metrics: &[f64]) -> DataSensitivity {
1083        DataSensitivity {
1084            data_size_sensitivity: DataSizeSensitivity {
1085                current_size: 10000,
1086                minimum_effective_size: 1000,
1087                performance_curve: vec![(1000, 0.7), (5000, 0.85), (10000, 0.9), (20000, 0.92)],
1088                data_efficiency: 0.85,
1089                diminishing_returns_point: 15000,
1090            },
1091            data_quality_sensitivity: DataQualitySensitivity {
1092                noise_tolerance: 0.1,
1093                label_quality_importance: 0.9,
1094                feature_quality_importance: 0.7,
1095                quality_impact_curve: vec![(0.9, 0.9), (0.8, 0.85), (0.7, 0.75)],
1096            },
1097            distribution_sensitivity: DistributionSensitivity {
1098                shift_sensitivity: 0.6,
1099                imbalance_sensitivity: 0.5,
1100                domain_adaptation_requirements: vec!["Gradual domain adaptation".to_string()],
1101                distribution_robustness: 0.7,
1102            },
1103            feature_sensitivity: FeatureSensitivityAnalysis {
1104                most_important_features: vec!["feature_1".to_string(), "feature_2".to_string()],
1105                least_important_features: vec!["feature_10".to_string()],
1106                feature_interactions: HashMap::new(),
1107                feature_stability: HashMap::new(),
1108            },
1109        }
1110    }
1111
1112    fn analyze_training_sensitivity(
1113        &self,
1114        _params: &HashMap<String, f64>,
1115        _metrics: &[f64],
1116    ) -> TrainingSensitivity {
1117        TrainingSensitivity {
1118            initialization_sensitivity: InitializationSensitivity {
1119                weight_init_sensitivity: 0.6,
1120                bias_init_sensitivity: 0.3,
1121                seed_sensitivity: 0.2,
1122                scheme_importance: HashMap::new(),
1123            },
1124            optimization_sensitivity: OptimizationSensitivity {
1125                optimizer_sensitivity: 0.7,
1126                momentum_sensitivity: 0.5,
1127                second_moment_sensitivity: 0.4,
1128                optimizer_comparison: HashMap::new(),
1129            },
1130            schedule_sensitivity: ScheduleSensitivity {
1131                lr_schedule_sensitivity: 0.8,
1132                duration_sensitivity: 0.6,
1133                warmup_sensitivity: 0.4,
1134                schedule_param_importance: HashMap::new(),
1135            },
1136            regularization_sensitivity: RegularizationSensitivity {
1137                dropout_sensitivity: 0.5,
1138                weight_decay_sensitivity: 0.4,
1139                batch_norm_sensitivity: 0.6,
1140                method_comparison: HashMap::new(),
1141            },
1142        }
1143    }
1144
1145    fn generate_sensitivity_insights(
1146        &self,
1147        _hyper_sens: &HyperparameterSensitivity,
1148        _arch_sens: &ArchitectureSensitivity,
1149        _data_sens: &DataSensitivity,
1150        _training_sens: &TrainingSensitivity,
1151    ) -> SensitivityInsights {
1152        SensitivityInsights {
1153            most_critical_factors: vec![
1154                "learning_rate".to_string(),
1155                "model_depth".to_string(),
1156                "skip_connections".to_string(),
1157            ],
1158            least_critical_factors: vec!["bias_initialization".to_string()],
1159            surprising_findings: vec!["Batch size has higher than expected impact".to_string()],
1160            robustness_assessment: RobustnessAssessment {
1161                overall_robustness: 0.7,
1162                category_robustness: HashMap::new(),
1163                vulnerabilities: vec![],
1164                strengths: vec!["Good hyperparameter stability".to_string()],
1165            },
1166            optimization_recommendations: vec![
1167                "Focus on learning rate tuning first".to_string(),
1168                "Consider architectural modifications second".to_string(),
1169            ],
1170        }
1171    }
1172
1173    // Additional helper methods
1174
1175    fn estimate_layer_loss_contribution(&self, loss_history: &[f64]) -> f64 {
1176        // Simplified estimation
1177        if loss_history.len() >= 2 {
1178            (loss_history[loss_history.len() - 2] - loss_history[loss_history.len() - 1]).abs()
1179        } else {
1180            0.1
1181        }
1182    }
1183
1184    fn calculate_layer_stability(&self, gradients: &ArrayD<f32>) -> f64 {
1185        let gradient_variance = gradients.iter().map(|&x| x as f64).collect::<Vec<_>>();
1186
1187        if gradient_variance.is_empty() {
1188            return 0.5;
1189        }
1190
1191        let mean = gradient_variance.iter().sum::<f64>() / gradient_variance.len() as f64;
1192        let variance = gradient_variance.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
1193            / gradient_variance.len() as f64;
1194
1195        1.0 / (1.0 + variance) // Higher stability for lower variance
1196    }
1197
1198    fn estimate_convergence_rate(&self, loss_history: &[f64]) -> f64 {
1199        if loss_history.len() < 3 {
1200            return 0.5;
1201        }
1202
1203        let recent_improvement =
1204            loss_history[loss_history.len() - 3] - loss_history[loss_history.len() - 1];
1205        recent_improvement.abs()
1206    }
1207
1208    fn infer_layer_type(&self, layer_id: &str) -> String {
1209        if layer_id.contains("attention") {
1210            "attention".to_string()
1211        } else if layer_id.contains("feedforward") || layer_id.contains("mlp") {
1212            "feedforward".to_string()
1213        } else if layer_id.contains("embedding") {
1214            "embedding".to_string()
1215        } else {
1216            "unknown".to_string()
1217        }
1218    }
1219
1220    fn calculate_lr_distribution_health(
1221        &self,
1222        recommendations: &HashMap<String, LayerLRRecommendation>,
1223    ) -> f64 {
1224        let lr_ratios: Vec<f64> = recommendations
1225            .values()
1226            .map(|rec| rec.recommended_lr / rec.current_lr)
1227            .collect();
1228
1229        if lr_ratios.is_empty() {
1230            return 0.5;
1231        }
1232
1233        let mean_ratio = lr_ratios.iter().sum::<f64>() / lr_ratios.len() as f64;
1234        let variance = lr_ratios.iter().map(|&x| (x - mean_ratio).powi(2)).sum::<f64>()
1235            / lr_ratios.len() as f64;
1236
1237        1.0 / (1.0 + variance) // Better health for lower variance
1238    }
1239
1240    fn calculate_gradient_flow_quality(
1241        &self,
1242        recommendations: &HashMap<String, LayerLRRecommendation>,
1243    ) -> f64 {
1244        recommendations
1245            .values()
1246            .map(|rec| rec.layer_metrics.stability_score)
1247            .sum::<f64>()
1248            / recommendations.len() as f64
1249    }
1250
1251    fn assess_training_stability(
1252        &self,
1253        recommendations: &HashMap<String, LayerLRRecommendation>,
1254        _loss_history: &[f64],
1255    ) -> TrainingStability {
1256        let stability_score = recommendations
1257            .values()
1258            .map(|rec| rec.layer_metrics.stability_score)
1259            .sum::<f64>()
1260            / recommendations.len() as f64;
1261
1262        TrainingStability {
1263            stability_score,
1264            instability_indicators: vec![],
1265            stability_trends: vec![],
1266            predicted_stability: stability_score * 0.9, // Slightly pessimistic prediction
1267        }
1268    }
1269
1270    fn generate_global_adjustments(
1271        &self,
1272        _recommendations: &HashMap<String, LayerLRRecommendation>,
1273    ) -> Vec<GlobalLRAdjustment> {
1274        vec![GlobalLRAdjustment {
1275            adjustment_type: GlobalAdjustmentType::LayerTypeSpecific,
1276            magnitude: 1.5,
1277            expected_impact: 0.1,
1278            priority: AdjustmentPriority::Medium,
1279            instructions: "Apply different learning rates to attention vs feedforward layers"
1280                .to_string(),
1281        }]
1282    }
1283
1284    fn identify_critical_issues(
1285        &self,
1286        recommendations: &HashMap<String, LayerLRRecommendation>,
1287    ) -> Vec<String> {
1288        let mut issues = Vec::new();
1289
1290        for recommendation in recommendations.values() {
1291            if matches!(recommendation.urgency, AdaptationUrgency::Critical) {
1292                issues.push(format!(
1293                    "Critical learning rate issue in layer {}",
1294                    recommendation.layer_id
1295                ));
1296            }
1297        }
1298
1299        issues
1300    }
1301
1302    fn generate_advanced_insights(&self) -> HashMap<String, String> {
1303        let mut insights = HashMap::new();
1304
1305        insights.insert(
1306            "total_lr_analyses".to_string(),
1307            self.lr_analysis_results.len().to_string(),
1308        );
1309        insights.insert(
1310            "total_sensitivity_analyses".to_string(),
1311            self.sensitivity_analysis_results.len().to_string(),
1312        );
1313
1314        if let Some(latest_lr) = self.lr_analysis_results.last() {
1315            insights.insert(
1316                "latest_lr_efficiency".to_string(),
1317                format!("{:.2}", latest_lr.global_lr_insights.overall_efficiency),
1318            );
1319        }
1320
1321        insights
1322    }
1323}
1324
/// Comprehensive advanced ML debugging report
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMLDebuggingReport {
    /// When this report was generated (UTC).
    pub timestamp: DateTime<Utc>,
    /// Debugger configuration in effect when the report was produced.
    pub config: AdvancedMLDebuggingConfig,
    /// Total number of layer-wise learning-rate analyses performed so far.
    pub lr_analysis_count: usize,
    /// Total number of model sensitivity analyses performed so far.
    pub sensitivity_analysis_count: usize,
    /// Most recent layer-wise learning-rate analysis results.
    pub recent_lr_analyses: Vec<LayerWiseLRAnalysisResult>,
    /// Most recent model sensitivity analysis results.
    pub recent_sensitivity_analyses: Vec<ModelSensitivityAnalysisResult>,
    /// Free-form summary metrics keyed by insight name (see
    /// `generate_advanced_insights`): counts and latest efficiency figures.
    pub advanced_insights: HashMap<String, String>,
}
1336
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed debugger starts with no LR analysis results.
    #[tokio::test]
    async fn test_advanced_ml_debugger_creation() {
        let config = AdvancedMLDebuggingConfig::default();
        let debugger = AdvancedMLDebugger::new(config);
        assert_eq!(debugger.lr_analysis_results.len(), 0);
    }

    /// End-to-end smoke test: one 10x10 layer of synthetic gradients/weights
    /// and a monotonically decreasing loss history should yield exactly one
    /// per-layer LR recommendation, keyed by the layer id.
    #[tokio::test]
    async fn test_layer_wise_lr_analysis() {
        let config = AdvancedMLDebuggingConfig::default();
        let mut debugger = AdvancedMLDebugger::new(config);

        let mut layer_gradients = HashMap::new();
        let mut layer_weights = HashMap::new();

        // Create test data: deterministic ramps so results are reproducible.
        let gradients =
            ArrayD::from_shape_vec(vec![10, 10], (0..100).map(|x| x as f32 * 0.01).collect())
                .expect("operation failed in test");
        let weights =
            ArrayD::from_shape_vec(vec![10, 10], (0..100).map(|x| x as f32 * 0.1).collect())
                .expect("operation failed in test");

        layer_gradients.insert("layer_0".to_string(), gradients);
        layer_weights.insert("layer_0".to_string(), weights);

        let loss_history = vec![1.0, 0.8, 0.6, 0.5];

        let result = debugger
            .analyze_layer_wise_learning_rates(
                &layer_gradients,
                &layer_weights,
                0.001,
                &loss_history,
            )
            .await;
        assert!(result.is_ok());

        let analysis = result.expect("operation failed in test");
        assert_eq!(analysis.layer_lr_recommendations.len(), 1);
        assert!(analysis.layer_lr_recommendations.contains_key("layer_0"));
    }
}