// trustformers_debug/auto_debugger.rs
//! Automated debugging system for common issues and optimization suggestions.
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::time::Duration;
7
8use crate::{
9    AnomalyDetectorReport, DashboardMetrics, DebugConfig, GradientDebugReport, ProfilerReport,
10};
11
/// Automated debugging system.
///
/// Runs a set of pluggable [`IssueDetector`]s over a [`DebugContext`], maps the
/// detected issues to built-in [`FixSuggestion`]s, and keeps a bounded history of
/// [`OptimizationAttempt`]s that callers record.
#[derive(Debug)]
// NOTE(review): `config` and `knowledge_base` are stored but not read by any
// `impl AutoDebugger` method in this file; these allows presumably silence the
// resulting `dead_code` warnings — confirm before removing.
#[allow(dead_code)]
pub struct AutoDebugger {
    /// Copy of the configuration passed to `AutoDebugger::new`.
    #[allow(dead_code)]
    config: DebugConfig,
    /// Detectors run, in registration order, by `analyze_issues`.
    issue_detectors: Vec<Box<dyn IssueDetector>>,
    /// Built-in fix suggestions keyed by the issue type they address.
    fix_suggestions: HashMap<IssueType, Vec<FixSuggestion>>,
    /// Attempts recorded via `record_optimization_attempt`; trimmed once it grows past a cap.
    optimization_history: Vec<OptimizationAttempt>,
    /// Issue patterns, hyperparameter advice, architecture patterns and best practices.
    knowledge_base: KnowledgeBase,
}
23
/// Common training and model issues.
///
/// Used as the key of the fix-suggestion table in [`AutoDebugger`], hence the
/// `Eq + Hash` derives.
// NOTE(review): some variants overlap (e.g. `LearningRateProblems` vs
// `LearningRateTooHigh`/`LearningRateTooLow`, `MemoryIssues` vs `MemoryBottleneck`,
// `BatchSizeProblems` vs `BatchSizeTooLarge`/`BatchSizeTooSmall`); confirm which
// one detectors are expected to emit.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IssueType {
    // Training Issues
    VanishingGradients,
    ExplodingGradients,
    LearningRateProblems,
    OverfittingDetected,
    UnderfittingDetected,
    TrainingStalled,
    LossNotDecreasing,
    UnstableTraining,
    MemoryIssues,

    // Model Architecture Issues
    ModelTooLarge,
    ModelTooSmall,
    InappropriateArchitecture,
    LayerMismatch,
    ActivationProblems,

    // Data Issues
    DataImbalance,
    DataLeakage,
    InsufficientData,
    DataQualityIssues,
    BatchSizeProblems,

    // Performance Issues
    SlowTraining,
    LowGpuUtilization,
    MemoryBottleneck,
    IoBottleneck,
    ComputeBottleneck,

    // Hyperparameter Issues
    LearningRateTooHigh,
    LearningRateTooLow,
    BatchSizeTooLarge,
    BatchSizeTooSmall,
    RegularizationIssues,
}
66
/// Issue detection result: a single issue reported by an [`IssueDetector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedIssue {
    /// Category of the issue; also the lookup key for fix suggestions.
    pub issue_type: IssueType,
    /// How serious the issue is; reports list more severe issues first.
    pub severity: IssueSeverity,
    /// Detector confidence; higher values rank earlier within the same severity.
    // NOTE(review): built-in detectors use values in [0, 1]; presumably that is
    // the intended range — confirm.
    pub confidence: f64,
    /// Human-readable summary of the issue.
    pub description: String,
    /// Observations supporting the detection.
    pub evidence: Vec<Evidence>,
    /// Named metrics attached by the detector (built-in detectors leave it empty).
    pub metrics: HashMap<String, f64>,
    /// UTC timestamp of the detection.
    pub detected_at: chrono::DateTime<chrono::Utc>,
}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub enum IssueSeverity {
81    Critical,
82    High,
83    Medium,
84    Low,
85    Info,
86}
87
/// One piece of supporting evidence for a [`DetectedIssue`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evidence {
    /// Name of the inspected metric (e.g. `"gradient_norm"`).
    pub metric_name: String,
    /// Value reported for the metric.
    pub observed_value: f64,
    /// Range considered normal; presumably `(min, max)` — confirm with detector authors.
    pub expected_range: (f64, f64),
    /// Why the observation is considered abnormal.
    pub explanation: String,
}
95
/// Fix suggestion with implementation guidance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixSuggestion {
    /// Stable identifier (e.g. `"vg_001"`).
    pub fix_id: String,
    /// Broad category of the change.
    pub fix_type: FixType,
    /// Short, human-readable title.
    pub title: String,
    /// What the fix does and why it helps.
    pub description: String,
    /// Ordered, human-readable steps to apply the fix.
    pub implementation_steps: Vec<String>,
    /// Estimated effect of applying the fix.
    pub expected_impact: ExpectedImpact,
    /// How urgently the fix should be applied.
    pub priority: FixPriority,
    /// Rough effort bucket needed to apply the fix.
    pub estimated_effort: EstimatedEffort,
    /// Conditions that must hold before the fix can be applied (may be empty).
    pub prerequisites: Vec<String>,
    /// Illustrative snippets; the built-in ones are Python.
    pub code_examples: Vec<CodeExample>,
}
110
/// Broad category of a [`FixSuggestion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FixType {
    /// Change a hyperparameter (learning rate, batch size, ...).
    HyperparameterAdjustment,
    /// Change the model structure (e.g. add residual connections).
    ArchitectureChange,
    /// Change the training loop (e.g. gradient clipping).
    TrainingProcedure,
    DataProcessing,
    OptimizationTechnique,
    EnvironmentConfig,
}
120
/// Estimated effect of applying a [`FixSuggestion`].
// NOTE(review): built-in values look like fractional changes (e.g. 0.15, -0.05),
// where a negative improvement is a regression and a positive
// `memory_usage_change` means more memory used (larger batch => +0.20).
// Presumably relative to the current run — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedImpact {
    pub performance_improvement: f64,
    pub training_speed_improvement: f64,
    pub stability_improvement: f64,
    pub memory_usage_change: f64,
}
128
/// How urgently a [`FixSuggestion`] should be applied, from most to least urgent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FixPriority {
    Critical,
    High,
    Medium,
    Low,
}
136
/// Rough effort bucket for applying a [`FixSuggestion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EstimatedEffort {
    Trivial, // < 5 minutes
    Easy,    // 5-30 minutes
    Medium,  // 30 minutes - 2 hours
    Hard,    // 2-8 hours
    Complex, // > 8 hours
}
145
/// A code snippet illustrating how to apply a fix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeExample {
    /// Language of `code` (the built-in examples use `"python"`).
    pub language: String,
    /// Source text of the snippet.
    pub code: String,
    /// What the snippet demonstrates.
    pub explanation: String,
}
152
/// Optimization attempt tracking: one fix that a caller tried, stored by
/// `AutoDebugger::record_optimization_attempt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationAttempt {
    /// Caller-chosen identifier of the attempt.
    pub attempt_id: String,
    /// Issue the attempt was meant to address.
    pub issue_addressed: IssueType,
    /// Description or id of the applied fix (presumably a `FixSuggestion::fix_id` — confirm).
    pub fix_applied: String,
    /// Metrics captured before the fix.
    pub before_metrics: HashMap<String, f64>,
    /// Metrics captured after the fix; presumably `None` until measured.
    pub after_metrics: Option<HashMap<String, f64>>,
    /// Whether the fix helped; presumably `None` until evaluated.
    pub success: Option<bool>,
    /// Free-form notes.
    pub notes: String,
    /// UTC time of the attempt.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
