1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::time::Duration;
7
8use crate::{
9 AnomalyDetectorReport, DashboardMetrics, DebugConfig, GradientDebugReport, ProfilerReport,
10};
11
/// Automated training-issue analyzer: runs every registered [`IssueDetector`]
/// over a [`DebugContext`] and maps detected issues onto canned fix suggestions.
#[derive(Debug)]
#[allow(dead_code)]
pub struct AutoDebugger {
    // Stored for future use; only cloned at construction today.
    #[allow(dead_code)]
    config: DebugConfig,
    // Pluggable detectors consulted in registration order by `analyze_issues`.
    issue_detectors: Vec<Box<dyn IssueDetector>>,
    // Static catalog mapping each issue type to its candidate fixes.
    fix_suggestions: HashMap<IssueType, Vec<FixSuggestion>>,
    // Rolling log of applied fixes; trimmed once it exceeds 1000 entries.
    optimization_history: Vec<OptimizationAttempt>,
    // Reference material; currently constructed empty (see `KnowledgeBase::new`).
    knowledge_base: KnowledgeBase,
}
23
/// Categories of training problems the auto-debugger can report.
///
/// Used as the key of the fix-suggestion catalog, hence `Eq + Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IssueType {
    // Training-dynamics issues
    VanishingGradients,
    ExplodingGradients,
    LearningRateProblems,
    OverfittingDetected,
    UnderfittingDetected,
    TrainingStalled,
    LossNotDecreasing,
    UnstableTraining,
    MemoryIssues,

    // Model / architecture issues
    ModelTooLarge,
    ModelTooSmall,
    InappropriateArchitecture,
    LayerMismatch,
    ActivationProblems,

    // Data issues
    DataImbalance,
    DataLeakage,
    InsufficientData,
    DataQualityIssues,
    BatchSizeProblems,

    // Performance / resource issues
    SlowTraining,
    LowGpuUtilization,
    MemoryBottleneck,
    IoBottleneck,
    ComputeBottleneck,

    // Hyperparameter issues
    LearningRateTooHigh,
    LearningRateTooLow,
    BatchSizeTooLarge,
    BatchSizeTooSmall,
    RegularizationIssues,
}
66
/// One concrete problem found by a detector, with supporting evidence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedIssue {
    pub issue_type: IssueType,
    pub severity: IssueSeverity,
    // Detector's confidence in [0.0, 1.0]; also feeds urgency scoring.
    pub confidence: f64,
    pub description: String,
    // Metric observations that justify this diagnosis.
    pub evidence: Vec<Evidence>,
    // Optional free-form metric snapshot (name -> value).
    pub metrics: HashMap<String, f64>,
    pub detected_at: chrono::DateTime<chrono::Utc>,
}
78
/// How serious a detected issue is; ordered from most to least severe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IssueSeverity {
    Critical,
    High,
    Medium,
    Low,
    Info,
}
87
/// A single metric observation backing a [`DetectedIssue`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Evidence {
    pub metric_name: String,
    pub observed_value: f64,
    // (min, max) range considered healthy for this metric.
    pub expected_range: (f64, f64),
    pub explanation: String,
}
95
/// A concrete, actionable remedy for one [`IssueType`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixSuggestion {
    // Stable identifier (e.g. "vg_001") for cross-referencing attempts.
    pub fix_id: String,
    pub fix_type: FixType,
    pub title: String,
    pub description: String,
    // Ordered, human-readable steps to apply the fix.
    pub implementation_steps: Vec<String>,
    pub expected_impact: ExpectedImpact,
    pub priority: FixPriority,
    pub estimated_effort: EstimatedEffort,
    // Conditions that must hold before this fix can be applied.
    pub prerequisites: Vec<String>,
    pub code_examples: Vec<CodeExample>,
}
110
/// Broad category of change a [`FixSuggestion`] asks for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FixType {
    HyperparameterAdjustment,
    ArchitectureChange,
    TrainingProcedure,
    DataProcessing,
    OptimizationTechnique,
    EnvironmentConfig,
}
120
/// Estimated effect of applying a fix. Values are signed fractions
/// (e.g. 0.15 ~ +15%, -0.05 ~ -5%) based on how the built-in catalog
/// populates them — TODO confirm the intended scale with the authors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedImpact {
    pub performance_improvement: f64,
    pub training_speed_improvement: f64,
    pub stability_improvement: f64,
    // Positive means the fix is expected to use more memory.
    pub memory_usage_change: f64,
}
128
/// How urgently a fix should be applied, from most to least pressing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FixPriority {
    Critical,
    High,
    Medium,
    Low,
}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub enum EstimatedEffort {
139 Trivial, Easy, Medium, Hard, Complex, }
145
/// Snippet illustrating how to implement a fix in a given language.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeExample {
    // Language identifier, e.g. "python".
    pub language: String,
    pub code: String,
    pub explanation: String,
}
152
/// Record of one fix that was actually applied, for before/after comparison.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationAttempt {
    pub attempt_id: String,
    pub issue_addressed: IssueType,
    // `fix_id` of the suggestion that was applied.
    pub fix_applied: String,
    pub before_metrics: HashMap<String, f64>,
    // `None` until post-fix metrics have been collected.
    pub after_metrics: Option<HashMap<String, f64>>,
    // `None` while the outcome is still undetermined.
    pub success: Option<bool>,
    pub notes: String,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
