1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::time::Duration;
7
8use crate::{
9 AnomalyDetectorReport, DashboardMetrics, DebugConfig, GradientDebugReport, ProfilerReport,
10};
11
/// Rule-based training debugger.
///
/// Runs a list of [`IssueDetector`]s over a [`DebugContext`], maps the detected issues to
/// entries of a built-in [`FixSuggestion`] catalogue, and keeps a bounded log of
/// [`OptimizationAttempt`]s.
#[derive(Debug)]
// NOTE(review): blanket allowance without a stated reason; `config` and `knowledge_base`
// are stored but not read by any method visible in this chunk.
#[allow(dead_code)]
pub struct AutoDebugger {
    #[allow(dead_code)]
    config: DebugConfig,
    /// Detectors, run in registration order by `analyze_issues`.
    issue_detectors: Vec<Box<dyn IssueDetector>>,
    /// Fix catalogue keyed by the issue type each fix addresses.
    fix_suggestions: HashMap<IssueType, Vec<FixSuggestion>>,
    /// Recorded optimization attempts, trimmed by `record_optimization_attempt`.
    optimization_history: Vec<OptimizationAttempt>,
    knowledge_base: KnowledgeBase,
}
23
/// Category of problem an [`IssueDetector`] can report.
///
/// Also used as the key of the fix-suggestion catalogue, hence the `Eq` + `Hash` derives.
/// Several variants are not emitted by any detector visible in this chunk.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IssueType {
    /// Emitted by `GradientIssueDetector`; has built-in fix suggestions.
    VanishingGradients,
    /// Emitted by `GradientIssueDetector`; has built-in fix suggestions.
    ExplodingGradients,
    LearningRateProblems,
    OverfittingDetected,
    UnderfittingDetected,
    /// Emitted by `TrainingIssueDetector`.
    TrainingStalled,
    LossNotDecreasing,
    UnstableTraining,
    MemoryIssues,

    /// Emitted by `ArchitectureIssueDetector` for very large parameter counts.
    ModelTooLarge,
    ModelTooSmall,
    /// Emitted by `ArchitectureIssueDetector` for very deep models.
    InappropriateArchitecture,
    LayerMismatch,
    ActivationProblems,

    /// Emitted by `DataIssueDetector`.
    DataImbalance,
    DataLeakage,
    /// Emitted by `DataIssueDetector`.
    InsufficientData,
    DataQualityIssues,
    /// Emitted by `DataIssueDetector`.
    BatchSizeProblems,

    SlowTraining,
    /// Emitted by `PerformanceIssueDetector`; has built-in fix suggestions.
    LowGpuUtilization,
    MemoryBottleneck,
    IoBottleneck,
    ComputeBottleneck,

    /// Emitted by `HyperparameterIssueDetector`; has built-in fix suggestions.
    LearningRateTooHigh,
    /// Emitted by `HyperparameterIssueDetector`.
    LearningRateTooLow,
    BatchSizeTooLarge,
    BatchSizeTooSmall,
    RegularizationIssues,
}
66
/// A single finding produced by an [`IssueDetector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedIssue {
    pub issue_type: IssueType,
    pub severity: IssueSeverity,
    /// Detector confidence; the built-in detectors use values between 0.5 and 0.95
    /// (presumably a [0, 1] scale — confirm before relying on it).
    pub confidence: f64,
    pub description: String,
    pub evidence: Vec<Evidence>,
    /// Extra named measurements. `DataIssueDetector` fills this; the other built-in
    /// detectors leave it empty.
    pub metrics: HashMap<String, f64>,
    /// Time the detector created the finding (`chrono::Utc::now()` in built-in detectors).
    pub detected_at: chrono::DateTime<chrono::Utc>,
}
78
/// Severity of a [`DetectedIssue`], declared from most to least severe.
///
/// `AutoDebugger::analyze_issues` orders issues in this declaration order, and
/// `calculate_urgency` weights them from 1.0 (`Critical`) down to 0.2 (`Info`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IssueSeverity {
    Critical,
    High,
    Medium,
    Low,
    Info,
}
87
/// One metric observation that supports a [`DetectedIssue`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evidence {
    pub metric_name: String,
    pub observed_value: f64,
    /// `(low, high)` range the metric would normally fall in; `f64::INFINITY` is used
    /// by `DataIssueDetector` for an open upper bound.
    pub expected_range: (f64, f64),
    pub explanation: String,
}
95
/// A catalogued remedy for an [`IssueType`], as stored in `AutoDebugger::fix_suggestions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixSuggestion {
    /// Short identifier such as `"vg_001"`.
    pub fix_id: String,
    pub fix_type: FixType,
    pub title: String,
    pub description: String,
    /// Ordered, human-readable steps to apply the fix.
    pub implementation_steps: Vec<String>,
    pub expected_impact: ExpectedImpact,
    pub priority: FixPriority,
    pub estimated_effort: EstimatedEffort,
    /// Conditions that should hold before applying the fix (free text).
    pub prerequisites: Vec<String>,
    pub code_examples: Vec<CodeExample>,
}
110
/// Broad category of change a [`FixSuggestion`] asks for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FixType {
    HyperparameterAdjustment,
    ArchitectureChange,
    TrainingProcedure,
    DataProcessing,
    OptimizationTechnique,
    EnvironmentConfig,
}
120
/// Estimated effect of applying a [`FixSuggestion`].
///
/// The built-in catalogue uses small signed values such as `0.15` or `-0.05`;
/// presumably fractional changes where negative means worse (or, for
/// `memory_usage_change`, less memory) — TODO confirm the intended scale.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedImpact {
    pub performance_improvement: f64,
    pub training_speed_improvement: f64,
    pub stability_improvement: f64,
    pub memory_usage_change: f64,
}
128
/// Priority assigned to a [`FixSuggestion`], declared from most to least pressing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FixPriority {
    Critical,
    High,
    Medium,
    Low,
}
136
/// Rough effort needed to apply a [`FixSuggestion`], declared from least to most effort.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EstimatedEffort {
    Trivial,
    Easy,
    Medium,
    Hard,
    Complex,
}
145
/// Snippet illustrating how to apply a [`FixSuggestion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeExample {
    /// Language of `code` (the built-in catalogue only uses `"python"`).
    pub language: String,
    pub code: String,
    pub explanation: String,
}
152
/// Record of a fix that was tried, stored via `AutoDebugger::record_optimization_attempt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationAttempt {
    pub attempt_id: String,
    pub issue_addressed: IssueType,
    /// Free-text identifier of the applied fix (possibly a `FixSuggestion::fix_id` — confirm).
    pub fix_applied: String,
    pub before_metrics: HashMap<String, f64>,
    /// Presumably `None` until the attempt has been evaluated — TODO confirm with callers.
    pub after_metrics: Option<HashMap<String, f64>>,
    /// Presumably `None` until the outcome is known — TODO confirm with callers.
    pub success: Option<bool>,
    pub notes: String,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
