Skip to main content

trustformers_debug/model_diagnostics/
auto_debug.rs

1//! Auto-debugging and automated recommendation system.
2//!
3//! This module provides intelligent debugging capabilities that automatically
4//! analyze model behavior, identify potential issues, and generate actionable
5//! recommendations for optimization and problem resolution.
6
7use anyhow::Result;
8use std::collections::{HashMap, VecDeque};
9
10use super::types::{
11    ConvergenceStatus, LayerActivationStats, ModelPerformanceMetrics, TrainingDynamics,
12};
13
/// Auto-debugging system for intelligent model analysis.
///
/// Collects bounded histories of performance metrics, per-layer activation statistics and
/// training dynamics through the `record_*` methods. It turns them into identified issues and
/// prioritized recommendations in [`AutoDebugger::perform_analysis`].
#[derive(Debug)]
pub struct AutoDebugger {
    /// Debugging configuration (history bounds, sample requirements, feature flags)
    config: AutoDebugConfig,
    /// Historical performance data for analysis, oldest at the front; capped at
    /// `config.max_history_size` entries (oldest evicted first)
    performance_history: VecDeque<ModelPerformanceMetrics>,
    /// Layer statistics history keyed by layer name; each per-layer queue is capped at
    /// `config.max_history_size` entries (oldest evicted first)
    layer_history: HashMap<String, VecDeque<LayerActivationStats>>,
    /// Training dynamics history, oldest at the front; capped at `config.max_history_size`
    /// entries (oldest evicted first)
    dynamics_history: VecDeque<TrainingDynamics>,
    /// Known issue patterns and solutions.
    ///
    /// Built on construction but not read anywhere yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    issue_patterns: IssuePatternDatabase,
    /// Current debugging session state; replaced by a fresh session at the start of every
    /// `perform_analysis` call
    session_state: DebuggingSession,
}
31
/// Configuration for auto-debugging system.
///
/// `AutoDebugConfig::default()` keeps 1000 entries per history, requires 10 performance samples,
/// uses a 0.7 recommendation confidence threshold and enables every optional feature.
#[derive(Debug, Clone)]
pub struct AutoDebugConfig {
    /// Maximum number of entries kept in each history buffer (performance metrics, training
    /// dynamics, and each individual layer's statistics); older entries are evicted first
    pub max_history_size: usize,
    /// Minimum number of recorded performance samples required before
    /// `AutoDebugger::perform_analysis` runs; with fewer it returns an error
    pub min_samples_for_analysis: usize,
    /// Confidence threshold for recommendations
    // NOTE(review): not consulted by the analysis code visible next to this type — confirm
    // where (or whether) it is applied.
    pub recommendation_confidence_threshold: f64,
    /// Enable advanced pattern recognition
    pub enable_advanced_patterns: bool,
    /// Enable hyperparameter suggestions
    pub enable_hyperparameter_suggestions: bool,
    /// Enable architectural recommendations
    pub enable_architectural_recommendations: bool,
}
48
/// Current debugging session state.
///
/// A serializable record of one analysis run. `AutoDebugger::perform_analysis` replaces it with
/// a new session before analysing and embeds a clone of it in the returned report.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DebuggingSession {
    /// Session start time (UTC)
    pub session_start: chrono::DateTime<chrono::Utc>,
    /// Issues identified in the current session
    pub identified_issues: Vec<IdentifiedIssue>,
    /// Recommendations generated from the identified issues; once generated they are sorted by
    /// descending priority, then descending confidence
    pub recommendations: Vec<DebuggingRecommendation>,
    /// Session statistics
    pub session_stats: SessionStatistics,
}
61
/// An identified issue with diagnostic information.
///
/// Produced by the `AutoDebugger` analysis passes; each issue is later turned into zero or more
/// recommendations.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct IdentifiedIssue {
    /// Issue category
    pub category: IssueCategory,
    /// Human-readable issue description.
    ///
    /// Recommendation generation inspects this text (it looks for "too high" / "too low" in
    /// learning-rate issues), so the wording is significant.
    pub description: String,
    /// Severity level
    pub severity: IssueSeverity,
    /// Confidence in identification (the built-in passes use fractions between 0 and 1)
    pub confidence: f64,
    /// Human-readable evidence supporting the identification
    pub evidence: Vec<String>,
    /// Potential causes
    pub potential_causes: Vec<String>,
    /// When the issue was identified (UTC)
    pub identified_at: chrono::DateTime<chrono::Utc>,
}
80
/// Categories of issues that can be identified.
///
/// Derives `Eq` + `Hash` so it can key `SessionStatistics::issues_by_category`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum IssueCategory {
    /// Learning rate related issues
    LearningRate,
    /// Gradient flow problems
    GradientFlow,
    /// Overfitting issues
    Overfitting,
    /// Underfitting issues
    Underfitting,
    /// Data quality problems
    DataQuality,
    /// Architecture inefficiencies
    Architecture,
    /// Memory management issues
    Memory,
    /// Convergence problems
    Convergence,
    /// Numerical stability issues
    NumericalStability,
}
103
/// Severity levels for identified issues.
///
/// Variants are declared from least to most severe. The derived `PartialOrd` therefore orders
/// them `Minor < Moderate < Major < Critical`.
#[derive(Debug, Clone, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize)]
pub enum IssueSeverity {
    /// Minor issue with low impact
    Minor,
    /// Moderate issue affecting performance
    Moderate,
    /// Major issue requiring attention
    Major,
    /// Critical issue requiring immediate action
    Critical,
}
116
/// Auto-generated debugging recommendation.
///
/// Within a session, recommendations are sorted by descending `priority`, then descending
/// `confidence`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DebuggingRecommendation {
    /// Recommendation category
    pub category: RecommendationCategory,
    /// Short recommendation title
    pub title: String,
    /// Detailed description
    pub description: String,
    /// Specific actions to take
    pub actions: Vec<String>,
    /// Expected impact, as free-form text
    pub expected_impact: String,
    /// Confidence in recommendation (fraction between 0 and 1 in the built-in generators)
    pub confidence: f64,
    /// Priority level
    pub priority: AutoDebugRecommendationPriority,
    /// Relevant hyperparameters to adjust (may be empty)
    pub hyperparameter_suggestions: Vec<HyperparameterSuggestion>,
}
137
/// Categories of recommendations.
///
/// Describes the kind of change a `DebuggingRecommendation` proposes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum RecommendationCategory {
    /// Hyperparameter adjustments
    HyperparameterTuning,
    /// Architectural changes
    ArchitecturalModification,
    /// Data preprocessing recommendations
    DataPreprocessing,
    /// Training strategy changes
    TrainingStrategy,
    /// Debugging and monitoring
    DebuggingAndMonitoring,
    /// Resource optimization
    ResourceOptimization,
}
154
/// Priority levels for recommendations.
///
/// Variants are declared from lowest to highest priority. The derived `PartialOrd` therefore
/// orders them `Low < Medium < High < Urgent`. Recommendation lists are sorted so that the
/// highest priority comes first.
#[derive(Debug, Clone, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize)]
pub enum AutoDebugRecommendationPriority {
    /// Low priority suggestion
    Low,
    /// Medium priority recommendation
    Medium,
    /// High priority action needed
    High,
    /// Urgent action required
    Urgent,
}
167
/// Hyperparameter adjustment suggestion.
///
/// Attached to a `DebuggingRecommendation` to propose a concrete new value for one parameter.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HyperparameterSuggestion {
    /// Parameter name (e.g. `"learning_rate"`)
    pub parameter_name: String,
    /// Current value, or `None` when the debugger does not know it
    pub current_value: Option<f64>,
    /// Suggested value
    pub suggested_value: f64,
    /// Why this adjustment is suggested
    pub reasoning: String,
    /// Expected effect of applying the suggestion
    pub expected_effect: String,
}
182
/// Database of known issue patterns and solutions.
///
/// Patterns are grouped by the kind of signal they inspect.
#[derive(Debug, Clone)]
pub struct IssuePatternDatabase {
    /// Learning rate patterns
    pub learning_rate_patterns: Vec<IssuePattern>,
    /// Gradient patterns
    pub gradient_patterns: Vec<IssuePattern>,
    /// Convergence patterns
    pub convergence_patterns: Vec<IssuePattern>,
    /// Layer behavior patterns
    pub layer_patterns: Vec<IssuePattern>,
}
195
/// Pattern for identifying specific issues.
///
/// A pattern is a named set of metric conditions associated with an issue category and a list
/// of candidate solutions.
#[derive(Debug, Clone)]
pub struct IssuePattern {
    /// Pattern name
    pub name: String,
    /// Pattern description
    pub description: String,
    /// Conditions that must be met
    pub conditions: Vec<PatternCondition>,
    /// Associated issue category
    pub issue_category: IssueCategory,
    /// Confidence weight for this pattern
    pub confidence_weight: f64,
    /// Recommended solutions, as human-readable text
    pub solutions: Vec<String>,
}
212
/// Condition for pattern matching.
///
/// Compares the named metric against `threshold` using `operator`.
#[derive(Debug, Clone)]
pub struct PatternCondition {
    /// Metric name
    pub metric: String,
    /// Comparison operator
    pub operator: ComparisonOperator,
    /// Threshold value
    pub threshold: f64,
    /// Number of consecutive occurrences required
    pub consecutive_count: usize,
}
225
/// Comparison operators for pattern conditions.
///
/// The first three compare a single value against a threshold. The last three describe the
/// shape of a series of values.
#[derive(Debug, Clone)]
pub enum ComparisonOperator {
    /// Greater than
    GreaterThan,
    /// Less than
    LessThan,
    /// Equal to (within a tolerance; the tolerance is not defined on this type)
    EqualTo,
    /// Increasing trend
    Increasing,
    /// Decreasing trend
    Decreasing,
    /// Oscillating pattern
    Oscillating,
}
242
/// Session statistics for debugging analysis.
///
/// Summarizes the issues and recommendations recorded in a `DebuggingSession`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SessionStatistics {
    /// Total issues identified
    pub total_issues: usize,
    /// Issue counts per category
    pub issues_by_category: HashMap<IssueCategory, usize>,
    /// Total recommendations generated
    pub total_recommendations: usize,
    /// Average confidence of recommendations
    pub avg_recommendation_confidence: f64,
    /// Wall-clock time taken by the analysis (set by `AutoDebugger::perform_analysis`)
    pub analysis_duration: chrono::Duration,
}
257
258impl Default for AutoDebugConfig {
259    fn default() -> Self {
260        Self {
261            max_history_size: 1000,
262            min_samples_for_analysis: 10,
263            recommendation_confidence_threshold: 0.7,
264            enable_advanced_patterns: true,
265            enable_hyperparameter_suggestions: true,
266            enable_architectural_recommendations: true,
267        }
268    }
269}
270
271impl AutoDebugger {
272    /// Create a new auto-debugger.
273    pub fn new() -> Self {
274        Self {
275            config: AutoDebugConfig::default(),
276            performance_history: VecDeque::new(),
277            layer_history: HashMap::new(),
278            dynamics_history: VecDeque::new(),
279            issue_patterns: IssuePatternDatabase::new(),
280            session_state: DebuggingSession::new(),
281        }
282    }
283
284    /// Create auto-debugger with custom configuration.
285    pub fn with_config(config: AutoDebugConfig) -> Self {
286        Self {
287            config,
288            performance_history: VecDeque::new(),
289            layer_history: HashMap::new(),
290            dynamics_history: VecDeque::new(),
291            issue_patterns: IssuePatternDatabase::new(),
292            session_state: DebuggingSession::new(),
293        }
294    }
295
296    /// Record performance metrics for analysis.
297    pub fn record_performance_metrics(&mut self, metrics: ModelPerformanceMetrics) {
298        self.performance_history.push_back(metrics);
299
300        while self.performance_history.len() > self.config.max_history_size {
301            self.performance_history.pop_front();
302        }
303    }
304
305    /// Record layer statistics for analysis.
306    pub fn record_layer_stats(&mut self, stats: LayerActivationStats) {
307        let layer_name = stats.layer_name.clone();
308
309        let layer_history = self.layer_history.entry(layer_name).or_default();
310        layer_history.push_back(stats);
311
312        while layer_history.len() > self.config.max_history_size {
313            layer_history.pop_front();
314        }
315    }
316
317    /// Record training dynamics for analysis.
318    pub fn record_training_dynamics(&mut self, dynamics: TrainingDynamics) {
319        self.dynamics_history.push_back(dynamics);
320
321        while self.dynamics_history.len() > self.config.max_history_size {
322            self.dynamics_history.pop_front();
323        }
324    }
325
326    /// Perform comprehensive auto-debugging analysis.
327    pub fn perform_analysis(&mut self) -> Result<DebuggingReport> {
328        let analysis_start = chrono::Utc::now();
329
330        if self.performance_history.len() < self.config.min_samples_for_analysis {
331            return Err(anyhow::anyhow!("Insufficient data for analysis"));
332        }
333
334        // Clear previous session state
335        self.session_state = DebuggingSession::new();
336
337        // Analyze different aspects
338        self.analyze_learning_rate_issues()?;
339        self.analyze_convergence_issues()?;
340        self.analyze_gradient_flow_issues()?;
341        self.analyze_layer_health_issues()?;
342        self.analyze_memory_issues()?;
343        self.analyze_overfitting_underfitting()?;
344
345        // Generate recommendations based on identified issues
346        self.generate_recommendations()?;
347
348        // Update session statistics
349        self.session_state.session_stats.analysis_duration = chrono::Utc::now() - analysis_start;
350        self.update_session_statistics();
351
352        Ok(DebuggingReport {
353            session_info: self.session_state.clone(),
354            identified_issues: self.session_state.identified_issues.clone(),
355            recommendations: self.session_state.recommendations.clone(),
356            summary: self.generate_analysis_summary(),
357        })
358    }
359
360    /// Analyze learning rate related issues.
361    fn analyze_learning_rate_issues(&mut self) -> Result<()> {
362        let recent_metrics: Vec<_> = self.performance_history.iter().rev().take(20).collect();
363        if recent_metrics.len() < 10 {
364            return Ok(());
365        }
366
367        let mut issues_to_add = Vec::new();
368
369        // Check for loss explosion
370        let recent_losses: Vec<f64> = recent_metrics.iter().map(|m| m.loss).collect();
371        if let Some(max_loss) = recent_losses
372            .iter()
373            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
374        {
375            if let Some(min_loss) = recent_losses
376                .iter()
377                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
378            {
379                if max_loss / min_loss > 10.0 {
380                    issues_to_add.push(IdentifiedIssue {
381                        category: IssueCategory::LearningRate,
382                        description: "Learning rate too high - loss explosion detected".to_string(),
383                        severity: IssueSeverity::Critical,
384                        confidence: 0.9,
385                        evidence: vec![
386                            format!("Loss ratio: {:.2}", max_loss / min_loss),
387                            "Rapid loss increase observed".to_string(),
388                        ],
389                        potential_causes: vec![
390                            "Learning rate set too high".to_string(),
391                            "Gradient clipping disabled".to_string(),
392                            "Numerical instability".to_string(),
393                        ],
394                        identified_at: chrono::Utc::now(),
395                    });
396                }
397            }
398        }
399
400        // Check for learning stagnation
401        let loss_variance = self.calculate_variance(&recent_losses);
402        let recent_metrics_len = recent_metrics.len();
403        if loss_variance < 1e-6 && recent_metrics_len >= 15 {
404            issues_to_add.push(IdentifiedIssue {
405                category: IssueCategory::LearningRate,
406                description: "Learning rate too low - training stagnation".to_string(),
407                severity: IssueSeverity::Major,
408                confidence: 0.8,
409                evidence: vec![
410                    format!("Loss variance: {:.2e}", loss_variance),
411                    "No learning progress in recent steps".to_string(),
412                ],
413                potential_causes: vec![
414                    "Learning rate set too low".to_string(),
415                    "Learning rate decay too aggressive".to_string(),
416                    "Model has converged".to_string(),
417                ],
418                identified_at: chrono::Utc::now(),
419            });
420        }
421
422        // Add all collected issues
423        for issue in issues_to_add {
424            self.add_issue(issue);
425        }
426
427        Ok(())
428    }
429
430    /// Analyze convergence related issues.
431    fn analyze_convergence_issues(&mut self) -> Result<()> {
432        if let Some(latest_dynamics) = self.dynamics_history.back() {
433            match latest_dynamics.convergence_status {
434                ConvergenceStatus::Diverging => {
435                    self.add_issue(IdentifiedIssue {
436                        category: IssueCategory::Convergence,
437                        description: "Training is diverging".to_string(),
438                        severity: IssueSeverity::Critical,
439                        confidence: 0.95,
440                        evidence: vec!["Convergence status: Diverging".to_string()],
441                        potential_causes: vec![
442                            "Learning rate too high".to_string(),
443                            "Gradient explosion".to_string(),
444                            "Numerical instability".to_string(),
445                        ],
446                        identified_at: chrono::Utc::now(),
447                    });
448                },
449                ConvergenceStatus::Plateau => {
450                    if let Some(plateau_info) = &latest_dynamics.plateau_detection {
451                        if plateau_info.duration_steps > 100 {
452                            self.add_issue(IdentifiedIssue {
453                                category: IssueCategory::Convergence,
454                                description: "Training has plateaued".to_string(),
455                                severity: IssueSeverity::Moderate,
456                                confidence: 0.8,
457                                evidence: vec![
458                                    format!(
459                                        "Plateau duration: {} steps",
460                                        plateau_info.duration_steps
461                                    ),
462                                    format!("Plateau value: {:.4}", plateau_info.plateau_value),
463                                ],
464                                potential_causes: vec![
465                                    "Learning rate too low".to_string(),
466                                    "Model capacity insufficient".to_string(),
467                                    "Local minimum reached".to_string(),
468                                ],
469                                identified_at: chrono::Utc::now(),
470                            });
471                        }
472                    }
473                },
474                ConvergenceStatus::Oscillating => {
475                    self.add_issue(IdentifiedIssue {
476                        category: IssueCategory::NumericalStability,
477                        description: "Training is oscillating".to_string(),
478                        severity: IssueSeverity::Moderate,
479                        confidence: 0.7,
480                        evidence: vec!["Convergence status: Oscillating".to_string()],
481                        potential_causes: vec![
482                            "Learning rate too high".to_string(),
483                            "Batch size too small".to_string(),
484                            "Momentum settings suboptimal".to_string(),
485                        ],
486                        identified_at: chrono::Utc::now(),
487                    });
488                },
489                _ => {},
490            }
491        }
492
493        Ok(())
494    }
495
496    /// Analyze gradient flow issues.
497    fn analyze_gradient_flow_issues(&mut self) -> Result<()> {
498        let mut issues_to_add = Vec::new();
499
500        // Check layer statistics for gradient flow problems
501        for (layer_name, layer_history) in &self.layer_history {
502            if let Some(latest_stats) = layer_history.back() {
503                // Check for dead neurons
504                if latest_stats.dead_neurons_ratio > 0.5 {
505                    issues_to_add.push(IdentifiedIssue {
506                        category: IssueCategory::GradientFlow,
507                        description: format!("High dead neuron ratio in layer {}", layer_name),
508                        severity: IssueSeverity::Major,
509                        confidence: 0.85,
510                        evidence: vec![
511                            format!(
512                                "Dead neurons: {:.1}%",
513                                latest_stats.dead_neurons_ratio * 100.0
514                            ),
515                            format!("Layer: {}", layer_name),
516                        ],
517                        potential_causes: vec![
518                            "Dying ReLU problem".to_string(),
519                            "Poor weight initialization".to_string(),
520                            "Learning rate too high".to_string(),
521                        ],
522                        identified_at: chrono::Utc::now(),
523                    });
524                }
525
526                // Check for activation saturation
527                if latest_stats.saturated_neurons_ratio > 0.3 {
528                    issues_to_add.push(IdentifiedIssue {
529                        category: IssueCategory::GradientFlow,
530                        description: format!("High activation saturation in layer {}", layer_name),
531                        severity: IssueSeverity::Moderate,
532                        confidence: 0.8,
533                        evidence: vec![
534                            format!(
535                                "Saturated neurons: {:.1}%",
536                                latest_stats.saturated_neurons_ratio * 100.0
537                            ),
538                            format!("Layer: {}", layer_name),
539                        ],
540                        potential_causes: vec![
541                            "Vanishing gradient problem".to_string(),
542                            "Poor activation function choice".to_string(),
543                            "Input normalization issues".to_string(),
544                        ],
545                        identified_at: chrono::Utc::now(),
546                    });
547                }
548            }
549        }
550
551        // Add all collected issues
552        for issue in issues_to_add {
553            self.add_issue(issue);
554        }
555
556        Ok(())
557    }
558
559    /// Analyze layer health issues.
560    fn analyze_layer_health_issues(&mut self) -> Result<()> {
561        let mut issues_to_add = Vec::new();
562
563        for (layer_name, layer_history) in &self.layer_history {
564            if layer_history.len() >= 5 {
565                let recent_stats: Vec<_> = layer_history.iter().rev().take(5).collect();
566
567                // Check for activation variance trends
568                let variances: Vec<f64> = recent_stats.iter().map(|s| s.std_activation).collect();
569                let avg_variance = variances.iter().sum::<f64>() / variances.len() as f64;
570
571                if avg_variance < 0.01 {
572                    issues_to_add.push(IdentifiedIssue {
573                        category: IssueCategory::Architecture,
574                        description: format!("Low activation variance in layer {}", layer_name),
575                        severity: IssueSeverity::Minor,
576                        confidence: 0.6,
577                        evidence: vec![
578                            format!("Average variance: {:.4}", avg_variance),
579                            format!("Layer: {}", layer_name),
580                        ],
581                        potential_causes: vec![
582                            "Poor weight initialization".to_string(),
583                            "Input normalization too aggressive".to_string(),
584                            "Activation function saturation".to_string(),
585                        ],
586                        identified_at: chrono::Utc::now(),
587                    });
588                }
589            }
590        }
591
592        // Add all collected issues
593        for issue in issues_to_add {
594            self.add_issue(issue);
595        }
596
597        Ok(())
598    }
599
600    /// Analyze memory usage issues.
601    fn analyze_memory_issues(&mut self) -> Result<()> {
602        if self.performance_history.len() >= 10 {
603            let recent_memory: Vec<f64> = self
604                .performance_history
605                .iter()
606                .rev()
607                .take(10)
608                .map(|m| m.memory_usage_mb)
609                .collect();
610
611            // Check for memory leaks
612            let memory_trend = self.calculate_trend(&recent_memory);
613            if memory_trend > 10.0 {
614                // MB per step
615                self.add_issue(IdentifiedIssue {
616                    category: IssueCategory::Memory,
617                    description: "Memory leak detected".to_string(),
618                    severity: IssueSeverity::Critical,
619                    confidence: 0.9,
620                    evidence: vec![
621                        format!("Memory growth rate: {:.2} MB/step", memory_trend),
622                        "Increasing memory usage trend".to_string(),
623                    ],
624                    potential_causes: vec![
625                        "Gradient accumulation without clearing".to_string(),
626                        "Cached tensors not being released".to_string(),
627                        "Memory fragmentation".to_string(),
628                    ],
629                    identified_at: chrono::Utc::now(),
630                });
631            }
632
633            // Check for excessive memory usage
634            if let Some(max_memory) = recent_memory
635                .iter()
636                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
637            {
638                if *max_memory > 16384.0 {
639                    // 16GB
640                    self.add_issue(IdentifiedIssue {
641                        category: IssueCategory::Memory,
642                        description: "Excessive memory usage detected".to_string(),
643                        severity: IssueSeverity::Major,
644                        confidence: 0.8,
645                        evidence: vec![
646                            format!("Peak memory: {:.0} MB", max_memory),
647                            "High memory consumption".to_string(),
648                        ],
649                        potential_causes: vec![
650                            "Batch size too large".to_string(),
651                            "Model too large for available memory".to_string(),
652                            "Inefficient memory allocation".to_string(),
653                        ],
654                        identified_at: chrono::Utc::now(),
655                    });
656                }
657            }
658        }
659
660        Ok(())
661    }
662
663    /// Analyze overfitting and underfitting issues.
664    fn analyze_overfitting_underfitting(&mut self) -> Result<()> {
665        let mut issues_to_add = Vec::new();
666
667        if let Some(latest_dynamics) = self.dynamics_history.back() {
668            // Check for overfitting indicators
669            for indicator in &latest_dynamics.overfitting_indicators {
670                if let super::types::OverfittingIndicator::TrainValidationGap { gap } = indicator {
671                    if *gap > 0.1 {
672                        issues_to_add.push(IdentifiedIssue {
673                            category: IssueCategory::Overfitting,
674                            description: "Large training-validation gap detected".to_string(),
675                            severity: IssueSeverity::Major,
676                            confidence: 0.85,
677                            evidence: vec![
678                                format!("Train-validation gap: {:.3}", gap),
679                                "Overfitting indicator present".to_string(),
680                            ],
681                            potential_causes: vec![
682                                "Model complexity too high".to_string(),
683                                "Insufficient regularization".to_string(),
684                                "Training set too small".to_string(),
685                            ],
686                            identified_at: chrono::Utc::now(),
687                        });
688                    }
689                }
690            }
691
692            // Check for underfitting indicators
693            for indicator in &latest_dynamics.underfitting_indicators {
694                match indicator {
695                    super::types::UnderfittingIndicator::HighTrainingLoss { loss, threshold } => {
696                        issues_to_add.push(IdentifiedIssue {
697                            category: IssueCategory::Underfitting,
698                            description: "High training loss indicates underfitting".to_string(),
699                            severity: IssueSeverity::Moderate,
700                            confidence: 0.7,
701                            evidence: vec![
702                                format!("Training loss: {:.3}", loss),
703                                format!("Threshold: {:.3}", threshold),
704                            ],
705                            potential_causes: vec![
706                                "Model capacity too low".to_string(),
707                                "Learning rate too low".to_string(),
708                                "Insufficient training time".to_string(),
709                            ],
710                            identified_at: chrono::Utc::now(),
711                        });
712                    },
713                    super::types::UnderfittingIndicator::SlowConvergence {
714                        steps_taken,
715                        expected,
716                    } => {
717                        issues_to_add.push(IdentifiedIssue {
718                            category: IssueCategory::Underfitting,
719                            description: "Slow convergence detected".to_string(),
720                            severity: IssueSeverity::Minor,
721                            confidence: 0.6,
722                            evidence: vec![
723                                format!("Steps taken: {}", steps_taken),
724                                format!("Expected: {}", expected),
725                            ],
726                            potential_causes: vec![
727                                "Learning rate too conservative".to_string(),
728                                "Optimizer choice suboptimal".to_string(),
729                                "Poor initialization".to_string(),
730                            ],
731                            identified_at: chrono::Utc::now(),
732                        });
733                    },
734                    _ => {},
735                }
736            }
737        }
738
739        // Add all collected issues
740        for issue in issues_to_add {
741            self.add_issue(issue);
742        }
743
744        Ok(())
745    }
746
747    /// Generate recommendations based on identified issues.
748    fn generate_recommendations(&mut self) -> Result<()> {
749        for issue in &self.session_state.identified_issues {
750            let recommendations = self.generate_recommendations_for_issue(issue);
751            self.session_state.recommendations.extend(recommendations);
752        }
753
754        // Sort recommendations by priority and confidence
755        self.session_state.recommendations.sort_by(|a, b| {
756            b.priority
757                .partial_cmp(&a.priority)
758                .unwrap_or(std::cmp::Ordering::Equal)
759                .then(b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal))
760        });
761
762        Ok(())
763    }
764
765    /// Generate specific recommendations for an issue.
766    fn generate_recommendations_for_issue(
767        &self,
768        issue: &IdentifiedIssue,
769    ) -> Vec<DebuggingRecommendation> {
770        match issue.category {
771            IssueCategory::LearningRate => {
772                if issue.description.contains("too high") {
773                    vec![DebuggingRecommendation {
774                        category: RecommendationCategory::HyperparameterTuning,
775                        title: "Reduce Learning Rate".to_string(),
776                        description: "Lower the learning rate to stabilize training".to_string(),
777                        actions: vec![
778                            "Reduce learning rate by factor of 2-10".to_string(),
779                            "Enable gradient clipping".to_string(),
780                            "Consider learning rate scheduling".to_string(),
781                        ],
782                        expected_impact: "Stabilized training with reduced loss oscillations"
783                            .to_string(),
784                        confidence: 0.9,
785                        priority: AutoDebugRecommendationPriority::High,
786                        hyperparameter_suggestions: vec![HyperparameterSuggestion {
787                            parameter_name: "learning_rate".to_string(),
788                            current_value: None,
789                            suggested_value: 0.0001,
790                            reasoning: "Reduce to prevent loss explosion".to_string(),
791                            expected_effect: "More stable training".to_string(),
792                        }],
793                    }]
794                } else if issue.description.contains("too low") {
795                    vec![DebuggingRecommendation {
796                        category: RecommendationCategory::HyperparameterTuning,
797                        title: "Increase Learning Rate".to_string(),
798                        description: "Increase learning rate to improve convergence speed"
799                            .to_string(),
800                        actions: vec![
801                            "Increase learning rate by factor of 2-5".to_string(),
802                            "Use learning rate warmup".to_string(),
803                            "Consider adaptive learning rate methods".to_string(),
804                        ],
805                        expected_impact: "Faster convergence and better final performance"
806                            .to_string(),
807                        confidence: 0.8,
808                        priority: AutoDebugRecommendationPriority::Medium,
809                        hyperparameter_suggestions: vec![HyperparameterSuggestion {
810                            parameter_name: "learning_rate".to_string(),
811                            current_value: None,
812                            suggested_value: 0.001,
813                            reasoning: "Increase to improve learning speed".to_string(),
814                            expected_effect: "Faster convergence".to_string(),
815                        }],
816                    }]
817                } else {
818                    Vec::new()
819                }
820            },
821            IssueCategory::Memory => {
822                vec![DebuggingRecommendation {
823                    category: RecommendationCategory::ResourceOptimization,
824                    title: "Optimize Memory Usage".to_string(),
825                    description: "Implement memory optimization strategies".to_string(),
826                    actions: vec![
827                        "Reduce batch size".to_string(),
828                        "Enable gradient checkpointing".to_string(),
829                        "Clear cached tensors regularly".to_string(),
830                        "Use mixed precision training".to_string(),
831                    ],
832                    expected_impact: "Reduced memory consumption and stable training".to_string(),
833                    confidence: 0.85,
834                    priority: AutoDebugRecommendationPriority::High,
835                    hyperparameter_suggestions: vec![HyperparameterSuggestion {
836                        parameter_name: "batch_size".to_string(),
837                        current_value: None,
838                        suggested_value: 16.0,
839                        reasoning: "Reduce to lower memory usage".to_string(),
840                        expected_effect: "Lower memory consumption".to_string(),
841                    }],
842                }]
843            },
844            IssueCategory::Overfitting => {
845                vec![DebuggingRecommendation {
846                    category: RecommendationCategory::TrainingStrategy,
847                    title: "Address Overfitting".to_string(),
848                    description: "Implement regularization strategies to reduce overfitting"
849                        .to_string(),
850                    actions: vec![
851                        "Add dropout layers".to_string(),
852                        "Increase weight decay".to_string(),
853                        "Use data augmentation".to_string(),
854                        "Reduce model complexity".to_string(),
855                        "Implement early stopping".to_string(),
856                    ],
857                    expected_impact: "Better generalization and validation performance".to_string(),
858                    confidence: 0.8,
859                    priority: AutoDebugRecommendationPriority::Medium,
860                    hyperparameter_suggestions: vec![HyperparameterSuggestion {
861                        parameter_name: "dropout_rate".to_string(),
862                        current_value: None,
863                        suggested_value: 0.1,
864                        reasoning: "Add regularization to reduce overfitting".to_string(),
865                        expected_effect: "Better generalization".to_string(),
866                    }],
867                }]
868            },
869            IssueCategory::GradientFlow => {
870                vec![DebuggingRecommendation {
871                    category: RecommendationCategory::ArchitecturalModification,
872                    title: "Improve Gradient Flow".to_string(),
873                    description: "Address gradient flow issues in the network".to_string(),
874                    actions: vec![
875                        "Use different activation functions (e.g., Leaky ReLU, Swish)".to_string(),
876                        "Add batch normalization".to_string(),
877                        "Implement residual connections".to_string(),
878                        "Adjust weight initialization".to_string(),
879                    ],
880                    expected_impact: "Better gradient flow and training stability".to_string(),
881                    confidence: 0.75,
882                    priority: AutoDebugRecommendationPriority::Medium,
883                    hyperparameter_suggestions: Vec::new(),
884                }]
885            },
886            _ => Vec::new(),
887        }
888    }
889
890    /// Add an issue to the current session.
891    fn add_issue(&mut self, issue: IdentifiedIssue) {
892        self.session_state.identified_issues.push(issue);
893    }
894
895    /// Calculate variance of a sequence of values.
896    fn calculate_variance(&self, values: &[f64]) -> f64 {
897        if values.len() < 2 {
898            return 0.0;
899        }
900
901        let mean = values.iter().sum::<f64>() / values.len() as f64;
902        let variance =
903            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
904
905        variance
906    }
907
908    /// Calculate trend (slope) of a sequence of values.
909    fn calculate_trend(&self, values: &[f64]) -> f64 {
910        if values.len() < 2 {
911            return 0.0;
912        }
913
914        let n = values.len() as f64;
915        let x_mean = (n - 1.0) / 2.0;
916        let y_mean = values.iter().sum::<f64>() / n;
917
918        let numerator: f64 = values
919            .iter()
920            .enumerate()
921            .map(|(i, &y)| (i as f64 - x_mean) * (y - y_mean))
922            .sum();
923
924        let denominator: f64 = (0..values.len()).map(|i| (i as f64 - x_mean).powi(2)).sum();
925
926        if denominator == 0.0 {
927            0.0
928        } else {
929            numerator / denominator
930        }
931    }
932
933    /// Update session statistics.
934    fn update_session_statistics(&mut self) {
935        let mut issues_by_category = HashMap::new();
936        for issue in &self.session_state.identified_issues {
937            *issues_by_category.entry(issue.category.clone()).or_insert(0) += 1;
938        }
939
940        let avg_confidence = if self.session_state.recommendations.is_empty() {
941            0.0
942        } else {
943            self.session_state.recommendations.iter().map(|r| r.confidence).sum::<f64>()
944                / self.session_state.recommendations.len() as f64
945        };
946
947        self.session_state.session_stats = SessionStatistics {
948            total_issues: self.session_state.identified_issues.len(),
949            issues_by_category,
950            total_recommendations: self.session_state.recommendations.len(),
951            avg_recommendation_confidence: avg_confidence,
952            analysis_duration: self.session_state.session_stats.analysis_duration,
953        };
954    }
955
956    /// Generate analysis summary.
957    fn generate_analysis_summary(&self) -> String {
958        let critical_issues = self
959            .session_state
960            .identified_issues
961            .iter()
962            .filter(|i| i.severity == IssueSeverity::Critical)
963            .count();
964
965        let major_issues = self
966            .session_state
967            .identified_issues
968            .iter()
969            .filter(|i| i.severity == IssueSeverity::Major)
970            .count();
971
972        let high_priority_recommendations = self
973            .session_state
974            .recommendations
975            .iter()
976            .filter(|r| r.priority == AutoDebugRecommendationPriority::High)
977            .count();
978
979        format!(
980            "Auto-debugging analysis completed. Found {} critical issues, {} major issues. \
981            Generated {} recommendations with {} high-priority actions. \
982            Average recommendation confidence: {:.2}",
983            critical_issues,
984            major_issues,
985            self.session_state.recommendations.len(),
986            high_priority_recommendations,
987            self.session_state.session_stats.avg_recommendation_confidence
988        )
989    }
990}
991
/// Comprehensive debugging report.
///
/// Serializable (serde) snapshot of a debugging session, combining:
/// - the session state;
/// - lists of issues and recommendations;
/// - a textual summary.
///
/// NOTE(review): `DebuggingSession` carries its own `identified_issues` and
/// `recommendations` fields. The top-level lists here therefore presumably
/// duplicate them; confirm against the code that builds the report.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DebuggingReport {
    /// State of the debugging session the report was produced from
    /// (start time, issues, recommendations and session statistics).
    pub session_info: DebuggingSession,
    /// All issues identified during the analysis.
    pub identified_issues: Vec<IdentifiedIssue>,
    /// Recommendations generated for the identified issues
    /// (presumably in the ranked order produced by the debugger; verify).
    pub recommendations: Vec<DebuggingRecommendation>,
    /// Human-readable summary of the analysis.
    pub summary: String,
}
1004
1005impl IssuePatternDatabase {
1006    /// Create a new pattern database with default patterns.
1007    pub fn new() -> Self {
1008        Self {
1009            learning_rate_patterns: Self::create_learning_rate_patterns(),
1010            gradient_patterns: Self::create_gradient_patterns(),
1011            convergence_patterns: Self::create_convergence_patterns(),
1012            layer_patterns: Self::create_layer_patterns(),
1013        }
1014    }
1015
1016    /// Create default learning rate patterns.
1017    fn create_learning_rate_patterns() -> Vec<IssuePattern> {
1018        vec![IssuePattern {
1019            name: "Loss Explosion".to_string(),
1020            description: "Rapid increase in loss indicating learning rate too high".to_string(),
1021            conditions: vec![PatternCondition {
1022                metric: "loss".to_string(),
1023                operator: ComparisonOperator::Increasing,
1024                threshold: 2.0,
1025                consecutive_count: 3,
1026            }],
1027            issue_category: IssueCategory::LearningRate,
1028            confidence_weight: 0.9,
1029            solutions: vec![
1030                "Reduce learning rate by factor of 10".to_string(),
1031                "Enable gradient clipping".to_string(),
1032            ],
1033        }]
1034    }
1035
1036    /// Create default gradient patterns.
1037    fn create_gradient_patterns() -> Vec<IssuePattern> {
1038        vec![]
1039    }
1040
1041    /// Create default convergence patterns.
1042    fn create_convergence_patterns() -> Vec<IssuePattern> {
1043        vec![]
1044    }
1045
1046    /// Create default layer patterns.
1047    fn create_layer_patterns() -> Vec<IssuePattern> {
1048        vec![]
1049    }
1050}
1051
1052impl DebuggingSession {
1053    /// Create a new debugging session.
1054    fn new() -> Self {
1055        Self {
1056            session_start: chrono::Utc::now(),
1057            identified_issues: Vec::new(),
1058            recommendations: Vec::new(),
1059            session_stats: SessionStatistics {
1060                total_issues: 0,
1061                issues_by_category: HashMap::new(),
1062                total_recommendations: 0,
1063                avg_recommendation_confidence: 0.0,
1064                analysis_duration: chrono::Duration::zero(),
1065            },
1066        }
1067    }
1068}
1069
1070impl Default for AutoDebugger {
1071    fn default() -> Self {
1072        Self::new()
1073    }
1074}
1075
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in these tests.
    const EPSILON: f64 = 1e-12;

    /// Returns true when `a` and `b` differ by less than `EPSILON`.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPSILON
    }

    /// Builds an issue with fixed evidence/causes and the given category,
    /// description and severity.
    fn make_issue(
        category: IssueCategory,
        description: &str,
        severity: IssueSeverity,
    ) -> IdentifiedIssue {
        IdentifiedIssue {
            category,
            description: description.to_string(),
            severity,
            confidence: 0.8,
            evidence: vec!["Test evidence".to_string()],
            potential_causes: vec!["Test cause".to_string()],
            identified_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_auto_debugger_creation() {
        let debugger = AutoDebugger::new();
        assert_eq!(debugger.performance_history.len(), 0);
        assert_eq!(debugger.layer_history.len(), 0);
    }

    #[test]
    fn test_issue_identification() {
        let mut debugger = AutoDebugger::new();
        debugger.add_issue(make_issue(
            IssueCategory::LearningRate,
            "Test issue",
            IssueSeverity::Major,
        ));
        assert_eq!(debugger.session_state.identified_issues.len(), 1);
    }

    #[test]
    fn test_variance_calculation() {
        let debugger = AutoDebugger::new();
        // Sample variance (n - 1 denominator) of 1..=5 is 10 / 4 = 2.5.
        let variance = debugger.calculate_variance(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(approx_eq(variance, 2.5), "variance was {variance}");
    }

    #[test]
    fn test_variance_degenerate_inputs() {
        let debugger = AutoDebugger::new();
        assert!(approx_eq(debugger.calculate_variance(&[]), 0.0));
        assert!(approx_eq(debugger.calculate_variance(&[42.0]), 0.0));
        assert!(approx_eq(debugger.calculate_variance(&[3.0, 3.0, 3.0]), 0.0));
    }

    #[test]
    fn test_trend_calculation() {
        let debugger = AutoDebugger::new();
        // y = i + 1 has a least-squares slope of exactly 1.
        let trend = debugger.calculate_trend(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(approx_eq(trend, 1.0), "trend was {trend}");
        // Decreasing by 2 per step.
        assert!(approx_eq(debugger.calculate_trend(&[10.0, 8.0, 6.0]), -2.0));
        // Flat series and too-short input both give 0.
        assert!(approx_eq(debugger.calculate_trend(&[4.0, 4.0, 4.0]), 0.0));
        assert!(approx_eq(debugger.calculate_trend(&[7.0]), 0.0));
        assert!(approx_eq(debugger.calculate_trend(&[]), 0.0));
    }

    #[test]
    fn test_learning_rate_recommendations_follow_description() {
        let debugger = AutoDebugger::new();

        let too_high = make_issue(
            IssueCategory::LearningRate,
            "Learning rate too high",
            IssueSeverity::Major,
        );
        let recs = debugger.generate_recommendations_for_issue(&too_high);
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0].title, "Reduce Learning Rate");

        let too_low = make_issue(
            IssueCategory::LearningRate,
            "Learning rate too low",
            IssueSeverity::Minor,
        );
        let recs = debugger.generate_recommendations_for_issue(&too_low);
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0].title, "Increase Learning Rate");

        let unclear = make_issue(
            IssueCategory::LearningRate,
            "Unclear learning rate issue",
            IssueSeverity::Minor,
        );
        assert!(debugger.generate_recommendations_for_issue(&unclear).is_empty());
    }

    #[test]
    fn test_uncovered_category_yields_no_recommendations() {
        let debugger = AutoDebugger::new();
        let issue = make_issue(
            IssueCategory::Underfitting,
            "Slow convergence detected",
            IssueSeverity::Minor,
        );
        assert!(debugger.generate_recommendations_for_issue(&issue).is_empty());
    }

    #[test]
    fn test_session_statistics_after_recommendations() {
        let mut debugger = AutoDebugger::new();
        debugger.add_issue(make_issue(
            IssueCategory::Memory,
            "High memory usage",
            IssueSeverity::Critical,
        ));
        debugger.add_issue(make_issue(
            IssueCategory::Underfitting,
            "Slow convergence detected",
            IssueSeverity::Minor,
        ));

        debugger.generate_recommendations().expect("generation does not fail");
        debugger.update_session_statistics();

        let stats = &debugger.session_state.session_stats;
        assert_eq!(stats.total_issues, 2);
        // Only the memory issue has a recommendation rule.
        assert_eq!(stats.total_recommendations, 1);
        assert!(approx_eq(stats.avg_recommendation_confidence, 0.85));
    }
}