165
/// Static reference material about issues, hyperparameters and architectures.
/// Currently constructed empty; intended to be populated over time.
#[derive(Debug)]
#[allow(dead_code)]
pub struct KnowledgeBase {
    #[allow(dead_code)]
    issue_patterns: HashMap<IssueType, IssuePattern>,
    // Keyed by hyperparameter name (e.g. "learning_rate").
    hyperparameter_recommendations: HashMap<String, HyperparameterAdvice>,
    architecture_patterns: Vec<ArchitecturePattern>,
    // Keyed by topic; values are individual best-practice statements.
    best_practices: HashMap<String, Vec<String>>,
}
176
/// Descriptive knowledge about one issue type: how it manifests and
/// how it is typically resolved.
#[derive(Debug, Clone)]
pub struct IssuePattern {
    pub symptoms: Vec<String>,
    pub common_causes: Vec<String>,
    // Metric names most useful for diagnosing this issue.
    pub diagnostic_metrics: Vec<String>,
    pub typical_solutions: Vec<String>,
}
184
/// Tuning guidance for a single hyperparameter.
#[derive(Debug, Clone)]
pub struct HyperparameterAdvice {
    pub parameter_name: String,
    // (min, max) of values generally considered reasonable.
    pub recommended_range: (f64, f64),
    pub tuning_strategy: String,
    // Other hyperparameters whose values interact with this one.
    pub dependencies: Vec<String>,
    pub common_mistakes: Vec<String>,
}
193
/// A known model-architecture pattern and when to use it.
#[derive(Debug, Clone)]
pub struct ArchitecturePattern {
    pub pattern_name: String,
    pub use_cases: Vec<String>,
    pub typical_layers: Vec<String>,
    // Suggested starting values keyed by hyperparameter name.
    pub hyperparameter_suggestions: HashMap<String, f64>,
    pub performance_characteristics: String,
}
202
/// Interface implemented by every pluggable issue detector.
///
/// `Debug` is required so the owning [`AutoDebugger`] can derive `Debug`.
pub trait IssueDetector: std::fmt::Debug {
    /// Inspects `context` and returns any issues found (possibly empty).
    fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>>;
    /// Human-readable name used in log messages when detection fails.
    fn get_detector_name(&self) -> &str;
    /// Issue types this detector is capable of reporting.
    fn get_supported_issues(&self) -> Vec<IssueType>;
}
209
/// Borrowed snapshot of everything detectors may inspect. All report fields
/// are optional; detectors must tolerate any of them being absent.
#[derive(Debug)]
pub struct DebugContext<'a> {
    pub profiler_report: Option<&'a ProfilerReport>,
    pub gradient_report: Option<&'a GradientDebugReport>,
    pub anomaly_report: Option<&'a AnomalyDetectorReport>,
    // Chronological metric samples; detectors read the newest via `last()`.
    pub recent_metrics: &'a [DashboardMetrics],
    pub training_duration: Duration,
    pub model_info: Option<&'a ModelInfo>,
}
220
/// Summary of the model under training, used for architecture heuristics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub model_type: String,
    pub parameter_count: usize,
    pub layer_count: usize,
    // Free-form key/value details (framework-specific).
    pub architecture_details: HashMap<String, String>,
}
228
229impl AutoDebugger {
230 pub fn new(config: &DebugConfig) -> Self {
232 let mut auto_debugger = Self {
233 config: config.clone(),
234 issue_detectors: Vec::new(),
235 fix_suggestions: HashMap::new(),
236 optimization_history: Vec::new(),
237 knowledge_base: KnowledgeBase::new(),
238 };
239
240 auto_debugger.register_default_detectors();
242 auto_debugger.initialize_fix_suggestions();
243
244 auto_debugger
245 }
246
247 fn register_default_detectors(&mut self) {
249 self.issue_detectors.push(Box::new(GradientIssueDetector::new()));
250 self.issue_detectors.push(Box::new(TrainingIssueDetector::new()));
251 self.issue_detectors.push(Box::new(PerformanceIssueDetector::new()));
252 self.issue_detectors.push(Box::new(HyperparameterIssueDetector::new()));
253 self.issue_detectors.push(Box::new(ArchitectureIssueDetector::new()));
254 self.issue_detectors.push(Box::new(DataIssueDetector::new()));
255 }
256
257 fn initialize_fix_suggestions(&mut self) {
259 self.fix_suggestions.insert(
261 IssueType::VanishingGradients,
262 vec![
263 FixSuggestion {
264 fix_id: "vg_001".to_string(),
265 fix_type: FixType::ArchitectureChange,
266 title: "Add Residual Connections".to_string(),
267 description:
268 "Implement skip connections to help gradients flow through deep networks"
269 .to_string(),
270 implementation_steps: vec![
271 "Add residual blocks to your model architecture".to_string(),
272 "Ensure input and output dimensions match for residual connections"
273 .to_string(),
274 "Consider using batch normalization within residual blocks".to_string(),
275 ],
276 expected_impact: ExpectedImpact {
277 performance_improvement: 0.15,
278 training_speed_improvement: 0.05,
279 stability_improvement: 0.25,
280 memory_usage_change: 0.02,
281 },
282 priority: FixPriority::High,
283 estimated_effort: EstimatedEffort::Medium,
284 prerequisites: vec!["Model architecture access".to_string()],
285 code_examples: vec![CodeExample {
286 language: "python".to_string(),
287 code: r#"
288class ResidualBlock(nn.Module):
289 def __init__(self, channels):
290 super().__init__()
291 self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
292 self.bn1 = nn.BatchNorm2d(channels)
293 self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
294 self.bn2 = nn.BatchNorm2d(channels)
295
296 def forward(self, x):
297 residual = x
298 out = F.relu(self.bn1(self.conv1(x)))
299 out = self.bn2(self.conv2(out))
300 out += residual # Skip connection
301 return F.relu(out)
302"#
303 .to_string(),
304 explanation: "Basic residual block implementation with skip connection"
305 .to_string(),
306 }],
307 },
308 FixSuggestion {
309 fix_id: "vg_002".to_string(),
310 fix_type: FixType::HyperparameterAdjustment,
311 title: "Adjust Learning Rate".to_string(),
312 description:
313 "Increase learning rate to help gradients propagate more effectively"
314 .to_string(),
315 implementation_steps: vec![
316 "Increase learning rate by 2-5x".to_string(),
317 "Monitor training stability".to_string(),
318 "Consider learning rate scheduling".to_string(),
319 ],
320 expected_impact: ExpectedImpact {
321 performance_improvement: 0.08,
322 training_speed_improvement: 0.10,
323 stability_improvement: -0.05,
324 memory_usage_change: 0.0,
325 },
326 priority: FixPriority::Medium,
327 estimated_effort: EstimatedEffort::Trivial,
328 prerequisites: vec![],
329 code_examples: vec![CodeExample {
330 language: "python".to_string(),
331 code: "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)"
332 .to_string(),
333 explanation: "Increase learning rate to help overcome vanishing gradients"
334 .to_string(),
335 }],
336 },
337 ],
338 );
339
340 self.fix_suggestions.insert(
342 IssueType::ExplodingGradients,
343 vec![FixSuggestion {
344 fix_id: "eg_001".to_string(),
345 fix_type: FixType::TrainingProcedure,
346 title: "Apply Gradient Clipping".to_string(),
347 description: "Clip gradients to prevent explosion during backpropagation"
348 .to_string(),
349 implementation_steps: vec![
350 "Add gradient clipping to your training loop".to_string(),
351 "Start with clip value of 1.0 and adjust based on results".to_string(),
352 "Monitor gradient norms to ensure clipping is effective".to_string(),
353 ],
354 expected_impact: ExpectedImpact {
355 performance_improvement: 0.10,
356 training_speed_improvement: 0.0,
357 stability_improvement: 0.30,
358 memory_usage_change: 0.0,
359 },
360 priority: FixPriority::Critical,
361 estimated_effort: EstimatedEffort::Easy,
362 prerequisites: vec![],
363 code_examples: vec![CodeExample {
364 language: "python".to_string(),
365 code: "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)"
366 .to_string(),
367 explanation: "Clip gradients before optimizer step".to_string(),
368 }],
369 }],
370 );
371
372 self.fix_suggestions.insert(
374 IssueType::LearningRateTooHigh,
375 vec![FixSuggestion {
376 fix_id: "lr_high_001".to_string(),
377 fix_type: FixType::HyperparameterAdjustment,
378 title: "Reduce Learning Rate".to_string(),
379 description: "Lower the learning rate to improve training stability".to_string(),
380 implementation_steps: vec![
381 "Reduce learning rate by 2-10x".to_string(),
382 "Consider learning rate scheduling".to_string(),
383 "Monitor loss convergence".to_string(),
384 ],
385 expected_impact: ExpectedImpact {
386 performance_improvement: 0.12,
387 training_speed_improvement: -0.05,
388 stability_improvement: 0.25,
389 memory_usage_change: 0.0,
390 },
391 priority: FixPriority::High,
392 estimated_effort: EstimatedEffort::Trivial,
393 prerequisites: vec![],
394 code_examples: vec![CodeExample {
395 language: "python".to_string(),
396 code: "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)".to_string(),
397 explanation: "Reduce learning rate for more stable training".to_string(),
398 }],
399 }],
400 );
401
402 self.fix_suggestions.insert(
404 IssueType::LowGpuUtilization,
405 vec![FixSuggestion {
406 fix_id: "gpu_001".to_string(),
407 fix_type: FixType::HyperparameterAdjustment,
408 title: "Increase Batch Size".to_string(),
409 description: "Increase batch size to better utilize GPU compute capacity"
410 .to_string(),
411 implementation_steps: vec![
412 "Double the current batch size".to_string(),
413 "Monitor memory usage to avoid OOM".to_string(),
414 "Adjust learning rate proportionally".to_string(),
415 ],
416 expected_impact: ExpectedImpact {
417 performance_improvement: 0.05,
418 training_speed_improvement: 0.30,
419 stability_improvement: 0.0,
420 memory_usage_change: 0.20,
421 },
422 priority: FixPriority::Medium,
423 estimated_effort: EstimatedEffort::Easy,
424 prerequisites: vec!["Available GPU memory".to_string()],
425 code_examples: vec![CodeExample {
426 language: "python".to_string(),
427 code: "train_loader = DataLoader(dataset, batch_size=64, shuffle=True)"
428 .to_string(),
429 explanation: "Increase batch size to improve GPU utilization".to_string(),
430 }],
431 }],
432 );
433 }
434
435 pub fn analyze_issues(&self, context: &DebugContext) -> Result<AutoDebugReport> {
437 let mut all_issues = Vec::new();
438
439 for detector in &self.issue_detectors {
441 match detector.detect_issues(context) {
442 Ok(mut issues) => all_issues.append(&mut issues),
443 Err(e) => {
444 tracing::warn!(
445 "Issue detector '{}' failed: {}",
446 detector.get_detector_name(),
447 e
448 );
449 },
450 }
451 }
452
453 all_issues.sort_by(|a, b| {
455 let severity_order = |s: &IssueSeverity| match s {
456 IssueSeverity::Critical => 0,
457 IssueSeverity::High => 1,
458 IssueSeverity::Medium => 2,
459 IssueSeverity::Low => 3,
460 IssueSeverity::Info => 4,
461 };
462
463 let severity_cmp = severity_order(&a.severity).cmp(&severity_order(&b.severity));
464 if severity_cmp == std::cmp::Ordering::Equal {
465 b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
466 } else {
467 severity_cmp
468 }
469 });
470
471 let fix_recommendations = self.generate_fix_recommendations(&all_issues);
473
474 let hyperparameter_recommendations = self.generate_hyperparameter_recommendations(context);
476
477 let architecture_suggestions = self.generate_architecture_suggestions(context);
479
480 let training_recipe = self.generate_training_recipe_optimization(context);
482
483 Ok(AutoDebugReport {
484 detected_issues: all_issues,
485 fix_recommendations: fix_recommendations.clone(),
486 hyperparameter_recommendations,
487 architecture_suggestions,
488 training_recipe,
489 analysis_summary: self.generate_analysis_summary(&fix_recommendations),
490 confidence_score: self.calculate_overall_confidence(&fix_recommendations),
491 })
492 }
493
494 fn generate_fix_recommendations(&self, issues: &[DetectedIssue]) -> Vec<FixRecommendation> {
496 let mut recommendations = Vec::new();
497
498 for issue in issues {
499 if let Some(suggestions) = self.fix_suggestions.get(&issue.issue_type) {
500 for suggestion in suggestions {
501 recommendations.push(FixRecommendation {
502 issue: issue.clone(),
503 fix_suggestion: suggestion.clone(),
504 confidence: issue.confidence * 0.9, urgency: self.calculate_urgency(issue),
506 });
507 }
508 }
509 }
510
511 recommendations.sort_by(|a, b| {
513 let urgency_cmp =
514 b.urgency.partial_cmp(&a.urgency).unwrap_or(std::cmp::Ordering::Equal);
515 if urgency_cmp == std::cmp::Ordering::Equal {
516 b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal)
517 } else {
518 urgency_cmp
519 }
520 });
521
522 recommendations
523 }
524
525 fn calculate_urgency(&self, issue: &DetectedIssue) -> f64 {
526 let severity_multiplier = match issue.severity {
527 IssueSeverity::Critical => 1.0,
528 IssueSeverity::High => 0.8,
529 IssueSeverity::Medium => 0.6,
530 IssueSeverity::Low => 0.4,
531 IssueSeverity::Info => 0.2,
532 };
533
534 issue.confidence * severity_multiplier
535 }
536
537 fn generate_hyperparameter_recommendations(
539 &self,
540 context: &DebugContext,
541 ) -> Vec<HyperparameterRecommendation> {
542 let mut recommendations = Vec::new();
543
544 if let Some(metrics) = context.recent_metrics.last() {
546 if let Some(loss) = metrics.loss {
547 if loss > 1.0 {
548 recommendations.push(HyperparameterRecommendation {
549 parameter: "learning_rate".to_string(),
550 current_value: None,
551 recommended_value: 0.001,
552 reason: "High loss suggests learning rate might be too low".to_string(),
553 confidence: 0.7,
554 });
555 }
556 }
557 }
558
559 if let Some(_profiler_report) = context.profiler_report {
561 recommendations.push(HyperparameterRecommendation {
563 parameter: "batch_size".to_string(),
564 current_value: None,
565 recommended_value: 32.0,
566 reason: "Optimize batch size for better GPU utilization".to_string(),
567 confidence: 0.6,
568 });
569 }
570
571 recommendations
572 }
573
574 fn generate_architecture_suggestions(
576 &self,
577 context: &DebugContext,
578 ) -> Vec<ArchitectureSuggestion> {
579 let mut suggestions = Vec::new();
580
581 if let Some(model_info) = context.model_info {
583 if model_info.parameter_count > 100_000_000 {
584 suggestions.push(ArchitectureSuggestion {
585 suggestion_type: "model_compression".to_string(),
586 title: "Consider Model Compression".to_string(),
587 description: "Large model may benefit from pruning or distillation".to_string(),
588 impact_assessment: "Reduce memory usage by 20-50% with minimal accuracy loss"
589 .to_string(),
590 implementation_difficulty: "Medium".to_string(),
591 });
592 }
593
594 if model_info.layer_count > 50 {
595 suggestions.push(ArchitectureSuggestion {
596 suggestion_type: "depth_optimization".to_string(),
597 title: "Optimize Network Depth".to_string(),
598 description: "Very deep network may suffer from gradient flow issues"
599 .to_string(),
600 impact_assessment: "Improve training stability and convergence speed"
601 .to_string(),
602 implementation_difficulty: "High".to_string(),
603 });
604 }
605 }
606
607 suggestions
608 }
609
610 fn generate_training_recipe_optimization(
612 &self,
613 context: &DebugContext,
614 ) -> TrainingRecipeOptimization {
615 let mut optimizations = Vec::new();
616
617 if context.training_duration > Duration::from_secs(3600) {
619 optimizations
620 .push("Consider learning rate scheduling to speed up convergence".to_string());
621 optimizations.push("Implement early stopping to avoid overtraining".to_string());
622 }
623
624 if context.recent_metrics.len() > 10 {
626 let recent_losses: Vec<f64> =
627 context.recent_metrics.iter().rev().take(10).filter_map(|m| m.loss).collect();
628
629 if recent_losses.len() >= 5 {
630 let variance = self.calculate_variance(&recent_losses);
631 if variance > 0.1 {
632 optimizations.push(
633 "Training loss is unstable - consider reducing learning rate".to_string(),
634 );
635 }
636 }
637 }
638
639 TrainingRecipeOptimization {
640 recommended_optimizations: optimizations,
641 training_schedule: TrainingSchedule {
642 warmup_steps: 1000,
643 learning_rate_schedule: "cosine_annealing".to_string(),
644 batch_size_schedule: "constant".to_string(),
645 early_stopping: true,
646 checkpoint_frequency: 1000,
647 },
648 data_strategy: DataStrategy {
649 data_augmentation: vec!["horizontal_flip".to_string(), "random_crop".to_string()],
650 sampling_strategy: "balanced".to_string(),
651 preprocessing_optimizations: vec![
652 "normalization".to_string(),
653 "standardization".to_string(),
654 ],
655 },
656 }
657 }
658
659 fn calculate_variance(&self, values: &[f64]) -> f64 {
660 if values.len() < 2 {
661 return 0.0;
662 }
663
664 let mean = values.iter().sum::<f64>() / values.len() as f64;
665 let variance =
666 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
667 variance
668 }
669
670 fn generate_analysis_summary(&self, recommendations: &[FixRecommendation]) -> String {
671 let critical_count = recommendations
672 .iter()
673 .filter(|r| matches!(r.issue.severity, IssueSeverity::Critical))
674 .count();
675
676 let high_count = recommendations
677 .iter()
678 .filter(|r| matches!(r.issue.severity, IssueSeverity::High))
679 .count();
680
681 if critical_count > 0 {
682 format!("Found {} critical issues requiring immediate attention. {} high-priority issues also detected.",
683 critical_count, high_count)
684 } else if high_count > 0 {
685 format!(
686 "Found {} high-priority issues that should be addressed soon.",
687 high_count
688 )
689 } else if !recommendations.is_empty() {
690 "Found some optimization opportunities to improve training performance.".to_string()
691 } else {
692 "No significant issues detected. Training appears to be proceeding normally."
693 .to_string()
694 }
695 }
696
697 fn calculate_overall_confidence(&self, recommendations: &[FixRecommendation]) -> f64 {
698 if recommendations.is_empty() {
699 return 1.0;
700 }
701
702 let sum_confidence: f64 = recommendations.iter().map(|r| r.confidence).sum();
703 sum_confidence / recommendations.len() as f64
704 }
705
706 pub fn record_optimization_attempt(&mut self, attempt: OptimizationAttempt) {
708 self.optimization_history.push(attempt);
709
710 if self.optimization_history.len() > 1000 {
712 self.optimization_history.drain(0..500);
713 }
714 }
715
716 pub fn get_optimization_history(&self) -> &[OptimizationAttempt] {
718 &self.optimization_history
719 }
720}
721
722#[derive(Debug)]
725struct GradientIssueDetector;
726
727impl GradientIssueDetector {
728 fn new() -> Self {
729 Self
730 }
731}
732
733impl IssueDetector for GradientIssueDetector {
734 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
735 let mut issues = Vec::new();
736
737 if let Some(gradient_report) = context.gradient_report {
738 if gradient_report.has_vanishing_gradients() {
740 issues.push(DetectedIssue {
741 issue_type: IssueType::VanishingGradients,
742 severity: IssueSeverity::High,
743 confidence: 0.9,
744 description: "Vanishing gradients detected in multiple layers".to_string(),
745 evidence: vec![Evidence {
746 metric_name: "gradient_norm".to_string(),
747 observed_value: 0.001,
748 expected_range: (0.01, 1.0),
749 explanation: "Gradient norms are significantly below normal range"
750 .to_string(),
751 }],
752 metrics: HashMap::new(),
753 detected_at: chrono::Utc::now(),
754 });
755 }
756
757 if gradient_report.has_exploding_gradients() {
759 issues.push(DetectedIssue {
760 issue_type: IssueType::ExplodingGradients,
761 severity: IssueSeverity::Critical,
762 confidence: 0.95,
763 description: "Exploding gradients detected - training instability likely"
764 .to_string(),
765 evidence: vec![Evidence {
766 metric_name: "gradient_norm".to_string(),
767 observed_value: 100.0,
768 expected_range: (0.01, 10.0),
769 explanation: "Gradient norms are extremely high".to_string(),
770 }],
771 metrics: HashMap::new(),
772 detected_at: chrono::Utc::now(),
773 });
774 }
775 }
776
777 Ok(issues)
778 }
779
780 fn get_detector_name(&self) -> &str {
781 "GradientIssueDetector"
782 }
783
784 fn get_supported_issues(&self) -> Vec<IssueType> {
785 vec![IssueType::VanishingGradients, IssueType::ExplodingGradients]
786 }
787}
788
789#[derive(Debug)]
790struct TrainingIssueDetector;
791
792impl TrainingIssueDetector {
793 fn new() -> Self {
794 Self
795 }
796}
797
798impl IssueDetector for TrainingIssueDetector {
799 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
800 let mut issues = Vec::new();
801
802 if context.recent_metrics.len() >= 10 {
804 let recent_losses: Vec<f64> =
805 context.recent_metrics.iter().rev().take(10).filter_map(|m| m.loss).collect();
806
807 if recent_losses.len() >= 5 {
808 let first_half_avg = recent_losses[..recent_losses.len() / 2].iter().sum::<f64>()
810 / (recent_losses.len() / 2) as f64;
811 let second_half_avg = recent_losses[recent_losses.len() / 2..].iter().sum::<f64>()
812 / (recent_losses.len() - recent_losses.len() / 2) as f64;
813
814 if (first_half_avg - second_half_avg).abs() / first_half_avg < 0.01 {
815 issues.push(DetectedIssue {
816 issue_type: IssueType::TrainingStalled,
817 severity: IssueSeverity::Medium,
818 confidence: 0.8,
819 description: "Training appears to have stalled - loss not decreasing"
820 .to_string(),
821 evidence: vec![Evidence {
822 metric_name: "loss_change".to_string(),
823 observed_value: (first_half_avg - second_half_avg).abs()
824 / first_half_avg,
825 expected_range: (0.05, 1.0),
826 explanation: "Loss change is below expected threshold".to_string(),
827 }],
828 metrics: HashMap::new(),
829 detected_at: chrono::Utc::now(),
830 });
831 }
832 }
833 }
834
835 Ok(issues)
836 }
837
838 fn get_detector_name(&self) -> &str {
839 "TrainingIssueDetector"
840 }
841
842 fn get_supported_issues(&self) -> Vec<IssueType> {
843 vec![
844 IssueType::TrainingStalled,
845 IssueType::LossNotDecreasing,
846 IssueType::UnstableTraining,
847 ]
848 }
849}
850
851#[derive(Debug)]
852struct PerformanceIssueDetector;
853
854impl PerformanceIssueDetector {
855 fn new() -> Self {
856 Self
857 }
858}
859
860impl IssueDetector for PerformanceIssueDetector {
861 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
862 let mut issues = Vec::new();
863
864 if let Some(metrics) = context.recent_metrics.last() {
866 if let Some(gpu_util) = metrics.gpu_utilization {
867 if gpu_util < 0.5 {
868 issues.push(DetectedIssue {
869 issue_type: IssueType::LowGpuUtilization,
870 severity: IssueSeverity::Medium,
871 confidence: 0.8,
872 description:
873 "Low GPU utilization detected - compute resources underutilized"
874 .to_string(),
875 evidence: vec![Evidence {
876 metric_name: "gpu_utilization".to_string(),
877 observed_value: gpu_util,
878 expected_range: (0.7, 1.0),
879 explanation: "GPU utilization is below optimal range".to_string(),
880 }],
881 metrics: HashMap::new(),
882 detected_at: chrono::Utc::now(),
883 });
884 }
885 }
886 }
887
888 Ok(issues)
889 }
890
891 fn get_detector_name(&self) -> &str {
892 "PerformanceIssueDetector"
893 }
894
895 fn get_supported_issues(&self) -> Vec<IssueType> {
896 vec![
897 IssueType::LowGpuUtilization,
898 IssueType::SlowTraining,
899 IssueType::MemoryBottleneck,
900 ]
901 }
902}
903
904#[derive(Debug)]
905struct HyperparameterIssueDetector;
906
907impl HyperparameterIssueDetector {
908 fn new() -> Self {
909 Self
910 }
911}
912
913impl IssueDetector for HyperparameterIssueDetector {
914 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
915 let mut issues = Vec::new();
916
917 if let Some(metrics) = context.recent_metrics.last() {
918 if let Some(lr) = metrics.learning_rate {
920 if lr > 0.1 {
921 issues.push(DetectedIssue {
922 issue_type: IssueType::LearningRateTooHigh,
923 severity: IssueSeverity::High,
924 confidence: 0.7,
925 description:
926 "Learning rate appears too high - may cause training instability"
927 .to_string(),
928 evidence: vec![Evidence {
929 metric_name: "learning_rate".to_string(),
930 observed_value: lr,
931 expected_range: (0.0001, 0.01),
932 explanation: "Learning rate is above typical range".to_string(),
933 }],
934 metrics: HashMap::new(),
935 detected_at: chrono::Utc::now(),
936 });
937 } else if lr < 0.00001 {
938 issues.push(DetectedIssue {
939 issue_type: IssueType::LearningRateTooLow,
940 severity: IssueSeverity::Medium,
941 confidence: 0.6,
942 description: "Learning rate might be too low - training could be slow"
943 .to_string(),
944 evidence: vec![Evidence {
945 metric_name: "learning_rate".to_string(),
946 observed_value: lr,
947 expected_range: (0.0001, 0.01),
948 explanation: "Learning rate is below typical range".to_string(),
949 }],
950 metrics: HashMap::new(),
951 detected_at: chrono::Utc::now(),
952 });
953 }
954 }
955 }
956
957 Ok(issues)
958 }
959
960 fn get_detector_name(&self) -> &str {
961 "HyperparameterIssueDetector"
962 }
963
964 fn get_supported_issues(&self) -> Vec<IssueType> {
965 vec![
966 IssueType::LearningRateTooHigh,
967 IssueType::LearningRateTooLow,
968 ]
969 }
970}
971
972#[derive(Debug)]
973struct ArchitectureIssueDetector;
974
975impl ArchitectureIssueDetector {
976 fn new() -> Self {
977 Self
978 }
979}
980
981impl IssueDetector for ArchitectureIssueDetector {
982 fn detect_issues(&self, context: &DebugContext) -> Result<Vec<DetectedIssue>> {
983 let mut issues = Vec::new();
984
985 if let Some(model_info) = context.model_info {
986 if model_info.parameter_count > 1_000_000_000 {
988 issues.push(DetectedIssue {
989 issue_type: IssueType::ModelTooLarge,
990 severity: IssueSeverity::Medium,
991 confidence: 0.6,
992 description:
993 "Model has very large number of parameters - consider optimization"
994 .to_string(),
995 evidence: vec![Evidence {
996 metric_name: "parameter_count".to_string(),
997 observed_value: model_info.parameter_count as f64,
998 expected_range: (1_000_000.0, 100_000_000.0),
999 explanation: "Parameter count is extremely high".to_string(),
1000 }],
1001 metrics: HashMap::new(),
1002 detected_at: chrono::Utc::now(),
1003 });
1004 }
1005
1006 if model_info.layer_count > 100 {
1007 issues.push(DetectedIssue {
1008 issue_type: IssueType::InappropriateArchitecture,
1009 severity: IssueSeverity::Low,
1010 confidence: 0.5,
1011 description: "Very deep model - may have gradient flow issues".to_string(),
1012 evidence: vec![Evidence {
1013 metric_name: "layer_count".to_string(),
1014 observed_value: model_info.layer_count as f64,
1015 expected_range: (10.0, 50.0),
1016 explanation: "Layer count is very high".to_string(),
1017 }],
1018 metrics: HashMap::new(),
1019 detected_at: chrono::Utc::now(),
1020 });
1021 }
1022 }
1023
1024 Ok(issues)
1025 }
1026
1027 fn get_detector_name(&self) -> &str {
1028 "ArchitectureIssueDetector"
1029 }
1030
1031 fn get_supported_issues(&self) -> Vec<IssueType> {
1032 vec![
1033 IssueType::ModelTooLarge,
1034 IssueType::InappropriateArchitecture,
1035 ]
1036 }
1037}
1038
1039#[derive(Debug)]
1040struct DataIssueDetector;
1041
1042impl DataIssueDetector {
1043 fn new() -> Self {
1044 Self
1045 }
1046}
1047
1048impl IssueDetector for DataIssueDetector {
1049 fn detect_issues(&self, _context: &DebugContext) -> Result<Vec<DetectedIssue>> {
1050 Ok(Vec::new())
1053 }
1054
1055 fn get_detector_name(&self) -> &str {
1056 "DataIssueDetector"
1057 }
1058
1059 fn get_supported_issues(&self) -> Vec<IssueType> {
1060 vec![
1061 IssueType::DataImbalance,
1062 IssueType::BatchSizeProblems,
1063 IssueType::InsufficientData,
1064 ]
1065 }
1066}
1067
1068impl Default for KnowledgeBase {
1069 fn default() -> Self {
1070 Self::new()
1071 }
1072}
1073
1074impl KnowledgeBase {
1075 pub fn new() -> Self {
1076 Self {
1077 issue_patterns: HashMap::new(),
1078 hyperparameter_recommendations: HashMap::new(),
1079 architecture_patterns: Vec::new(),
1080 best_practices: HashMap::new(),
1081 }
1082 }
1083}
1084
/// Complete output of [`AutoDebugger::analyze_issues`].
#[derive(Debug, Serialize, Deserialize)]
pub struct AutoDebugReport {
    // Sorted most severe first, ties broken by descending confidence.
    pub detected_issues: Vec<DetectedIssue>,
    // Sorted by descending urgency, then descending confidence.
    pub fix_recommendations: Vec<FixRecommendation>,
    pub hyperparameter_recommendations: Vec<HyperparameterRecommendation>,
    pub architecture_suggestions: Vec<ArchitectureSuggestion>,
    pub training_recipe: TrainingRecipeOptimization,
    pub analysis_summary: String,
    // Mean recommendation confidence; 1.0 when nothing was flagged.
    pub confidence_score: f64,
}
1097
/// Pairing of a detected issue with one catalogued fix for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixRecommendation {
    pub issue: DetectedIssue,
    pub fix_suggestion: FixSuggestion,
    // Issue confidence discounted for the genericness of the fix.
    pub confidence: f64,
    // Confidence scaled by severity; drives recommendation ordering.
    pub urgency: f64,
}
1105
/// Suggested change to a single hyperparameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterRecommendation {
    pub parameter: String,
    // `None` when the current setting is not known to the debugger.
    pub current_value: Option<f64>,
    pub recommended_value: f64,
    pub reason: String,
    pub confidence: f64,
}
1114
/// High-level suggestion about model architecture (compression, depth, etc.).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureSuggestion {
    // Machine-readable category, e.g. "model_compression".
    pub suggestion_type: String,
    pub title: String,
    pub description: String,
    pub impact_assessment: String,
    pub implementation_difficulty: String,
}
1123
/// Bundle of training-recipe advice: free-text tips plus a concrete
/// schedule and data-handling strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecipeOptimization {
    pub recommended_optimizations: Vec<String>,
    pub training_schedule: TrainingSchedule,
    pub data_strategy: DataStrategy,
}
1130
/// Recommended training schedule parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSchedule {
    pub warmup_steps: u32,
    // Schedule name, e.g. "cosine_annealing".
    pub learning_rate_schedule: String,
    // Schedule name, e.g. "constant".
    pub batch_size_schedule: String,
    pub early_stopping: bool,
    // Checkpoint every N steps.
    pub checkpoint_frequency: u32,
}
1139
/// Recommended data-handling strategy accompanying a training recipe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataStrategy {
    // Augmentation names, e.g. "horizontal_flip".
    pub data_augmentation: Vec<String>,
    pub sampling_strategy: String,
    pub preprocessing_optimizations: Vec<String>,
}