165
/// Static reference data about issues, hyperparameters and architectures.
///
/// Constructed by `KnowledgeBase::new()` (defined outside this chunk). None of its
/// fields are read by code visible in this chunk, hence the dead-code allowances.
#[derive(Debug)]
#[allow(dead_code)]
pub struct KnowledgeBase {
    #[allow(dead_code)]
    issue_patterns: HashMap<IssueType, IssuePattern>,
    /// Advice keyed by hyperparameter name.
    hyperparameter_recommendations: HashMap<String, HyperparameterAdvice>,
    architecture_patterns: Vec<ArchitecturePattern>,
    /// Lists of best-practice strings keyed by topic name.
    best_practices: HashMap<String, Vec<String>>,
}
176
/// Descriptive knowledge about one [`IssueType`]: how it shows up and how it is fixed.
#[derive(Debug, Clone)]
pub struct IssuePattern {
    pub symptoms: Vec<String>,
    pub common_causes: Vec<String>,
    /// Names of metrics useful for diagnosing the issue.
    pub diagnostic_metrics: Vec<String>,
    pub typical_solutions: Vec<String>,
}
184
/// Tuning guidance for a single hyperparameter.
#[derive(Debug, Clone)]
pub struct HyperparameterAdvice {
    pub parameter_name: String,
    /// `(low, high)` range of recommended values.
    pub recommended_range: (f64, f64),
    pub tuning_strategy: String,
    /// Names of other hyperparameters this one interacts with.
    pub dependencies: Vec<String>,
    pub common_mistakes: Vec<String>,
}
193
/// Description of a reusable model-architecture pattern.
#[derive(Debug, Clone)]
pub struct ArchitecturePattern {
    pub pattern_name: String,
    pub use_cases: Vec<String>,
    pub typical_layers: Vec<String>,
    /// Suggested numeric hyperparameter values keyed by name.
    pub hyperparameter_suggestions: HashMap<String, f64>,
    pub performance_characteristics: String,
}
202
/// A pluggable analysis that inspects a [`DebugContext`] and reports [`DetectedIssue`]s.
///
/// The `Debug` supertrait is needed because detectors are stored as
/// `Box<dyn IssueDetector>` inside the `Debug`-derived [`AutoDebugger`].
pub trait IssueDetector: std::fmt::Debug {
    /// Runs the detector. An `Err` is logged and skipped by `AutoDebugger::analyze_issues`
    /// instead of aborting the whole analysis.
    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>>;
    /// Detector name, used in the warning logged when `detect_issues` fails.
    fn get_detector_name(&self) -> &str;
    /// Issue types the detector declares it covers.
    // NOTE(review): not called by code visible in this chunk, and some built-in
    // detectors list types they never emit.
    fn get_supported_issues(&self) -> Vec<IssueType>;
}
209
/// Borrowed snapshot of everything the detectors may inspect.
///
/// Optional reports are `None` when the corresponding tool did not run; detectors
/// skip their checks in that case.
#[derive(Debug)]
pub struct DebugContext<'a> {
    pub profiler_report: Option<&'a ProfilerReport>,
    pub gradient_report: Option<&'a GradientDebugReport>,
    pub anomaly_report: Option<&'a AnomalyDetectorReport>,
    /// Metric samples; assumed oldest-first, since detectors read `.last()` as the
    /// latest sample and `.iter().rev()` for the newest window — TODO confirm with callers.
    pub recent_metrics: &'a [DashboardMetrics],
    /// Wall-clock time spent training so far (compared against one hour by the
    /// training-recipe heuristics).
    pub training_duration: Duration,
    pub model_info: Option<&'a ModelInfo>,
}
220
/// Static description of the model being trained, used by architecture heuristics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub model_type: String,
    /// Total number of parameters (compared against 1e8 / 1e9 thresholds).
    pub parameter_count: usize,
    /// Number of layers (compared against 50 / 100 thresholds).
    pub layer_count: usize,
    /// Free-form key/value architecture metadata.
    pub architecture_details: HashMap<String, String>,
}
228
229impl AutoDebugger {
230 pub fn new(config: &DebugConfig) -> Self {
232 let mut auto_debugger = Self {
233 config: config.clone(),
234 issue_detectors: Vec::new(),
235 fix_suggestions: HashMap::new(),
236 optimization_history: Vec::new(),
237 knowledge_base: KnowledgeBase::new(),
238 };
239
240 auto_debugger.register_default_detectors();
242 auto_debugger.initialize_fix_suggestions();
243
244 auto_debugger
245 }
246
247 fn register_default_detectors(&mut self) {
249 self.issue_detectors.push(Box::new(GradientIssueDetector::new()));
250 self.issue_detectors.push(Box::new(TrainingIssueDetector::new()));
251 self.issue_detectors.push(Box::new(PerformanceIssueDetector::new()));
252 self.issue_detectors.push(Box::new(HyperparameterIssueDetector::new()));
253 self.issue_detectors.push(Box::new(ArchitectureIssueDetector::new()));
254 self.issue_detectors.push(Box::new(DataIssueDetector::new()));
255 }
256
257 fn initialize_fix_suggestions(&mut self) {
259 self.fix_suggestions.insert(
261 IssueType::VanishingGradients,
262 vec![
263 FixSuggestion {
264 fix_id: "vg_001".to_string(),
265 fix_type: FixType::ArchitectureChange,
266 title: "Add Residual Connections".to_string(),
267 description:
268 "Implement skip connections to help gradients flow through deep networks"
269 .to_string(),
270 implementation_steps: vec![
271 "Add residual blocks to your model architecture".to_string(),
272 "Ensure input and output dimensions match for residual connections"
273 .to_string(),
274 "Consider using batch normalization within residual blocks".to_string(),
275 ],
276 expected_impact: ExpectedImpact {
277 performance_improvement: 0.15,
278 training_speed_improvement: 0.05,
279 stability_improvement: 0.25,
280 memory_usage_change: 0.02,
281 },
282 priority: FixPriority::High,
283 estimated_effort: EstimatedEffort::Medium,
284 prerequisites: vec!["Model architecture access".to_string()],
285 code_examples: vec![CodeExample {
286 language: "python".to_string(),
287 code: r#"
288class ResidualBlock(nn.Module):
289 def __init__(self, channels):
290 super().__init__()
291 self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
292 self.bn1 = nn.BatchNorm2d(channels)
293 self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
294 self.bn2 = nn.BatchNorm2d(channels)
295
296 def forward(self, x):
297 residual = x
298 out = F.relu(self.bn1(self.conv1(x)))
299 out = self.bn2(self.conv2(out))
300 out += residual # Skip connection
301 return F.relu(out)
302"#
303 .to_string(),
304 explanation: "Basic residual block implementation with skip connection"
305 .to_string(),
306 }],
307 },
308 FixSuggestion {
309 fix_id: "vg_002".to_string(),
310 fix_type: FixType::HyperparameterAdjustment,
311 title: "Adjust Learning Rate".to_string(),
312 description:
313 "Increase learning rate to help gradients propagate more effectively"
314 .to_string(),
315 implementation_steps: vec![
316 "Increase learning rate by 2-5x".to_string(),
317 "Monitor training stability".to_string(),
318 "Consider learning rate scheduling".to_string(),
319 ],
320 expected_impact: ExpectedImpact {
321 performance_improvement: 0.08,
322 training_speed_improvement: 0.10,
323 stability_improvement: -0.05,
324 memory_usage_change: 0.0,
325 },
326 priority: FixPriority::Medium,
327 estimated_effort: EstimatedEffort::Trivial,
328 prerequisites: vec![],
329 code_examples: vec![CodeExample {
330 language: "python".to_string(),
331 code: "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)"
332 .to_string(),
333 explanation: "Increase learning rate to help overcome vanishing gradients"
334 .to_string(),
335 }],
336 },
337 ],
338 );
339
340 self.fix_suggestions.insert(
342 IssueType::ExplodingGradients,
343 vec![FixSuggestion {
344 fix_id: "eg_001".to_string(),
345 fix_type: FixType::TrainingProcedure,
346 title: "Apply Gradient Clipping".to_string(),
347 description: "Clip gradients to prevent explosion during backpropagation"
348 .to_string(),
349 implementation_steps: vec![
350 "Add gradient clipping to your training loop".to_string(),
351 "Start with clip value of 1.0 and adjust based on results".to_string(),
352 "Monitor gradient norms to ensure clipping is effective".to_string(),
353 ],
354 expected_impact: ExpectedImpact {
355 performance_improvement: 0.10,
356 training_speed_improvement: 0.0,
357 stability_improvement: 0.30,
358 memory_usage_change: 0.0,
359 },
360 priority: FixPriority::Critical,
361 estimated_effort: EstimatedEffort::Easy,
362 prerequisites: vec![],
363 code_examples: vec![CodeExample {
364 language: "python".to_string(),
365 code: "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)"
366 .to_string(),
367 explanation: "Clip gradients before optimizer step".to_string(),
368 }],
369 }],
370 );
371
372 self.fix_suggestions.insert(
374 IssueType::LearningRateTooHigh,
375 vec![FixSuggestion {
376 fix_id: "lr_high_001".to_string(),
377 fix_type: FixType::HyperparameterAdjustment,
378 title: "Reduce Learning Rate".to_string(),
379 description: "Lower the learning rate to improve training stability".to_string(),
380 implementation_steps: vec![
381 "Reduce learning rate by 2-10x".to_string(),
382 "Consider learning rate scheduling".to_string(),
383 "Monitor loss convergence".to_string(),
384 ],
385 expected_impact: ExpectedImpact {
386 performance_improvement: 0.12,
387 training_speed_improvement: -0.05,
388 stability_improvement: 0.25,
389 memory_usage_change: 0.0,
390 },
391 priority: FixPriority::High,
392 estimated_effort: EstimatedEffort::Trivial,
393 prerequisites: vec![],
394 code_examples: vec![CodeExample {
395 language: "python".to_string(),
396 code: "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)".to_string(),
397 explanation: "Reduce learning rate for more stable training".to_string(),
398 }],
399 }],
400 );
401
402 self.fix_suggestions.insert(
404 IssueType::LowGpuUtilization,
405 vec![FixSuggestion {
406 fix_id: "gpu_001".to_string(),
407 fix_type: FixType::HyperparameterAdjustment,
408 title: "Increase Batch Size".to_string(),
409 description: "Increase batch size to better utilize GPU compute capacity"
410 .to_string(),
411 implementation_steps: vec![
412 "Double the current batch size".to_string(),
413 "Monitor memory usage to avoid OOM".to_string(),
414 "Adjust learning rate proportionally".to_string(),
415 ],
416 expected_impact: ExpectedImpact {
417 performance_improvement: 0.05,
418 training_speed_improvement: 0.30,
419 stability_improvement: 0.0,
420 memory_usage_change: 0.20,
421 },
422 priority: FixPriority::Medium,
423 estimated_effort: EstimatedEffort::Easy,
424 prerequisites: vec!["Available GPU memory".to_string()],
425 code_examples: vec![CodeExample {
426 language: "python".to_string(),
427 code: "train_loader = DataLoader(dataset, batch_size=64, shuffle=True)"
428 .to_string(),
429 explanation: "Increase batch size to improve GPU utilization".to_string(),
430 }],
431 }],
432 );
433 }
434
435 pub fn analyze_issues(&self, context: &DebugContext) -> Result<AutoDebugReport> {
437 let mut all_issues = Vec::new();
438
439 for detector in &self.issue_detectors {
441 match detector.detect_issues(context) {
442 Ok(mut issues) => all_issues.append(&mut issues),
443 Err(e) => {
444 tracing::warn!(
445 "Issue detector '{}' failed: {}",
446 detector.get_detector_name(),
447 e
448 );
449 },
450 }
451 }
452
453 all_issues.sort_by(|a, b| {
455 let severity_order = |s: &IssueSeverity| match s {
456 IssueSeverity::Critical => 0,
457 IssueSeverity::High => 1,
458 IssueSeverity::Medium => 2,
459 IssueSeverity::Low => 3,
460 IssueSeverity::Info => 4,
461 };
462
463 let severity_cmp = severity_order(&a.severity).cmp(&severity_order(&b.severity));
464 if severity_cmp == std::cmp::Ordering::Equal {
465 b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
466 } else {
467 severity_cmp
468 }
469 });
470
471 let fix_recommendations = self.generate_fix_recommendations(&all_issues);
473
474 let hyperparameter_recommendations = self.generate_hyperparameter_recommendations(context);
476
477 let architecture_suggestions = self.generate_architecture_suggestions(context);
479
480 let training_recipe = self.generate_training_recipe_optimization(context);
482
483 Ok(AutoDebugReport {
484 detected_issues: all_issues,
485 fix_recommendations: fix_recommendations.clone(),
486 hyperparameter_recommendations,
487 architecture_suggestions,
488 training_recipe,
489 analysis_summary: self.generate_analysis_summary(&fix_recommendations),
490 confidence_score: self.calculate_overall_confidence(&fix_recommendations),
491 })
492 }
493
494 fn generate_fix_recommendations(&self, issues: &[DetectedIssue]) -> Vec<FixRecommendation> {
496 let mut recommendations = Vec::new();
497
498 for issue in issues {
499 if let Some(suggestions) = self.fix_suggestions.get(&issue.issue_type) {
500 for suggestion in suggestions {
501 recommendations.push(FixRecommendation {
502 issue: issue.clone(),
503 fix_suggestion: suggestion.clone(),
504 confidence: issue.confidence * 0.9, urgency: self.calculate_urgency(issue),
506 });
507 }
508 }
509 }
510
511 recommendations.sort_by(|a, b| {
513 let urgency_cmp =
514 b.urgency.partial_cmp(&a.urgency).unwrap_or(std::cmp::Ordering::Equal);
515 if urgency_cmp == std::cmp::Ordering::Equal {
516 b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
517 } else {
518 urgency_cmp
519 }
520 });
521
522 recommendations
523 }
524
525 fn calculate_urgency(&self, issue: &DetectedIssue) -> f64 {
526 let severity_multiplier = match issue.severity {
527 IssueSeverity::Critical => 1.0,
528 IssueSeverity::High => 0.8,
529 IssueSeverity::Medium => 0.6,
530 IssueSeverity::Low => 0.4,
531 IssueSeverity::Info => 0.2,
532 };
533
534 issue.confidence * severity_multiplier
535 }
536
537 fn generate_hyperparameter_recommendations(
539 &self,
540 context: &DebugContext,
541 ) -> Vec<HyperparameterRecommendation> {
542 let mut recommendations = Vec::new();
543
544 if let Some(metrics) = context.recent_metrics.last() {
546 if let Some(loss) = metrics.loss {
547 if loss > 1.0 {
548 recommendations.push(HyperparameterRecommendation {
549 parameter: "learning_rate".to_string(),
550 current_value: None,
551 recommended_value: 0.001,
552 reason: "High loss suggests learning rate might be too low".to_string(),
553 confidence: 0.7,
554 });
555 }
556 }
557 }
558
559 if let Some(_profiler_report) = context.profiler_report {
561 recommendations.push(HyperparameterRecommendation {
563 parameter: "batch_size".to_string(),
564 current_value: None,
565 recommended_value: 32.0,
566 reason: "Optimize batch size for better GPU utilization".to_string(),
567 confidence: 0.6,
568 });
569 }
570
571 recommendations
572 }
573
574 fn generate_architecture_suggestions(
576 &self,
577 context: &DebugContext,
578 ) -> Vec<ArchitectureSuggestion> {
579 let mut suggestions = Vec::new();
580
581 if let Some(model_info) = context.model_info {
583 if model_info.parameter_count > 100_000_000 {
584 suggestions.push(ArchitectureSuggestion {
585 suggestion_type: "model_compression".to_string(),
586 title: "Consider Model Compression".to_string(),
587 description: "Large model may benefit from pruning or distillation".to_string(),
588 impact_assessment: "Reduce memory usage by 20-50% with minimal accuracy loss"
589 .to_string(),
590 implementation_difficulty: "Medium".to_string(),
591 });
592 }
593
594 if model_info.layer_count > 50 {
595 suggestions.push(ArchitectureSuggestion {
596 suggestion_type: "depth_optimization".to_string(),
597 title: "Optimize Network Depth".to_string(),
598 description: "Very deep network may suffer from gradient flow issues"
599 .to_string(),
600 impact_assessment: "Improve training stability and convergence speed"
601 .to_string(),
602 implementation_difficulty: "High".to_string(),
603 });
604 }
605 }
606
607 suggestions
608 }
609
610 fn generate_training_recipe_optimization(
612 &self,
613 context: &DebugContext,
614 ) -> TrainingRecipeOptimization {
615 let mut optimizations = Vec::new();
616
617 if context.training_duration > Duration::from_secs(3600) {
619 optimizations
620 .push("Consider learning rate scheduling to speed up convergence".to_string());
621 optimizations.push("Implement early stopping to avoid overtraining".to_string());
622 }
623
624 if context.recent_metrics.len() > 10 {
626 let recent_losses: Vec<f64> =
627 context.recent_metrics.iter().rev().take(10).filter_map(|m| m.loss).collect();
628
629 if recent_losses.len() >= 5 {
630 let variance = self.calculate_variance(&recent_losses);
631 if variance > 0.1 {
632 optimizations.push(
633 "Training loss is unstable - consider reducing learning rate".to_string(),
634 );
635 }
636 }
637 }
638
639 TrainingRecipeOptimization {
640 recommended_optimizations: optimizations,
641 training_schedule: TrainingSchedule {
642 warmup_steps: 1000,
643 learning_rate_schedule: "cosine_annealing".to_string(),
644 batch_size_schedule: "constant".to_string(),
645 early_stopping: true,
646 checkpoint_frequency: 1000,
647 },
648 data_strategy: DataStrategy {
649 data_augmentation: vec!["horizontal_flip".to_string(), "random_crop".to_string()],
650 sampling_strategy: "balanced".to_string(),
651 preprocessing_optimizations: vec![
652 "normalization".to_string(),
653 "standardization".to_string(),
654 ],
655 },
656 }
657 }
658
659 fn calculate_variance(&self, values: &[f64]) -> f64 {
660 if values.len() < 2 {
661 return 0.0;
662 }
663
664 let mean = values.iter().sum::<f64>() / values.len() as f64;
665 let variance =
666 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
667 variance
668 }
669
670 fn generate_analysis_summary(&self, recommendations: &[FixRecommendation]) -> String {
671 let critical_count = recommendations
672 .iter()
673 .filter(|r| matches!(r.issue.severity, IssueSeverity::Critical))
674 .count();
675
676 let high_count = recommendations
677 .iter()
678 .filter(|r| matches!(r.issue.severity, IssueSeverity::High))
679 .count();
680
681 if critical_count > 0 {
682 format!("Found {} critical issues requiring immediate attention. {} high-priority issues also detected.",
683 critical_count, high_count)
684 } else if high_count > 0 {
685 format!(
686 "Found {} high-priority issues that should be addressed soon.",
687 high_count
688 )
689 } else if !recommendations.is_empty() {
690 "Found some optimization opportunities to improve training performance.".to_string()
691 } else {
692 "No significant issues detected. Training appears to be proceeding normally."
693 .to_string()
694 }
695 }
696
697 fn calculate_overall_confidence(&self, recommendations: &[FixRecommendation]) -> f64 {
698 if recommendations.is_empty() {
699 return 1.0;
700 }
701
702 let sum_confidence: f64 = recommendations.iter().map(|r| r.confidence).sum();
703 sum_confidence / recommendations.len() as f64
704 }
705
706 pub fn record_optimization_attempt(&mut self, attempt: OptimizationAttempt) {
708 self.optimization_history.push(attempt);
709
710 if self.optimization_history.len() > 1000 {
712 self.optimization_history.drain(0..500);
713 }
714 }
715
716 pub fn get_optimization_history(&self) -> &[OptimizationAttempt] {
718 &self.optimization_history
719 }
720}
721
/// Stateless detector for vanishing and exploding gradients, driven by the gradient
/// debug report in the [`DebugContext`].
#[derive(Debug)]
struct GradientIssueDetector;

