Skip to main content

trustformers_debug/
auto_debugger.rs

1//! Automated debugging system for common issues and optimization suggestions
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::time::Duration;
7
8use crate::{
9    AnomalyDetectorReport, DashboardMetrics, DebugConfig, GradientDebugReport, ProfilerReport,
10};
11
/// Automated debugging system
///
/// Aggregates pluggable [`IssueDetector`]s, a catalog of canned fixes keyed
/// by [`IssueType`], and a history of optimization attempts.
#[derive(Debug)]
#[allow(dead_code)]
pub struct AutoDebugger {
    #[allow(dead_code)]
    // Retained configuration; currently unused by the visible methods.
    config: DebugConfig,
    // Detectors run in registration order by `analyze_issues`.
    issue_detectors: Vec<Box<dyn IssueDetector>>,
    // Canned remediations looked up per detected issue type.
    fix_suggestions: HashMap<IssueType, Vec<FixSuggestion>>,
    // Bounded log of applied fixes (see `record_optimization_attempt`).
    optimization_history: Vec<OptimizationAttempt>,
    knowledge_base: KnowledgeBase,
}
23
/// Common training and model issues
///
/// Serves as the key into the fix-suggestion catalog, hence the
/// `Eq + Hash` derives alongside serde support.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IssueType {
    // Training Issues
    VanishingGradients,
    ExplodingGradients,
    LearningRateProblems,
    OverfittingDetected,
    UnderfittingDetected,
    TrainingStalled,
    LossNotDecreasing,
    UnstableTraining,
    MemoryIssues,

    // Model Architecture Issues
    ModelTooLarge,
    ModelTooSmall,
    InappropriateArchitecture,
    LayerMismatch,
    ActivationProblems,

    // Data Issues
    DataImbalance,
    DataLeakage,
    InsufficientData,
    DataQualityIssues,
    BatchSizeProblems,

    // Performance Issues
    SlowTraining,
    LowGpuUtilization,
    MemoryBottleneck,
    IoBottleneck,
    ComputeBottleneck,

    // Hyperparameter Issues
    LearningRateTooHigh,
    LearningRateTooLow,
    BatchSizeTooLarge,
    BatchSizeTooSmall,
    RegularizationIssues,
}
66
/// Issue detection result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedIssue {
    pub issue_type: IssueType,
    pub severity: IssueSeverity,
    /// Detector's confidence in the finding; detectors in this file emit
    /// values in [0, 1].
    pub confidence: f64,
    /// Human-readable summary of what was detected.
    pub description: String,
    /// Concrete metric observations backing the finding.
    pub evidence: Vec<Evidence>,
    /// Free-form metric snapshot (may be empty).
    pub metrics: HashMap<String, f64>,
    pub detected_at: chrono::DateTime<chrono::Utc>,
}
78
/// Severity ladder, ordered most to least urgent; sorting code in this
/// file ranks `Critical` first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IssueSeverity {
    Critical,
    High,
    Medium,
    Low,
    Info,
}
87
/// A single metric observation supporting a [`DetectedIssue`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evidence {
    pub metric_name: String,
    pub observed_value: f64,
    /// (low, high) band considered normal for this metric.
    pub expected_range: (f64, f64),
    pub explanation: String,
}
95
/// Fix suggestion with implementation guidance
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixSuggestion {
    /// Stable identifier (e.g. "vg_001") for tracking applied fixes.
    pub fix_id: String,
    pub fix_type: FixType,
    pub title: String,
    pub description: String,
    /// Ordered, human-readable steps to apply the fix.
    pub implementation_steps: Vec<String>,
    pub expected_impact: ExpectedImpact,
    pub priority: FixPriority,
    pub estimated_effort: EstimatedEffort,
    /// Conditions that must hold before attempting the fix.
    pub prerequisites: Vec<String>,
    pub code_examples: Vec<CodeExample>,
}
110
/// Broad category of intervention a [`FixSuggestion`] represents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FixType {
    HyperparameterAdjustment,
    ArchitectureChange,
    TrainingProcedure,
    DataProcessing,
    OptimizationTechnique,
    EnvironmentConfig,
}
120
/// Predicted effect of applying a fix. Values are fractional deltas
/// (e.g. 0.15 = +15%); negative values indicate a regression in that axis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedImpact {
    pub performance_improvement: f64,
    pub training_speed_improvement: f64,
    pub stability_improvement: f64,
    /// Positive means the fix is expected to use more memory.
    pub memory_usage_change: f64,
}
128
/// How soon a fix should be applied, ordered most to least urgent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FixPriority {
    Critical,
    High,
    Medium,
    Low,
}
136
/// Rough engineering effort to implement a fix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EstimatedEffort {
    Trivial, // < 5 minutes
    Easy,    // 5-30 minutes
    Medium,  // 30 minutes - 2 hours
    Hard,    // 2-8 hours
    Complex, // > 8 hours
}
145
/// A runnable snippet illustrating a fix, tagged with its language.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeExample {
    /// Language identifier, e.g. "python".
    pub language: String,
    pub code: String,
    pub explanation: String,
}
152
/// Optimization attempt tracking
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationAttempt {
    pub attempt_id: String,
    pub issue_addressed: IssueType,
    /// Identifier of the fix that was applied (matches `FixSuggestion::fix_id`).
    pub fix_applied: String,
    pub before_metrics: HashMap<String, f64>,
    /// `None` until post-fix metrics have been collected.
    pub after_metrics: Option<HashMap<String, f64>>,
    /// `None` while the outcome is still undetermined.
    pub success: Option<bool>,
    pub notes: String,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
165
/// Knowledge base for common patterns and solutions
#[derive(Debug)]
#[allow(dead_code)]
pub struct KnowledgeBase {
    #[allow(dead_code)]
    // Symptom/cause/solution catalog per issue type.
    issue_patterns: HashMap<IssueType, IssuePattern>,
    // Keyed by parameter name (e.g. "learning_rate").
    hyperparameter_recommendations: HashMap<String, HyperparameterAdvice>,
    architecture_patterns: Vec<ArchitecturePattern>,
    // Keyed by topic; each entry is a list of practice descriptions.
    best_practices: HashMap<String, Vec<String>>,
}
176
/// Catalog entry describing how a class of issue typically presents
/// and how it is usually resolved.
#[derive(Debug, Clone)]
pub struct IssuePattern {
    pub symptoms: Vec<String>,
    pub common_causes: Vec<String>,
    /// Metric names worth inspecting when diagnosing this issue.
    pub diagnostic_metrics: Vec<String>,
    pub typical_solutions: Vec<String>,
}
184
/// Tuning guidance for a single hyperparameter.
#[derive(Debug, Clone)]
pub struct HyperparameterAdvice {
    pub parameter_name: String,
    /// (low, high) band of sensible values.
    pub recommended_range: (f64, f64),
    pub tuning_strategy: String,
    /// Other parameters whose values interact with this one.
    pub dependencies: Vec<String>,
    pub common_mistakes: Vec<String>,
}
193
/// A named model-architecture template with its typical configuration.
#[derive(Debug, Clone)]
pub struct ArchitecturePattern {
    pub pattern_name: String,
    pub use_cases: Vec<String>,
    pub typical_layers: Vec<String>,
    /// Suggested starting values keyed by hyperparameter name.
    pub hyperparameter_suggestions: HashMap<String, f64>,
    pub performance_characteristics: String,
}
202
/// Issue detector trait for modular detection
///
/// Implementors inspect a [`DebugContext`] and report zero or more
/// [`DetectedIssue`]s. The `Debug` supertrait lets `AutoDebugger`
/// derive `Debug` over its boxed detectors.
pub trait IssueDetector: std::fmt::Debug {
    /// Examine `context` and return any issues found; an `Err` is logged
    /// and skipped by the caller rather than aborting analysis.
    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>>;
    /// Stable name used in log messages.
    fn get_detector_name(&self) -> &str;
    /// Issue types this detector is capable of reporting.
    fn get_supported_issues(&self) -> Vec<IssueType>;
}
209
/// Context for issue detection
///
/// Borrowed snapshot of all diagnostic inputs; every report is optional
/// so detectors must handle missing data gracefully.
#[derive(Debug)]
pub struct DebugContext<'a> {
    pub profiler_report: Option<&'a ProfilerReport>,
    pub gradient_report: Option<&'a GradientDebugReport>,
    pub anomaly_report: Option<&'a AnomalyDetectorReport>,
    /// Metrics in chronological order; detectors read the tail.
    pub recent_metrics: &'a [DashboardMetrics],
    pub training_duration: Duration,
    pub model_info: Option<&'a ModelInfo>,
}
220
/// Static description of the model under training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub model_type: String,
    pub parameter_count: usize,
    pub layer_count: usize,
    /// Free-form key/value details about the architecture.
    pub architecture_details: HashMap<String, String>,
}
228
229impl AutoDebugger {
230    /// Create new auto-debugger with default detectors
231    pub fn new(config: &DebugConfig) -> Self {
232        let mut auto_debugger = Self {
233            config: config.clone(),
234            issue_detectors: Vec::new(),
235            fix_suggestions: HashMap::new(),
236            optimization_history: Vec::new(),
237            knowledge_base: KnowledgeBase::new(),
238        };
239
240        // Register default detectors
241        auto_debugger.register_default_detectors();
242        auto_debugger.initialize_fix_suggestions();
243
244        auto_debugger
245    }
246
247    /// Register all default issue detectors
248    fn register_default_detectors(&mut self) {
249        self.issue_detectors.push(Box::new(GradientIssueDetector::new()));
250        self.issue_detectors.push(Box::new(TrainingIssueDetector::new()));
251        self.issue_detectors.push(Box::new(PerformanceIssueDetector::new()));
252        self.issue_detectors.push(Box::new(HyperparameterIssueDetector::new()));
253        self.issue_detectors.push(Box::new(ArchitectureIssueDetector::new()));
254        self.issue_detectors.push(Box::new(DataIssueDetector::new()));
255    }
256
    /// Initialize fix suggestions for common issues
    ///
    /// Populates the `IssueType -> Vec<FixSuggestion>` catalog with
    /// hand-curated remediations for vanishing/exploding gradients,
    /// excessive learning rate, and low GPU utilization. Called once
    /// from `new`.
    fn initialize_fix_suggestions(&mut self) {
        // Vanishing gradients fixes
        self.fix_suggestions.insert(
            IssueType::VanishingGradients,
            vec![
                FixSuggestion {
                    fix_id: "vg_001".to_string(),
                    fix_type: FixType::ArchitectureChange,
                    title: "Add Residual Connections".to_string(),
                    description:
                        "Implement skip connections to help gradients flow through deep networks"
                            .to_string(),
                    implementation_steps: vec![
                        "Add residual blocks to your model architecture".to_string(),
                        "Ensure input and output dimensions match for residual connections"
                            .to_string(),
                        "Consider using batch normalization within residual blocks".to_string(),
                    ],
                    expected_impact: ExpectedImpact {
                        performance_improvement: 0.15,
                        training_speed_improvement: 0.05,
                        stability_improvement: 0.25,
                        memory_usage_change: 0.02,
                    },
                    priority: FixPriority::High,
                    estimated_effort: EstimatedEffort::Medium,
                    prerequisites: vec!["Model architecture access".to_string()],
                    code_examples: vec![CodeExample {
                        language: "python".to_string(),
                        code: r#"
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual  # Skip connection
        return F.relu(out)
"#
                        .to_string(),
                        explanation: "Basic residual block implementation with skip connection"
                            .to_string(),
                    }],
                },
                FixSuggestion {
                    fix_id: "vg_002".to_string(),
                    fix_type: FixType::HyperparameterAdjustment,
                    title: "Adjust Learning Rate".to_string(),
                    description:
                        "Increase learning rate to help gradients propagate more effectively"
                            .to_string(),
                    implementation_steps: vec![
                        "Increase learning rate by 2-5x".to_string(),
                        "Monitor training stability".to_string(),
                        "Consider learning rate scheduling".to_string(),
                    ],
                    // Note: trades a small stability regression for speed.
                    expected_impact: ExpectedImpact {
                        performance_improvement: 0.08,
                        training_speed_improvement: 0.10,
                        stability_improvement: -0.05,
                        memory_usage_change: 0.0,
                    },
                    priority: FixPriority::Medium,
                    estimated_effort: EstimatedEffort::Trivial,
                    prerequisites: vec![],
                    code_examples: vec![CodeExample {
                        language: "python".to_string(),
                        code: "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)"
                            .to_string(),
                        explanation: "Increase learning rate to help overcome vanishing gradients"
                            .to_string(),
                    }],
                },
            ],
        );

        // Exploding gradients fixes
        self.fix_suggestions.insert(
            IssueType::ExplodingGradients,
            vec![FixSuggestion {
                fix_id: "eg_001".to_string(),
                fix_type: FixType::TrainingProcedure,
                title: "Apply Gradient Clipping".to_string(),
                description: "Clip gradients to prevent explosion during backpropagation"
                    .to_string(),
                implementation_steps: vec![
                    "Add gradient clipping to your training loop".to_string(),
                    "Start with clip value of 1.0 and adjust based on results".to_string(),
                    "Monitor gradient norms to ensure clipping is effective".to_string(),
                ],
                expected_impact: ExpectedImpact {
                    performance_improvement: 0.10,
                    training_speed_improvement: 0.0,
                    stability_improvement: 0.30,
                    memory_usage_change: 0.0,
                },
                priority: FixPriority::Critical,
                estimated_effort: EstimatedEffort::Easy,
                prerequisites: vec![],
                code_examples: vec![CodeExample {
                    language: "python".to_string(),
                    code: "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)"
                        .to_string(),
                    explanation: "Clip gradients before optimizer step".to_string(),
                }],
            }],
        );

        // Learning rate issues
        self.fix_suggestions.insert(
            IssueType::LearningRateTooHigh,
            vec![FixSuggestion {
                fix_id: "lr_high_001".to_string(),
                fix_type: FixType::HyperparameterAdjustment,
                title: "Reduce Learning Rate".to_string(),
                description: "Lower the learning rate to improve training stability".to_string(),
                implementation_steps: vec![
                    "Reduce learning rate by 2-10x".to_string(),
                    "Consider learning rate scheduling".to_string(),
                    "Monitor loss convergence".to_string(),
                ],
                expected_impact: ExpectedImpact {
                    performance_improvement: 0.12,
                    training_speed_improvement: -0.05,
                    stability_improvement: 0.25,
                    memory_usage_change: 0.0,
                },
                priority: FixPriority::High,
                estimated_effort: EstimatedEffort::Trivial,
                prerequisites: vec![],
                code_examples: vec![CodeExample {
                    language: "python".to_string(),
                    code: "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)".to_string(),
                    explanation: "Reduce learning rate for more stable training".to_string(),
                }],
            }],
        );

        // Performance issues
        self.fix_suggestions.insert(
            IssueType::LowGpuUtilization,
            vec![FixSuggestion {
                fix_id: "gpu_001".to_string(),
                fix_type: FixType::HyperparameterAdjustment,
                title: "Increase Batch Size".to_string(),
                description: "Increase batch size to better utilize GPU compute capacity"
                    .to_string(),
                implementation_steps: vec![
                    "Double the current batch size".to_string(),
                    "Monitor memory usage to avoid OOM".to_string(),
                    "Adjust learning rate proportionally".to_string(),
                ],
                expected_impact: ExpectedImpact {
                    performance_improvement: 0.05,
                    training_speed_improvement: 0.30,
                    stability_improvement: 0.0,
                    memory_usage_change: 0.20,
                },
                priority: FixPriority::Medium,
                estimated_effort: EstimatedEffort::Easy,
                prerequisites: vec!["Available GPU memory".to_string()],
                code_examples: vec![CodeExample {
                    language: "python".to_string(),
                    code: "train_loader = DataLoader(dataset, batch_size=64, shuffle=True)"
                        .to_string(),
                    explanation: "Increase batch size to improve GPU utilization".to_string(),
                }],
            }],
        );
    }
434
435    /// Analyze debug context and detect issues
436    pub fn analyze_issues(&self, context: &DebugContext) -> Result<AutoDebugReport> {
437        let mut all_issues = Vec::new();
438
439        // Run all issue detectors
440        for detector in &self.issue_detectors {
441            match detector.detect_issues(context) {
442                Ok(mut issues) => all_issues.append(&mut issues),
443                Err(e) => {
444                    tracing::warn!(
445                        "Issue detector '{}' failed: {}",
446                        detector.get_detector_name(),
447                        e
448                    );
449                },
450            }
451        }
452
453        // Sort issues by severity and confidence
454        all_issues.sort_by(|a, b| {
455            let severity_order = |s: &IssueSeverity| match s {
456                IssueSeverity::Critical => 0,
457                IssueSeverity::High => 1,
458                IssueSeverity::Medium => 2,
459                IssueSeverity::Low => 3,
460                IssueSeverity::Info => 4,
461            };
462
463            let severity_cmp = severity_order(&a.severity).cmp(&severity_order(&b.severity));
464            if severity_cmp == std::cmp::Ordering::Equal {
465                b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
466            } else {
467                severity_cmp
468            }
469        });
470
471        // Generate fix recommendations
472        let fix_recommendations = self.generate_fix_recommendations(&all_issues);
473
474        // Generate hyperparameter recommendations
475        let hyperparameter_recommendations = self.generate_hyperparameter_recommendations(context);
476
477        // Generate architecture suggestions
478        let architecture_suggestions = self.generate_architecture_suggestions(context);
479
480        // Generate training recipe optimization
481        let training_recipe = self.generate_training_recipe_optimization(context);
482
483        Ok(AutoDebugReport {
484            detected_issues: all_issues,
485            fix_recommendations: fix_recommendations.clone(),
486            hyperparameter_recommendations,
487            architecture_suggestions,
488            training_recipe,
489            analysis_summary: self.generate_analysis_summary(&fix_recommendations),
490            confidence_score: self.calculate_overall_confidence(&fix_recommendations),
491        })
492    }
493
494    /// Generate fix recommendations for detected issues
495    fn generate_fix_recommendations(&self, issues: &[DetectedIssue]) -> Vec<FixRecommendation> {
496        let mut recommendations = Vec::new();
497
498        for issue in issues {
499            if let Some(suggestions) = self.fix_suggestions.get(&issue.issue_type) {
500                for suggestion in suggestions {
501                    recommendations.push(FixRecommendation {
502                        issue: issue.clone(),
503                        fix_suggestion: suggestion.clone(),
504                        confidence: issue.confidence * 0.9, // Slightly reduce confidence
505                        urgency: self.calculate_urgency(issue),
506                    });
507                }
508            }
509        }
510
511        // Sort by urgency and confidence
512        recommendations.sort_by(|a, b| {
513            let urgency_cmp =
514                b.urgency.partial_cmp(&a.urgency).unwrap_or(std::cmp::Ordering::Equal);
515            if urgency_cmp == std::cmp::Ordering::Equal {
516                b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
517            } else {
518                urgency_cmp
519            }
520        });
521
522        recommendations
523    }
524
525    fn calculate_urgency(&self, issue: &DetectedIssue) -> f64 {
526        let severity_multiplier = match issue.severity {
527            IssueSeverity::Critical => 1.0,
528            IssueSeverity::High => 0.8,
529            IssueSeverity::Medium => 0.6,
530            IssueSeverity::Low => 0.4,
531            IssueSeverity::Info => 0.2,
532        };
533
534        issue.confidence * severity_multiplier
535    }
536
537    /// Generate hyperparameter recommendations
538    fn generate_hyperparameter_recommendations(
539        &self,
540        context: &DebugContext,
541    ) -> Vec<HyperparameterRecommendation> {
542        let mut recommendations = Vec::new();
543
544        // Learning rate recommendations
545        if let Some(metrics) = context.recent_metrics.last() {
546            if let Some(loss) = metrics.loss {
547                if loss > 1.0 {
548                    recommendations.push(HyperparameterRecommendation {
549                        parameter: "learning_rate".to_string(),
550                        current_value: None,
551                        recommended_value: 0.001,
552                        reason: "High loss suggests learning rate might be too low".to_string(),
553                        confidence: 0.7,
554                    });
555                }
556            }
557        }
558
559        // Batch size recommendations based on GPU utilization
560        if let Some(_profiler_report) = context.profiler_report {
561            // Simplified logic - in practice would analyze detailed metrics
562            recommendations.push(HyperparameterRecommendation {
563                parameter: "batch_size".to_string(),
564                current_value: None,
565                recommended_value: 32.0,
566                reason: "Optimize batch size for better GPU utilization".to_string(),
567                confidence: 0.6,
568            });
569        }
570
571        recommendations
572    }
573
574    /// Generate architecture suggestions
575    fn generate_architecture_suggestions(
576        &self,
577        context: &DebugContext,
578    ) -> Vec<ArchitectureSuggestion> {
579        let mut suggestions = Vec::new();
580
581        // Analyze model size vs performance
582        if let Some(model_info) = context.model_info {
583            if model_info.parameter_count > 100_000_000 {
584                suggestions.push(ArchitectureSuggestion {
585                    suggestion_type: "model_compression".to_string(),
586                    title: "Consider Model Compression".to_string(),
587                    description: "Large model may benefit from pruning or distillation".to_string(),
588                    impact_assessment: "Reduce memory usage by 20-50% with minimal accuracy loss"
589                        .to_string(),
590                    implementation_difficulty: "Medium".to_string(),
591                });
592            }
593
594            if model_info.layer_count > 50 {
595                suggestions.push(ArchitectureSuggestion {
596                    suggestion_type: "depth_optimization".to_string(),
597                    title: "Optimize Network Depth".to_string(),
598                    description: "Very deep network may suffer from gradient flow issues"
599                        .to_string(),
600                    impact_assessment: "Improve training stability and convergence speed"
601                        .to_string(),
602                    implementation_difficulty: "High".to_string(),
603                });
604            }
605        }
606
607        suggestions
608    }
609
610    /// Generate training recipe optimization
611    fn generate_training_recipe_optimization(
612        &self,
613        context: &DebugContext,
614    ) -> TrainingRecipeOptimization {
615        let mut optimizations = Vec::new();
616
617        // Analyze training duration and suggest optimizations
618        if context.training_duration > Duration::from_secs(3600) {
619            optimizations
620                .push("Consider learning rate scheduling to speed up convergence".to_string());
621            optimizations.push("Implement early stopping to avoid overtraining".to_string());
622        }
623
624        // Analyze recent metrics for training recipe suggestions
625        if context.recent_metrics.len() > 10 {
626            let recent_losses: Vec<f64> =
627                context.recent_metrics.iter().rev().take(10).filter_map(|m| m.loss).collect();
628
629            if recent_losses.len() >= 5 {
630                let variance = self.calculate_variance(&recent_losses);
631                if variance > 0.1 {
632                    optimizations.push(
633                        "Training loss is unstable - consider reducing learning rate".to_string(),
634                    );
635                }
636            }
637        }
638
639        TrainingRecipeOptimization {
640            recommended_optimizations: optimizations,
641            training_schedule: TrainingSchedule {
642                warmup_steps: 1000,
643                learning_rate_schedule: "cosine_annealing".to_string(),
644                batch_size_schedule: "constant".to_string(),
645                early_stopping: true,
646                checkpoint_frequency: 1000,
647            },
648            data_strategy: DataStrategy {
649                data_augmentation: vec!["horizontal_flip".to_string(), "random_crop".to_string()],
650                sampling_strategy: "balanced".to_string(),
651                preprocessing_optimizations: vec![
652                    "normalization".to_string(),
653                    "standardization".to_string(),
654                ],
655            },
656        }
657    }
658
659    fn calculate_variance(&self, values: &[f64]) -> f64 {
660        if values.len() < 2 {
661            return 0.0;
662        }
663
664        let mean = values.iter().sum::<f64>() / values.len() as f64;
665        let variance =
666            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
667        variance
668    }
669
670    fn generate_analysis_summary(&self, recommendations: &[FixRecommendation]) -> String {
671        let critical_count = recommendations
672            .iter()
673            .filter(|r| matches!(r.issue.severity, IssueSeverity::Critical))
674            .count();
675
676        let high_count = recommendations
677            .iter()
678            .filter(|r| matches!(r.issue.severity, IssueSeverity::High))
679            .count();
680
681        if critical_count > 0 {
682            format!("Found {} critical issues requiring immediate attention. {} high-priority issues also detected.",
683                   critical_count, high_count)
684        } else if high_count > 0 {
685            format!(
686                "Found {} high-priority issues that should be addressed soon.",
687                high_count
688            )
689        } else if !recommendations.is_empty() {
690            "Found some optimization opportunities to improve training performance.".to_string()
691        } else {
692            "No significant issues detected. Training appears to be proceeding normally."
693                .to_string()
694        }
695    }
696
697    fn calculate_overall_confidence(&self, recommendations: &[FixRecommendation]) -> f64 {
698        if recommendations.is_empty() {
699            return 1.0;
700        }
701
702        let sum_confidence: f64 = recommendations.iter().map(|r| r.confidence).sum();
703        sum_confidence / recommendations.len() as f64
704    }
705
706    /// Record optimization attempt for learning
707    pub fn record_optimization_attempt(&mut self, attempt: OptimizationAttempt) {
708        self.optimization_history.push(attempt);
709
710        // Keep only recent attempts to prevent unbounded growth
711        if self.optimization_history.len() > 1000 {
712            self.optimization_history.drain(0..500);
713        }
714    }
715
716    /// Get optimization history for analysis
717    pub fn get_optimization_history(&self) -> &[OptimizationAttempt] {
718        &self.optimization_history
719    }
720}
721
722// Issue detector implementations
723
/// Detects gradient-flow problems via the gradient debug report.
#[derive(Debug)]
struct GradientIssueDetector;

