use anyhow::Result;
use chrono::{DateTime, Utc};
use scirs2_core::ndarray::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

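/// Configuration for the advanced ML debugger: feature flags for each analysis
/// plus sampling and adaptation thresholds.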
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMLDebuggingConfig {
    pub enable_layer_wise_lr_analysis: bool,
    pub enable_model_sensitivity_analysis: bool,
    pub enable_gradient_flow_optimization: bool,
    pub enable_neural_architecture_debugging: bool,
    pub enable_activation_pattern_analysis: bool,
    pub enable_weight_distribution_analysis: bool,
    pub enable_training_dynamics_analysis: bool,
    pub enable_optimization_landscape_analysis: bool,
    pub sensitivity_samples: usize,
    pub lr_adaptation_threshold: f64,
    pub max_layers_to_analyze: usize,
}

impl Default for AdvancedMLDebuggingConfig {
    fn default() -> Self {
        Self {
            enable_layer_wise_lr_analysis: true,
            enable_model_sensitivity_analysis: true,
            enable_gradient_flow_optimization: true,
            enable_neural_architecture_debugging: true,
            enable_activation_pattern_analysis: true,
            enable_weight_distribution_analysis: true,
            enable_training_dynamics_analysis: true,
            enable_optimization_landscape_analysis: true,
            sensitivity_samples: 1000,
            lr_adaptation_threshold: 0.1,
            max_layers_to_analyze: 50,
        }
    }
}

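/// Result of a single layer-wise learning-rate analysis pass.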
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerWiseLRAnalysisResult {
    pub timestamp: DateTime<Utc>,
    pub layer_lr_recommendations: HashMap<String, LayerLRRecommendation>,
    pub global_lr_insights: GlobalLRInsights,
    pub adaptation_strategy: LRAdaptationStrategy,
    pub training_phase_recommendations: Vec<TrainingPhaseRecommendation>,
    pub lr_schedule_predictions: Vec<LRSchedulePrediction>,
}

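/// Per-layer learning-rate recommendation with supporting metrics and urgency.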
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerLRRecommendation {
    pub layer_id: String,
    pub layer_type: String,
    pub current_lr: f64,
    pub recommended_lr: f64,
    pub confidence: f64,
    pub reasoning: String,
    pub layer_metrics: LayerLRMetrics,
    pub lr_sensitivity: f64,
    pub urgency: AdaptationUrgency,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerLRMetrics {
    pub gradient_magnitude: f64,
    pub weight_update_magnitude: f64,
    pub parameter_norm: f64,
    pub loss_contribution: f64,
    pub stability_score: f64,
    pub convergence_rate: f64,
    pub learning_efficiency: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AdaptationUrgency {
    Low,
    Medium,
    High,
    Critical,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalLRInsights {
    pub overall_efficiency: f64,
    pub lr_distribution_health: f64,
    pub gradient_flow_quality: f64,
    pub training_stability: TrainingStability,
    pub global_adjustments: Vec<GlobalLRAdjustment>,
    pub critical_issues: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStability {
    pub stability_score: f64,
    pub instability_indicators: Vec<InstabilityIndicator>,
    pub stability_trends: Vec<StabilityTrendPoint>,
    pub predicted_stability: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstabilityIndicator {
    pub instability_type: InstabilityType,
    pub severity: f64,
    pub affected_layers: Vec<String>,
    pub recommended_actions: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InstabilityType {
    GradientExplosion,
    GradientVanishing,
    OscillatingLoss,
    SlowConvergence,
    WeightDivergence,
    NumericalInstability,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityTrendPoint {
    pub time_step: usize,
    pub stability_score: f64,
    pub contributing_factors: HashMap<String, f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalLRAdjustment {
    pub adjustment_type: GlobalAdjustmentType,
    pub magnitude: f64,
    pub expected_impact: f64,
    pub priority: AdjustmentPriority,
    pub instructions: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GlobalAdjustmentType {
    UniformScaling,
    LayerTypeSpecific,
    DepthDependent,
    AdaptiveScheduling,
    WarmupAdjustment,
    DecayRateModification,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AdjustmentPriority {
    Low,
    Medium,
    High,
    Immediate,
}

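/// A named strategy for adapting learning rates, with concrete implementation
/// steps, expected benefits, risks, and how to monitor success.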
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRAdaptationStrategy {
    pub strategy_name: String,
    pub description: String,
    pub implementation_steps: Vec<ImplementationStep>,
    pub expected_benefits: Vec<String>,
    pub potential_risks: Vec<String>,
    pub success_metrics: Vec<String>,
    pub monitoring_requirements: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImplementationStep {
    pub step_number: usize,
    pub description: String,
    pub code_changes: Vec<String>,
    pub timeline: String,
    pub dependencies: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingPhaseRecommendation {
    pub phase_name: String,
    pub duration_epochs: usize,
    pub lr_schedule: LRSchedule,
    pub objectives: Vec<String>,
    pub success_criteria: Vec<String>,
    pub transition_conditions: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRSchedule {
    pub schedule_type: LRScheduleType,
    pub initial_lr: f64,
    pub parameters: HashMap<String, f64>,
    pub layer_multipliers: HashMap<String, f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LRScheduleType {
    Constant,
    LinearDecay,
    ExponentialDecay,
    CosineAnnealing,
    StepDecay,
    CyclicalLR,
    OneCycleLR,
    AdaptiveSchedule,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRSchedulePrediction {
    pub schedule: LRSchedule,
    pub predicted_accuracy: f64,
    pub predicted_convergence_epochs: usize,
    pub predicted_stability: f64,
    pub prediction_confidence: f64,
    pub risk_assessment: RiskAssessment,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskAssessment {
    pub overall_risk: RiskLevel,
    pub specific_risks: Vec<SpecificRisk>,
    pub mitigation_strategies: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RiskLevel {
    VeryLow,
    Low,
    Medium,
    High,
    VeryHigh,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecificRisk {
    pub risk_type: String,
    pub probability: f64,
    pub impact: f64,
    pub description: String,
}

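/// Result of a model sensitivity analysis across hyperparameters, architecture,
/// data, and training procedure.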
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSensitivityAnalysisResult {
    pub timestamp: DateTime<Utc>,
    pub hyperparameter_sensitivity: HyperparameterSensitivity,
    pub architecture_sensitivity: ArchitectureSensitivity,
    pub data_sensitivity: DataSensitivity,
    pub training_sensitivity: TrainingSensitivity,
    pub sensitivity_insights: SensitivityInsights,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSensitivity {
    pub learning_rate_sensitivity: ParameterSensitivity,
    pub batch_size_sensitivity: ParameterSensitivity,
    pub regularization_sensitivity: ParameterSensitivity,
    pub architecture_param_sensitivity: HashMap<String, ParameterSensitivity>,
    pub most_sensitive_params: Vec<String>,
    pub least_sensitive_params: Vec<String>,
    pub interaction_effects: Vec<ParameterInteraction>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterSensitivity {
    pub parameter_name: String,
    pub current_value: f64,
    pub sensitivity_score: f64,
    pub optimal_range: (f64, f64),
    pub impact_curve: Vec<(f64, f64)>,
    pub stability_region: (f64, f64),
    pub critical_thresholds: Vec<f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterInteraction {
    pub param1: String,
    pub param2: String,
    pub interaction_strength: f64,
    pub interaction_type: InteractionType,
    pub joint_optimal_region: HashMap<String, (f64, f64)>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InteractionType {
    Synergistic,
    Antagonistic,
    Independent,
    Conditional,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSensitivity {
    pub depth_sensitivity: ArchitecturalSensitivity,
    pub width_sensitivity: ArchitecturalSensitivity,
    pub attention_head_sensitivity: ArchitecturalSensitivity,
    pub skip_connection_sensitivity: ArchitecturalSensitivity,
    pub component_importance: HashMap<String, f64>,
    pub bottlenecks: Vec<ArchitecturalBottleneck>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitecturalSensitivity {
    pub component_name: String,
    pub change_sensitivity: f64,
    pub degradation_curve: Vec<(f64, f64)>,
    pub min_viable_config: f64,
    pub optimal_config: f64,
    pub diminishing_returns_threshold: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitecturalBottleneck {
    pub location: String,
    pub bottleneck_type: BottleneckType,
    pub severity: f64,
    pub performance_impact: f64,
    pub resolution_recommendations: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BottleneckType {
    ComputationalBottleneck,
    MemoryBottleneck,
    InformationBottleneck,
    CapacityBottleneck,
    CommunicationBottleneck,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataSensitivity {
    pub data_size_sensitivity: DataSizeSensitivity,
    pub data_quality_sensitivity: DataQualitySensitivity,
    pub distribution_sensitivity: DistributionSensitivity,
    pub feature_sensitivity: FeatureSensitivityAnalysis,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataSizeSensitivity {
    pub current_size: usize,
    pub minimum_effective_size: usize,
    pub performance_curve: Vec<(usize, f64)>,
    pub data_efficiency: f64,
    pub diminishing_returns_point: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataQualitySensitivity {
    pub noise_tolerance: f64,
    pub label_quality_importance: f64,
    pub feature_quality_importance: f64,
    pub quality_impact_curve: Vec<(f64, f64)>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionSensitivity {
    pub shift_sensitivity: f64,
    pub imbalance_sensitivity: f64,
    pub domain_adaptation_requirements: Vec<String>,
    pub distribution_robustness: f64,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureSensitivityAnalysis {
    pub most_important_features: Vec<String>,
    pub least_important_features: Vec<String>,
    pub feature_interactions: HashMap<(String, String), f64>,
    pub feature_stability: HashMap<String, f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSensitivity {
    pub initialization_sensitivity: InitializationSensitivity,
    pub optimization_sensitivity: OptimizationSensitivity,
    pub schedule_sensitivity: ScheduleSensitivity,
    pub regularization_sensitivity: RegularizationSensitivity,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializationSensitivity {
    pub weight_init_sensitivity: f64,
    pub bias_init_sensitivity: f64,
    pub seed_sensitivity: f64,
    pub scheme_importance: HashMap<String, f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSensitivity {
    pub optimizer_sensitivity: f64,
    pub momentum_sensitivity: f64,
    pub second_moment_sensitivity: f64,
    pub optimizer_comparison: HashMap<String, f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleSensitivity {
    pub lr_schedule_sensitivity: f64,
    pub duration_sensitivity: f64,
    pub warmup_sensitivity: f64,
    pub schedule_param_importance: HashMap<String, f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegularizationSensitivity {
    pub dropout_sensitivity: f64,
    pub weight_decay_sensitivity: f64,
    pub batch_norm_sensitivity: f64,
    pub method_comparison: HashMap<String, f64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityInsights {
    pub most_critical_factors: Vec<String>,
    pub least_critical_factors: Vec<String>,
    pub surprising_findings: Vec<String>,
    pub robustness_assessment: RobustnessAssessment,
    pub optimization_recommendations: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustnessAssessment {
    pub overall_robustness: f64,
    pub category_robustness: HashMap<String, f64>,
    pub vulnerabilities: Vec<Vulnerability>,
    pub strengths: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Vulnerability {
    pub vulnerability_type: String,
    pub severity: f64,
    pub impact: String,
    pub mitigation_strategies: Vec<String>,
}

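/// Advanced ML debugger that accumulates layer-wise learning-rate analyses and
/// model sensitivity analyses over a training run.
///
/// Illustrative usage (the gradient, weight, and loss-history variables are
/// placeholders for data collected by the caller):
///
/// ```ignore
/// let mut debugger = AdvancedMLDebugger::new(AdvancedMLDebuggingConfig::default());
/// let lr_result = debugger
///     .analyze_layer_wise_learning_rates(&layer_gradients, &layer_weights, 1e-3, &loss_history)
///     .await?;
/// let report = debugger.generate_report().await?;
/// ```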
#[derive(Debug)]
pub struct AdvancedMLDebugger {
    config: AdvancedMLDebuggingConfig,
    lr_analysis_results: Vec<LayerWiseLRAnalysisResult>,
    sensitivity_analysis_results: Vec<ModelSensitivityAnalysisResult>,
}

impl AdvancedMLDebugger {
    pub fn new(config: AdvancedMLDebuggingConfig) -> Self {
        Self {
            config,
            lr_analysis_results: Vec::new(),
            sensitivity_analysis_results: Vec::new(),
        }
    }

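    /// Analyzes per-layer gradients and weights to recommend layer-wise learning
    /// rates. Returns an error if layer-wise analysis is disabled in the configuration.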
    pub async fn analyze_layer_wise_learning_rates(
        &mut self,
        layer_gradients: &HashMap<String, ArrayD<f32>>,
        layer_weights: &HashMap<String, ArrayD<f32>>,
        current_lr: f64,
        loss_history: &[f64],
    ) -> Result<LayerWiseLRAnalysisResult> {
        if !self.config.enable_layer_wise_lr_analysis {
            return Err(anyhow::anyhow!(
                "Layer-wise learning rate analysis is disabled"
            ));
        }

        let mut layer_lr_recommendations = HashMap::new();

        for (layer_id, gradients) in layer_gradients {
            if let Some(weights) = layer_weights.get(layer_id) {
                let recommendation = self.analyze_single_layer_lr(
                    layer_id,
                    gradients,
                    weights,
                    current_lr,
                    loss_history,
                );
                layer_lr_recommendations.insert(layer_id.clone(), recommendation);
            }
        }

        let global_lr_insights =
            self.generate_global_lr_insights(&layer_lr_recommendations, loss_history);

        let adaptation_strategy =
            self.create_lr_adaptation_strategy(&layer_lr_recommendations, &global_lr_insights);

        let training_phase_recommendations =
            self.generate_training_phase_recommendations(&adaptation_strategy);

        let lr_schedule_predictions =
            self.predict_lr_schedule_performance(&layer_lr_recommendations);

        let result = LayerWiseLRAnalysisResult {
            timestamp: Utc::now(),
            layer_lr_recommendations,
            global_lr_insights,
            adaptation_strategy,
            training_phase_recommendations,
            lr_schedule_predictions,
        };

        self.lr_analysis_results.push(result.clone());
        Ok(result)
    }

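    /// Analyzes how sensitive the model is to hyperparameters, architecture,
    /// data, and training choices. Returns an error if disabled in the configuration.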
    pub async fn analyze_model_sensitivity(
        &mut self,
        model_params: &HashMap<String, f64>,
        performance_metrics: &[f64],
        architecture_config: &HashMap<String, f64>,
    ) -> Result<ModelSensitivityAnalysisResult> {
        if !self.config.enable_model_sensitivity_analysis {
            return Err(anyhow::anyhow!("Model sensitivity analysis is disabled"));
        }

        let hyperparameter_sensitivity =
            self.analyze_hyperparameter_sensitivity(model_params, performance_metrics);

        let architecture_sensitivity =
            self.analyze_architecture_sensitivity(architecture_config, performance_metrics);

        let data_sensitivity = self.analyze_data_sensitivity(performance_metrics);

        let training_sensitivity =
            self.analyze_training_sensitivity(model_params, performance_metrics);

        let sensitivity_insights = self.generate_sensitivity_insights(
            &hyperparameter_sensitivity,
            &architecture_sensitivity,
            &data_sensitivity,
            &training_sensitivity,
        );

        let result = ModelSensitivityAnalysisResult {
            timestamp: Utc::now(),
            hyperparameter_sensitivity,
            architecture_sensitivity,
            data_sensitivity,
            training_sensitivity,
            sensitivity_insights,
        };

        self.sensitivity_analysis_results.push(result.clone());
        Ok(result)
    }

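    /// Summarizes all analyses run so far, keeping the three most recent results
    /// of each kind.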
    pub async fn generate_report(&self) -> Result<AdvancedMLDebuggingReport> {
        Ok(AdvancedMLDebuggingReport {
            timestamp: Utc::now(),
            config: self.config.clone(),
            lr_analysis_count: self.lr_analysis_results.len(),
            sensitivity_analysis_count: self.sensitivity_analysis_results.len(),
            recent_lr_analyses: self
                .lr_analysis_results
                .iter()
                .rev()
                .take(3)
                .cloned()
                .collect(),
            recent_sensitivity_analyses: self
                .sensitivity_analysis_results
                .iter()
                .rev()
                .take(3)
                .cloned()
                .collect(),
            advanced_insights: self.generate_advanced_insights(),
        })
    }

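    /// Derives a learning-rate recommendation for a single layer from simple
    /// gradient statistics (mean absolute gradient and its variance), along with
    /// an urgency level and a short textual rationale.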
    fn analyze_single_layer_lr(
        &self,
        layer_id: &str,
        gradients: &ArrayD<f32>,
        weights: &ArrayD<f32>,
        current_lr: f64,
        loss_history: &[f64],
    ) -> LayerLRRecommendation {
        let gradient_magnitude =
            gradients.iter().map(|&x| x.abs() as f64).sum::<f64>() / gradients.len() as f64;
        let weight_magnitude =
            weights.iter().map(|&x| x.abs() as f64).sum::<f64>() / weights.len() as f64;

        let gradient_variance = gradients
            .iter()
            .map(|&x| (x as f64 - gradient_magnitude).powi(2))
            .sum::<f64>()
            / gradients.len() as f64;

        let gradient_norm = gradient_magnitude;
        let recommended_lr = if gradient_norm > 0.0 {
            let base_lr = 0.001;
            let adaptation_factor = (1.0 / (1.0 + gradient_variance)).sqrt();
            let magnitude_factor = (1.0 / (1.0 + gradient_norm)).sqrt();
            base_lr * adaptation_factor * magnitude_factor * 10.0
        } else {
            current_lr
        };

        let layer_metrics = LayerLRMetrics {
            gradient_magnitude,
            weight_update_magnitude: gradient_magnitude * current_lr,
            parameter_norm: weight_magnitude,
            loss_contribution: self.estimate_layer_loss_contribution(loss_history),
            stability_score: self.calculate_layer_stability(gradients),
            convergence_rate: self.estimate_convergence_rate(loss_history),
            learning_efficiency: gradient_magnitude / (weight_magnitude + 1e-8),
        };

        let lr_ratio = recommended_lr / current_lr;
        let urgency = if !(0.1..=10.0).contains(&lr_ratio) {
            AdaptationUrgency::Critical
        } else if !(0.33..=3.0).contains(&lr_ratio) {
            AdaptationUrgency::High
        } else if !(0.67..=1.5).contains(&lr_ratio) {
            AdaptationUrgency::Medium
        } else {
            AdaptationUrgency::Low
        };

        let reasoning = if recommended_lr > current_lr * 1.2 {
            "Layer shows slow learning with small gradients, increase learning rate".to_string()
        } else if recommended_lr < current_lr * 0.8 {
            "Layer shows instability or large gradients, decrease learning rate".to_string()
        } else {
            "Current learning rate appears appropriate for this layer".to_string()
        };

        LayerLRRecommendation {
            layer_id: layer_id.to_string(),
            layer_type: self.infer_layer_type(layer_id),
            current_lr,
            recommended_lr,
            confidence: 0.8,
            reasoning,
            layer_metrics,
            lr_sensitivity: lr_ratio.abs(),
            urgency,
        }
    }

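    /// Aggregates per-layer recommendations into model-wide efficiency, learning-rate
    /// distribution health, gradient-flow quality, and stability insights.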
    fn generate_global_lr_insights(
        &self,
        layer_recommendations: &HashMap<String, LayerLRRecommendation>,
        loss_history: &[f64],
    ) -> GlobalLRInsights {
        let overall_efficiency = layer_recommendations
            .values()
            .map(|rec| rec.layer_metrics.learning_efficiency)
            .sum::<f64>()
            / layer_recommendations.len() as f64;

        let lr_distribution_health = self.calculate_lr_distribution_health(layer_recommendations);
        let gradient_flow_quality = self.calculate_gradient_flow_quality(layer_recommendations);
        let training_stability =
            self.assess_training_stability(layer_recommendations, loss_history);
        let global_adjustments = self.generate_global_adjustments(layer_recommendations);
        let critical_issues = self.identify_critical_issues(layer_recommendations);

        GlobalLRInsights {
            overall_efficiency,
            lr_distribution_health,
            gradient_flow_quality,
            training_stability,
            global_adjustments,
            critical_issues,
        }
    }

    fn create_lr_adaptation_strategy(
        &self,
        _layer_recommendations: &HashMap<String, LayerLRRecommendation>,
        global_insights: &GlobalLRInsights,
    ) -> LRAdaptationStrategy {
        let strategy_name = if global_insights.overall_efficiency < 0.5 {
            "Aggressive Learning Rate Adaptation".to_string()
        } else {
            "Conservative Learning Rate Tuning".to_string()
        };

        LRAdaptationStrategy {
            strategy_name,
            description: "Strategy to optimize learning rates based on current model state"
                .to_string(),
            implementation_steps: vec![ImplementationStep {
                step_number: 1,
                description: "Implement layer-wise learning rate multipliers".to_string(),
                code_changes: vec!["Add lr_multipliers to optimizer config".to_string()],
                timeline: "1-2 days".to_string(),
                dependencies: vec!["Optimizer modification".to_string()],
            }],
            expected_benefits: vec![
                "Improved convergence speed".to_string(),
                "Better training stability".to_string(),
                "Reduced overfitting risk".to_string(),
            ],
            potential_risks: vec!["Initial instability during adaptation".to_string()],
            success_metrics: vec![
                "Faster loss reduction".to_string(),
                "Improved validation accuracy".to_string(),
            ],
            monitoring_requirements: vec!["Track per-layer gradient norms".to_string()],
        }
    }

    fn generate_training_phase_recommendations(
        &self,
        _strategy: &LRAdaptationStrategy,
    ) -> Vec<TrainingPhaseRecommendation> {
        vec![TrainingPhaseRecommendation {
            phase_name: "Warmup Phase".to_string(),
            duration_epochs: 5,
            lr_schedule: LRSchedule {
                schedule_type: LRScheduleType::LinearDecay,
                initial_lr: 0.0001,
                parameters: HashMap::new(),
                layer_multipliers: HashMap::new(),
            },
            objectives: vec!["Stabilize training".to_string()],
            success_criteria: vec!["Decreasing loss".to_string()],
            transition_conditions: vec!["Stable gradient norms".to_string()],
        }]
    }

    fn predict_lr_schedule_performance(
        &self,
        _layer_recommendations: &HashMap<String, LayerLRRecommendation>,
    ) -> Vec<LRSchedulePrediction> {
        vec![LRSchedulePrediction {
            schedule: LRSchedule {
                schedule_type: LRScheduleType::ExponentialDecay,
                initial_lr: 0.001,
                parameters: HashMap::new(),
                layer_multipliers: HashMap::new(),
            },
            predicted_accuracy: 0.92,
            predicted_convergence_epochs: 50,
            predicted_stability: 0.8,
            prediction_confidence: 0.7,
            risk_assessment: RiskAssessment {
                overall_risk: RiskLevel::Medium,
                specific_risks: vec![],
                mitigation_strategies: vec![],
            },
        }]
    }

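    /// Builds sensitivity profiles for the main hyperparameters. Current values are
    /// taken from `params`; the scores, ranges, and impact curves are fixed heuristic
    /// estimates.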
    fn analyze_hyperparameter_sensitivity(
        &self,
        params: &HashMap<String, f64>,
        _metrics: &[f64],
    ) -> HyperparameterSensitivity {
        let learning_rate_sensitivity = ParameterSensitivity {
            parameter_name: "learning_rate".to_string(),
            current_value: params.get("learning_rate").copied().unwrap_or(0.001),
            sensitivity_score: 0.8,
            optimal_range: (0.0001, 0.01),
            impact_curve: vec![(0.0001, 0.7), (0.001, 0.9), (0.01, 0.85)],
            stability_region: (0.0005, 0.005),
            critical_thresholds: vec![0.0001, 0.1],
        };

        let batch_size_sensitivity = ParameterSensitivity {
            parameter_name: "batch_size".to_string(),
            current_value: params.get("batch_size").copied().unwrap_or(32.0),
            sensitivity_score: 0.6,
            optimal_range: (16.0, 128.0),
            impact_curve: vec![(16.0, 0.85), (32.0, 0.9), (64.0, 0.88), (128.0, 0.82)],
            stability_region: (16.0, 64.0),
            critical_thresholds: vec![8.0, 256.0],
        };

        let regularization_sensitivity = ParameterSensitivity {
            parameter_name: "weight_decay".to_string(),
            current_value: params.get("weight_decay").copied().unwrap_or(0.01),
            sensitivity_score: 0.4,
            optimal_range: (0.001, 0.1),
            impact_curve: vec![(0.001, 0.88), (0.01, 0.9), (0.1, 0.87)],
            stability_region: (0.005, 0.05),
            critical_thresholds: vec![0.0001, 1.0],
        };

        HyperparameterSensitivity {
            learning_rate_sensitivity,
            batch_size_sensitivity,
            regularization_sensitivity,
            architecture_param_sensitivity: HashMap::new(),
            most_sensitive_params: vec!["learning_rate".to_string(), "batch_size".to_string()],
            least_sensitive_params: vec!["weight_decay".to_string()],
            interaction_effects: vec![],
        }
    }

    fn analyze_architecture_sensitivity(
        &self,
        _config: &HashMap<String, f64>,
        _metrics: &[f64],
    ) -> ArchitectureSensitivity {
        ArchitectureSensitivity {
            depth_sensitivity: ArchitecturalSensitivity {
                component_name: "model_depth".to_string(),
                change_sensitivity: 0.7,
                degradation_curve: vec![(6.0, 0.85), (12.0, 0.9), (24.0, 0.88)],
                min_viable_config: 6.0,
                optimal_config: 12.0,
                diminishing_returns_threshold: 18.0,
            },
            width_sensitivity: ArchitecturalSensitivity {
                component_name: "hidden_size".to_string(),
                change_sensitivity: 0.6,
                degradation_curve: vec![(256.0, 0.82), (512.0, 0.9), (1024.0, 0.91)],
                min_viable_config: 256.0,
                optimal_config: 512.0,
                diminishing_returns_threshold: 768.0,
            },
            attention_head_sensitivity: ArchitecturalSensitivity {
                component_name: "num_attention_heads".to_string(),
                change_sensitivity: 0.5,
                degradation_curve: vec![(4.0, 0.87), (8.0, 0.9), (16.0, 0.89)],
                min_viable_config: 4.0,
                optimal_config: 8.0,
                diminishing_returns_threshold: 12.0,
            },
            skip_connection_sensitivity: ArchitecturalSensitivity {
                component_name: "skip_connections".to_string(),
                change_sensitivity: 0.8,
                degradation_curve: vec![(0.0, 0.75), (1.0, 0.9)],
                min_viable_config: 1.0,
                optimal_config: 1.0,
                diminishing_returns_threshold: 1.0,
            },
            component_importance: HashMap::new(),
            bottlenecks: vec![],
        }
    }

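    /// Estimates sensitivity to dataset size, quality, distribution shift, and
    /// individual features using fixed heuristic values.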
    fn analyze_data_sensitivity(&self, _metrics: &[f64]) -> DataSensitivity {
        DataSensitivity {
            data_size_sensitivity: DataSizeSensitivity {
                current_size: 10000,
                minimum_effective_size: 1000,
                performance_curve: vec![(1000, 0.7), (5000, 0.85), (10000, 0.9), (20000, 0.92)],
                data_efficiency: 0.85,
                diminishing_returns_point: 15000,
            },
            data_quality_sensitivity: DataQualitySensitivity {
                noise_tolerance: 0.1,
                label_quality_importance: 0.9,
                feature_quality_importance: 0.7,
                quality_impact_curve: vec![(0.9, 0.9), (0.8, 0.85), (0.7, 0.75)],
            },
            distribution_sensitivity: DistributionSensitivity {
                shift_sensitivity: 0.6,
                imbalance_sensitivity: 0.5,
                domain_adaptation_requirements: vec!["Gradual domain adaptation".to_string()],
                distribution_robustness: 0.7,
            },
            feature_sensitivity: FeatureSensitivityAnalysis {
                most_important_features: vec!["feature_1".to_string(), "feature_2".to_string()],
                least_important_features: vec!["feature_10".to_string()],
                feature_interactions: HashMap::new(),
                feature_stability: HashMap::new(),
            },
        }
    }

    fn analyze_training_sensitivity(
        &self,
        _params: &HashMap<String, f64>,
        _metrics: &[f64],
    ) -> TrainingSensitivity {
        TrainingSensitivity {
            initialization_sensitivity: InitializationSensitivity {
                weight_init_sensitivity: 0.6,
                bias_init_sensitivity: 0.3,
                seed_sensitivity: 0.2,
                scheme_importance: HashMap::new(),
            },
            optimization_sensitivity: OptimizationSensitivity {
                optimizer_sensitivity: 0.7,
                momentum_sensitivity: 0.5,
                second_moment_sensitivity: 0.4,
                optimizer_comparison: HashMap::new(),
            },
            schedule_sensitivity: ScheduleSensitivity {
                lr_schedule_sensitivity: 0.8,
                duration_sensitivity: 0.6,
                warmup_sensitivity: 0.4,
                schedule_param_importance: HashMap::new(),
            },
            regularization_sensitivity: RegularizationSensitivity {
                dropout_sensitivity: 0.5,
                weight_decay_sensitivity: 0.4,
                batch_norm_sensitivity: 0.6,
                method_comparison: HashMap::new(),
            },
        }
    }

    fn generate_sensitivity_insights(
        &self,
        _hyper_sens: &HyperparameterSensitivity,
        _arch_sens: &ArchitectureSensitivity,
        _data_sens: &DataSensitivity,
        _training_sens: &TrainingSensitivity,
    ) -> SensitivityInsights {
        SensitivityInsights {
            most_critical_factors: vec![
                "learning_rate".to_string(),
                "model_depth".to_string(),
                "skip_connections".to_string(),
            ],
            least_critical_factors: vec!["bias_initialization".to_string()],
            surprising_findings: vec!["Batch size has higher than expected impact".to_string()],
            robustness_assessment: RobustnessAssessment {
                overall_robustness: 0.7,
                category_robustness: HashMap::new(),
                vulnerabilities: vec![],
                strengths: vec!["Good hyperparameter stability".to_string()],
            },
            optimization_recommendations: vec![
                "Focus on learning rate tuning first".to_string(),
                "Consider architectural modifications second".to_string(),
            ],
        }
    }

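    /// Approximates a layer's loss contribution as the magnitude of the most recent
    /// change in the loss history (falls back to 0.1 with fewer than two points).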
    fn estimate_layer_loss_contribution(&self, loss_history: &[f64]) -> f64 {
        if loss_history.len() >= 2 {
            (loss_history[loss_history.len() - 2] - loss_history[loss_history.len() - 1]).abs()
        } else {
            0.1
        }
    }

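    /// Maps gradient variance to a stability score in (0, 1]; higher variance yields
    /// a lower score (0.5 when there are no gradient values).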
    fn calculate_layer_stability(&self, gradients: &ArrayD<f32>) -> f64 {
        let gradient_values = gradients.iter().map(|&x| x as f64).collect::<Vec<_>>();

        if gradient_values.is_empty() {
            return 0.5;
        }

        let mean = gradient_values.iter().sum::<f64>() / gradient_values.len() as f64;
        let variance = gradient_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
            / gradient_values.len() as f64;

        1.0 / (1.0 + variance)
    }

    fn estimate_convergence_rate(&self, loss_history: &[f64]) -> f64 {
        if loss_history.len() < 3 {
            return 0.5;
        }

        let recent_improvement =
            loss_history[loss_history.len() - 3] - loss_history[loss_history.len() - 1];
        recent_improvement.abs()
    }

    fn infer_layer_type(&self, layer_id: &str) -> String {
        if layer_id.contains("attention") {
            "attention".to_string()
        } else if layer_id.contains("feedforward") || layer_id.contains("mlp") {
            "feedforward".to_string()
        } else if layer_id.contains("embedding") {
            "embedding".to_string()
        } else {
            "unknown".to_string()
        }
    }

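    /// Scores how uniform the recommended-to-current learning-rate ratios are across
    /// layers; lower variance yields a score closer to 1.0 (0.5 with no recommendations).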
    fn calculate_lr_distribution_health(
        &self,
        recommendations: &HashMap<String, LayerLRRecommendation>,
    ) -> f64 {
        let lr_ratios: Vec<f64> = recommendations
            .values()
            .map(|rec| rec.recommended_lr / rec.current_lr)
            .collect();

        if lr_ratios.is_empty() {
            return 0.5;
        }

        let mean_ratio = lr_ratios.iter().sum::<f64>() / lr_ratios.len() as f64;
        let variance = lr_ratios.iter().map(|&x| (x - mean_ratio).powi(2)).sum::<f64>()
            / lr_ratios.len() as f64;

        1.0 / (1.0 + variance)
    }

    fn calculate_gradient_flow_quality(
        &self,
        recommendations: &HashMap<String, LayerLRRecommendation>,
    ) -> f64 {
        recommendations
            .values()
            .map(|rec| rec.layer_metrics.stability_score)
            .sum::<f64>()
            / recommendations.len() as f64
    }

    fn assess_training_stability(
        &self,
        recommendations: &HashMap<String, LayerLRRecommendation>,
        _loss_history: &[f64],
    ) -> TrainingStability {
        let stability_score = recommendations
            .values()
            .map(|rec| rec.layer_metrics.stability_score)
            .sum::<f64>()
            / recommendations.len() as f64;

        TrainingStability {
            stability_score,
            instability_indicators: vec![],
            stability_trends: vec![],
            predicted_stability: stability_score * 0.9,
        }
    }

    fn generate_global_adjustments(
        &self,
        _recommendations: &HashMap<String, LayerLRRecommendation>,
    ) -> Vec<GlobalLRAdjustment> {
        vec![GlobalLRAdjustment {
            adjustment_type: GlobalAdjustmentType::LayerTypeSpecific,
            magnitude: 1.5,
            expected_impact: 0.1,
            priority: AdjustmentPriority::Medium,
            instructions: "Apply different learning rates to attention vs feedforward layers"
                .to_string(),
        }]
    }

    fn identify_critical_issues(
        &self,
        recommendations: &HashMap<String, LayerLRRecommendation>,
    ) -> Vec<String> {
        let mut issues = Vec::new();

        for recommendation in recommendations.values() {
            if matches!(recommendation.urgency, AdaptationUrgency::Critical) {
                issues.push(format!(
                    "Critical learning rate issue in layer {}",
                    recommendation.layer_id
                ));
            }
        }

        issues
    }

    fn generate_advanced_insights(&self) -> HashMap<String, String> {
        let mut insights = HashMap::new();

        insights.insert(
            "total_lr_analyses".to_string(),
            self.lr_analysis_results.len().to_string(),
        );
        insights.insert(
            "total_sensitivity_analyses".to_string(),
            self.sensitivity_analysis_results.len().to_string(),
        );

        if let Some(latest_lr) = self.lr_analysis_results.last() {
            insights.insert(
                "latest_lr_efficiency".to_string(),
                format!("{:.2}", latest_lr.global_lr_insights.overall_efficiency),
            );
        }

        insights
    }
}

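/// Summary report of all analyses performed by an `AdvancedMLDebugger`.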
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedMLDebuggingReport {
    pub timestamp: DateTime<Utc>,
    pub config: AdvancedMLDebuggingConfig,
    pub lr_analysis_count: usize,
    pub sensitivity_analysis_count: usize,
    pub recent_lr_analyses: Vec<LayerWiseLRAnalysisResult>,
    pub recent_sensitivity_analyses: Vec<ModelSensitivityAnalysisResult>,
    pub advanced_insights: HashMap<String, String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_advanced_ml_debugger_creation() {
        let config = AdvancedMLDebuggingConfig::default();
        let debugger = AdvancedMLDebugger::new(config);
        assert_eq!(debugger.lr_analysis_results.len(), 0);
    }

    #[tokio::test]
    async fn test_layer_wise_lr_analysis() {
        let config = AdvancedMLDebuggingConfig::default();
        let mut debugger = AdvancedMLDebugger::new(config);

        let mut layer_gradients = HashMap::new();
        let mut layer_weights = HashMap::new();

        let gradients =
            ArrayD::from_shape_vec(vec![10, 10], (0..100).map(|x| x as f32 * 0.01).collect())
                .expect("operation failed in test");
        let weights =
            ArrayD::from_shape_vec(vec![10, 10], (0..100).map(|x| x as f32 * 0.1).collect())
                .expect("operation failed in test");

        layer_gradients.insert("layer_0".to_string(), gradients);
        layer_weights.insert("layer_0".to_string(), weights);

        let loss_history = vec![1.0, 0.8, 0.6, 0.5];

        let result = debugger
            .analyze_layer_wise_learning_rates(
                &layer_gradients,
                &layer_weights,
                0.001,
                &loss_history,
            )
            .await;
        assert!(result.is_ok());

        let analysis = result.expect("operation failed in test");
        assert_eq!(analysis.layer_lr_recommendations.len(), 1);
        assert!(analysis.layer_lr_recommendations.contains_key("layer_0"));
    }
}