impl GradientIssueDetector {
    /// Returns the (stateless) detector.
    fn new() -> Self {
        GradientIssueDetector
    }
}
732
733impl IssueDetector for GradientIssueDetector {
734 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
735 let mut issues = Vec::new();
736
737 if let Some(gradient_report) = context.gradient_report {
738 if gradient_report.has_vanishing_gradients() {
740 issues.push(DetectedIssue {
741 issue_type: IssueType::VanishingGradients,
742 severity: IssueSeverity::High,
743 confidence: 0.9,
744 description: "Vanishing gradients detected in multiple layers".to_string(),
745 evidence: vec![Evidence {
746 metric_name: "gradient_norm".to_string(),
747 observed_value: 0.001,
748 expected_range: (0.01, 1.0),
749 explanation: "Gradient norms are significantly below normal range"
750 .to_string(),
751 }],
752 metrics: HashMap::new(),
753 detected_at: chrono::Utc::now(),
754 });
755 }
756
757 if gradient_report.has_exploding_gradients() {
759 issues.push(DetectedIssue {
760 issue_type: IssueType::ExplodingGradients,
761 severity: IssueSeverity::Critical,
762 confidence: 0.95,
763 description: "Exploding gradients detected - training instability likely"
764 .to_string(),
765 evidence: vec![Evidence {
766 metric_name: "gradient_norm".to_string(),
767 observed_value: 100.0,
768 expected_range: (0.01, 10.0),
769 explanation: "Gradient norms are extremely high".to_string(),
770 }],
771 metrics: HashMap::new(),
772 detected_at: chrono::Utc::now(),
773 });
774 }
775 }
776
777 Ok(issues)
778 }
779
780 fn get_detector_name(&self) -> &str {
781 "GradientIssueDetector"
782 }
783
784 fn get_supported_issues(&self) -> Vec<IssueType> {
785 vec![IssueType::VanishingGradients, IssueType::ExplodingGradients]
786 }
787}
788
/// Stateless detector for training-loop problems, based on recent loss values.
#[derive(Debug)]
struct TrainingIssueDetector;