impl GradientIssueDetector {
    // Stateless detector; nothing to configure.
    fn new() -> Self {
        Self
    }
}
732
733impl IssueDetector for GradientIssueDetector {
734    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
735        let mut issues = Vec::new();
736
737        if let Some(gradient_report) = context.gradient_report {
738            // Check for vanishing gradients
739            if gradient_report.has_vanishing_gradients() {
740                issues.push(DetectedIssue {
741                    issue_type: IssueType::VanishingGradients,
742                    severity: IssueSeverity::High,
743                    confidence: 0.9,
744                    description: "Vanishing gradients detected in multiple layers".to_string(),
745                    evidence: vec![Evidence {
746                        metric_name: "gradient_norm".to_string(),
747                        observed_value: 0.001,
748                        expected_range: (0.01, 1.0),
749                        explanation: "Gradient norms are significantly below normal range"
750                            .to_string(),
751                    }],
752                    metrics: HashMap::new(),
753                    detected_at: chrono::Utc::now(),
754                });
755            }
756
757            // Check for exploding gradients
758            if gradient_report.has_exploding_gradients() {
759                issues.push(DetectedIssue {
760                    issue_type: IssueType::ExplodingGradients,
761                    severity: IssueSeverity::Critical,
762                    confidence: 0.95,
763                    description: "Exploding gradients detected - training instability likely"
764                        .to_string(),
765                    evidence: vec![Evidence {
766                        metric_name: "gradient_norm".to_string(),
767                        observed_value: 100.0,
768                        expected_range: (0.01, 10.0),
769                        explanation: "Gradient norms are extremely high".to_string(),
770                    }],
771                    metrics: HashMap::new(),
772                    detected_at: chrono::Utc::now(),
773                });
774            }
775        }
776
777        Ok(issues)
778    }
779
780    fn get_detector_name(&self) -> &str {
781        "GradientIssueDetector"
782    }
783
784    fn get_supported_issues(&self) -> Vec<IssueType> {
785        vec![IssueType::VanishingGradients, IssueType::ExplodingGradients]
786    }
787}
788
/// Detects training-progress problems from recent loss metrics.
#[derive(Debug)]
struct TrainingIssueDetector;

