use std::collections::{HashMap, VecDeque};

use anyhow::Result;

use super::types::{
    ConvergenceStatus, LayerActivationStats, ModelPerformanceMetrics, TrainingDynamics,
};
14#[derive(Debug)]
16pub struct AutoDebugger {
17 config: AutoDebugConfig,
19 performance_history: VecDeque<ModelPerformanceMetrics>,
21 layer_history: HashMap<String, VecDeque<LayerActivationStats>>,
23 dynamics_history: VecDeque<TrainingDynamics>,
25 #[allow(dead_code)]
27 issue_patterns: IssuePatternDatabase,
28 session_state: DebuggingSession,
30}
31
/// Configuration for [`AutoDebugger`].
#[derive(Debug, Clone)]
pub struct AutoDebugConfig {
    /// Maximum number of samples retained in each history buffer.
    pub max_history_size: usize,
    /// Minimum performance samples required before analysis may run.
    pub min_samples_for_analysis: usize,
    /// Intended minimum confidence for surfacing a recommendation.
    /// NOTE(review): not enforced by the visible analyzers — confirm usage.
    pub recommendation_confidence_threshold: f64,
    /// Enables advanced pattern matching (presumably gates the pattern DB).
    pub enable_advanced_patterns: bool,
    /// Enables hyperparameter suggestions in recommendations.
    pub enable_hyperparameter_suggestions: bool,
    /// Enables architecture-change recommendations.
    pub enable_architectural_recommendations: bool,
}
48
49#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
51pub struct DebuggingSession {
52 pub session_start: chrono::DateTime<chrono::Utc>,
54 pub identified_issues: Vec<IdentifiedIssue>,
56 pub recommendations: Vec<DebuggingRecommendation>,
58 pub session_stats: SessionStatistics,
60}
61
62#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64pub struct IdentifiedIssue {
65 pub category: IssueCategory,
67 pub description: String,
69 pub severity: IssueSeverity,
71 pub confidence: f64,
73 pub evidence: Vec<String>,
75 pub potential_causes: Vec<String>,
77 pub identified_at: chrono::DateTime<chrono::Utc>,
79}
80
81#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
83pub enum IssueCategory {
84 LearningRate,
86 GradientFlow,
88 Overfitting,
90 Underfitting,
92 DataQuality,
94 Architecture,
96 Memory,
98 Convergence,
100 NumericalStability,
102}
103
104#[derive(Debug, Clone, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize)]
106pub enum IssueSeverity {
107 Minor,
109 Moderate,
111 Major,
113 Critical,
115}
116
117#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
119pub struct DebuggingRecommendation {
120 pub category: RecommendationCategory,
122 pub title: String,
124 pub description: String,
126 pub actions: Vec<String>,
128 pub expected_impact: String,
130 pub confidence: f64,
132 pub priority: AutoDebugRecommendationPriority,
134 pub hyperparameter_suggestions: Vec<HyperparameterSuggestion>,
136}
137
138#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
140pub enum RecommendationCategory {
141 HyperparameterTuning,
143 ArchitecturalModification,
145 DataPreprocessing,
147 TrainingStrategy,
149 DebuggingAndMonitoring,
151 ResourceOptimization,
153}
154
155#[derive(Debug, Clone, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize)]
157pub enum AutoDebugRecommendationPriority {
158 Low,
160 Medium,
162 High,
164 Urgent,
166}
167
168#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
170pub struct HyperparameterSuggestion {
171 pub parameter_name: String,
173 pub current_value: Option<f64>,
175 pub suggested_value: f64,
177 pub reasoning: String,
179 pub expected_effect: String,
181}
182
183#[derive(Debug, Clone)]
185pub struct IssuePatternDatabase {
186 pub learning_rate_patterns: Vec<IssuePattern>,
188 pub gradient_patterns: Vec<IssuePattern>,
190 pub convergence_patterns: Vec<IssuePattern>,
192 pub layer_patterns: Vec<IssuePattern>,
194}
195
196#[derive(Debug, Clone)]
198pub struct IssuePattern {
199 pub name: String,
201 pub description: String,
203 pub conditions: Vec<PatternCondition>,
205 pub issue_category: IssueCategory,
207 pub confidence_weight: f64,
209 pub solutions: Vec<String>,
211}
212
213#[derive(Debug, Clone)]
215pub struct PatternCondition {
216 pub metric: String,
218 pub operator: ComparisonOperator,
220 pub threshold: f64,
222 pub consecutive_count: usize,
224}
225
/// Comparison applied by a [`PatternCondition`].
#[derive(Debug, Clone)]
pub enum ComparisonOperator {
    /// Metric exceeds the threshold.
    GreaterThan,
    /// Metric is below the threshold.
    LessThan,
    /// Metric equals the threshold.
    EqualTo,
    /// Metric is trending upward over consecutive samples.
    Increasing,
    /// Metric is trending downward over consecutive samples.
    Decreasing,
    /// Metric alternates direction over consecutive samples.
    Oscillating,
}
242
243#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
245pub struct SessionStatistics {
246 pub total_issues: usize,
248 pub issues_by_category: HashMap<IssueCategory, usize>,
250 pub total_recommendations: usize,
252 pub avg_recommendation_confidence: f64,
254 pub analysis_duration: chrono::Duration,
256}
257
258impl Default for AutoDebugConfig {
259 fn default() -> Self {
260 Self {
261 max_history_size: 1000,
262 min_samples_for_analysis: 10,
263 recommendation_confidence_threshold: 0.7,
264 enable_advanced_patterns: true,
265 enable_hyperparameter_suggestions: true,
266 enable_architectural_recommendations: true,
267 }
268 }
269}
270
271impl AutoDebugger {
272 pub fn new() -> Self {
274 Self {
275 config: AutoDebugConfig::default(),
276 performance_history: VecDeque::new(),
277 layer_history: HashMap::new(),
278 dynamics_history: VecDeque::new(),
279 issue_patterns: IssuePatternDatabase::new(),
280 session_state: DebuggingSession::new(),
281 }
282 }
283
284 pub fn with_config(config: AutoDebugConfig) -> Self {
286 Self {
287 config,
288 performance_history: VecDeque::new(),
289 layer_history: HashMap::new(),
290 dynamics_history: VecDeque::new(),
291 issue_patterns: IssuePatternDatabase::new(),
292 session_state: DebuggingSession::new(),
293 }
294 }
295
296 pub fn record_performance_metrics(&mut self, metrics: ModelPerformanceMetrics) {
298 self.performance_history.push_back(metrics);
299
300 while self.performance_history.len() > self.config.max_history_size {
301 self.performance_history.pop_front();
302 }
303 }
304
305 pub fn record_layer_stats(&mut self, stats: LayerActivationStats) {
307 let layer_name = stats.layer_name.clone();
308
309 let layer_history = self.layer_history.entry(layer_name).or_default();
310 layer_history.push_back(stats);
311
312 while layer_history.len() > self.config.max_history_size {
313 layer_history.pop_front();
314 }
315 }
316
317 pub fn record_training_dynamics(&mut self, dynamics: TrainingDynamics) {
319 self.dynamics_history.push_back(dynamics);
320
321 while self.dynamics_history.len() > self.config.max_history_size {
322 self.dynamics_history.pop_front();
323 }
324 }
325
326 pub fn perform_analysis(&mut self) -> Result<DebuggingReport> {
328 let analysis_start = chrono::Utc::now();
329
330 if self.performance_history.len() < self.config.min_samples_for_analysis {
331 return Err(anyhow::anyhow!("Insufficient data for analysis"));
332 }
333
334 self.session_state = DebuggingSession::new();
336
337 self.analyze_learning_rate_issues()?;
339 self.analyze_convergence_issues()?;
340 self.analyze_gradient_flow_issues()?;
341 self.analyze_layer_health_issues()?;
342 self.analyze_memory_issues()?;
343 self.analyze_overfitting_underfitting()?;
344
345 self.generate_recommendations()?;
347
348 self.session_state.session_stats.analysis_duration = chrono::Utc::now() - analysis_start;
350 self.update_session_statistics();
351
352 Ok(DebuggingReport {
353 session_info: self.session_state.clone(),
354 identified_issues: self.session_state.identified_issues.clone(),
355 recommendations: self.session_state.recommendations.clone(),
356 summary: self.generate_analysis_summary(),
357 })
358 }
359
360 fn analyze_learning_rate_issues(&mut self) -> Result<()> {
362 let recent_metrics: Vec<_> = self.performance_history.iter().rev().take(20).collect();
363 if recent_metrics.len() < 10 {
364 return Ok(());
365 }
366
367 let mut issues_to_add = Vec::new();
368
369 let recent_losses: Vec<f64> = recent_metrics.iter().map(|m| m.loss).collect();
371 if let Some(max_loss) = recent_losses
372 .iter()
373 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
374 {
375 if let Some(min_loss) = recent_losses
376 .iter()
377 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
378 {
379 if max_loss / min_loss > 10.0 {
380 issues_to_add.push(IdentifiedIssue {
381 category: IssueCategory::LearningRate,
382 description: "Learning rate too high - loss explosion detected".to_string(),
383 severity: IssueSeverity::Critical,
384 confidence: 0.9,
385 evidence: vec![
386 format!("Loss ratio: {:.2}", max_loss / min_loss),
387 "Rapid loss increase observed".to_string(),
388 ],
389 potential_causes: vec![
390 "Learning rate set too high".to_string(),
391 "Gradient clipping disabled".to_string(),
392 "Numerical instability".to_string(),
393 ],
394 identified_at: chrono::Utc::now(),
395 });
396 }
397 }
398 }
399
400 let loss_variance = self.calculate_variance(&recent_losses);
402 let recent_metrics_len = recent_metrics.len();
403 if loss_variance < 1e-6 && recent_metrics_len >= 15 {
404 issues_to_add.push(IdentifiedIssue {
405 category: IssueCategory::LearningRate,
406 description: "Learning rate too low - training stagnation".to_string(),
407 severity: IssueSeverity::Major,
408 confidence: 0.8,
409 evidence: vec![
410 format!("Loss variance: {:.2e}", loss_variance),
411 "No learning progress in recent steps".to_string(),
412 ],
413 potential_causes: vec![
414 "Learning rate set too low".to_string(),
415 "Learning rate decay too aggressive".to_string(),
416 "Model has converged".to_string(),
417 ],
418 identified_at: chrono::Utc::now(),
419 });
420 }
421
422 for issue in issues_to_add {
424 self.add_issue(issue);
425 }
426
427 Ok(())
428 }
429
430 fn analyze_convergence_issues(&mut self) -> Result<()> {
432 if let Some(latest_dynamics) = self.dynamics_history.back() {
433 match latest_dynamics.convergence_status {
434 ConvergenceStatus::Diverging => {
435 self.add_issue(IdentifiedIssue {
436 category: IssueCategory::Convergence,
437 description: "Training is diverging".to_string(),
438 severity: IssueSeverity::Critical,
439 confidence: 0.95,
440 evidence: vec!["Convergence status: Diverging".to_string()],
441 potential_causes: vec![
442 "Learning rate too high".to_string(),
443 "Gradient explosion".to_string(),
444 "Numerical instability".to_string(),
445 ],
446 identified_at: chrono::Utc::now(),
447 });
448 },
449 ConvergenceStatus::Plateau => {
450 if let Some(plateau_info) = &latest_dynamics.plateau_detection {
451 if plateau_info.duration_steps > 100 {
452 self.add_issue(IdentifiedIssue {
453 category: IssueCategory::Convergence,
454 description: "Training has plateaued".to_string(),
455 severity: IssueSeverity::Moderate,
456 confidence: 0.8,
457 evidence: vec![
458 format!(
459 "Plateau duration: {} steps",
460 plateau_info.duration_steps
461 ),
462 format!("Plateau value: {:.4}", plateau_info.plateau_value),
463 ],
464 potential_causes: vec![
465 "Learning rate too low".to_string(),
466 "Model capacity insufficient".to_string(),
467 "Local minimum reached".to_string(),
468 ],
469 identified_at: chrono::Utc::now(),
470 });
471 }
472 }
473 },
474 ConvergenceStatus::Oscillating => {
475 self.add_issue(IdentifiedIssue {
476 category: IssueCategory::NumericalStability,
477 description: "Training is oscillating".to_string(),
478 severity: IssueSeverity::Moderate,
479 confidence: 0.7,
480 evidence: vec!["Convergence status: Oscillating".to_string()],
481 potential_causes: vec![
482 "Learning rate too high".to_string(),
483 "Batch size too small".to_string(),
484 "Momentum settings suboptimal".to_string(),
485 ],
486 identified_at: chrono::Utc::now(),
487 });
488 },
489 _ => {},
490 }
491 }
492
493 Ok(())
494 }
495
496 fn analyze_gradient_flow_issues(&mut self) -> Result<()> {
498 let mut issues_to_add = Vec::new();
499
500 for (layer_name, layer_history) in &self.layer_history {
502 if let Some(latest_stats) = layer_history.back() {
503 if latest_stats.dead_neurons_ratio > 0.5 {
505 issues_to_add.push(IdentifiedIssue {
506 category: IssueCategory::GradientFlow,
507 description: format!("High dead neuron ratio in layer {}", layer_name),
508 severity: IssueSeverity::Major,
509 confidence: 0.85,
510 evidence: vec![
511 format!(
512 "Dead neurons: {:.1}%",
513 latest_stats.dead_neurons_ratio * 100.0
514 ),
515 format!("Layer: {}", layer_name),
516 ],
517 potential_causes: vec![
518 "Dying ReLU problem".to_string(),
519 "Poor weight initialization".to_string(),
520 "Learning rate too high".to_string(),
521 ],
522 identified_at: chrono::Utc::now(),
523 });
524 }
525
526 if latest_stats.saturated_neurons_ratio > 0.3 {
528 issues_to_add.push(IdentifiedIssue {
529 category: IssueCategory::GradientFlow,
530 description: format!("High activation saturation in layer {}", layer_name),
531 severity: IssueSeverity::Moderate,
532 confidence: 0.8,
533 evidence: vec![
534 format!(
535 "Saturated neurons: {:.1}%",
536 latest_stats.saturated_neurons_ratio * 100.0
537 ),
538 format!("Layer: {}", layer_name),
539 ],
540 potential_causes: vec![
541 "Vanishing gradient problem".to_string(),
542 "Poor activation function choice".to_string(),
543 "Input normalization issues".to_string(),
544 ],
545 identified_at: chrono::Utc::now(),
546 });
547 }
548 }
549 }
550
551 for issue in issues_to_add {
553 self.add_issue(issue);
554 }
555
556 Ok(())
557 }
558
559 fn analyze_layer_health_issues(&mut self) -> Result<()> {
561 let mut issues_to_add = Vec::new();
562
563 for (layer_name, layer_history) in &self.layer_history {
564 if layer_history.len() >= 5 {
565 let recent_stats: Vec<_> = layer_history.iter().rev().take(5).collect();
566
567 let variances: Vec<f64> = recent_stats.iter().map(|s| s.std_activation).collect();
569 let avg_variance = variances.iter().sum::<f64>() / variances.len() as f64;
570
571 if avg_variance < 0.01 {
572 issues_to_add.push(IdentifiedIssue {
573 category: IssueCategory::Architecture,
574 description: format!("Low activation variance in layer {}", layer_name),
575 severity: IssueSeverity::Minor,
576 confidence: 0.6,
577 evidence: vec![
578 format!("Average variance: {:.4}", avg_variance),
579 format!("Layer: {}", layer_name),
580 ],
581 potential_causes: vec![
582 "Poor weight initialization".to_string(),
583 "Input normalization too aggressive".to_string(),
584 "Activation function saturation".to_string(),
585 ],
586 identified_at: chrono::Utc::now(),
587 });
588 }
589 }
590 }
591
592 for issue in issues_to_add {
594 self.add_issue(issue);
595 }
596
597 Ok(())
598 }
599
600 fn analyze_memory_issues(&mut self) -> Result<()> {
602 if self.performance_history.len() >= 10 {
603 let recent_memory: Vec<f64> = self
604 .performance_history
605 .iter()
606 .rev()
607 .take(10)
608 .map(|m| m.memory_usage_mb)
609 .collect();
610
611 let memory_trend = self.calculate_trend(&recent_memory);
613 if memory_trend > 10.0 {
614 self.add_issue(IdentifiedIssue {
616 category: IssueCategory::Memory,
617 description: "Memory leak detected".to_string(),
618 severity: IssueSeverity::Critical,
619 confidence: 0.9,
620 evidence: vec![
621 format!("Memory growth rate: {:.2} MB/step", memory_trend),
622 "Increasing memory usage trend".to_string(),
623 ],
624 potential_causes: vec![
625 "Gradient accumulation without clearing".to_string(),
626 "Cached tensors not being released".to_string(),
627 "Memory fragmentation".to_string(),
628 ],
629 identified_at: chrono::Utc::now(),
630 });
631 }
632
633 if let Some(max_memory) = recent_memory
635 .iter()
636 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
637 {
638 if *max_memory > 16384.0 {
639 self.add_issue(IdentifiedIssue {
641 category: IssueCategory::Memory,
642 description: "Excessive memory usage detected".to_string(),
643 severity: IssueSeverity::Major,
644 confidence: 0.8,
645 evidence: vec![
646 format!("Peak memory: {:.0} MB", max_memory),
647 "High memory consumption".to_string(),
648 ],
649 potential_causes: vec![
650 "Batch size too large".to_string(),
651 "Model too large for available memory".to_string(),
652 "Inefficient memory allocation".to_string(),
653 ],
654 identified_at: chrono::Utc::now(),
655 });
656 }
657 }
658 }
659
660 Ok(())
661 }
662
663 fn analyze_overfitting_underfitting(&mut self) -> Result<()> {
665 let mut issues_to_add = Vec::new();
666
667 if let Some(latest_dynamics) = self.dynamics_history.back() {
668 for indicator in &latest_dynamics.overfitting_indicators {
670 if let super::types::OverfittingIndicator::TrainValidationGap { gap } = indicator {
671 if *gap > 0.1 {
672 issues_to_add.push(IdentifiedIssue {
673 category: IssueCategory::Overfitting,
674 description: "Large training-validation gap detected".to_string(),
675 severity: IssueSeverity::Major,
676 confidence: 0.85,
677 evidence: vec![
678 format!("Train-validation gap: {:.3}", gap),
679 "Overfitting indicator present".to_string(),
680 ],
681 potential_causes: vec![
682 "Model complexity too high".to_string(),
683 "Insufficient regularization".to_string(),
684 "Training set too small".to_string(),
685 ],
686 identified_at: chrono::Utc::now(),
687 });
688 }
689 }
690 }
691
692 for indicator in &latest_dynamics.underfitting_indicators {
694 match indicator {
695 super::types::UnderfittingIndicator::HighTrainingLoss { loss, threshold } => {
696 issues_to_add.push(IdentifiedIssue {
697 category: IssueCategory::Underfitting,
698 description: "High training loss indicates underfitting".to_string(),
699 severity: IssueSeverity::Moderate,
700 confidence: 0.7,
701 evidence: vec![
702 format!("Training loss: {:.3}", loss),
703 format!("Threshold: {:.3}", threshold),
704 ],
705 potential_causes: vec![
706 "Model capacity too low".to_string(),
707 "Learning rate too low".to_string(),
708 "Insufficient training time".to_string(),
709 ],
710 identified_at: chrono::Utc::now(),
711 });
712 },
713 super::types::UnderfittingIndicator::SlowConvergence {
714 steps_taken,
715 expected,
716 } => {
717 issues_to_add.push(IdentifiedIssue {
718 category: IssueCategory::Underfitting,
719 description: "Slow convergence detected".to_string(),
720 severity: IssueSeverity::Minor,
721 confidence: 0.6,
722 evidence: vec![
723 format!("Steps taken: {}", steps_taken),
724 format!("Expected: {}", expected),
725 ],
726 potential_causes: vec![
727 "Learning rate too conservative".to_string(),
728 "Optimizer choice suboptimal".to_string(),
729 "Poor initialization".to_string(),
730 ],
731 identified_at: chrono::Utc::now(),
732 });
733 },
734 _ => {},
735 }
736 }
737 }
738
739 for issue in issues_to_add {
741 self.add_issue(issue);
742 }
743
744 Ok(())
745 }
746
747 fn generate_recommendations(&mut self) -> Result<()> {
749 for issue in &self.session_state.identified_issues {
750 let recommendations = self.generate_recommendations_for_issue(issue);
751 self.session_state.recommendations.extend(recommendations);
752 }
753
754 self.session_state.recommendations.sort_by(|a, b| {
756 b.priority
757 .partial_cmp(&a.priority)
758 .unwrap_or(std::cmp::Ordering::Equal)
759 .then(b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal))
760 });
761
762 Ok(())
763 }
764
765 fn generate_recommendations_for_issue(
767 &self,
768 issue: &IdentifiedIssue,
769 ) -> Vec<DebuggingRecommendation> {
770 match issue.category {
771 IssueCategory::LearningRate => {
772 if issue.description.contains("too high") {
773 vec![DebuggingRecommendation {
774 category: RecommendationCategory::HyperparameterTuning,
775 title: "Reduce Learning Rate".to_string(),
776 description: "Lower the learning rate to stabilize training".to_string(),
777 actions: vec![
778 "Reduce learning rate by factor of 2-10".to_string(),
779 "Enable gradient clipping".to_string(),
780 "Consider learning rate scheduling".to_string(),
781 ],
782 expected_impact: "Stabilized training with reduced loss oscillations"
783 .to_string(),
784 confidence: 0.9,
785 priority: AutoDebugRecommendationPriority::High,
786 hyperparameter_suggestions: vec![HyperparameterSuggestion {
787 parameter_name: "learning_rate".to_string(),
788 current_value: None,
789 suggested_value: 0.0001,
790 reasoning: "Reduce to prevent loss explosion".to_string(),
791 expected_effect: "More stable training".to_string(),
792 }],
793 }]
794 } else if issue.description.contains("too low") {
795 vec![DebuggingRecommendation {
796 category: RecommendationCategory::HyperparameterTuning,
797 title: "Increase Learning Rate".to_string(),
798 description: "Increase learning rate to improve convergence speed"
799 .to_string(),
800 actions: vec![
801 "Increase learning rate by factor of 2-5".to_string(),
802 "Use learning rate warmup".to_string(),
803 "Consider adaptive learning rate methods".to_string(),
804 ],
805 expected_impact: "Faster convergence and better final performance"
806 .to_string(),
807 confidence: 0.8,
808 priority: AutoDebugRecommendationPriority::Medium,
809 hyperparameter_suggestions: vec![HyperparameterSuggestion {
810 parameter_name: "learning_rate".to_string(),
811 current_value: None,
812 suggested_value: 0.001,
813 reasoning: "Increase to improve learning speed".to_string(),
814 expected_effect: "Faster convergence".to_string(),
815 }],
816 }]
817 } else {
818 Vec::new()
819 }
820 },
821 IssueCategory::Memory => {
822 vec![DebuggingRecommendation {
823 category: RecommendationCategory::ResourceOptimization,
824 title: "Optimize Memory Usage".to_string(),
825 description: "Implement memory optimization strategies".to_string(),
826 actions: vec![
827 "Reduce batch size".to_string(),
828 "Enable gradient checkpointing".to_string(),
829 "Clear cached tensors regularly".to_string(),
830 "Use mixed precision training".to_string(),
831 ],
832 expected_impact: "Reduced memory consumption and stable training".to_string(),
833 confidence: 0.85,
834 priority: AutoDebugRecommendationPriority::High,
835 hyperparameter_suggestions: vec![HyperparameterSuggestion {
836 parameter_name: "batch_size".to_string(),
837 current_value: None,
838 suggested_value: 16.0,
839 reasoning: "Reduce to lower memory usage".to_string(),
840 expected_effect: "Lower memory consumption".to_string(),
841 }],
842 }]
843 },
844 IssueCategory::Overfitting => {
845 vec![DebuggingRecommendation {
846 category: RecommendationCategory::TrainingStrategy,
847 title: "Address Overfitting".to_string(),
848 description: "Implement regularization strategies to reduce overfitting"
849 .to_string(),
850 actions: vec![
851 "Add dropout layers".to_string(),
852 "Increase weight decay".to_string(),
853 "Use data augmentation".to_string(),
854 "Reduce model complexity".to_string(),
855 "Implement early stopping".to_string(),
856 ],
857 expected_impact: "Better generalization and validation performance".to_string(),
858 confidence: 0.8,
859 priority: AutoDebugRecommendationPriority::Medium,
860 hyperparameter_suggestions: vec![HyperparameterSuggestion {
861 parameter_name: "dropout_rate".to_string(),
862 current_value: None,
863 suggested_value: 0.1,
864 reasoning: "Add regularization to reduce overfitting".to_string(),
865 expected_effect: "Better generalization".to_string(),
866 }],
867 }]
868 },
869 IssueCategory::GradientFlow => {
870 vec![DebuggingRecommendation {
871 category: RecommendationCategory::ArchitecturalModification,
872 title: "Improve Gradient Flow".to_string(),
873 description: "Address gradient flow issues in the network".to_string(),
874 actions: vec![
875 "Use different activation functions (e.g., Leaky ReLU, Swish)".to_string(),
876 "Add batch normalization".to_string(),
877 "Implement residual connections".to_string(),
878 "Adjust weight initialization".to_string(),
879 ],
880 expected_impact: "Better gradient flow and training stability".to_string(),
881 confidence: 0.75,
882 priority: AutoDebugRecommendationPriority::Medium,
883 hyperparameter_suggestions: Vec::new(),
884 }]
885 },
886 _ => Vec::new(),
887 }
888 }
889
890 fn add_issue(&mut self, issue: IdentifiedIssue) {
892 self.session_state.identified_issues.push(issue);
893 }
894
895 fn calculate_variance(&self, values: &[f64]) -> f64 {
897 if values.len() < 2 {
898 return 0.0;
899 }
900
901 let mean = values.iter().sum::<f64>() / values.len() as f64;
902 let variance =
903 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
904
905 variance
906 }
907
908 fn calculate_trend(&self, values: &[f64]) -> f64 {
910 if values.len() < 2 {
911 return 0.0;
912 }
913
914 let n = values.len() as f64;
915 let x_mean = (n - 1.0) / 2.0;
916 let y_mean = values.iter().sum::<f64>() / n;
917
918 let numerator: f64 = values
919 .iter()
920 .enumerate()
921 .map(|(i, &y)| (i as f64 - x_mean) * (y - y_mean))
922 .sum();
923
924 let denominator: f64 = (0..values.len()).map(|i| (i as f64 - x_mean).powi(2)).sum();
925
926 if denominator == 0.0 {
927 0.0
928 } else {
929 numerator / denominator
930 }
931 }
932
933 fn update_session_statistics(&mut self) {
935 let mut issues_by_category = HashMap::new();
936 for issue in &self.session_state.identified_issues {
937 *issues_by_category.entry(issue.category.clone()).or_insert(0) += 1;
938 }
939
940 let avg_confidence = if self.session_state.recommendations.is_empty() {
941 0.0
942 } else {
943 self.session_state.recommendations.iter().map(|r| r.confidence).sum::<f64>()
944 / self.session_state.recommendations.len() as f64
945 };
946
947 self.session_state.session_stats = SessionStatistics {
948 total_issues: self.session_state.identified_issues.len(),
949 issues_by_category,
950 total_recommendations: self.session_state.recommendations.len(),
951 avg_recommendation_confidence: avg_confidence,
952 analysis_duration: self.session_state.session_stats.analysis_duration,
953 };
954 }
955
956 fn generate_analysis_summary(&self) -> String {
958 let critical_issues = self
959 .session_state
960 .identified_issues
961 .iter()
962 .filter(|i| i.severity == IssueSeverity::Critical)
963 .count();
964
965 let major_issues = self
966 .session_state
967 .identified_issues
968 .iter()
969 .filter(|i| i.severity == IssueSeverity::Major)
970 .count();
971
972 let high_priority_recommendations = self
973 .session_state
974 .recommendations
975 .iter()
976 .filter(|r| r.priority == AutoDebugRecommendationPriority::High)
977 .count();
978
979 format!(
980 "Auto-debugging analysis completed. Found {} critical issues, {} major issues. \
981 Generated {} recommendations with {} high-priority actions. \
982 Average recommendation confidence: {:.2}",
983 critical_issues,
984 major_issues,
985 self.session_state.recommendations.len(),
986 high_priority_recommendations,
987 self.session_state.session_stats.avg_recommendation_confidence
988 )
989 }
990}
991
992#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
994pub struct DebuggingReport {
995 pub session_info: DebuggingSession,
997 pub identified_issues: Vec<IdentifiedIssue>,
999 pub recommendations: Vec<DebuggingRecommendation>,
1001 pub summary: String,
1003}
1004
1005impl IssuePatternDatabase {
1006 pub fn new() -> Self {
1008 Self {
1009 learning_rate_patterns: Self::create_learning_rate_patterns(),
1010 gradient_patterns: Self::create_gradient_patterns(),
1011 convergence_patterns: Self::create_convergence_patterns(),
1012 layer_patterns: Self::create_layer_patterns(),
1013 }
1014 }
1015
1016 fn create_learning_rate_patterns() -> Vec<IssuePattern> {
1018 vec![IssuePattern {
1019 name: "Loss Explosion".to_string(),
1020 description: "Rapid increase in loss indicating learning rate too high".to_string(),
1021 conditions: vec![PatternCondition {
1022 metric: "loss".to_string(),
1023 operator: ComparisonOperator::Increasing,
1024 threshold: 2.0,
1025 consecutive_count: 3,
1026 }],
1027 issue_category: IssueCategory::LearningRate,
1028 confidence_weight: 0.9,
1029 solutions: vec![
1030 "Reduce learning rate by factor of 10".to_string(),
1031 "Enable gradient clipping".to_string(),
1032 ],
1033 }]
1034 }
1035
1036 fn create_gradient_patterns() -> Vec<IssuePattern> {
1038 vec![]
1039 }
1040
1041 fn create_convergence_patterns() -> Vec<IssuePattern> {
1043 vec![]
1044 }
1045
1046 fn create_layer_patterns() -> Vec<IssuePattern> {
1048 vec![]
1049 }
1050}
1051
1052impl DebuggingSession {
1053 fn new() -> Self {
1055 Self {
1056 session_start: chrono::Utc::now(),
1057 identified_issues: Vec::new(),
1058 recommendations: Vec::new(),
1059 session_stats: SessionStatistics {
1060 total_issues: 0,
1061 issues_by_category: HashMap::new(),
1062 total_recommendations: 0,
1063 avg_recommendation_confidence: 0.0,
1064 analysis_duration: chrono::Duration::zero(),
1065 },
1066 }
1067 }
1068}
1069
1070impl Default for AutoDebugger {
1071 fn default() -> Self {
1072 Self::new()
1073 }
1074}
1075
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auto_debugger_creation() {
        let debugger = AutoDebugger::new();
        assert_eq!(debugger.performance_history.len(), 0);
        assert_eq!(debugger.layer_history.len(), 0);
    }

    #[test]
    fn test_issue_identification() {
        let mut debugger = AutoDebugger::new();

        let issue = IdentifiedIssue {
            category: IssueCategory::LearningRate,
            description: "Test issue".to_string(),
            severity: IssueSeverity::Major,
            confidence: 0.8,
            evidence: vec!["Test evidence".to_string()],
            potential_causes: vec!["Test cause".to_string()],
            identified_at: chrono::Utc::now(),
        };

        debugger.add_issue(issue);
        assert_eq!(debugger.session_state.identified_issues.len(), 1);
    }

    #[test]
    fn test_variance_calculation() {
        let debugger = AutoDebugger::new();
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let variance = debugger.calculate_variance(&values);
        assert!(variance > 0.0);
    }

    #[test]
    fn test_trend_calculation() {
        let debugger = AutoDebugger::new();
        let increasing_values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let trend = debugger.calculate_trend(&increasing_values);
        assert!(trend > 0.0);
    }
}