impl TrainingIssueDetector {
    /// Returns the (stateless) detector.
    fn new() -> Self {
        TrainingIssueDetector
    }
}
797
798impl IssueDetector for TrainingIssueDetector {
799 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
800 let mut issues = Vec::new();
801
802 if context.recent_metrics.len() >= 10 {
804 let recent_losses: Vec<f64> =
805 context.recent_metrics.iter().rev().take(10).filter_map(|m| m.loss).collect();
806
807 if recent_losses.len() >= 5 {
808 let first_half_avg = recent_losses[..recent_losses.len() / 2].iter().sum::<f64>()
810 / (recent_losses.len() / 2) as f64;
811 let second_half_avg = recent_losses[recent_losses.len() / 2..].iter().sum::<f64>()
812 / (recent_losses.len() - recent_losses.len() / 2) as f64;
813
814 if (first_half_avg - second_half_avg).abs() / first_half_avg < 0.01 {
815 issues.push(DetectedIssue {
816 issue_type: IssueType::TrainingStalled,
817 severity: IssueSeverity::Medium,
818 confidence: 0.8,
819 description: "Training appears to have stalled - loss not decreasing"
820 .to_string(),
821 evidence: vec![Evidence {
822 metric_name: "loss_change".to_string(),
823 observed_value: (first_half_avg - second_half_avg).abs()
824 / first_half_avg,
825 expected_range: (0.05, 1.0),
826 explanation: "Loss change is below expected threshold".to_string(),
827 }],
828 metrics: HashMap::new(),
829 detected_at: chrono::Utc::now(),
830 });
831 }
832 }
833 }
834
835 Ok(issues)
836 }
837
838 fn get_detector_name(&self) -> &str {
839 "TrainingIssueDetector"
840 }
841
842 fn get_supported_issues(&self) -> Vec<IssueType> {
843 vec![
844 IssueType::TrainingStalled,
845 IssueType::LossNotDecreasing,
846 IssueType::UnstableTraining,
847 ]
848 }
849}
850
/// Stateless detector for resource-utilisation problems, based on the latest metrics sample.
#[derive(Debug)]
struct PerformanceIssueDetector;