impl TrainingIssueDetector {
    // Stateless detector; nothing to configure.
    fn new() -> Self {
        Self
    }
}
797
798impl IssueDetector for TrainingIssueDetector {
799    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
800        let mut issues = Vec::new();
801
802        // Analyze recent training metrics
803        if context.recent_metrics.len() >= 10 {
804            let recent_losses: Vec<f64> =
805                context.recent_metrics.iter().rev().take(10).filter_map(|m| m.loss).collect();
806
807            if recent_losses.len() >= 5 {
808                // Check for stalled training
809                let first_half_avg = recent_losses[..recent_losses.len() / 2].iter().sum::<f64>()
810                    / (recent_losses.len() / 2) as f64;
811                let second_half_avg = recent_losses[recent_losses.len() / 2..].iter().sum::<f64>()
812                    / (recent_losses.len() - recent_losses.len() / 2) as f64;
813
814                if (first_half_avg - second_half_avg).abs() / first_half_avg < 0.01 {
815                    issues.push(DetectedIssue {
816                        issue_type: IssueType::TrainingStalled,
817                        severity: IssueSeverity::Medium,
818                        confidence: 0.8,
819                        description: "Training appears to have stalled - loss not decreasing"
820                            .to_string(),
821                        evidence: vec![Evidence {
822                            metric_name: "loss_change".to_string(),
823                            observed_value: (first_half_avg - second_half_avg).abs()
824                                / first_half_avg,
825                            expected_range: (0.05, 1.0),
826                            explanation: "Loss change is below expected threshold".to_string(),
827                        }],
828                        metrics: HashMap::new(),
829                        detected_at: chrono::Utc::now(),
830                    });
831                }
832            }
833        }
834
835        Ok(issues)
836    }
837
838    fn get_detector_name(&self) -> &str {
839        "TrainingIssueDetector"
840    }
841
842    fn get_supported_issues(&self) -> Vec<IssueType> {
843        vec![
844            IssueType::TrainingStalled,
845            IssueType::LossNotDecreasing,
846            IssueType::UnstableTraining,
847        ]
848    }
849}
850
/// Detects performance problems (e.g. low GPU utilization) from the most
/// recent metrics sample; see its `IssueDetector` implementation below.
#[derive(Debug)]
struct PerformanceIssueDetector;
853
854impl PerformanceIssueDetector {
855    fn new() -> Self {
856        Self
857    }
858}
859
860impl IssueDetector for PerformanceIssueDetector {
861    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
862        let mut issues = Vec::new();
863
864        // Check GPU utilization
865        if let Some(metrics) = context.recent_metrics.last() {
866            if let Some(gpu_util) = metrics.gpu_utilization {
867                if gpu_util < 0.5 {
868                    issues.push(DetectedIssue {
869                        issue_type: IssueType::LowGpuUtilization,
870                        severity: IssueSeverity::Medium,
871                        confidence: 0.8,
872                        description:
873                            "Low GPU utilization detected - compute resources underutilized"
874                                .to_string(),
875                        evidence: vec![Evidence {
876                            metric_name: "gpu_utilization".to_string(),
877                            observed_value: gpu_util,
878                            expected_range: (0.7, 1.0),
879                            explanation: "GPU utilization is below optimal range".to_string(),
880                        }],
881                        metrics: HashMap::new(),
882                        detected_at: chrono::Utc::now(),
883                    });
884                }
885            }
886        }
887
888        Ok(issues)
889    }
890
891    fn get_detector_name(&self) -> &str {
892        "PerformanceIssueDetector"
893    }
894
895    fn get_supported_issues(&self) -> Vec<IssueType> {
896        vec![
897            IssueType::LowGpuUtilization,
898            IssueType::SlowTraining,
899            IssueType::MemoryBottleneck,
900        ]
901    }
902}
903
/// Sanity-checks hyperparameters (currently the learning rate) from the most
/// recent metrics sample; see its `IssueDetector` implementation below.
#[derive(Debug)]
struct HyperparameterIssueDetector;
906
907impl HyperparameterIssueDetector {
908    fn new() -> Self {
909        Self
910    }
911}
912
913impl IssueDetector for HyperparameterIssueDetector {
914    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
915        let mut issues = Vec::new();
916
917        if let Some(metrics) = context.recent_metrics.last() {
918            // Check learning rate issues
919            if let Some(lr) = metrics.learning_rate {
920                if lr > 0.1 {
921                    issues.push(DetectedIssue {
922                        issue_type: IssueType::LearningRateTooHigh,
923                        severity: IssueSeverity::High,
924                        confidence: 0.7,
925                        description:
926                            "Learning rate appears too high - may cause training instability"
927                                .to_string(),
928                        evidence: vec![Evidence {
929                            metric_name: "learning_rate".to_string(),
930                            observed_value: lr,
931                            expected_range: (0.0001, 0.01),
932                            explanation: "Learning rate is above typical range".to_string(),
933                        }],
934                        metrics: HashMap::new(),
935                        detected_at: chrono::Utc::now(),
936                    });
937                } else if lr < 0.00001 {
938                    issues.push(DetectedIssue {
939                        issue_type: IssueType::LearningRateTooLow,
940                        severity: IssueSeverity::Medium,
941                        confidence: 0.6,
942                        description: "Learning rate might be too low - training could be slow"
943                            .to_string(),
944                        evidence: vec![Evidence {
945                            metric_name: "learning_rate".to_string(),
946                            observed_value: lr,
947                            expected_range: (0.0001, 0.01),
948                            explanation: "Learning rate is below typical range".to_string(),
949                        }],
950                        metrics: HashMap::new(),
951                        detected_at: chrono::Utc::now(),
952                    });
953                }
954            }
955        }
956
957        Ok(issues)
958    }
959
960    fn get_detector_name(&self) -> &str {
961        "HyperparameterIssueDetector"
962    }
963
964    fn get_supported_issues(&self) -> Vec<IssueType> {
965        vec![
966            IssueType::LearningRateTooHigh,
967            IssueType::LearningRateTooLow,
968        ]
969    }
970}
971
/// Flags architecture-level concerns (oversized or very deep models) based
/// on model metadata; see its `IssueDetector` implementation below.
#[derive(Debug)]
struct ArchitectureIssueDetector;
974
975impl ArchitectureIssueDetector {
976    fn new() -> Self {
977        Self
978    }
979}
980
981impl IssueDetector for ArchitectureIssueDetector {
982    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
983        let mut issues = Vec::new();
984
985        if let Some(model_info) = context.model_info {
986            // Check model size
987            if model_info.parameter_count > 1_000_000_000 {
988                issues.push(DetectedIssue {
989                    issue_type: IssueType::ModelTooLarge,
990                    severity: IssueSeverity::Medium,
991                    confidence: 0.6,
992                    description:
993                        "Model has very large number of parameters - consider optimization"
994                            .to_string(),
995                    evidence: vec![Evidence {
996                        metric_name: "parameter_count".to_string(),
997                        observed_value: model_info.parameter_count as f64,
998                        expected_range: (1_000_000.0, 100_000_000.0),
999                        explanation: "Parameter count is extremely high".to_string(),
1000                    }],
1001                    metrics: HashMap::new(),
1002                    detected_at: chrono::Utc::now(),
1003                });
1004            }
1005
1006            if model_info.layer_count > 100 {
1007                issues.push(DetectedIssue {
1008                    issue_type: IssueType::InappropriateArchitecture,
1009                    severity: IssueSeverity::Low,
1010                    confidence: 0.5,
1011                    description: "Very deep model - may have gradient flow issues".to_string(),
1012                    evidence: vec![Evidence {
1013                        metric_name: "layer_count".to_string(),
1014                        observed_value: model_info.layer_count as f64,
1015                        expected_range: (10.0, 50.0),
1016                        explanation: "Layer count is very high".to_string(),
1017                    }],
1018                    metrics: HashMap::new(),
1019                    detected_at: chrono::Utc::now(),
1020                });
1021            }
1022        }
1023
1024        Ok(issues)
1025    }
1026
1027    fn get_detector_name(&self) -> &str {
1028        "ArchitectureIssueDetector"
1029    }
1030
1031    fn get_supported_issues(&self) -> Vec<IssueType> {
1032        vec![
1033            IssueType::ModelTooLarge,
1034            IssueType::InappropriateArchitecture,
1035        ]
1036    }
1037}
1038
/// Intended to detect data-quality issues; currently a placeholder — its
/// `IssueDetector` implementation below reports nothing yet.
#[derive(Debug)]
struct DataIssueDetector;
1041
1042impl DataIssueDetector {
1043    fn new() -> Self {
1044        Self
1045    }
1046}
1047
impl IssueDetector for DataIssueDetector {
    /// Currently a no-op: always returns an empty issue list.
    fn detect_issues(&self, _context: &DebugContext) -> Result<Vec<DetectedIssue>> {
        // Placeholder for data issue detection
        // In practice, would analyze data loading patterns, batch composition, etc.
        Ok(Vec::new())
    }