165
/// Knowledge base for common patterns and solutions.
#[derive(Debug)]
// NOTE(review): the fields below are not read by any code visible in this file;
// the allows presumably silence `dead_code` warnings — confirm before removing.
#[allow(dead_code)]
pub struct KnowledgeBase {
    /// Symptom/cause/solution pattern for each issue type.
    #[allow(dead_code)]
    issue_patterns: HashMap<IssueType, IssuePattern>,
    /// Advice keyed by hyperparameter name (presumably matches `HyperparameterAdvice::parameter_name`).
    hyperparameter_recommendations: HashMap<String, HyperparameterAdvice>,
    /// Known architecture patterns.
    architecture_patterns: Vec<ArchitecturePattern>,
    /// Best-practice lists keyed by topic.
    best_practices: HashMap<String, Vec<String>>,
}
176
/// Human-readable description of how an issue manifests and how it is usually fixed.
#[derive(Debug, Clone)]
pub struct IssuePattern {
    /// Observable symptoms of the issue.
    pub symptoms: Vec<String>,
    /// Typical root causes.
    pub common_causes: Vec<String>,
    /// Names of metrics useful for diagnosing the issue.
    pub diagnostic_metrics: Vec<String>,
    /// Typical remedies.
    pub typical_solutions: Vec<String>,
}
184
/// Tuning advice for a single hyperparameter.
#[derive(Debug, Clone)]
pub struct HyperparameterAdvice {
    /// Name of the hyperparameter (e.g. `"learning_rate"`).
    pub parameter_name: String,
    /// Recommended range; presumably `(min, max)` — confirm.
    pub recommended_range: (f64, f64),
    /// How to search within the range.
    pub tuning_strategy: String,
    /// Other hyperparameters this one interacts with.
    pub dependencies: Vec<String>,
    /// Frequent tuning pitfalls.
    pub common_mistakes: Vec<String>,
}
193
/// A reusable model-architecture pattern with its typical settings.
#[derive(Debug, Clone)]
pub struct ArchitecturePattern {
    /// Name of the pattern.
    pub pattern_name: String,
    /// Problems the pattern is suited for.
    pub use_cases: Vec<String>,
    /// Layer types typically used by the pattern.
    pub typical_layers: Vec<String>,
    /// Suggested hyperparameter values keyed by hyperparameter name.
    pub hyperparameter_suggestions: HashMap<String, f64>,
    /// Free-form description of performance trade-offs.
    pub performance_characteristics: String,
}
202
203/// Issue detector trait for modular detection
204pub trait IssueDetector: std::fmt::Debug {
205    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>>;
206    fn get_detector_name(&self) -> &str;
207    fn get_supported_issues(&self) -> Vec<IssueType>;
208}
209
/// Context for issue detection: borrowed views of whatever debug data is available.
///
/// Every report is optional; detectors skip checks whose input is `None`.
#[derive(Debug)]
pub struct DebugContext<'a> {
    pub profiler_report: Option<&'a ProfilerReport>,
    pub gradient_report: Option<&'a GradientDebugReport>,
    pub anomaly_report: Option<&'a AnomalyDetectorReport>,
    /// Recent dashboard metrics, oldest first (detectors treat `last()` as the newest).
    pub recent_metrics: &'a [DashboardMetrics],
    /// Wall-clock time spent training so far.
    pub training_duration: Duration,
    pub model_info: Option<&'a ModelInfo>,
}
220
/// Summary of the model under debug, used for architecture suggestions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Free-form model family/type name.
    pub model_type: String,
    /// Total number of parameters.
    pub parameter_count: usize,
    /// Number of layers.
    pub layer_count: usize,
    /// Additional free-form architecture details.
    pub architecture_details: HashMap<String, String>,
}
228
229impl AutoDebugger {
230    /// Create new auto-debugger with default detectors
231    pub fn new(config: &DebugConfig) -> Self {
232        let mut auto_debugger = Self {
233            config: config.clone(),
234            issue_detectors: Vec::new(),
235            fix_suggestions: HashMap::new(),
236            optimization_history: Vec::new(),
237            knowledge_base: KnowledgeBase::new(),
238        };
239
240        // Register default detectors
241        auto_debugger.register_default_detectors();
242        auto_debugger.initialize_fix_suggestions();
243
244        auto_debugger
245    }
246
247    /// Register all default issue detectors
248    fn register_default_detectors(&mut self) {
249        self.issue_detectors.push(Box::new(GradientIssueDetector::new()));
250        self.issue_detectors.push(Box::new(TrainingIssueDetector::new()));
251        self.issue_detectors.push(Box::new(PerformanceIssueDetector::new()));
252        self.issue_detectors.push(Box::new(HyperparameterIssueDetector::new()));
253        self.issue_detectors.push(Box::new(ArchitectureIssueDetector::new()));
254        self.issue_detectors.push(Box::new(DataIssueDetector::new()));
255    }
256
257    /// Initialize fix suggestions for common issues
258    fn initialize_fix_suggestions(&mut self) {
259        // Vanishing gradients fixes
260        self.fix_suggestions.insert(
261            IssueType::VanishingGradients,
262            vec![
263                FixSuggestion {
264                    fix_id: "vg_001".to_string(),
265                    fix_type: FixType::ArchitectureChange,
266                    title: "Add Residual Connections".to_string(),
267                    description:
268                        "Implement skip connections to help gradients flow through deep networks"
269                            .to_string(),
270                    implementation_steps: vec![
271                        "Add residual blocks to your model architecture".to_string(),
272                        "Ensure input and output dimensions match for residual connections"
273                            .to_string(),
274                        "Consider using batch normalization within residual blocks".to_string(),
275                    ],
276                    expected_impact: ExpectedImpact {
277                        performance_improvement: 0.15,
278                        training_speed_improvement: 0.05,
279                        stability_improvement: 0.25,
280                        memory_usage_change: 0.02,
281                    },
282                    priority: FixPriority::High,
283                    estimated_effort: EstimatedEffort::Medium,
284                    prerequisites: vec!["Model architecture access".to_string()],
285                    code_examples: vec![CodeExample {
286                        language: "python".to_string(),
287                        code: r#"
288class ResidualBlock(nn.Module):
289    def __init__(self, channels):
290        super().__init__()
291        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
292        self.bn1 = nn.BatchNorm2d(channels)
293        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
294        self.bn2 = nn.BatchNorm2d(channels)
295
296    def forward(self, x):
297        residual = x
298        out = F.relu(self.bn1(self.conv1(x)))
299        out = self.bn2(self.conv2(out))
300        out += residual  # Skip connection
301        return F.relu(out)
302"#
303                        .to_string(),
304                        explanation: "Basic residual block implementation with skip connection"
305                            .to_string(),
306                    }],
307                },
308                FixSuggestion {
309                    fix_id: "vg_002".to_string(),
310                    fix_type: FixType::HyperparameterAdjustment,
311                    title: "Adjust Learning Rate".to_string(),
312                    description:
313                        "Increase learning rate to help gradients propagate more effectively"
314                            .to_string(),
315                    implementation_steps: vec![
316                        "Increase learning rate by 2-5x".to_string(),
317                        "Monitor training stability".to_string(),
318                        "Consider learning rate scheduling".to_string(),
319                    ],
320                    expected_impact: ExpectedImpact {
321                        performance_improvement: 0.08,
322                        training_speed_improvement: 0.10,
323                        stability_improvement: -0.05,
324                        memory_usage_change: 0.0,
325                    },
326                    priority: FixPriority::Medium,
327                    estimated_effort: EstimatedEffort::Trivial,
328                    prerequisites: vec![],
329                    code_examples: vec![CodeExample {
330                        language: "python".to_string(),
331                        code: "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)"
332                            .to_string(),
333                        explanation: "Increase learning rate to help overcome vanishing gradients"
334                            .to_string(),
335                    }],
336                },
337            ],
338        );
339
340        // Exploding gradients fixes
341        self.fix_suggestions.insert(
342            IssueType::ExplodingGradients,
343            vec![FixSuggestion {
344                fix_id: "eg_001".to_string(),
345                fix_type: FixType::TrainingProcedure,
346                title: "Apply Gradient Clipping".to_string(),
347                description: "Clip gradients to prevent explosion during backpropagation"
348                    .to_string(),
349                implementation_steps: vec![
350                    "Add gradient clipping to your training loop".to_string(),
351                    "Start with clip value of 1.0 and adjust based on results".to_string(),
352                    "Monitor gradient norms to ensure clipping is effective".to_string(),
353                ],
354                expected_impact: ExpectedImpact {
355                    performance_improvement: 0.10,
356                    training_speed_improvement: 0.0,
357                    stability_improvement: 0.30,
358                    memory_usage_change: 0.0,
359                },
360                priority: FixPriority::Critical,
361                estimated_effort: EstimatedEffort::Easy,
362                prerequisites: vec![],
363                code_examples: vec![CodeExample {
364                    language: "python".to_string(),
365                    code: "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)"
366                        .to_string(),
367                    explanation: "Clip gradients before optimizer step".to_string(),
368                }],
369            }],
370        );
371
372        // Learning rate issues
373        self.fix_suggestions.insert(
374            IssueType::LearningRateTooHigh,
375            vec![FixSuggestion {
376                fix_id: "lr_high_001".to_string(),
377                fix_type: FixType::HyperparameterAdjustment,
378                title: "Reduce Learning Rate".to_string(),
379                description: "Lower the learning rate to improve training stability".to_string(),
380                implementation_steps: vec![
381                    "Reduce learning rate by 2-10x".to_string(),
382                    "Consider learning rate scheduling".to_string(),
383                    "Monitor loss convergence".to_string(),
384                ],
385                expected_impact: ExpectedImpact {
386                    performance_improvement: 0.12,
387                    training_speed_improvement: -0.05,
388                    stability_improvement: 0.25,
389                    memory_usage_change: 0.0,
390                },
391                priority: FixPriority::High,
392                estimated_effort: EstimatedEffort::Trivial,
393                prerequisites: vec![],
394                code_examples: vec![CodeExample {
395                    language: "python".to_string(),
396                    code: "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)".to_string(),
397                    explanation: "Reduce learning rate for more stable training".to_string(),
398                }],
399            }],
400        );
401
402        // Performance issues
403        self.fix_suggestions.insert(
404            IssueType::LowGpuUtilization,
405            vec![FixSuggestion {
406                fix_id: "gpu_001".to_string(),
407                fix_type: FixType::HyperparameterAdjustment,
408                title: "Increase Batch Size".to_string(),
409                description: "Increase batch size to better utilize GPU compute capacity"
410                    .to_string(),
411                implementation_steps: vec![
412                    "Double the current batch size".to_string(),
413                    "Monitor memory usage to avoid OOM".to_string(),
414                    "Adjust learning rate proportionally".to_string(),
415                ],
416                expected_impact: ExpectedImpact {
417                    performance_improvement: 0.05,
418                    training_speed_improvement: 0.30,
419                    stability_improvement: 0.0,
420                    memory_usage_change: 0.20,
421                },
422                priority: FixPriority::Medium,
423                estimated_effort: EstimatedEffort::Easy,
424                prerequisites: vec!["Available GPU memory".to_string()],
425                code_examples: vec![CodeExample {
426                    language: "python".to_string(),
427                    code: "train_loader = DataLoader(dataset, batch_size=64, shuffle=True)"
428                        .to_string(),
429                    explanation: "Increase batch size to improve GPU utilization".to_string(),
430                }],
431            }],
432        );
433    }
434
435    /// Analyze debug context and detect issues
436    pub fn analyze_issues(&self, context: &DebugContext) -> Result<AutoDebugReport> {
437        let mut all_issues = Vec::new();
438
439        // Run all issue detectors
440        for detector in &self.issue_detectors {
441            match detector.detect_issues(context) {
442                Ok(mut issues) => all_issues.append(&mut issues),
443                Err(e) => {
444                    tracing::warn!(
445                        "Issue detector '{}' failed: {}",
446                        detector.get_detector_name(),
447                        e
448                    );
449                },
450            }
451        }
452
453        // Sort issues by severity and confidence
454        all_issues.sort_by(|a, b| {
455            let severity_order = |s: &IssueSeverity| match s {
456                IssueSeverity::Critical => 0,
457                IssueSeverity::High => 1,
458                IssueSeverity::Medium => 2,
459                IssueSeverity::Low => 3,
460                IssueSeverity::Info => 4,
461            };
462
463            let severity_cmp = severity_order(&a.severity).cmp(&severity_order(&b.severity));
464            if severity_cmp == std::cmp::Ordering::Equal {
465                b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
466            } else {
467                severity_cmp
468            }
469        });
470
471        // Generate fix recommendations
472        let fix_recommendations = self.generate_fix_recommendations(&all_issues);
473
474        // Generate hyperparameter recommendations
475        let hyperparameter_recommendations = self.generate_hyperparameter_recommendations(context);
476
477        // Generate architecture suggestions
478        let architecture_suggestions = self.generate_architecture_suggestions(context);
479
480        // Generate training recipe optimization
481        let training_recipe = self.generate_training_recipe_optimization(context);
482
483        Ok(AutoDebugReport {
484            detected_issues: all_issues,
485            fix_recommendations: fix_recommendations.clone(),
486            hyperparameter_recommendations,
487            architecture_suggestions,
488            training_recipe,
489            analysis_summary: self.generate_analysis_summary(&fix_recommendations),
490            confidence_score: self.calculate_overall_confidence(&fix_recommendations),
491        })
492    }
493
494    /// Generate fix recommendations for detected issues
495    fn generate_fix_recommendations(&self, issues: &[DetectedIssue]) -> Vec<FixRecommendation> {
496        let mut recommendations = Vec::new();
497
498        for issue in issues {
499            if let Some(suggestions) = self.fix_suggestions.get(&issue.issue_type) {
500                for suggestion in suggestions {
501                    recommendations.push(FixRecommendation {
502                        issue: issue.clone(),
503                        fix_suggestion: suggestion.clone(),
504                        confidence: issue.confidence * 0.9, // Slightly reduce confidence
505                        urgency: self.calculate_urgency(issue),
506                    });
507                }
508            }
509        }
510
511        // Sort by urgency and confidence
512        recommendations.sort_by(|a, b| {
513            let urgency_cmp =
514                b.urgency.partial_cmp(&a.urgency).unwrap_or(std::cmp::Ordering::Equal);
515            if urgency_cmp == std::cmp::Ordering::Equal {
516                b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
517            } else {
518                urgency_cmp
519            }
520        });
521
522        recommendations
523    }
524
525    fn calculate_urgency(&self, issue: &DetectedIssue) -> f64 {
526        let severity_multiplier = match issue.severity {
527            IssueSeverity::Critical => 1.0,
528            IssueSeverity::High => 0.8,
529            IssueSeverity::Medium => 0.6,
530            IssueSeverity::Low => 0.4,
531            IssueSeverity::Info => 0.2,
532        };
533
534        issue.confidence * severity_multiplier
535    }
536
537    /// Generate hyperparameter recommendations
538    fn generate_hyperparameter_recommendations(
539        &self,
540        context: &DebugContext,
541    ) -> Vec<HyperparameterRecommendation> {
542        let mut recommendations = Vec::new();
543
544        // Learning rate recommendations
545        if let Some(metrics) = context.recent_metrics.last() {
546            if let Some(loss) = metrics.loss {
547                if loss > 1.0 {
548                    recommendations.push(HyperparameterRecommendation {
549                        parameter: "learning_rate".to_string(),
550                        current_value: None,
551                        recommended_value: 0.001,
552                        reason: "High loss suggests learning rate might be too low".to_string(),
553                        confidence: 0.7,
554                    });
555                }
556            }
557        }
558
559        // Batch size recommendations based on GPU utilization
560        if let Some(_profiler_report) = context.profiler_report {
561            // Simplified logic - in practice would analyze detailed metrics
562            recommendations.push(HyperparameterRecommendation {
563                parameter: "batch_size".to_string(),
564                current_value: None,
565                recommended_value: 32.0,
566                reason: "Optimize batch size for better GPU utilization".to_string(),
567                confidence: 0.6,
568            });
569        }
570
571        recommendations
572    }
573
574    /// Generate architecture suggestions
575    fn generate_architecture_suggestions(
576        &self,
577        context: &DebugContext,
578    ) -> Vec<ArchitectureSuggestion> {
579        let mut suggestions = Vec::new();
580
581        // Analyze model size vs performance
582        if let Some(model_info) = context.model_info {
583            if model_info.parameter_count > 100_000_000 {
584                suggestions.push(ArchitectureSuggestion {
585                    suggestion_type: "model_compression".to_string(),
586                    title: "Consider Model Compression".to_string(),
587                    description: "Large model may benefit from pruning or distillation".to_string(),
588                    impact_assessment: "Reduce memory usage by 20-50% with minimal accuracy loss"
589                        .to_string(),
590                    implementation_difficulty: "Medium".to_string(),
591                });
592            }
593
594            if model_info.layer_count > 50 {
595                suggestions.push(ArchitectureSuggestion {
596                    suggestion_type: "depth_optimization".to_string(),
597                    title: "Optimize Network Depth".to_string(),
598                    description: "Very deep network may suffer from gradient flow issues"
599                        .to_string(),
600                    impact_assessment: "Improve training stability and convergence speed"
601                        .to_string(),
602                    implementation_difficulty: "High".to_string(),
603                });
604            }
605        }
606
607        suggestions
608    }
609
610    /// Generate training recipe optimization
611    fn generate_training_recipe_optimization(
612        &self,
613        context: &DebugContext,
614    ) -> TrainingRecipeOptimization {
615        let mut optimizations = Vec::new();
616
617        // Analyze training duration and suggest optimizations
618        if context.training_duration > Duration::from_secs(3600) {
619            optimizations
620                .push("Consider learning rate scheduling to speed up convergence".to_string());
621            optimizations.push("Implement early stopping to avoid overtraining".to_string());
622        }
623
624        // Analyze recent metrics for training recipe suggestions
625        if context.recent_metrics.len() > 10 {
626            let recent_losses: Vec<f64> =
627                context.recent_metrics.iter().rev().take(10).filter_map(|m| m.loss).collect();
628
629            if recent_losses.len() >= 5 {
630                let variance = self.calculate_variance(&recent_losses);
631                if variance > 0.1 {
632                    optimizations.push(
633                        "Training loss is unstable - consider reducing learning rate".to_string(),
634                    );
635                }
636            }
637        }
638
639        TrainingRecipeOptimization {
640            recommended_optimizations: optimizations,
641            training_schedule: TrainingSchedule {
642                warmup_steps: 1000,
643                learning_rate_schedule: "cosine_annealing".to_string(),
644                batch_size_schedule: "constant".to_string(),
645                early_stopping: true,
646                checkpoint_frequency: 1000,
647            },
648            data_strategy: DataStrategy {
649                data_augmentation: vec!["horizontal_flip".to_string(), "random_crop".to_string()],
650                sampling_strategy: "balanced".to_string(),
651                preprocessing_optimizations: vec![
652                    "normalization".to_string(),
653                    "standardization".to_string(),
654                ],
655            },
656        }
657    }
658
659    fn calculate_variance(&self, values: &[f64]) -> f64 {
660        if values.len() < 2 {
661            return 0.0;
662        }
663
664        let mean = values.iter().sum::<f64>() / values.len() as f64;
665        let variance =
666            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
667        variance
668    }
669
670    fn generate_analysis_summary(&self, recommendations: &[FixRecommendation]) -> String {
671        let critical_count = recommendations
672            .iter()
673            .filter(|r| matches!(r.issue.severity, IssueSeverity::Critical))
674            .count();
675
676        let high_count = recommendations
677            .iter()
678            .filter(|r| matches!(r.issue.severity, IssueSeverity::High))
679            .count();
680
681        if critical_count > 0 {
682            format!("Found {} critical issues requiring immediate attention. {} high-priority issues also detected.",
683                   critical_count, high_count)
684        } else if high_count > 0 {
685            format!(
686                "Found {} high-priority issues that should be addressed soon.",
687                high_count
688            )
689        } else if !recommendations.is_empty() {
690            "Found some optimization opportunities to improve training performance.".to_string()
691        } else {
692            "No significant issues detected. Training appears to be proceeding normally."
693                .to_string()
694        }
695    }
696
697    fn calculate_overall_confidence(&self, recommendations: &[FixRecommendation]) -> f64 {
698        if recommendations.is_empty() {
699            return 1.0;
700        }
701
702        let sum_confidence: f64 = recommendations.iter().map(|r| r.confidence).sum();
703        sum_confidence / recommendations.len() as f64
704    }
705
706    /// Record optimization attempt for learning
707    pub fn record_optimization_attempt(&mut self, attempt: OptimizationAttempt) {
708        self.optimization_history.push(attempt);
709
710        // Keep only recent attempts to prevent unbounded growth
711        if self.optimization_history.len() > 1000 {
712            self.optimization_history.drain(0..500);
713        }
714    }
715
716    /// Get optimization history for analysis
717    pub fn get_optimization_history(&self) -> &[OptimizationAttempt] {
718        &self.optimization_history
719    }
720}
721
// Issue detector implementations
723
724#[derive(Debug)]
725struct GradientIssueDetector;
726
727impl GradientIssueDetector {
728    fn new() -> Self {
729        Self
730    }
731}
732
733impl IssueDetector for GradientIssueDetector {
734    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
735        let mut issues = Vec::new();
736
737        if let Some(gradient_report) = context.gradient_report {
738            // Check for vanishing gradients
739            if gradient_report.has_vanishing_gradients() {
740                issues.push(DetectedIssue {
741                    issue_type: IssueType::VanishingGradients,
742                    severity: IssueSeverity::High,
743                    confidence: 0.9,
744                    description: "Vanishing gradients detected in multiple layers".to_string(),
745                    evidence: vec![Evidence {
746                        metric_name: "gradient_norm".to_string(),
747                        observed_value: 0.001,
748                        expected_range: (0.01, 1.0),
749                        explanation: "Gradient norms are significantly below normal range"
750                            .to_string(),
751                    }],
752                    metrics: HashMap::new(),
753                    detected_at: chrono::Utc::now(),
754                });
755            }
756
757            // Check for exploding gradients
758            if gradient_report.has_exploding_gradients() {
759                issues.push(DetectedIssue {
760                    issue_type: IssueType::ExplodingGradients,
761                    severity: IssueSeverity::Critical,
762                    confidence: 0.95,
763                    description: "Exploding gradients detected - training instability likely"
764                        .to_string(),
765                    evidence: vec![Evidence {
766                        metric_name: "gradient_norm".to_string(),
767                        observed_value: 100.0,
768                        expected_range: (0.01, 10.0),
769                        explanation: "Gradient norms are extremely high".to_string(),
770                    }],
771                    metrics: HashMap::new(),
772                    detected_at: chrono::Utc::now(),
773                });
774            }
775        }
776
777        Ok(issues)
778    }
779
780    fn get_detector_name(&self) -> &str {
781        "GradientIssueDetector"
782    }
783
784    fn get_supported_issues(&self) -> Vec<IssueType> {
785        vec![IssueType::VanishingGradients, IssueType::ExplodingGradients]
786    }
787}
788
789#[derive(Debug)]
790struct TrainingIssueDetector;
791
792impl TrainingIssueDetector {
793    fn new() -> Self {
794        Self
795    }
796}
797
798impl IssueDetector for TrainingIssueDetector {
799    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
800        let mut issues = Vec::new();
801
802        // Analyze recent training metrics
803        if context.recent_metrics.len() >= 10 {
804            let recent_losses: Vec<f64> =
805                context.recent_metrics.iter().rev().take(10).filter_map(|m| m.loss).collect();
806
807            if recent_losses.len() >= 5 {
808                // Check for stalled training
809                let first_half_avg = recent_losses[..recent_losses.len() / 2].iter().sum::<f64>()
810                    / (recent_losses.len() / 2) as f64;
811                let second_half_avg = recent_losses[recent_losses.len() / 2..].iter().sum::<f64>()
812                    / (recent_losses.len() - recent_losses.len() / 2) as f64;
813
814                if (first_half_avg - second_half_avg).abs() / first_half_avg < 0.01 {
815                    issues.push(DetectedIssue {
816                        issue_type: IssueType::TrainingStalled,
817                        severity: IssueSeverity::Medium,
818                        confidence: 0.8,
819                        description: "Training appears to have stalled - loss not decreasing"
820                            .to_string(),
821                        evidence: vec![Evidence {
822                            metric_name: "loss_change".to_string(),
823                            observed_value: (first_half_avg - second_half_avg).abs()
824                                / first_half_avg,
825                            expected_range: (0.05, 1.0),
826                            explanation: "Loss change is below expected threshold".to_string(),
827                        }],
828                        metrics: HashMap::new(),
829                        detected_at: chrono::Utc::now(),
830                    });
831                }
832            }
833        }
834
835        Ok(issues)
836    }
837
838    fn get_detector_name(&self) -> &str {
839        "TrainingIssueDetector"
840    }
841
842    fn get_supported_issues(&self) -> Vec<IssueType> {
843        vec![
844            IssueType::TrainingStalled,
845            IssueType::LossNotDecreasing,
846            IssueType::UnstableTraining,
847        ]
848    }
849}
850
/// Stateless detector for runtime-performance issues (currently low GPU utilisation).
#[derive(Debug)]
struct PerformanceIssueDetector;