impl PerformanceIssueDetector {
    /// Returns the (stateless) detector.
    fn new() -> Self {
        PerformanceIssueDetector
    }
}
859
860impl IssueDetector for PerformanceIssueDetector {
861 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
862 let mut issues = Vec::new();
863
864 if let Some(metrics) = context.recent_metrics.last() {
866 if let Some(gpu_util) = metrics.gpu_utilization {
867 if gpu_util < 0.5 {
868 issues.push(DetectedIssue {
869 issue_type: IssueType::LowGpuUtilization,
870 severity: IssueSeverity::Medium,
871 confidence: 0.8,
872 description:
873 "Low GPU utilization detected - compute resources underutilized"
874 .to_string(),
875 evidence: vec![Evidence {
876 metric_name: "gpu_utilization".to_string(),
877 observed_value: gpu_util,
878 expected_range: (0.7, 1.0),
879 explanation: "GPU utilization is below optimal range".to_string(),
880 }],
881 metrics: HashMap::new(),
882 detected_at: chrono::Utc::now(),
883 });
884 }
885 }
886 }
887
888 Ok(issues)
889 }
890
891 fn get_detector_name(&self) -> &str {
892 "PerformanceIssueDetector"
893 }
894
895 fn get_supported_issues(&self) -> Vec<IssueType> {
896 vec![
897 IssueType::LowGpuUtilization,
898 IssueType::SlowTraining,
899 IssueType::MemoryBottleneck,
900 ]
901 }
902}
903
/// Stateless detector for out-of-range hyperparameters, based on the latest metrics sample.
#[derive(Debug)]
struct HyperparameterIssueDetector;

impl HyperparameterIssueDetector {
    /// Returns the (stateless) detector.
    fn new() -> Self {
        HyperparameterIssueDetector
    }
}
912
913impl IssueDetector for HyperparameterIssueDetector {
914 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
915 let mut issues = Vec::new();
916
917 if let Some(metrics) = context.recent_metrics.last() {
918 if let Some(lr) = metrics.learning_rate {
920 if lr > 0.1 {
921 issues.push(DetectedIssue {
922 issue_type: IssueType::LearningRateTooHigh,
923 severity: IssueSeverity::High,
924 confidence: 0.7,
925 description:
926 "Learning rate appears too high - may cause training instability"
927 .to_string(),
928 evidence: vec![Evidence {
929 metric_name: "learning_rate".to_string(),
930 observed_value: lr,
931 expected_range: (0.0001, 0.01),
932 explanation: "Learning rate is above typical range".to_string(),
933 }],
934 metrics: HashMap::new(),
935 detected_at: chrono::Utc::now(),
936 });
937 } else if lr < 0.00001 {
938 issues.push(DetectedIssue {
939 issue_type: IssueType::LearningRateTooLow,
940 severity: IssueSeverity::Medium,
941 confidence: 0.6,
942 description: "Learning rate might be too low - training could be slow"
943 .to_string(),
944 evidence: vec![Evidence {
945 metric_name: "learning_rate".to_string(),
946 observed_value: lr,
947 expected_range: (0.0001, 0.01),
948 explanation: "Learning rate is below typical range".to_string(),
949 }],
950 metrics: HashMap::new(),
951 detected_at: chrono::Utc::now(),
952 });
953 }
954 }
955 }
956
957 Ok(issues)
958 }
959
960 fn get_detector_name(&self) -> &str {
961 "HyperparameterIssueDetector"
962 }
963
964 fn get_supported_issues(&self) -> Vec<IssueType> {
965 vec![
966 IssueType::LearningRateTooHigh,
967 IssueType::LearningRateTooLow,
968 ]
969 }
970}
971
/// Stateless detector for model-size and depth concerns, based on [`ModelInfo`].
#[derive(Debug)]
struct ArchitectureIssueDetector;