    /// Stable identifier for this detector.
    fn get_detector_name(&self) -> &str {
        "DataIssueDetector"
    }

    /// Issue types this detector is intended to cover once implemented.
    fn get_supported_issues(&self) -> Vec<IssueType> {
        vec![
            IssueType::DataImbalance,
            IssueType::BatchSizeProblems,
            IssueType::InsufficientData,
        ]
    }
}
1067
1068impl Default for KnowledgeBase {
1069    fn default() -> Self {
1070        Self::new()
1071    }
1072}
1073
1074impl KnowledgeBase {
1075    pub fn new() -> Self {
1076        Self {
1077            issue_patterns: HashMap::new(),
1078            hyperparameter_recommendations: HashMap::new(),
1079            architecture_patterns: Vec::new(),
1080            best_practices: HashMap::new(),
1081        }
1082    }
1083}
1084
1085// Report structures
1086
/// Aggregated result of one automated debugging pass: detected issues plus
/// the recommendations and optimizations derived from them.
#[derive(Debug, Serialize, Deserialize)]
pub struct AutoDebugReport {
    /// Issues flagged by the issue detectors.
    pub detected_issues: Vec<DetectedIssue>,
    /// Concrete fix suggestions paired with the issues they address.
    pub fix_recommendations: Vec<FixRecommendation>,
    /// Suggested hyperparameter value changes.
    pub hyperparameter_recommendations: Vec<HyperparameterRecommendation>,
    /// Model-architecture-level suggestions.
    pub architecture_suggestions: Vec<ArchitectureSuggestion>,
    /// Recommended training-recipe adjustments.
    pub training_recipe: TrainingRecipeOptimization,
    /// Human-readable summary of the analysis.
    pub analysis_summary: String,
    /// Overall confidence in the report (presumably 0.0–1.0; not enforced here).
    pub confidence_score: f64,
}
1097
/// A suggested fix bound to the specific issue it addresses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixRecommendation {
    /// The issue this recommendation targets.
    pub issue: DetectedIssue,
    /// Suggested remediation for the issue.
    pub fix_suggestion: FixSuggestion,
    /// Confidence in the recommendation (presumably 0.0–1.0; not enforced here).
    pub confidence: f64,
    /// Relative urgency used for prioritization (scale not enforced here).
    pub urgency: f64,
}
1105
/// A single recommended hyperparameter change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterRecommendation {
    /// Name of the hyperparameter (e.g. "learning_rate").
    pub parameter: String,
    /// Currently observed value, if known.
    pub current_value: Option<f64>,
    /// Suggested new value.
    pub recommended_value: f64,
    /// Rationale for the suggestion.
    pub reason: String,
    /// Confidence in the recommendation (presumably 0.0–1.0; not enforced here).
    pub confidence: f64,
}
1114
/// A model-architecture improvement suggestion (all fields are free-form text).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSuggestion {
    /// Category of the suggestion.
    pub suggestion_type: String,
    /// Short title.
    pub title: String,
    /// Detailed description of the suggested change.
    pub description: String,
    /// Expected impact if the suggestion is applied.
    pub impact_assessment: String,
    /// How difficult the change is expected to be to implement.
    pub implementation_difficulty: String,
}
1123
/// Bundle of training-recipe recommendations: general optimizations plus a
/// schedule and a data strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecipeOptimization {
    /// Free-form optimization suggestions.
    pub recommended_optimizations: Vec<String>,
    /// Recommended training schedule.
    pub training_schedule: TrainingSchedule,
    /// Recommended data-handling strategy.
    pub data_strategy: DataStrategy,
}
1130
/// Schedule-related settings for a recommended training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSchedule {
    /// Number of warmup steps before the main schedule begins.
    pub warmup_steps: u32,
    /// Name/description of the learning-rate schedule (free-form string).
    pub learning_rate_schedule: String,
    /// Name/description of the batch-size schedule (free-form string).
    pub batch_size_schedule: String,
    /// Whether early stopping is recommended.
    pub early_stopping: bool,
    /// How often to write checkpoints (unit not specified here; presumably steps).
    pub checkpoint_frequency: u32,
}
1139
/// Data-handling recommendations accompanying a training recipe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataStrategy {
    /// Suggested data-augmentation techniques (free-form names).
    pub data_augmentation: Vec<String>,
    /// Suggested sampling strategy (free-form name).
    pub sampling_strategy: String,
    /// Suggested preprocessing optimizations (free-form names).
    pub preprocessing_optimizations: Vec<String>,
}