impl PerformanceIssueDetector {
    /// Creates the detector; the type carries no state.
    fn new() -> Self {
        PerformanceIssueDetector
    }
}
859
860impl IssueDetector for PerformanceIssueDetector {
861    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
862        let mut issues = Vec::new();
863
864        // Check GPU utilization
865        if let Some(metrics) = context.recent_metrics.last() {
866            if let Some(gpu_util) = metrics.gpu_utilization {
867                if gpu_util < 0.5 {
868                    issues.push(DetectedIssue {
869                        issue_type: IssueType::LowGpuUtilization,
870                        severity: IssueSeverity::Medium,
871                        confidence: 0.8,
872                        description:
873                            "Low GPU utilization detected - compute resources underutilized"
874                                .to_string(),
875                        evidence: vec![Evidence {
876                            metric_name: "gpu_utilization".to_string(),
877                            observed_value: gpu_util,
878                            expected_range: (0.7, 1.0),
879                            explanation: "GPU utilization is below optimal range".to_string(),
880                        }],
881                        metrics: HashMap::new(),
882                        detected_at: chrono::Utc::now(),
883                    });
884                }
885            }
886        }
887
888        Ok(issues)
889    }
890
891    fn get_detector_name(&self) -> &str {
892        "PerformanceIssueDetector"
893    }
894
895    fn get_supported_issues(&self) -> Vec<IssueType> {
896        vec![
897            IssueType::LowGpuUtilization,
898            IssueType::SlowTraining,
899            IssueType::MemoryBottleneck,
900        ]
901    }
902}
903
/// Stateless detector for hyperparameter issues (currently the learning rate).
#[derive(Debug)]
struct HyperparameterIssueDetector;