impl ArchitectureIssueDetector {
    /// Returns the (stateless) detector.
    fn new() -> Self {
        ArchitectureIssueDetector
    }
}
980
981impl IssueDetector for ArchitectureIssueDetector {
982 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
983 let mut issues = Vec::new();
984
985 if let Some(model_info) = context.model_info {
986 if model_info.parameter_count > 1_000_000_000 {
988 issues.push(DetectedIssue {
989 issue_type: IssueType::ModelTooLarge,
990 severity: IssueSeverity::Medium,
991 confidence: 0.6,
992 description:
993 "Model has very large number of parameters - consider optimization"
994 .to_string(),
995 evidence: vec![Evidence {
996 metric_name: "parameter_count".to_string(),
997 observed_value: model_info.parameter_count as f64,
998 expected_range: (1_000_000.0, 100_000_000.0),
999 explanation: "Parameter count is extremely high".to_string(),
1000 }],
1001 metrics: HashMap::new(),
1002 detected_at: chrono::Utc::now(),
1003 });
1004 }
1005
1006 if model_info.layer_count > 100 {
1007 issues.push(DetectedIssue {
1008 issue_type: IssueType::InappropriateArchitecture,
1009 severity: IssueSeverity::Low,
1010 confidence: 0.5,
1011 description: "Very deep model - may have gradient flow issues".to_string(),
1012 evidence: vec![Evidence {
1013 metric_name: "layer_count".to_string(),
1014 observed_value: model_info.layer_count as f64,
1015 expected_range: (10.0, 50.0),
1016 explanation: "Layer count is very high".to_string(),
1017 }],
1018 metrics: HashMap::new(),
1019 detected_at: chrono::Utc::now(),
1020 });
1021 }
1022 }
1023
1024 Ok(issues)
1025 }
1026
1027 fn get_detector_name(&self) -> &str {
1028 "ArchitectureIssueDetector"
1029 }
1030
1031 fn get_supported_issues(&self) -> Vec<IssueType> {
1032 vec![
1033 IssueType::ModelTooLarge,
1034 IssueType::InappropriateArchitecture,
1035 ]
1036 }
1037}
1038
/// Stateless detector for data-pipeline problems inferred from recent metric trends.
#[derive(Debug)]
struct DataIssueDetector;

impl DataIssueDetector {
    /// Returns the (stateless) detector.
    fn new() -> Self {
        DataIssueDetector
    }
}
1047
1048impl IssueDetector for DataIssueDetector {
1049 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
1050 let mut issues = Vec::new();
1068
1069 const MIN_WINDOW: usize = 5;
1070 if context.recent_metrics.len() < MIN_WINDOW {
1071 return Ok(issues);
1072 }
1073
1074 let window: Vec<&DashboardMetrics> =
1076 context.recent_metrics.iter().rev().take(MIN_WINDOW * 2).collect();
1077
1078 let gpu_samples: Vec<f64> = window.iter().filter_map(|m| m.gpu_utilization).collect();
1080 let tps_samples: Vec<f64> = window.iter().filter_map(|m| m.tokens_per_second).collect();
1081 if gpu_samples.len() >= MIN_WINDOW && tps_samples.len() >= MIN_WINDOW {
1082 let gpu_mean = gpu_samples.iter().sum::<f64>() / gpu_samples.len() as f64;
1083 let tps_mean = tps_samples.iter().sum::<f64>() / tps_samples.len() as f64;
1084 if gpu_mean < 0.5 && tps_mean < 100.0 {
1088 let mut metrics = HashMap::new();
1089 metrics.insert("avg_gpu_utilization".to_string(), gpu_mean);
1090 metrics.insert("avg_tokens_per_second".to_string(), tps_mean);
1091 issues.push(DetectedIssue {
1092 issue_type: IssueType::BatchSizeProblems,
1093 severity: IssueSeverity::Medium,
1094 confidence: 0.7,
1095 description:
1096 "Sustained low GPU utilisation and throughput suggest batch size may be \
1097 too small to saturate the device"
1098 .to_string(),
1099 evidence: vec![
1100 Evidence {
1101 metric_name: "gpu_utilization".to_string(),
1102 observed_value: gpu_mean,
1103 expected_range: (0.7, 1.0),
1104 explanation:
1105 "Average GPU utilisation is below the typical training range"
1106 .to_string(),
1107 },
1108 Evidence {
1109 metric_name: "tokens_per_second".to_string(),
1110 observed_value: tps_mean,
1111 expected_range: (100.0, f64::INFINITY),
1112 explanation: "Throughput is below the typical training floor"
1113 .to_string(),
1114 },
1115 ],
1116 metrics,
1117 detected_at: chrono::Utc::now(),
1118 });
1119 }
1120 }
1121
1122 let acc_samples: Vec<f64> = window.iter().filter_map(|m| m.accuracy).collect();
1124 let loss_samples: Vec<f64> = window.iter().filter_map(|m| m.loss).collect();
1125
1126 if acc_samples.len() >= MIN_WINDOW && loss_samples.len() >= MIN_WINDOW {
1127 let acc_mean = acc_samples.iter().sum::<f64>() / acc_samples.len() as f64;
1128 let acc_var = acc_samples
1129 .iter()
1130 .map(|a| {
1131 let d = a - acc_mean;
1132 d * d
1133 })
1134 .sum::<f64>()
1135 / acc_samples.len() as f64;
1136 let acc_stddev = acc_var.sqrt();
1137
1138 let half = loss_samples.len() / 2;
1142 let newer_half = &loss_samples[..half];
1143 let older_half = &loss_samples[loss_samples.len() - half..];
1144 let newer_avg = if newer_half.is_empty() {
1145 0.0
1146 } else {
1147 newer_half.iter().sum::<f64>() / newer_half.len() as f64
1148 };
1149 let older_avg = if older_half.is_empty() {
1150 0.0
1151 } else {
1152 older_half.iter().sum::<f64>() / older_half.len() as f64
1153 };
1154 let loss_relative_change = if older_avg.abs() > f64::EPSILON {
1156 (older_avg - newer_avg) / older_avg.abs()
1157 } else {
1158 0.0
1159 };
1160
1161 let acc_pinned_extreme = acc_stddev < 0.01 && !(0.2..=0.95).contains(&acc_mean);
1166 let loss_changing = loss_relative_change.abs() > 0.05;
1167 if acc_pinned_extreme && loss_changing {
1168 let mut metrics = HashMap::new();
1169 metrics.insert("accuracy_mean".to_string(), acc_mean);
1170 metrics.insert("accuracy_stddev".to_string(), acc_stddev);
1171 metrics.insert("loss_relative_change".to_string(), loss_relative_change);
1172 issues.push(DetectedIssue {
1173 issue_type: IssueType::DataImbalance,
1174 severity: IssueSeverity::High,
1175 confidence: 0.75,
1176 description:
1177 "Accuracy is pinned at an extreme value while loss continues to change \
1178 — model may be collapsing onto a majority class"
1179 .to_string(),
1180 evidence: vec![Evidence {
1181 metric_name: "accuracy_stddev".to_string(),
1182 observed_value: acc_stddev,
1183 expected_range: (0.01, 0.5),
1184 explanation:
1185 "Accuracy variance is far below the range expected during healthy \
1186 training"
1187 .to_string(),
1188 }],
1189 metrics,
1190 detected_at: chrono::Utc::now(),
1191 });
1192 }
1193
1194 if loss_relative_change > 0.10 && acc_stddev > 0.15 {
1199 let mut metrics = HashMap::new();
1200 metrics.insert("accuracy_stddev".to_string(), acc_stddev);
1201 metrics.insert("loss_relative_change".to_string(), loss_relative_change);
1202 issues.push(DetectedIssue {
1203 issue_type: IssueType::InsufficientData,
1204 severity: IssueSeverity::Medium,
1205 confidence: 0.6,
1206 description:
1207 "Loss is decreasing but accuracy fluctuates wildly — the dataset may be \
1208 too small, leading to memorisation rather than generalisation"
1209 .to_string(),
1210 evidence: vec![Evidence {
1211 metric_name: "accuracy_stddev".to_string(),
1212 observed_value: acc_stddev,
1213 expected_range: (0.0, 0.10),
1214 explanation:
1215 "Accuracy variance is well above what is expected when the model \
1216 is generalising"
1217 .to_string(),
1218 }],
1219 metrics,
1220 detected_at: chrono::Utc::now(),
1221 });
1222 }
1223 }
1224
1225 Ok(issues)
1226 }
1227
1228 fn get_detector_name(&self) -> &str {
1229 "DataIssueDetector"
1230 }
1231
1232 fn get_supported_issues(&self) -> Vec<IssueType> {
1233 vec![
1234 IssueType::DataImbalance,
1235 IssueType::BatchSizeProblems,
1236 IssueType::InsufficientData,
1237 ]
1238 }
1239}
1240
1241impl Default for KnowledgeBase {
1242 fn default() -> Self {
1243 Self::new()
1244 }
1245}
1246
1247impl KnowledgeBase {
1248 pub fn new() -> Self {
1249 Self {
1250 issue_patterns: HashMap::new(),
1251 hyperparameter_recommendations: HashMap::new(),
1252 architecture_patterns: Vec::new(),
1253 best_practices: HashMap::new(),
1254 }
1255 }
1256}
1257
1258#[derive(Debug, Serialize, Deserialize)]
1261pub struct AutoDebugReport {
1262 pub detected_issues: Vec<DetectedIssue>,
1263 pub fix_recommendations: Vec<FixRecommendation>,
1264 pub hyperparameter_recommendations: Vec<HyperparameterRecommendation>,
1265 pub architecture_suggestions: Vec<ArchitectureSuggestion>,
1266 pub training_recipe: TrainingRecipeOptimization,
1267 pub analysis_summary: String,
1268 pub confidence_score: f64,
1269}
1270
/// A fix suggestion paired with the detected issue it is meant to address.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixRecommendation {
    /// The issue this recommendation responds to.
    pub issue: DetectedIssue,
    /// The suggested remedy.
    pub fix_suggestion: FixSuggestion,
    /// Confidence that the fix applies. NOTE(review): scale not established
    /// here (presumably `[0, 1]`) — confirm.
    pub confidence: f64,
    /// Relative urgency of applying the fix. NOTE(review): scale and
    /// direction (higher = more urgent?) are assumptions — confirm with the
    /// code that fills this in.
    pub urgency: f64,
}
1278
/// A suggested value for a single training hyperparameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterRecommendation {
    /// Name of the hyperparameter the recommendation applies to.
    pub parameter: String,
    /// The value currently in use. `None` presumably means it was not
    /// available to the analysis — confirm.
    pub current_value: Option<f64>,
    /// The value being recommended.
    pub recommended_value: f64,
    /// Explanation of why the change is recommended.
    pub reason: String,
    /// Confidence in the recommendation. NOTE(review): presumably `[0, 1]` —
    /// confirm.
    pub confidence: f64,
}
1287
/// A free-form suggestion for changing the model architecture.
///
/// All fields are plain strings; this type enforces no fixed vocabulary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSuggestion {
    /// Category of the suggestion.
    pub suggestion_type: String,
    /// Short headline for the suggestion.
    pub title: String,
    /// Detailed description of the proposed change.
    pub description: String,
    /// Expected effect of making the change.
    pub impact_assessment: String,
    /// How hard the change is to implement.
    pub implementation_difficulty: String,
}
1296
/// Suggested adjustments to the overall training recipe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecipeOptimization {
    /// Free-form list of recommended optimizations.
    pub recommended_optimizations: Vec<String>,
    /// Suggested warmup, learning-rate, batch-size and checkpoint schedule.
    pub training_schedule: TrainingSchedule,
    /// Suggested data-handling strategy.
    pub data_strategy: DataStrategy,
}
1303
/// Suggested scheduling parameters for a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSchedule {
    /// Number of warmup steps.
    pub warmup_steps: u32,
    /// Learning-rate schedule, described as free-form text.
    pub learning_rate_schedule: String,
    /// Batch-size schedule, described as free-form text.
    pub batch_size_schedule: String,
    /// Whether early stopping should be enabled.
    pub early_stopping: bool,
    /// How often to checkpoint. NOTE(review): the unit (steps vs. epochs) is
    /// not established here — confirm.
    pub checkpoint_frequency: u32,
}
1312
/// Suggested data-handling strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataStrategy {
    /// Data augmentation techniques to apply.
    pub data_augmentation: Vec<String>,
    /// Sampling strategy, described as free-form text.
    pub sampling_strategy: String,
    /// Suggested preprocessing optimizations.
    pub preprocessing_optimizations: Vec<String>,
}
1319
#[cfg(test)]
mod tests {
    use super::*;

