Skip to main content

trustformers_debug/model_diagnostics/
auto_debug.rs

1//! Auto-debugging and automated recommendation system.
2//!
3//! This module provides intelligent debugging capabilities that automatically
4//! analyze model behavior, identify potential issues, and generate actionable
5//! recommendations for optimization and problem resolution.
6
7use anyhow::Result;
8use std::collections::{HashMap, VecDeque};
9
10use super::types::{
11    ConvergenceStatus, LayerActivationStats, ModelPerformanceMetrics, TrainingDynamics,
12};
13
/// Auto-debugging system for intelligent model analysis.
///
/// Collects three kinds of observations (model performance metrics, per-layer
/// activation statistics and training dynamics) into bounded histories and turns
/// them into [`IdentifiedIssue`]s and [`DebuggingRecommendation`]s via
/// `perform_analysis`.
#[derive(Debug)]
pub struct AutoDebugger {
    /// Debugging configuration (history limits, analysis thresholds, feature flags).
    config: AutoDebugConfig,
    /// Historical performance data for analysis, oldest first; capped at
    /// `config.max_history_size` entries (oldest evicted first).
    performance_history: VecDeque<ModelPerformanceMetrics>,
    /// Layer statistics history, keyed by layer name; every per-layer queue is
    /// capped at `config.max_history_size` entries independently.
    layer_history: HashMap<String, VecDeque<LayerActivationStats>>,
    /// Training dynamics history, oldest first; capped at `config.max_history_size`.
    dynamics_history: VecDeque<TrainingDynamics>,
    /// Known issue patterns and solutions.
    // NOTE(review): the `dead_code` allow suggests nothing reads this database yet;
    // presumably reserved for pattern-based matching — confirm before removing.
    #[allow(dead_code)]
    issue_patterns: IssuePatternDatabase,
    /// Current debugging session state; replaced with a fresh session at the start
    /// of every `perform_analysis` call.
    session_state: DebuggingSession,
}
31
/// Configuration for the auto-debugging system.
///
/// The default value of each field (from the `Default` impl) is noted below.
#[derive(Debug, Clone)]
pub struct AutoDebugConfig {
    /// Maximum number of entries kept in each history queue (performance,
    /// training dynamics, and every per-layer queue); the oldest entries are
    /// evicted first. Default: `1000`.
    pub max_history_size: usize,
    /// Minimum number of recorded performance samples required before
    /// `AutoDebugger::perform_analysis` runs; below this it returns an error.
    /// Default: `10`.
    pub min_samples_for_analysis: usize,
    /// Confidence threshold for recommendations. Default: `0.7`.
    pub recommendation_confidence_threshold: f64,
    /// Enable advanced pattern recognition. Default: `true`.
    pub enable_advanced_patterns: bool,
    /// Enable hyperparameter suggestions. Default: `true`.
    pub enable_hyperparameter_suggestions: bool,
    /// Enable architectural recommendations. Default: `true`.
    pub enable_architectural_recommendations: bool,
}
48
/// Current debugging session state.
///
/// `AutoDebugger::perform_analysis` replaces the session with a fresh one before
/// analyzing, so a session only ever reflects the latest analysis run.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DebuggingSession {
    /// Session start time (UTC).
    pub session_start: chrono::DateTime<chrono::Utc>,
    /// Issues identified in the current session, in detection order.
    pub identified_issues: Vec<IdentifiedIssue>,
    /// Recommendations generated for the identified issues; after generation they
    /// are sorted by descending priority, then descending confidence.
    pub recommendations: Vec<DebuggingRecommendation>,
    /// Aggregate statistics for the session.
    pub session_stats: SessionStatistics,
}
61
/// An identified issue with diagnostic information.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IdentifiedIssue {
    /// Issue category; drives which recommendations are generated.
    pub category: IssueCategory,
    /// Human-readable issue description. For `IssueCategory::LearningRate` the
    /// recommendation generator inspects this text for "too high" / "too low".
    pub description: String,
    /// Severity level.
    pub severity: IssueSeverity,
    /// Heuristic confidence in the identification, as a fraction (e.g. `0.9`).
    pub confidence: f64,
    /// Human-readable evidence supporting the identification.
    pub evidence: Vec<String>,
    /// Potential causes of the issue.
    pub potential_causes: Vec<String>,
    /// When the issue was identified (UTC).
    pub identified_at: chrono::DateTime<chrono::Utc>,
}
80
/// Categories of issues that can be identified.
///
/// Implements `Hash` + `Eq` so it can key `SessionStatistics::issues_by_category`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum IssueCategory {
    /// Learning rate related issues
    LearningRate,
    /// Gradient flow problems (dead or saturated neurons)
    GradientFlow,
    /// Overfitting issues
    Overfitting,
    /// Underfitting issues
    Underfitting,
    /// Data quality problems
    DataQuality,
    /// Architecture inefficiencies
    Architecture,
    /// Memory management issues
    Memory,
    /// Convergence problems
    Convergence,
    /// Numerical stability issues
    NumericalStability,
}
103
104/// Severity levels for identified issues.
105#[derive(Debug, Clone, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize)]
106pub enum IssueSeverity {
107    /// Minor issue with low impact
108    Minor,
109    /// Moderate issue affecting performance
110    Moderate,
111    /// Major issue requiring attention
112    Major,
113    /// Critical issue requiring immediate action
114    Critical,
115}
116
/// Auto-generated debugging recommendation.
///
/// Session recommendations are ordered by `priority` (highest first) and then
/// by `confidence` (highest first).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DebuggingRecommendation {
    /// Recommendation category
    pub category: RecommendationCategory,
    /// Short recommendation title
    pub title: String,
    /// Detailed description
    pub description: String,
    /// Specific actions to take
    pub actions: Vec<String>,
    /// Expected impact of following the recommendation
    pub expected_impact: String,
    /// Confidence in the recommendation, as a fraction (e.g. `0.9`)
    pub confidence: f64,
    /// Priority level; primary sort key for the session's recommendations
    pub priority: AutoDebugRecommendationPriority,
    /// Relevant hyperparameters to adjust (may be empty)
    pub hyperparameter_suggestions: Vec<HyperparameterSuggestion>,
}
137
/// Categories of recommendations.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum RecommendationCategory {
    /// Hyperparameter adjustments (learning rate, batch size, ...)
    HyperparameterTuning,
    /// Architectural changes
    ArchitecturalModification,
    /// Data preprocessing recommendations
    DataPreprocessing,
    /// Training strategy changes
    TrainingStrategy,
    /// Debugging and monitoring
    DebuggingAndMonitoring,
    /// Resource optimization
    ResourceOptimization,
}
154
155/// Priority levels for recommendations.
156#[derive(Debug, Clone, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize)]
157pub enum AutoDebugRecommendationPriority {
158    /// Low priority suggestion
159    Low,
160    /// Medium priority recommendation
161    Medium,
162    /// High priority action needed
163    High,
164    /// Urgent action required
165    Urgent,
166}
167
/// Hyperparameter adjustment suggestion.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HyperparameterSuggestion {
    /// Parameter name (e.g. `"learning_rate"`)
    pub parameter_name: String,
    /// Current value, or `None` when the debugger does not know it
    pub current_value: Option<f64>,
    /// Suggested value
    pub suggested_value: f64,
    /// Why this adjustment is suggested
    pub reasoning: String,
    /// Expected effect of the adjustment
    pub expected_effect: String,
}
182
/// Database of known issue patterns and solutions, grouped by the kind of signal
/// each pattern inspects.
#[derive(Debug, Clone)]
pub struct IssuePatternDatabase {
    /// Learning rate patterns
    pub learning_rate_patterns: Vec<IssuePattern>,
    /// Gradient patterns
    pub gradient_patterns: Vec<IssuePattern>,
    /// Convergence patterns
    pub convergence_patterns: Vec<IssuePattern>,
    /// Layer behavior patterns
    pub layer_patterns: Vec<IssuePattern>,
}
195
/// Pattern for identifying a specific issue from a set of metric conditions.
#[derive(Debug, Clone)]
pub struct IssuePattern {
    /// Pattern name
    pub name: String,
    /// Pattern description
    pub description: String,
    /// Conditions that must be met for the pattern to match
    pub conditions: Vec<PatternCondition>,
    /// Issue category reported when the pattern matches
    pub issue_category: IssueCategory,
    /// Confidence weight for this pattern
    pub confidence_weight: f64,
    /// Recommended solutions, as human-readable text
    pub solutions: Vec<String>,
}
212
/// A single metric condition within an [`IssuePattern`].
#[derive(Debug, Clone)]
pub struct PatternCondition {
    /// Name of the metric the condition inspects
    pub metric: String,
    /// Comparison operator applied to the metric
    pub operator: ComparisonOperator,
    /// Threshold value compared against the metric
    pub threshold: f64,
    /// Number of consecutive occurrences required before the condition holds
    pub consecutive_count: usize,
}
225
/// Comparison operators for pattern conditions.
#[derive(Debug, Clone)]
pub enum ComparisonOperator {
    /// Metric is greater than the threshold
    GreaterThan,
    /// Metric is less than the threshold
    LessThan,
    /// Metric equals the threshold within a tolerance
    // NOTE(review): the tolerance is not defined on this type — confirm where it lives.
    EqualTo,
    /// Metric shows an increasing trend
    Increasing,
    /// Metric shows a decreasing trend
    Decreasing,
    /// Metric shows an oscillating pattern
    Oscillating,
}
242
/// Session statistics for debugging analysis.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SessionStatistics {
    /// Total issues identified
    pub total_issues: usize,
    /// Number of issues per category
    pub issues_by_category: HashMap<IssueCategory, usize>,
    /// Total recommendations generated
    pub total_recommendations: usize,
    /// Average confidence of the generated recommendations
    pub avg_recommendation_confidence: f64,
    /// Wall-clock time spent in `AutoDebugger::perform_analysis`, measured from
    /// its start until just before the session statistics are updated
    pub analysis_duration: chrono::Duration,
}
257
258impl Default for AutoDebugConfig {
259    fn default() -> Self {
260        Self {
261            max_history_size: 1000,
262            min_samples_for_analysis: 10,
263            recommendation_confidence_threshold: 0.7,
264            enable_advanced_patterns: true,
265            enable_hyperparameter_suggestions: true,
266            enable_architectural_recommendations: true,
267        }
268    }
269}
270
271impl AutoDebugger {
272    /// Create a new auto-debugger.
273    pub fn new() -> Self {
274        Self {
275            config: AutoDebugConfig::default(),
276            performance_history: VecDeque::new(),
277            layer_history: HashMap::new(),
278            dynamics_history: VecDeque::new(),
279            issue_patterns: IssuePatternDatabase::new(),
280            session_state: DebuggingSession::new(),
281        }
282    }
283
284    /// Create auto-debugger with custom configuration.
285    pub fn with_config(config: AutoDebugConfig) -> Self {
286        Self {
287            config,
288            performance_history: VecDeque::new(),
289            layer_history: HashMap::new(),
290            dynamics_history: VecDeque::new(),
291            issue_patterns: IssuePatternDatabase::new(),
292            session_state: DebuggingSession::new(),
293        }
294    }
295
296    /// Record performance metrics for analysis.
297    pub fn record_performance_metrics(&mut self, metrics: ModelPerformanceMetrics) {
298        self.performance_history.push_back(metrics);
299
300        while self.performance_history.len() > self.config.max_history_size {
301            self.performance_history.pop_front();
302        }
303    }
304
305    /// Record layer statistics for analysis.
306    pub fn record_layer_stats(&mut self, stats: LayerActivationStats) {
307        let layer_name = stats.layer_name.clone();
308
309        let layer_history = self.layer_history.entry(layer_name).or_default();
310        layer_history.push_back(stats);
311
312        while layer_history.len() > self.config.max_history_size {
313            layer_history.pop_front();
314        }
315    }
316
317    /// Record training dynamics for analysis.
318    pub fn record_training_dynamics(&mut self, dynamics: TrainingDynamics) {
319        self.dynamics_history.push_back(dynamics);
320
321        while self.dynamics_history.len() > self.config.max_history_size {
322            self.dynamics_history.pop_front();
323        }
324    }
325
326    /// Perform comprehensive auto-debugging analysis.
327    pub fn perform_analysis(&mut self) -> Result<DebuggingReport> {
328        let analysis_start = chrono::Utc::now();
329
330        if self.performance_history.len() < self.config.min_samples_for_analysis {
331            return Err(anyhow::anyhow!("Insufficient data for analysis"));
332        }
333
334        // Clear previous session state
335        self.session_state = DebuggingSession::new();
336
337        // Analyze different aspects
338        self.analyze_learning_rate_issues()?;
339        self.analyze_convergence_issues()?;
340        self.analyze_gradient_flow_issues()?;
341        self.analyze_layer_health_issues()?;
342        self.analyze_memory_issues()?;
343        self.analyze_overfitting_underfitting()?;
344
345        // Generate recommendations based on identified issues
346        self.generate_recommendations()?;
347
348        // Update session statistics
349        self.session_state.session_stats.analysis_duration = chrono::Utc::now() - analysis_start;
350        self.update_session_statistics();
351
352        Ok(DebuggingReport {
353            session_info: self.session_state.clone(),
354            identified_issues: self.session_state.identified_issues.clone(),
355            recommendations: self.session_state.recommendations.clone(),
356            summary: self.generate_analysis_summary(),
357        })
358    }
359
360    /// Analyze learning rate related issues.
361    fn analyze_learning_rate_issues(&mut self) -> Result<()> {
362        let recent_metrics: Vec<_> = self.performance_history.iter().rev().take(20).collect();
363        if recent_metrics.len() < 10 {
364            return Ok(());
365        }
366
367        let mut issues_to_add = Vec::new();
368
369        // Check for loss explosion
370        let recent_losses: Vec<f64> = recent_metrics.iter().map(|m| m.loss).collect();
371        if let Some(max_loss) = recent_losses.iter().max_by(|a, b| a.partial_cmp(b).unwrap()) {
372            if let Some(min_loss) = recent_losses.iter().min_by(|a, b| a.partial_cmp(b).unwrap()) {
373                if max_loss / min_loss > 10.0 {
374                    issues_to_add.push(IdentifiedIssue {
375                        category: IssueCategory::LearningRate,
376                        description: "Learning rate too high - loss explosion detected".to_string(),
377                        severity: IssueSeverity::Critical,
378                        confidence: 0.9,
379                        evidence: vec![
380                            format!("Loss ratio: {:.2}", max_loss / min_loss),
381                            "Rapid loss increase observed".to_string(),
382                        ],
383                        potential_causes: vec![
384                            "Learning rate set too high".to_string(),
385                            "Gradient clipping disabled".to_string(),
386                            "Numerical instability".to_string(),
387                        ],
388                        identified_at: chrono::Utc::now(),
389                    });
390                }
391            }
392        }
393
394        // Check for learning stagnation
395        let loss_variance = self.calculate_variance(&recent_losses);
396        let recent_metrics_len = recent_metrics.len();
397        if loss_variance < 1e-6 && recent_metrics_len >= 15 {
398            issues_to_add.push(IdentifiedIssue {
399                category: IssueCategory::LearningRate,
400                description: "Learning rate too low - training stagnation".to_string(),
401                severity: IssueSeverity::Major,
402                confidence: 0.8,
403                evidence: vec![
404                    format!("Loss variance: {:.2e}", loss_variance),
405                    "No learning progress in recent steps".to_string(),
406                ],
407                potential_causes: vec![
408                    "Learning rate set too low".to_string(),
409                    "Learning rate decay too aggressive".to_string(),
410                    "Model has converged".to_string(),
411                ],
412                identified_at: chrono::Utc::now(),
413            });
414        }
415
416        // Add all collected issues
417        for issue in issues_to_add {
418            self.add_issue(issue);
419        }
420
421        Ok(())
422    }
423
424    /// Analyze convergence related issues.
425    fn analyze_convergence_issues(&mut self) -> Result<()> {
426        if let Some(latest_dynamics) = self.dynamics_history.back() {
427            match latest_dynamics.convergence_status {
428                ConvergenceStatus::Diverging => {
429                    self.add_issue(IdentifiedIssue {
430                        category: IssueCategory::Convergence,
431                        description: "Training is diverging".to_string(),
432                        severity: IssueSeverity::Critical,
433                        confidence: 0.95,
434                        evidence: vec!["Convergence status: Diverging".to_string()],
435                        potential_causes: vec![
436                            "Learning rate too high".to_string(),
437                            "Gradient explosion".to_string(),
438                            "Numerical instability".to_string(),
439                        ],
440                        identified_at: chrono::Utc::now(),
441                    });
442                },
443                ConvergenceStatus::Plateau => {
444                    if let Some(plateau_info) = &latest_dynamics.plateau_detection {
445                        if plateau_info.duration_steps > 100 {
446                            self.add_issue(IdentifiedIssue {
447                                category: IssueCategory::Convergence,
448                                description: "Training has plateaued".to_string(),
449                                severity: IssueSeverity::Moderate,
450                                confidence: 0.8,
451                                evidence: vec![
452                                    format!(
453                                        "Plateau duration: {} steps",
454                                        plateau_info.duration_steps
455                                    ),
456                                    format!("Plateau value: {:.4}", plateau_info.plateau_value),
457                                ],
458                                potential_causes: vec![
459                                    "Learning rate too low".to_string(),
460                                    "Model capacity insufficient".to_string(),
461                                    "Local minimum reached".to_string(),
462                                ],
463                                identified_at: chrono::Utc::now(),
464                            });
465                        }
466                    }
467                },
468                ConvergenceStatus::Oscillating => {
469                    self.add_issue(IdentifiedIssue {
470                        category: IssueCategory::NumericalStability,
471                        description: "Training is oscillating".to_string(),
472                        severity: IssueSeverity::Moderate,
473                        confidence: 0.7,
474                        evidence: vec!["Convergence status: Oscillating".to_string()],
475                        potential_causes: vec![
476                            "Learning rate too high".to_string(),
477                            "Batch size too small".to_string(),
478                            "Momentum settings suboptimal".to_string(),
479                        ],
480                        identified_at: chrono::Utc::now(),
481                    });
482                },
483                _ => {},
484            }
485        }
486
487        Ok(())
488    }
489
490    /// Analyze gradient flow issues.
491    fn analyze_gradient_flow_issues(&mut self) -> Result<()> {
492        let mut issues_to_add = Vec::new();
493
494        // Check layer statistics for gradient flow problems
495        for (layer_name, layer_history) in &self.layer_history {
496            if let Some(latest_stats) = layer_history.back() {
497                // Check for dead neurons
498                if latest_stats.dead_neurons_ratio > 0.5 {
499                    issues_to_add.push(IdentifiedIssue {
500                        category: IssueCategory::GradientFlow,
501                        description: format!("High dead neuron ratio in layer {}", layer_name),
502                        severity: IssueSeverity::Major,
503                        confidence: 0.85,
504                        evidence: vec![
505                            format!(
506                                "Dead neurons: {:.1}%",
507                                latest_stats.dead_neurons_ratio * 100.0
508                            ),
509                            format!("Layer: {}", layer_name),
510                        ],
511                        potential_causes: vec![
512                            "Dying ReLU problem".to_string(),
513                            "Poor weight initialization".to_string(),
514                            "Learning rate too high".to_string(),
515                        ],
516                        identified_at: chrono::Utc::now(),
517                    });
518                }
519
520                // Check for activation saturation
521                if latest_stats.saturated_neurons_ratio > 0.3 {
522                    issues_to_add.push(IdentifiedIssue {
523                        category: IssueCategory::GradientFlow,
524                        description: format!("High activation saturation in layer {}", layer_name),
525                        severity: IssueSeverity::Moderate,
526                        confidence: 0.8,
527                        evidence: vec![
528                            format!(
529                                "Saturated neurons: {:.1}%",
530                                latest_stats.saturated_neurons_ratio * 100.0
531                            ),
532                            format!("Layer: {}", layer_name),
533                        ],
534                        potential_causes: vec![
535                            "Vanishing gradient problem".to_string(),
536                            "Poor activation function choice".to_string(),
537                            "Input normalization issues".to_string(),
538                        ],
539                        identified_at: chrono::Utc::now(),
540                    });
541                }
542            }
543        }
544
545        // Add all collected issues
546        for issue in issues_to_add {
547            self.add_issue(issue);
548        }
549
550        Ok(())
551    }
552
553    /// Analyze layer health issues.
554    fn analyze_layer_health_issues(&mut self) -> Result<()> {
555        let mut issues_to_add = Vec::new();
556
557        for (layer_name, layer_history) in &self.layer_history {
558            if layer_history.len() >= 5 {
559                let recent_stats: Vec<_> = layer_history.iter().rev().take(5).collect();
560
561                // Check for activation variance trends
562                let variances: Vec<f64> = recent_stats.iter().map(|s| s.std_activation).collect();
563                let avg_variance = variances.iter().sum::<f64>() / variances.len() as f64;
564
565                if avg_variance < 0.01 {
566                    issues_to_add.push(IdentifiedIssue {
567                        category: IssueCategory::Architecture,
568                        description: format!("Low activation variance in layer {}", layer_name),
569                        severity: IssueSeverity::Minor,
570                        confidence: 0.6,
571                        evidence: vec![
572                            format!("Average variance: {:.4}", avg_variance),
573                            format!("Layer: {}", layer_name),
574                        ],
575                        potential_causes: vec![
576                            "Poor weight initialization".to_string(),
577                            "Input normalization too aggressive".to_string(),
578                            "Activation function saturation".to_string(),
579                        ],
580                        identified_at: chrono::Utc::now(),
581                    });
582                }
583            }
584        }
585
586        // Add all collected issues
587        for issue in issues_to_add {
588            self.add_issue(issue);
589        }
590
591        Ok(())
592    }
593
594    /// Analyze memory usage issues.
595    fn analyze_memory_issues(&mut self) -> Result<()> {
596        if self.performance_history.len() >= 10 {
597            let recent_memory: Vec<f64> = self
598                .performance_history
599                .iter()
600                .rev()
601                .take(10)
602                .map(|m| m.memory_usage_mb)
603                .collect();
604
605            // Check for memory leaks
606            let memory_trend = self.calculate_trend(&recent_memory);
607            if memory_trend > 10.0 {
608                // MB per step
609                self.add_issue(IdentifiedIssue {
610                    category: IssueCategory::Memory,
611                    description: "Memory leak detected".to_string(),
612                    severity: IssueSeverity::Critical,
613                    confidence: 0.9,
614                    evidence: vec![
615                        format!("Memory growth rate: {:.2} MB/step", memory_trend),
616                        "Increasing memory usage trend".to_string(),
617                    ],
618                    potential_causes: vec![
619                        "Gradient accumulation without clearing".to_string(),
620                        "Cached tensors not being released".to_string(),
621                        "Memory fragmentation".to_string(),
622                    ],
623                    identified_at: chrono::Utc::now(),
624                });
625            }
626
627            // Check for excessive memory usage
628            if let Some(max_memory) = recent_memory.iter().max_by(|a, b| a.partial_cmp(b).unwrap())
629            {
630                if *max_memory > 16384.0 {
631                    // 16GB
632                    self.add_issue(IdentifiedIssue {
633                        category: IssueCategory::Memory,
634                        description: "Excessive memory usage detected".to_string(),
635                        severity: IssueSeverity::Major,
636                        confidence: 0.8,
637                        evidence: vec![
638                            format!("Peak memory: {:.0} MB", max_memory),
639                            "High memory consumption".to_string(),
640                        ],
641                        potential_causes: vec![
642                            "Batch size too large".to_string(),
643                            "Model too large for available memory".to_string(),
644                            "Inefficient memory allocation".to_string(),
645                        ],
646                        identified_at: chrono::Utc::now(),
647                    });
648                }
649            }
650        }
651
652        Ok(())
653    }
654
655    /// Analyze overfitting and underfitting issues.
656    fn analyze_overfitting_underfitting(&mut self) -> Result<()> {
657        let mut issues_to_add = Vec::new();
658
659        if let Some(latest_dynamics) = self.dynamics_history.back() {
660            // Check for overfitting indicators
661            for indicator in &latest_dynamics.overfitting_indicators {
662                if let super::types::OverfittingIndicator::TrainValidationGap { gap } = indicator {
663                    if *gap > 0.1 {
664                        issues_to_add.push(IdentifiedIssue {
665                            category: IssueCategory::Overfitting,
666                            description: "Large training-validation gap detected".to_string(),
667                            severity: IssueSeverity::Major,
668                            confidence: 0.85,
669                            evidence: vec![
670                                format!("Train-validation gap: {:.3}", gap),
671                                "Overfitting indicator present".to_string(),
672                            ],
673                            potential_causes: vec![
674                                "Model complexity too high".to_string(),
675                                "Insufficient regularization".to_string(),
676                                "Training set too small".to_string(),
677                            ],
678                            identified_at: chrono::Utc::now(),
679                        });
680                    }
681                }
682            }
683
684            // Check for underfitting indicators
685            for indicator in &latest_dynamics.underfitting_indicators {
686                match indicator {
687                    super::types::UnderfittingIndicator::HighTrainingLoss { loss, threshold } => {
688                        issues_to_add.push(IdentifiedIssue {
689                            category: IssueCategory::Underfitting,
690                            description: "High training loss indicates underfitting".to_string(),
691                            severity: IssueSeverity::Moderate,
692                            confidence: 0.7,
693                            evidence: vec![
694                                format!("Training loss: {:.3}", loss),
695                                format!("Threshold: {:.3}", threshold),
696                            ],
697                            potential_causes: vec![
698                                "Model capacity too low".to_string(),
699                                "Learning rate too low".to_string(),
700                                "Insufficient training time".to_string(),
701                            ],
702                            identified_at: chrono::Utc::now(),
703                        });
704                    },
705                    super::types::UnderfittingIndicator::SlowConvergence {
706                        steps_taken,
707                        expected,
708                    } => {
709                        issues_to_add.push(IdentifiedIssue {
710                            category: IssueCategory::Underfitting,
711                            description: "Slow convergence detected".to_string(),
712                            severity: IssueSeverity::Minor,
713                            confidence: 0.6,
714                            evidence: vec![
715                                format!("Steps taken: {}", steps_taken),
716                                format!("Expected: {}", expected),
717                            ],
718                            potential_causes: vec![
719                                "Learning rate too conservative".to_string(),
720                                "Optimizer choice suboptimal".to_string(),
721                                "Poor initialization".to_string(),
722                            ],
723                            identified_at: chrono::Utc::now(),
724                        });
725                    },
726                    _ => {},
727                }
728            }
729        }
730
731        // Add all collected issues
732        for issue in issues_to_add {
733            self.add_issue(issue);
734        }
735
736        Ok(())
737    }
738
739    /// Generate recommendations based on identified issues.
740    fn generate_recommendations(&mut self) -> Result<()> {
741        for issue in &self.session_state.identified_issues {
742            let recommendations = self.generate_recommendations_for_issue(issue);
743            self.session_state.recommendations.extend(recommendations);
744        }
745
746        // Sort recommendations by priority and confidence
747        self.session_state.recommendations.sort_by(|a, b| {
748            b.priority
749                .partial_cmp(&a.priority)
750                .unwrap()
751                .then(b.confidence.partial_cmp(&a.confidence).unwrap())
752        });
753
754        Ok(())
755    }
756
757    /// Generate specific recommendations for an issue.
758    fn generate_recommendations_for_issue(
759        &self,
760        issue: &IdentifiedIssue,
761    ) -> Vec<DebuggingRecommendation> {
762        match issue.category {
763            IssueCategory::LearningRate => {
764                if issue.description.contains("too high") {
765                    vec![DebuggingRecommendation {
766                        category: RecommendationCategory::HyperparameterTuning,
767                        title: "Reduce Learning Rate".to_string(),
768                        description: "Lower the learning rate to stabilize training".to_string(),
769                        actions: vec![
770                            "Reduce learning rate by factor of 2-10".to_string(),
771                            "Enable gradient clipping".to_string(),
772                            "Consider learning rate scheduling".to_string(),
773                        ],
774                        expected_impact: "Stabilized training with reduced loss oscillations"
775                            .to_string(),
776                        confidence: 0.9,
777                        priority: AutoDebugRecommendationPriority::High,
778                        hyperparameter_suggestions: vec![HyperparameterSuggestion {
779                            parameter_name: "learning_rate".to_string(),
780                            current_value: None,
781                            suggested_value: 0.0001,
782                            reasoning: "Reduce to prevent loss explosion".to_string(),
783                            expected_effect: "More stable training".to_string(),
784                        }],
785                    }]
786                } else if issue.description.contains("too low") {
787                    vec![DebuggingRecommendation {
788                        category: RecommendationCategory::HyperparameterTuning,
789                        title: "Increase Learning Rate".to_string(),
790                        description: "Increase learning rate to improve convergence speed"
791                            .to_string(),
792                        actions: vec![
793                            "Increase learning rate by factor of 2-5".to_string(),
794                            "Use learning rate warmup".to_string(),
795                            "Consider adaptive learning rate methods".to_string(),
796                        ],
797                        expected_impact: "Faster convergence and better final performance"
798                            .to_string(),
799                        confidence: 0.8,
800                        priority: AutoDebugRecommendationPriority::Medium,
801                        hyperparameter_suggestions: vec![HyperparameterSuggestion {
802                            parameter_name: "learning_rate".to_string(),
803                            current_value: None,
804                            suggested_value: 0.001,
805                            reasoning: "Increase to improve learning speed".to_string(),
806                            expected_effect: "Faster convergence".to_string(),
807                        }],
808                    }]
809                } else {
810                    Vec::new()
811                }
812            },
813            IssueCategory::Memory => {
814                vec![DebuggingRecommendation {
815                    category: RecommendationCategory::ResourceOptimization,
816                    title: "Optimize Memory Usage".to_string(),
817                    description: "Implement memory optimization strategies".to_string(),
818                    actions: vec![
819                        "Reduce batch size".to_string(),
820                        "Enable gradient checkpointing".to_string(),
821                        "Clear cached tensors regularly".to_string(),
822                        "Use mixed precision training".to_string(),
823                    ],
824                    expected_impact: "Reduced memory consumption and stable training".to_string(),
825                    confidence: 0.85,
826                    priority: AutoDebugRecommendationPriority::High,
827                    hyperparameter_suggestions: vec![HyperparameterSuggestion {
828                        parameter_name: "batch_size".to_string(),
829                        current_value: None,
830                        suggested_value: 16.0,
831                        reasoning: "Reduce to lower memory usage".to_string(),
832                        expected_effect: "Lower memory consumption".to_string(),
833                    }],
834                }]
835            },
836            IssueCategory::Overfitting => {
837                vec![DebuggingRecommendation {
838                    category: RecommendationCategory::TrainingStrategy,
839                    title: "Address Overfitting".to_string(),
840                    description: "Implement regularization strategies to reduce overfitting"
841                        .to_string(),
842                    actions: vec![
843                        "Add dropout layers".to_string(),
844                        "Increase weight decay".to_string(),
845                        "Use data augmentation".to_string(),
846                        "Reduce model complexity".to_string(),
847                        "Implement early stopping".to_string(),
848                    ],
849                    expected_impact: "Better generalization and validation performance".to_string(),
850                    confidence: 0.8,
851                    priority: AutoDebugRecommendationPriority::Medium,
852                    hyperparameter_suggestions: vec![HyperparameterSuggestion {
853                        parameter_name: "dropout_rate".to_string(),
854                        current_value: None,
855                        suggested_value: 0.1,
856                        reasoning: "Add regularization to reduce overfitting".to_string(),
857                        expected_effect: "Better generalization".to_string(),
858                    }],
859                }]
860            },
861            IssueCategory::GradientFlow => {
862                vec![DebuggingRecommendation {
863                    category: RecommendationCategory::ArchitecturalModification,
864                    title: "Improve Gradient Flow".to_string(),
865                    description: "Address gradient flow issues in the network".to_string(),
866                    actions: vec![
867                        "Use different activation functions (e.g., Leaky ReLU, Swish)".to_string(),
868                        "Add batch normalization".to_string(),
869                        "Implement residual connections".to_string(),
870                        "Adjust weight initialization".to_string(),
871                    ],
872                    expected_impact: "Better gradient flow and training stability".to_string(),
873                    confidence: 0.75,
874                    priority: AutoDebugRecommendationPriority::Medium,
875                    hyperparameter_suggestions: Vec::new(),
876                }]
877            },
878            _ => Vec::new(),
879        }
880    }
881
882    /// Add an issue to the current session.
883    fn add_issue(&mut self, issue: IdentifiedIssue) {
884        self.session_state.identified_issues.push(issue);
885    }
886
887    /// Calculate variance of a sequence of values.
888    fn calculate_variance(&self, values: &[f64]) -> f64 {
889        if values.len() < 2 {
890            return 0.0;
891        }
892
893        let mean = values.iter().sum::<f64>() / values.len() as f64;
894        let variance =
895            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
896
897        variance
898    }
899
900    /// Calculate trend (slope) of a sequence of values.
901    fn calculate_trend(&self, values: &[f64]) -> f64 {
902        if values.len() < 2 {
903            return 0.0;
904        }
905
906        let n = values.len() as f64;
907        let x_mean = (n - 1.0) / 2.0;
908        let y_mean = values.iter().sum::<f64>() / n;
909
910        let numerator: f64 = values
911            .iter()
912            .enumerate()
913            .map(|(i, &y)| (i as f64 - x_mean) * (y - y_mean))
914            .sum();
915
916        let denominator: f64 = (0..values.len()).map(|i| (i as f64 - x_mean).powi(2)).sum();
917
918        if denominator == 0.0 {
919            0.0
920        } else {
921            numerator / denominator
922        }
923    }
924
925    /// Update session statistics.
926    fn update_session_statistics(&mut self) {
927        let mut issues_by_category = HashMap::new();
928        for issue in &self.session_state.identified_issues {
929            *issues_by_category.entry(issue.category.clone()).or_insert(0) += 1;
930        }
931
932        let avg_confidence = if self.session_state.recommendations.is_empty() {
933            0.0
934        } else {
935            self.session_state.recommendations.iter().map(|r| r.confidence).sum::<f64>()
936                / self.session_state.recommendations.len() as f64
937        };
938
939        self.session_state.session_stats = SessionStatistics {
940            total_issues: self.session_state.identified_issues.len(),
941            issues_by_category,
942            total_recommendations: self.session_state.recommendations.len(),
943            avg_recommendation_confidence: avg_confidence,
944            analysis_duration: self.session_state.session_stats.analysis_duration,
945        };
946    }
947
948    /// Generate analysis summary.
949    fn generate_analysis_summary(&self) -> String {
950        let critical_issues = self
951            .session_state
952            .identified_issues
953            .iter()
954            .filter(|i| i.severity == IssueSeverity::Critical)
955            .count();
956
957        let major_issues = self
958            .session_state
959            .identified_issues
960            .iter()
961            .filter(|i| i.severity == IssueSeverity::Major)
962            .count();
963
964        let high_priority_recommendations = self
965            .session_state
966            .recommendations
967            .iter()
968            .filter(|r| r.priority == AutoDebugRecommendationPriority::High)
969            .count();
970
971        format!(
972            "Auto-debugging analysis completed. Found {} critical issues, {} major issues. \
973            Generated {} recommendations with {} high-priority actions. \
974            Average recommendation confidence: {:.2}",
975            critical_issues,
976            major_issues,
977            self.session_state.recommendations.len(),
978            high_priority_recommendations,
979            self.session_state.session_stats.avg_recommendation_confidence
980        )
981    }
982}
983
/// Comprehensive debugging report.
///
/// Bundles the debugging session state, the identified issues, the generated
/// recommendations, and a human-readable summary. Derives serde
/// `Serialize`/`Deserialize`, so it can be exported or persisted.
///
/// NOTE(review): `DebuggingSession` itself also has `identified_issues` and
/// `recommendations` fields, so that data appears twice in a serialized report.
/// Confirm the duplication is intentional.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DebuggingReport {
    /// Session information (start time, issues, recommendations, statistics)
    pub session_info: DebuggingSession,
    /// All identified issues
    pub identified_issues: Vec<IdentifiedIssue>,
    /// Generated recommendations
    pub recommendations: Vec<DebuggingRecommendation>,
    /// Human-readable analysis summary
    pub summary: String,
}
996
997impl IssuePatternDatabase {
998    /// Create a new pattern database with default patterns.
999    pub fn new() -> Self {
1000        Self {
1001            learning_rate_patterns: Self::create_learning_rate_patterns(),
1002            gradient_patterns: Self::create_gradient_patterns(),
1003            convergence_patterns: Self::create_convergence_patterns(),
1004            layer_patterns: Self::create_layer_patterns(),
1005        }
1006    }
1007
1008    /// Create default learning rate patterns.
1009    fn create_learning_rate_patterns() -> Vec<IssuePattern> {
1010        vec![IssuePattern {
1011            name: "Loss Explosion".to_string(),
1012            description: "Rapid increase in loss indicating learning rate too high".to_string(),
1013            conditions: vec![PatternCondition {
1014                metric: "loss".to_string(),
1015                operator: ComparisonOperator::Increasing,
1016                threshold: 2.0,
1017                consecutive_count: 3,
1018            }],
1019            issue_category: IssueCategory::LearningRate,
1020            confidence_weight: 0.9,
1021            solutions: vec![
1022                "Reduce learning rate by factor of 10".to_string(),
1023                "Enable gradient clipping".to_string(),
1024            ],
1025        }]
1026    }
1027
1028    /// Create default gradient patterns.
1029    fn create_gradient_patterns() -> Vec<IssuePattern> {
1030        vec![]
1031    }
1032
1033    /// Create default convergence patterns.
1034    fn create_convergence_patterns() -> Vec<IssuePattern> {
1035        vec![]
1036    }
1037
1038    /// Create default layer patterns.
1039    fn create_layer_patterns() -> Vec<IssuePattern> {
1040        vec![]
1041    }
1042}
1043
1044impl DebuggingSession {
1045    /// Create a new debugging session.
1046    fn new() -> Self {
1047        Self {
1048            session_start: chrono::Utc::now(),
1049            identified_issues: Vec::new(),
1050            recommendations: Vec::new(),
1051            session_stats: SessionStatistics {
1052                total_issues: 0,
1053                issues_by_category: HashMap::new(),
1054                total_recommendations: 0,
1055                avg_recommendation_confidence: 0.0,
1056                analysis_duration: chrono::Duration::zero(),
1057            },
1058        }
1059    }
1060}
1061
1062impl Default for AutoDebugger {
1063    fn default() -> Self {
1064        Self::new()
1065    }
1066}
1067
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-12;

    /// Build a learning-rate issue with the given description and severity.
    fn learning_rate_issue(description: &str, severity: IssueSeverity) -> IdentifiedIssue {
        IdentifiedIssue {
            category: IssueCategory::LearningRate,
            description: description.to_string(),
            severity,
            confidence: 0.8,
            evidence: vec!["Test evidence".to_string()],
            potential_causes: vec!["Test cause".to_string()],
            identified_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_auto_debugger_creation() {
        let debugger = AutoDebugger::new();
        assert_eq!(debugger.performance_history.len(), 0);
        assert_eq!(debugger.layer_history.len(), 0);
    }

    #[test]
    fn test_issue_identification() {
        let mut debugger = AutoDebugger::new();

        let issue = IdentifiedIssue {
            category: IssueCategory::LearningRate,
            description: "Test issue".to_string(),
            severity: IssueSeverity::Major,
            confidence: 0.8,
            evidence: vec!["Test evidence".to_string()],
            potential_causes: vec!["Test cause".to_string()],
            identified_at: chrono::Utc::now(),
        };

        debugger.add_issue(issue);
        assert_eq!(debugger.session_state.identified_issues.len(), 1);
    }

    #[test]
    fn test_variance_calculation() {
        let debugger = AutoDebugger::new();
        // Sample variance of 1..=5: squared deviations sum to 10, divided by n - 1 = 4.
        let variance = debugger.calculate_variance(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!((variance - 2.5).abs() < EPS, "variance = {}", variance);
    }

    #[test]
    fn test_variance_needs_two_samples() {
        let debugger = AutoDebugger::new();
        assert_eq!(debugger.calculate_variance(&[]), 0.0);
        assert_eq!(debugger.calculate_variance(&[42.0]), 0.0);
    }

    #[test]
    fn test_trend_calculation() {
        let debugger = AutoDebugger::new();

        // Values rise by exactly 1 per step.
        let rising = debugger.calculate_trend(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!((rising - 1.0).abs() < EPS, "rising trend = {}", rising);

        // Values fall by exactly 2 per step.
        let falling = debugger.calculate_trend(&[10.0, 8.0, 6.0]);
        assert!((falling + 2.0).abs() < EPS, "falling trend = {}", falling);

        // Too few samples to estimate a slope.
        assert_eq!(debugger.calculate_trend(&[3.0]), 0.0);
    }

    #[test]
    fn test_learning_rate_too_high_recommendation_and_stats() {
        let mut debugger = AutoDebugger::new();
        debugger.add_issue(learning_rate_issue("Learning rate too high", IssueSeverity::Critical));

        debugger.generate_recommendations().unwrap();
        debugger.update_session_statistics();

        let recs = &debugger.session_state.recommendations;
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0].title, "Reduce Learning Rate");
        assert!(recs[0].priority == AutoDebugRecommendationPriority::High);

        let stats = &debugger.session_state.session_stats;
        assert_eq!(stats.total_issues, 1);
        assert_eq!(stats.total_recommendations, 1);
        assert!((stats.avg_recommendation_confidence - 0.9).abs() < EPS);

        let summary = debugger.generate_analysis_summary();
        assert!(summary.contains("Found 1 critical issues, 0 major issues."), "{}", summary);
        assert!(
            summary.contains("Generated 1 recommendations with 1 high-priority actions."),
            "{}",
            summary
        );
        assert!(summary.contains("Average recommendation confidence: 0.90"), "{}", summary);
    }

    #[test]
    fn test_unrecognized_learning_rate_description_yields_no_recommendation() {
        let mut debugger = AutoDebugger::new();
        debugger.add_issue(learning_rate_issue("Test issue", IssueSeverity::Major));

        debugger.generate_recommendations().unwrap();
        assert!(debugger.session_state.recommendations.is_empty());
    }
}