impl HyperparameterIssueDetector {
    /// Creates the detector; the type carries no state.
    fn new() -> Self {
        HyperparameterIssueDetector
    }
}
912
913impl IssueDetector for HyperparameterIssueDetector {
914    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
915        let mut issues = Vec::new();
916
917        if let Some(metrics) = context.recent_metrics.last() {
918            // Check learning rate issues
919            if let Some(lr) = metrics.learning_rate {
920                if lr > 0.1 {
921                    issues.push(DetectedIssue {
922                        issue_type: IssueType::LearningRateTooHigh,
923                        severity: IssueSeverity::High,
924                        confidence: 0.7,
925                        description:
926                            "Learning rate appears too high - may cause training instability"
927                                .to_string(),
928                        evidence: vec![Evidence {
929                            metric_name: "learning_rate".to_string(),
930                            observed_value: lr,
931                            expected_range: (0.0001, 0.01),
932                            explanation: "Learning rate is above typical range".to_string(),
933                        }],
934                        metrics: HashMap::new(),
935                        detected_at: chrono::Utc::now(),
936                    });
937                } else if lr < 0.00001 {
938                    issues.push(DetectedIssue {
939                        issue_type: IssueType::LearningRateTooLow,
940                        severity: IssueSeverity::Medium,
941                        confidence: 0.6,
942                        description: "Learning rate might be too low - training could be slow"
943                            .to_string(),
944                        evidence: vec![Evidence {
945                            metric_name: "learning_rate".to_string(),
946                            observed_value: lr,
947                            expected_range: (0.0001, 0.01),
948                            explanation: "Learning rate is below typical range".to_string(),
949                        }],
950                        metrics: HashMap::new(),
951                        detected_at: chrono::Utc::now(),
952                    });
953                }
954            }
955        }
956
957        Ok(issues)
958    }
959
960    fn get_detector_name(&self) -> &str {
961        "HyperparameterIssueDetector"
962    }
963
964    fn get_supported_issues(&self) -> Vec<IssueType> {
965        vec![
966            IssueType::LearningRateTooHigh,
967            IssueType::LearningRateTooLow,
968        ]
969    }
970}
971
/// Stateless detector for model-architecture issues (model size and depth).
#[derive(Debug)]
struct ArchitectureIssueDetector;