    fn make_config() -> DebugConfig {
        DebugConfig::default()
    }

    #[test]
    fn test_knowledge_base_new() {
        let kb = KnowledgeBase::new();
        assert!(kb.issue_patterns.is_empty());
        assert!(kb.hyperparameter_recommendations.is_empty());
        assert!(kb.architecture_patterns.is_empty());
        assert!(kb.best_practices.is_empty());
    }

    #[test]
    fn test_knowledge_base_default() {
        let kb = KnowledgeBase::default();
        assert!(kb.issue_patterns.is_empty());
    }

    #[test]
    fn test_auto_debugger_new() {
        let config = make_config();
        let debugger = AutoDebugger::new(&config);
        assert!(!debugger.issue_detectors.is_empty());
        assert!(!debugger.fix_suggestions.is_empty());
        assert!(debugger.optimization_history.is_empty());
    }

    #[test]
    fn test_auto_debugger_has_default_detectors() {
        let config = make_config();
        let debugger = AutoDebugger::new(&config);
        assert_eq!(debugger.issue_detectors.len(), 6);
    }

    #[test]
    fn test_auto_debugger_has_fix_suggestions() {
        let config = make_config();
        let debugger = AutoDebugger::new(&config);
        assert!(debugger.fix_suggestions.contains_key(&IssueType::VanishingGradients));
        assert!(debugger.fix_suggestions.contains_key(&IssueType::ExplodingGradients));
    }