impl ArchitectureIssueDetector {
    /// Creates the detector; the type carries no state.
    fn new() -> Self {
        ArchitectureIssueDetector
    }
}
980
981impl IssueDetector for ArchitectureIssueDetector {
982    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
983        let mut issues = Vec::new();
984
985        if let Some(model_info) = context.model_info {
986            // Check model size
987            if model_info.parameter_count > 1_000_000_000 {
988                issues.push(DetectedIssue {
989                    issue_type: IssueType::ModelTooLarge,
990                    severity: IssueSeverity::Medium,
991                    confidence: 0.6,
992                    description:
993                        "Model has very large number of parameters - consider optimization"
994                            .to_string(),
995                    evidence: vec![Evidence {
996                        metric_name: "parameter_count".to_string(),
997                        observed_value: model_info.parameter_count as f64,
998                        expected_range: (1_000_000.0, 100_000_000.0),
999                        explanation: "Parameter count is extremely high".to_string(),
1000                    }],
1001                    metrics: HashMap::new(),
1002                    detected_at: chrono::Utc::now(),
1003                });
1004            }
1005
1006            if model_info.layer_count > 100 {
1007                issues.push(DetectedIssue {
1008                    issue_type: IssueType::InappropriateArchitecture,
1009                    severity: IssueSeverity::Low,
1010                    confidence: 0.5,
1011                    description: "Very deep model - may have gradient flow issues".to_string(),
1012                    evidence: vec![Evidence {
1013                        metric_name: "layer_count".to_string(),
1014                        observed_value: model_info.layer_count as f64,
1015                        expected_range: (10.0, 50.0),
1016                        explanation: "Layer count is very high".to_string(),
1017                    }],
1018                    metrics: HashMap::new(),
1019                    detected_at: chrono::Utc::now(),
1020                });
1021            }
1022        }
1023
1024        Ok(issues)
1025    }
1026
1027    fn get_detector_name(&self) -> &str {
1028        "ArchitectureIssueDetector"
1029    }
1030
1031    fn get_supported_issues(&self) -> Vec<IssueType> {
1032        vec![
1033            IssueType::ModelTooLarge,
1034            IssueType::InappropriateArchitecture,
1035        ]
1036    }
1037}
1038
/// Stateless detector for data-related issues (batch size, imbalance, data volume).
#[derive(Debug)]
struct DataIssueDetector;

impl DataIssueDetector {
    /// Creates the detector; the type carries no state.
    fn new() -> Self {
        DataIssueDetector
    }
}
1047
1048impl IssueDetector for DataIssueDetector {
1049    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
1050        // Detect three classes of data-related issues from dashboard metrics:
1051        //
1052        //  - BatchSizeProblems  : sustained low GPU utilisation paired with low
1053        //                         tokens/sec is a strong signal that the batch is
1054        //                         too small to saturate the device.
1055        //  - DataImbalance      : accuracy pinned at a near-trivial value (high
1056        //                         floor or low ceiling) while loss continues to
1057        //                         change is the canonical signature of a model
1058        //                         collapsing onto a majority class.
1059        //  - InsufficientData   : loss decreases but accuracy oscillates wildly,
1060        //                         indicating the model is memorising rather than
1061        //                         generalising — a classic small-dataset failure
1062        //                         mode.
1063        //
1064        // Heuristics use thresholds tuned for a "typical" supervised-learning
1065        // setup; they are intentionally conservative so we do not produce false
1066        // positives on small recent_metrics windows.
1067        let mut issues = Vec::new();
1068
1069        const MIN_WINDOW: usize = 5;
1070        if context.recent_metrics.len() < MIN_WINDOW {
1071            return Ok(issues);
1072        }
1073
1074        // Sample the freshest MIN_WINDOW metrics for analysis.
1075        let window: Vec<&DashboardMetrics> =
1076            context.recent_metrics.iter().rev().take(MIN_WINDOW * 2).collect();
1077
1078        // --- BatchSizeProblems ---------------------------------------------
1079        let gpu_samples: Vec<f64> = window.iter().filter_map(|m| m.gpu_utilization).collect();
1080        let tps_samples: Vec<f64> = window.iter().filter_map(|m| m.tokens_per_second).collect();
1081        if gpu_samples.len() >= MIN_WINDOW && tps_samples.len() >= MIN_WINDOW {
1082            let gpu_mean = gpu_samples.iter().sum::<f64>() / gpu_samples.len() as f64;
1083            let tps_mean = tps_samples.iter().sum::<f64>() / tps_samples.len() as f64;
1084            // Low GPU and low throughput together strongly suggest the batch is
1085            // starving the device. We use 0.5 (50% utilisation) and 100 tok/s as
1086            // canonical thresholds, matching the project's other detectors.
1087            if gpu_mean < 0.5 && tps_mean < 100.0 {
1088                let mut metrics = HashMap::new();
1089                metrics.insert("avg_gpu_utilization".to_string(), gpu_mean);
1090                metrics.insert("avg_tokens_per_second".to_string(), tps_mean);
1091                issues.push(DetectedIssue {
1092                    issue_type: IssueType::BatchSizeProblems,
1093                    severity: IssueSeverity::Medium,
1094                    confidence: 0.7,
1095                    description:
1096                        "Sustained low GPU utilisation and throughput suggest batch size may be \
1097                         too small to saturate the device"
1098                            .to_string(),
1099                    evidence: vec![
1100                        Evidence {
1101                            metric_name: "gpu_utilization".to_string(),
1102                            observed_value: gpu_mean,
1103                            expected_range: (0.7, 1.0),
1104                            explanation:
1105                                "Average GPU utilisation is below the typical training range"
1106                                    .to_string(),
1107                        },
1108                        Evidence {
1109                            metric_name: "tokens_per_second".to_string(),
1110                            observed_value: tps_mean,
1111                            expected_range: (100.0, f64::INFINITY),
1112                            explanation: "Throughput is below the typical training floor"
1113                                .to_string(),
1114                        },
1115                    ],
1116                    metrics,
1117                    detected_at: chrono::Utc::now(),
1118                });
1119            }
1120        }
1121
1122        // --- DataImbalance / InsufficientData ------------------------------
1123        let acc_samples: Vec<f64> = window.iter().filter_map(|m| m.accuracy).collect();
1124        let loss_samples: Vec<f64> = window.iter().filter_map(|m| m.loss).collect();
1125
1126        if acc_samples.len() >= MIN_WINDOW && loss_samples.len() >= MIN_WINDOW {
1127            let acc_mean = acc_samples.iter().sum::<f64>() / acc_samples.len() as f64;
1128            let acc_var = acc_samples
1129                .iter()
1130                .map(|a| {
1131                    let d = a - acc_mean;
1132                    d * d
1133                })
1134                .sum::<f64>()
1135                / acc_samples.len() as f64;
1136            let acc_stddev = acc_var.sqrt();
1137
1138            // `window` (and therefore `loss_samples`) is ordered newest-first.
1139            // Compare the older half (end of the slice) to the newer half
1140            // (start of the slice) to decide whether loss is decreasing.
1141            let half = loss_samples.len() / 2;
1142            let newer_half = &loss_samples[..half];
1143            let older_half = &loss_samples[loss_samples.len() - half..];
1144            let newer_avg = if newer_half.is_empty() {
1145                0.0
1146            } else {
1147                newer_half.iter().sum::<f64>() / newer_half.len() as f64
1148            };
1149            let older_avg = if older_half.is_empty() {
1150                0.0
1151            } else {
1152                older_half.iter().sum::<f64>() / older_half.len() as f64
1153            };
1154            // Positive => loss decreasing over time.
1155            let loss_relative_change = if older_avg.abs() > f64::EPSILON {
1156                (older_avg - newer_avg) / older_avg.abs()
1157            } else {
1158                0.0
1159            };
1160
1161            // DataImbalance: accuracy is pinned (very low variance) at an
1162            // extreme value (either trivially low or near-perfect) while the
1163            // loss continues to move meaningfully. Models collapsing onto the
1164            // majority class show exactly this signature.
1165            let acc_pinned_extreme = acc_stddev < 0.01 && !(0.2..=0.95).contains(&acc_mean);
1166            let loss_changing = loss_relative_change.abs() > 0.05;
1167            if acc_pinned_extreme && loss_changing {
1168                let mut metrics = HashMap::new();
1169                metrics.insert("accuracy_mean".to_string(), acc_mean);
1170                metrics.insert("accuracy_stddev".to_string(), acc_stddev);
1171                metrics.insert("loss_relative_change".to_string(), loss_relative_change);
1172                issues.push(DetectedIssue {
1173                    issue_type: IssueType::DataImbalance,
1174                    severity: IssueSeverity::High,
1175                    confidence: 0.75,
1176                    description:
1177                        "Accuracy is pinned at an extreme value while loss continues to change \
1178                         — model may be collapsing onto a majority class"
1179                            .to_string(),
1180                    evidence: vec![Evidence {
1181                        metric_name: "accuracy_stddev".to_string(),
1182                        observed_value: acc_stddev,
1183                        expected_range: (0.01, 0.5),
1184                        explanation:
1185                            "Accuracy variance is far below the range expected during healthy \
1186                             training"
1187                                .to_string(),
1188                    }],
1189                    metrics,
1190                    detected_at: chrono::Utc::now(),
1191                });
1192            }
1193
1194            // InsufficientData: loss is steadily decreasing (model fitting the
1195            // training set) but accuracy is highly volatile, suggesting the
1196            // model is memorising rather than generalising — a classic failure
1197            // mode of training on too little data.
1198            if loss_relative_change > 0.10 && acc_stddev > 0.15 {
1199                let mut metrics = HashMap::new();
1200                metrics.insert("accuracy_stddev".to_string(), acc_stddev);
1201                metrics.insert("loss_relative_change".to_string(), loss_relative_change);
1202                issues.push(DetectedIssue {
1203                    issue_type: IssueType::InsufficientData,
1204                    severity: IssueSeverity::Medium,
1205                    confidence: 0.6,
1206                    description:
1207                        "Loss is decreasing but accuracy fluctuates wildly — the dataset may be \
1208                         too small, leading to memorisation rather than generalisation"
1209                            .to_string(),
1210                    evidence: vec![Evidence {
1211                        metric_name: "accuracy_stddev".to_string(),
1212                        observed_value: acc_stddev,
1213                        expected_range: (0.0, 0.10),
1214                        explanation:
1215                            "Accuracy variance is well above what is expected when the model \
1216                             is generalising"
1217                                .to_string(),
1218                    }],
1219                    metrics,
1220                    detected_at: chrono::Utc::now(),
1221                });
1222            }
1223        }
1224
1225        Ok(issues)
1226    }
1227
1228    fn get_detector_name(&self) -> &str {
1229        "DataIssueDetector"
1230    }
1231
1232    fn get_supported_issues(&self) -> Vec<IssueType> {
1233        vec![
1234            IssueType::DataImbalance,
1235            IssueType::BatchSizeProblems,
1236            IssueType::InsufficientData,
1237        ]
1238    }
1239}
1240
1241impl Default for KnowledgeBase {
1242    fn default() -> Self {
1243        Self::new()
1244    }
1245}
1246
1247impl KnowledgeBase {
1248    pub fn new() -> Self {
1249        Self {
1250            issue_patterns: HashMap::new(),
1251            hyperparameter_recommendations: HashMap::new(),
1252            architecture_patterns: Vec::new(),
1253            best_practices: HashMap::new(),
1254        }
1255    }
1256}
1257
1258// Report structures
1259
1260#[derive(Debug, Serialize, Deserialize)]
1261pub struct AutoDebugReport {
1262    pub detected_issues: Vec<DetectedIssue>,
1263    pub fix_recommendations: Vec<FixRecommendation>,
1264    pub hyperparameter_recommendations: Vec<HyperparameterRecommendation>,
1265    pub architecture_suggestions: Vec<ArchitectureSuggestion>,
1266    pub training_recipe: TrainingRecipeOptimization,
1267    pub analysis_summary: String,
1268    pub confidence_score: f64,
1269}
1270
/// A [`FixSuggestion`] paired with the [`DetectedIssue`] it is meant to address.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixRecommendation {
    /// The issue this recommendation responds to.
    pub issue: DetectedIssue,
    /// The suggested fix for `issue`.
    pub fix_suggestion: FixSuggestion,
    /// Confidence in this recommendation.
    /// NOTE(review): scale, and how it relates to `issue.confidence`, is not
    /// defined by this type — confirm with the code that fills it in.
    pub confidence: f64,
    /// How urgently the fix should be applied.
    /// NOTE(review): scale not defined here; presumably higher means more urgent.
    pub urgency: f64,
}
1278
/// A suggested new value for a single numeric hyperparameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterRecommendation {
    /// Name of the hyperparameter (free-form string; no fixed vocabulary is
    /// enforced by this type).
    pub parameter: String,
    /// Current value of the hyperparameter, or `None` when it is not known.
    pub current_value: Option<f64>,
    /// Value the recommendation suggests switching to.
    pub recommended_value: f64,
    /// Explanation of why the change is suggested.
    pub reason: String,
    /// Confidence in the recommendation.
    /// NOTE(review): scale not enforced here — confirm.
    pub confidence: f64,
}
1287
/// A textual suggestion for changing the model architecture.
///
/// All fields are free-form strings; this type enforces no fixed vocabulary
/// (e.g. for `suggestion_type` or `implementation_difficulty`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSuggestion {
    /// Category of the suggestion.
    pub suggestion_type: String,
    /// Short headline for the suggestion.
    pub title: String,
    /// Detailed description of the proposed change.
    pub description: String,
    /// Expected effect of applying the change.
    pub impact_assessment: String,
    /// How hard the change is expected to be to implement.
    pub implementation_difficulty: String,
}
1296
/// Suggested adjustments to the overall training recipe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecipeOptimization {
    /// Free-form list of recommended optimisations.
    pub recommended_optimizations: Vec<String>,
    /// Suggested schedule (warmup, learning rate, batch size, checkpoints).
    pub training_schedule: TrainingSchedule,
    /// Suggested data-handling strategy.
    pub data_strategy: DataStrategy,
}
1303
/// Suggested training schedule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSchedule {
    /// Number of warmup steps.
    pub warmup_steps: u32,
    /// Learning-rate schedule, described as a free-form string.
    pub learning_rate_schedule: String,
    /// Batch-size schedule, described as a free-form string.
    pub batch_size_schedule: String,
    /// Whether early stopping is recommended.
    pub early_stopping: bool,
    /// How often to checkpoint.
    /// NOTE(review): unit (steps vs. epochs) is not stated here — confirm.
    pub checkpoint_frequency: u32,
}
1312
/// Suggested data-handling strategy; all entries are free-form text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataStrategy {
    /// Recommended data-augmentation techniques.
    pub data_augmentation: Vec<String>,
    /// Recommended sampling strategy.
    pub sampling_strategy: String,
    /// Recommended preprocessing optimisations.
    pub preprocessing_optimizations: Vec<String>,
}
1319
#[cfg(test)]
mod tests {
    use super::*;