    #[test]
    fn test_gradient_issue_detector_name() {
        let detector = GradientIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "GradientIssueDetector");
    }

    #[test]
    fn test_gradient_issue_detector_supported_issues() {
        let detector = GradientIssueDetector::new();
        let issues = detector.get_supported_issues();
        assert!(issues.contains(&IssueType::VanishingGradients));
        assert!(issues.contains(&IssueType::ExplodingGradients));
    }

    #[test]
    fn test_training_issue_detector_name() {
        let detector = TrainingIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "TrainingIssueDetector");
    }

    #[test]
    fn test_training_issue_detector_supported_issues() {
        let detector = TrainingIssueDetector::new();
        let issues = detector.get_supported_issues();
        assert!(!issues.is_empty());
    }

    #[test]
    fn test_performance_issue_detector_name() {
        let detector = PerformanceIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "PerformanceIssueDetector");
    }

    #[test]
    fn test_hyperparameter_issue_detector_name() {
        let detector = HyperparameterIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "HyperparameterIssueDetector");
    }

    #[test]
    fn test_architecture_issue_detector_name() {
        let detector = ArchitectureIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "ArchitectureIssueDetector");
    }

    #[test]
    fn test_data_issue_detector_name() {
        let detector = DataIssueDetector::new();
        assert_eq!(detector.get_detector_name(), "DataIssueDetector");
    }

    #[test]
    fn test_data_issue_detector_supported_issues() {
        let detector = DataIssueDetector::new();
        let issues = detector.get_supported_issues();
        assert_eq!(issues.len(), 3);
        assert!(issues.contains(&IssueType::DataImbalance));
        assert!(issues.contains(&IssueType::BatchSizeProblems));
        assert!(issues.contains(&IssueType::InsufficientData));
    }

    #[test]
    fn test_issue_type_equality() {
        assert_eq!(IssueType::VanishingGradients, IssueType::VanishingGradients);
        assert_ne!(IssueType::VanishingGradients, IssueType::ExplodingGradients);
    }

    #[test]
    fn test_issue_type_hash_compatible() {
        let mut map = HashMap::new();
        map.insert(IssueType::OverfittingDetected, "fix");
        assert!(map.contains_key(&IssueType::OverfittingDetected));
        assert!(!map.contains_key(&IssueType::UnderfittingDetected));
    }

    #[test]
    fn test_evidence_construction() {
        let evidence = Evidence {
            metric_name: "gradient_norm".to_string(),
            observed_value: 0.001,
            expected_range: (0.01, 1.0),
            explanation: "Gradient norm too low".to_string(),
        };
        assert_eq!(evidence.metric_name, "gradient_norm");
        assert!(evidence.observed_value < evidence.expected_range.0);
    }

    #[test]
    fn test_expected_impact_fields() {
        let impact = ExpectedImpact {
            performance_improvement: 0.15,
            training_speed_improvement: 0.05,
            stability_improvement: 0.25,
            memory_usage_change: 0.02,
        };
        assert!(impact.performance_improvement > 0.0);
        assert!(impact.stability_improvement > impact.performance_improvement);
    }

    #[test]
    fn test_model_info_construction() {
        let info = ModelInfo {
            model_type: "transformer".to_string(),
            parameter_count: 1_000_000,
            layer_count: 12,
            architecture_details: HashMap::new(),
        };
        assert_eq!(info.model_type, "transformer");
        assert_eq!(info.parameter_count, 1_000_000);
    }

    #[test]
    fn test_issue_pattern_construction() {
        let pattern = IssuePattern {
            symptoms: vec!["low gradient norm".to_string()],
            common_causes: vec!["deep network".to_string()],
            diagnostic_metrics: vec!["gradient_norm".to_string()],
            typical_solutions: vec!["add skip connections".to_string()],
        };
        assert_eq!(pattern.symptoms.len(), 1);
        assert_eq!(pattern.common_causes.len(), 1);
    }

    #[test]
    fn test_hyperparameter_advice_construction() {
        let advice = HyperparameterAdvice {
            parameter_name: "learning_rate".to_string(),
            recommended_range: (1e-5, 1e-2),
            tuning_strategy: "grid_search".to_string(),
            dependencies: vec!["batch_size".to_string()],
            common_mistakes: vec!["too high initial lr".to_string()],
        };
        assert!(advice.recommended_range.0 < advice.recommended_range.1);
    }

    // Builds a metrics sample with the fields the data detector reads (loss,
    // accuracy, GPU utilisation, throughput) set explicitly and everything
    // else fixed to neutral values.
    fn make_metric(
        loss: Option<f64>,
        accuracy: Option<f64>,
        gpu: Option<f64>,
        tps: Option<f64>,
    ) -> DashboardMetrics {
        DashboardMetrics {
            timestamp: std::time::SystemTime::now(),
            loss,
            accuracy,
            learning_rate: Some(1e-3),
            memory_usage_mb: 1024.0,
            gpu_utilization: gpu,
            tokens_per_second: tps,
            gradient_norm: Some(0.5),
            epoch: Some(0),
            step: Some(0),
        }
    }

    #[test]
    fn test_data_issue_detector_returns_empty_with_no_metrics() {
        let detector = DataIssueDetector::new();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &[],
            training_duration: Duration::from_secs(60),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(issues.is_empty());
    }

    #[test]
    fn test_data_issue_detector_flags_batch_size_problem() {
        let detector = DataIssueDetector::new();
        // Both GPU utilisation (0.2) and throughput (50 tok/s) are below the
        // batch-size heuristic's thresholds for the whole window.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| {
                make_metric(
                    Some(2.0 - i as f64 * 0.01),
                    Some(0.6),
                    Some(0.2),
                    Some(50.0),
                )
            })
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            issues.iter().any(|i| i.issue_type == IssueType::BatchSizeProblems),
            "expected BatchSizeProblems to be flagged, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_data_issue_detector_ignores_low_throughput_when_gpu_is_busy() {
        let detector = DataIssueDetector::new();
        // Throughput is low but the GPU is saturated; the batch-size heuristic
        // requires BOTH to be low, so it must not fire.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| {
                make_metric(
                    Some(2.0 - i as f64 * 0.01),
                    Some(0.6),
                    Some(0.9),
                    Some(50.0),
                )
            })
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            !issues.iter().any(|i| i.issue_type == IssueType::BatchSizeProblems),
            "did not expect BatchSizeProblems, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_data_issue_detector_flags_data_imbalance_when_accuracy_pinned() {
        let detector = DataIssueDetector::new();
        // Accuracy is constant at 0.97 (outside the healthy 0.2..=0.95 band)
        // while loss keeps moving — the majority-class collapse signature.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| {
                make_metric(
                    Some(2.0 - i as f64 * 0.10),
                    Some(0.97),
                    Some(0.85),
                    Some(500.0),
                )
            })
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            issues.iter().any(|i| i.issue_type == IssueType::DataImbalance),
            "expected DataImbalance to be flagged, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_data_issue_detector_no_imbalance_when_accuracy_in_normal_range() {
        let detector = DataIssueDetector::new();
        // Accuracy is flat but sits inside the healthy band, so it is not
        // "pinned at an extreme" even though loss is changing.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|i| {
                make_metric(
                    Some(2.0 - i as f64 * 0.10),
                    Some(0.6),
                    Some(0.85),
                    Some(500.0),
                )
            })
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            !issues.iter().any(|i| i.issue_type == IssueType::DataImbalance),
            "did not expect DataImbalance, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_data_issue_detector_no_imbalance_when_loss_flat() {
        let detector = DataIssueDetector::new();
        // Accuracy is pinned at an extreme, but loss is not changing at all;
        // the imbalance heuristic requires both, so it must not fire.
        let metrics: Vec<DashboardMetrics> = (0..10)
            .map(|_| make_metric(Some(1.0), Some(0.97), Some(0.85), Some(500.0)))
            .collect();
        let context = DebugContext {
            profiler_report: None,
            gradient_report: None,
            anomaly_report: None,
            recent_metrics: &metrics,
            training_duration: Duration::from_secs(600),
            model_info: None,
        };
        let issues = detector.detect_issues(&context).expect("detect_issues should succeed");
        assert!(
            !issues.iter().any(|i| i.issue_type == IssueType::DataImbalance),
            "did not expect DataImbalance, got: {:?}",
            issues.iter().map(|i| &i.issue_type).collect::<Vec<_>>()
        );
    }
}
1587}