    fn make_config() -> DebugConfig {
        DebugConfig::default()
    }

    #[test]
    fn test_knowledge_base_new() {
        let kb = KnowledgeBase::new();
        assert!(kb.issue_patterns.is_empty());
        assert!(kb.hyperparameter_recommendations.is_empty());
        assert!(kb.architecture_patterns.is_empty());
        assert!(kb.best_practices.is_empty());
    }

    #[test]
    fn test_knowledge_base_default() {
        let kb = KnowledgeBase::default();
        assert!(kb.issue_patterns.is_empty());
    }

    #[test]
    fn test_auto_debugger_new() {
        let config = make_config();
        let debugger = AutoDebugger::new(&config);
        assert!(!debugger.issue_detectors.is_empty());
        assert!(!debugger.fix_suggestions.is_empty());
        assert!(debugger.optimization_history.is_empty());
    }

    #[test]
    fn test_auto_debugger_has_default_detectors() {
        let config = make_config();
        let debugger = AutoDebugger::new(&config);
        assert_eq!(debugger.issue_detectors.len(), 6);
    }

    #[test]
    fn test_auto_debugger_has_fix_suggestions() {
        let config = make_config();
        let debugger = AutoDebugger::new(&config);
        assert!(debugger.fix_suggestions.contains_key(&IssueType::VanishingGradients));
        assert!(debugger.fix_suggestions.contains_key(&IssueType::ExplodingGradients));
    }

    #[test]
    fn test_gradient_issue_detector_name() {
        let detector = GradientIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "GradientIssueDetector");
    }

    #[test]
    fn test_gradient_issue_detector_supported_issues() {
        let detector = GradientIssueDetector::new();
        let issues = detector.get_supported_issues();
        assert!(issues.contains(&IssueType::VanishingGradients));
        assert!(issues.contains(&IssueType::ExplodingGradients));
    }

    #[test]
    fn test_training_issue_detector_name() {
        let detector = TrainingIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "TrainingIssueDetector");
    }

    #[test]
    fn test_training_issue_detector_supported_issues() {
        let detector = TrainingIssueDetector::new();
        let issues = detector.get_supported_issues();
        assert!(!issues.is_empty());
    }

    #[test]
    fn test_performance_issue_detector_name() {
        let detector = PerformanceIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "PerformanceIssueDetector");
    }

    #[test]
    fn test_hyperparameter_issue_detector_name() {
        let detector = HyperparameterIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "HyperparameterIssueDetector");
    }

    #[test]
    fn test_architecture_issue_detector_name() {
        let detector = ArchitectureIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "ArchitectureIssueDetector");
    }

    #[test]
    fn test_data_issue_detector_name() {
        let detector = DataIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "DataIssueDetector");
    }

    #[test]
    fn test_issue_type_equality() {
        assert_eq!(IssueType::VanishingGradients, IssueType::VanishingGradients);
        assert_ne!(IssueType::VanishingGradients, IssueType::ExplodingGradients);
    }

    #[test]
    fn test_issue_type_hash_compatible() {
        let mut map = HashMap::new();
        map.insert(IssueType::OverfittingDetected, "fix");
        assert!(map.contains_key(&IssueType::OverfittingDetected));
        assert!(!map.contains_key(&IssueType::UnderfittingDetected));
    }

    #[test]
    fn test_evidence_construction() {
        let evidence = Evidence {
            metric_name: "gradient_norm".to_string(),
            observed_value: 0.001,
            expected_range: (0.01, 1.0),
            explanation: "Gradient norm too low".to_string(),
        };
        assert_eq!(evidence.metric_name, "gradient_norm");
        assert!(evidence.observed_value < evidence.expected_range.0);
    }

    #[test]
    fn test_expected_impact_fields() {
        let impact = ExpectedImpact {
            performance_improvement: 0.15,
            training_speed_improvement: 0.05,
            stability_improvement: 0.25,
            memory_usage_change: 0.02,
        };
        assert!(impact.performance_improvement > 0.0);
        assert!(impact.stability_improvement > impact.performance_improvement);
    }

    #[test]
    fn test_model_info_construction() {
        let info = ModelInfo {
            model_type: "transformer".to_string(),
            parameter_count: 1_000_000,
            layer_count: 12,
            architecture_details: HashMap::new(),
        };
        assert_eq!(info.model_type, "transformer");
        assert_eq!(info.parameter_count, 1_000_000);
    }

    #[test]
    fn test_issue_pattern_construction() {
        let pattern = IssuePattern {
            symptoms: vec!["low gradient norm".to_string()],
            common_causes: vec!["deep network".to_string()],
            diagnostic_metrics: vec!["gradient_norm".to_string()],
            typical_solutions: vec!["add skip connections".to_string()],
        };
        assert_eq!(pattern.symptoms.len(), 1);
        assert_eq!(pattern.common_causes.len(), 1);
    }

    #[test]
    fn test_hyperparameter_advice_construction() {
        let advice = HyperparameterAdvice {
            parameter_name: "learning_rate".to_string(),
            recommended_range: (1e-5, 1e-2),
            tuning_strategy: "grid_search".to_string(),
            dependencies: vec!["batch_size".to_string()],
            common_mistakes: vec!["too high initial lr".to_string()],
        };
        assert!(advice.recommended_range.0 < advice.recommended_range.1);
    }

    fn make_metric(
        loss: Option<f64>,
        accuracy: Option<f64>,
        gpu: Option<f64>,
        tps: Option<f64>,
    ) -> DashboardMetrics {
        DashboardMetrics {
            timestamp: std::time::SystemTime::now(),
            loss,
            accuracy,
            learning_rate: Some(1e-3),
            memory_usage_mb: 1024.0,
            gpu_utilization: gpu,
            tokens_per_second: tps,
            gradient_norm: Some(0.5),
            epoch: Some(0),
            step: Some(0),
        }
    }

    #[test]
    fn test_data_issue_detector_returns_empty_with_no_metrics() {
        let detector = DataIssueDetector::new();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &[],
            training_duration: Duration::from_secs(60),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(issues.is_empty());
    }

    #[test]
    fn test_data_issue_detector_flags_batch_size_problem() {
        let detector = DataIssueDetector::new();
        // Simulate a long stretch of low GPU utilisation and low throughput.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| {
                make_metric(
                    Some(2.0 - i as f64 * 0.01),
                    Some(0.6),
                    Some(0.2),
                    Some(50.0),
                )
            })
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            issues.iter().any(|i| i.issue_type == IssueType::BatchSizeProblems),
            "expected BatchSizeProblems to be flagged, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_data_issue_detector_flags_data_imbalance_when_accuracy_pinned() {
        let detector = DataIssueDetector::new();
        // Accuracy pinned at ~0.97 with virtually no variance, while loss
        // continues to fall: a classic majority-class collapse.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| {
                make_metric(
                    Some(2.0 - i as f64 * 0.10),
                    Some(0.97),
                    Some(0.85),
                    Some(500.0),
                )
            })
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            issues.iter().any(|i| i.issue_type == IssueType::DataImbalance),
            "expected DataImbalance to be flagged, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_data_issue_detector_flags_insufficient_data() {
        let detector = DataIssueDetector::new();
        // Loss falls steadily (~28% between halves) while accuracy swings between
        // 0.2 and 0.9 (stddev 0.35): memorisation rather than generalisation.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| {
                let accuracy = if i % 2 == 0 { 0.2 } else { 0.9 };
                make_metric(
                    Some(2.0 - i as f64 * 0.10),
                    Some(accuracy),
                    Some(0.85),
                    Some(500.0),
                )
            })
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            issues.iter().any(|i| i.issue_type == IssueType::InsufficientData),
            "expected InsufficientData to be flagged, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
        assert!(!issues.iter().any(|i| i.issue_type == IssueType::DataImbalance));
    }

    #[test]
    fn test_data_issue_detector_quiet_on_healthy_training() {
        let detector = DataIssueDetector::new();
        // Busy GPU, good throughput, gently improving accuracy, falling loss.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| {
                make_metric(
                    Some(2.0 - i as f64 * 0.05),
                    Some(0.6 + i as f64 * 0.01),
                    Some(0.9),
                    Some(500.0),
                )
            })
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            issues.is_empty(),
            "expected no data issues, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_performance_issue_detector_flags_low_gpu_utilization() {
        let detector = PerformanceIssueDetector::new();
        let metrics = vec![make_metric(Some(1.0), Some(0.5), Some(0.2), Some(500.0))];
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(60),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(issues.iter().any(|i| i.issue_type == IssueType::LowGpuUtilization));
    }

    // Runs the hyperparameter detector on a single metrics entry with the given
    // learning rate and returns the reported issue types.
    fn hyperparameter_issues_for_lr(lr: f64) -> Vec<IssueType> {
        let metrics = vec![DashboardMetrics {
            learning_rate: Some(lr),
            ..make_metric(Some(1.0), Some(0.5), Some(0.9), Some(500.0))
        }];
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(60),
            model_info: None,
        };
        HyperparameterIssueDetector::new()
            .detect_issues(&context)
            .expect("detect_issues should succeed")
            .into_iter()
            .map(|issue| issue.issue_type)
            .collect()
    }

    #[test]
    fn test_hyperparameter_issue_detector_learning_rate_thresholds() {
        assert_eq!(hyperparameter_issues_for_lr(0.5), vec![IssueType::LearningRateTooHigh]);
        assert_eq!(hyperparameter_issues_for_lr(1e-7), vec![IssueType::LearningRateTooLow]);
        assert!(hyperparameter_issues_for_lr(1e-3).is_empty());
    }

    #[test]
    fn test_training_issue_detector_flags_flat_loss_as_stalled() {
        let detector = TrainingIssueDetector::new();
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|_| make_metric(Some(1.0), Some(0.5), Some(0.9), Some(500.0)))
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(issues.iter().any(|i| i.issue_type == IssueType::TrainingStalled));
    }

    #[test]
    fn test_training_issue_detector_ignores_steadily_falling_loss() {
        let detector = TrainingIssueDetector::new();
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| make_metric(Some(2.0 - i as f64 * 0.10), Some(0.5), Some(0.9), Some(500.0)))
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(!issues.iter().any(|i| i.issue_type == IssueType::TrainingStalled));
    }
}