trustformers_debug/model_diagnostics/auto_debug.rs

//! Auto-debugging and automated recommendation system.
//!
//! This module provides intelligent debugging capabilities that automatically
//! analyze model behavior, identify potential issues, and generate actionable
//! recommendations for optimization and problem resolution.
6
7use anyhow::Result;
8use std::collections::{HashMap, VecDeque};
9
10use super::types::{
11    ConvergenceStatus, LayerActivationStats, ModelPerformanceMetrics, TrainingDynamics,
12};
13
14/// Auto-debugging system for intelligent model analysis.
15#[derive(Debug)]
16pub struct AutoDebugger {
17    /// Debugging configuration
18    config: AutoDebugConfig,
19    /// Historical performance data for analysis
20    performance_history: VecDeque<ModelPerformanceMetrics>,
21    /// Layer statistics history
22    layer_history: HashMap<String, VecDeque<LayerActivationStats>>,
23    /// Training dynamics history
24    dynamics_history: VecDeque<TrainingDynamics>,
25    /// Known issue patterns and solutions
26    #[allow(dead_code)]
27    issue_patterns: IssuePatternDatabase,
28    /// Current debugging session state
29    session_state: DebuggingSession,
30}
31
/// Tuning knobs for [`AutoDebugger`].
///
/// The baseline values listed on each field are the ones produced by the
/// `Default` implementation.
#[derive(Debug, Clone)]
pub struct AutoDebugConfig {
    /// Upper bound on the number of samples retained in each history buffer
    /// (default: 1000).
    pub max_history_size: usize,
    /// Minimum number of performance samples required before an analysis run
    /// is attempted (default: 10).
    pub min_samples_for_analysis: usize,
    /// Confidence threshold intended for filtering recommendations
    /// (default: 0.7).
    pub recommendation_confidence_threshold: f64,
    /// Switch for the advanced pattern-recognition passes (default: true).
    pub enable_advanced_patterns: bool,
    /// Switch for hyperparameter adjustment suggestions (default: true).
    pub enable_hyperparameter_suggestions: bool,
    /// Switch for architecture-level recommendations (default: true).
    pub enable_architectural_recommendations: bool,
}
48
49/// Current debugging session state.
50#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
51pub struct DebuggingSession {
52    /// Session start time
53    pub session_start: chrono::DateTime<chrono::Utc>,
54    /// Issues identified in current session
55    pub identified_issues: Vec<IdentifiedIssue>,
56    /// Recommendations generated
57    pub recommendations: Vec<DebuggingRecommendation>,
58    /// Session statistics
59    pub session_stats: SessionStatistics,
60}
61
62/// An identified issue with diagnostic information.
63#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64pub struct IdentifiedIssue {
65    /// Issue category
66    pub category: IssueCategory,
67    /// Issue description
68    pub description: String,
69    /// Severity level
70    pub severity: IssueSeverity,
71    /// Confidence in identification
72    pub confidence: f64,
73    /// Evidence supporting the identification
74    pub evidence: Vec<String>,
75    /// Potential causes
76    pub potential_causes: Vec<String>,
77    /// When issue was identified
78    pub identified_at: chrono::DateTime<chrono::Utc>,
79}
80
81/// Categories of issues that can be identified.
82#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
83pub enum IssueCategory {
84    /// Learning rate related issues
85    LearningRate,
86    /// Gradient flow problems
87    GradientFlow,
88    /// Overfitting issues
89    Overfitting,
90    /// Underfitting issues
91    Underfitting,
92    /// Data quality problems
93    DataQuality,
94    /// Architecture inefficiencies
95    Architecture,
96    /// Memory management issues
97    Memory,
98    /// Convergence problems
99    Convergence,
100    /// Numerical stability issues
101    NumericalStability,
102}
103
104/// Severity levels for identified issues.
105#[derive(Debug, Clone, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize)]
106pub enum IssueSeverity {
107    /// Minor issue with low impact
108    Minor,
109    /// Moderate issue affecting performance
110    Moderate,
111    /// Major issue requiring attention
112    Major,
113    /// Critical issue requiring immediate action
114    Critical,
115}
116
117/// Auto-generated debugging recommendation.
118#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
119pub struct DebuggingRecommendation {
120    /// Recommendation category
121    pub category: RecommendationCategory,
122    /// Recommendation title
123    pub title: String,
124    /// Detailed description
125    pub description: String,
126    /// Specific actions to take
127    pub actions: Vec<String>,
128    /// Expected impact
129    pub expected_impact: String,
130    /// Confidence in recommendation
131    pub confidence: f64,
132    /// Priority level
133    pub priority: AutoDebugRecommendationPriority,
134    /// Relevant hyperparameters to adjust
135    pub hyperparameter_suggestions: Vec<HyperparameterSuggestion>,
136}
137
138/// Categories of recommendations.
139#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
140pub enum RecommendationCategory {
141    /// Hyperparameter adjustments
142    HyperparameterTuning,
143    /// Architectural changes
144    ArchitecturalModification,
145    /// Data preprocessing recommendations
146    DataPreprocessing,
147    /// Training strategy changes
148    TrainingStrategy,
149    /// Debugging and monitoring
150    DebuggingAndMonitoring,
151    /// Resource optimization
152    ResourceOptimization,
153}
154
155/// Priority levels for recommendations.
156#[derive(Debug, Clone, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize)]
157pub enum AutoDebugRecommendationPriority {
158    /// Low priority suggestion
159    Low,
160    /// Medium priority recommendation
161    Medium,
162    /// High priority action needed
163    High,
164    /// Urgent action required
165    Urgent,
166}
167
168/// Hyperparameter adjustment suggestion.
169#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
170pub struct HyperparameterSuggestion {
171    /// Parameter name
172    pub parameter_name: String,
173    /// Current value (if known)
174    pub current_value: Option<f64>,
175    /// Suggested value
176    pub suggested_value: f64,
177    /// Adjustment reasoning
178    pub reasoning: String,
179    /// Expected effect
180    pub expected_effect: String,
181}
182
183/// Database of known issue patterns and solutions.
184#[derive(Debug, Clone)]
185pub struct IssuePatternDatabase {
186    /// Learning rate patterns
187    pub learning_rate_patterns: Vec<IssuePattern>,
188    /// Gradient patterns
189    pub gradient_patterns: Vec<IssuePattern>,
190    /// Convergence patterns
191    pub convergence_patterns: Vec<IssuePattern>,
192    /// Layer behavior patterns
193    pub layer_patterns: Vec<IssuePattern>,
194}
195
196/// Pattern for identifying specific issues.
197#[derive(Debug, Clone)]
198pub struct IssuePattern {
199    /// Pattern name
200    pub name: String,
201    /// Pattern description
202    pub description: String,
203    /// Conditions that must be met
204    pub conditions: Vec<PatternCondition>,
205    /// Associated issue category
206    pub issue_category: IssueCategory,
207    /// Confidence weight for this pattern
208    pub confidence_weight: f64,
209    /// Recommended solutions
210    pub solutions: Vec<String>,
211}
212
213/// Condition for pattern matching.
214#[derive(Debug, Clone)]
215pub struct PatternCondition {
216    /// Metric name
217    pub metric: String,
218    /// Comparison operator
219    pub operator: ComparisonOperator,
220    /// Threshold value
221    pub threshold: f64,
222    /// Number of consecutive occurrences required
223    pub consecutive_count: usize,
224}
225
/// How a [`PatternCondition`] compares an observed metric with its threshold.
///
/// `PartialEq`, `Eq` and `Hash` are derived, consistent with
/// [`IssueCategory`]. Operators can then be compared directly, asserted on in
/// tests, and used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ComparisonOperator {
    /// Greater than.
    GreaterThan,
    /// Less than.
    LessThan,
    /// Equal to (within tolerance).
    EqualTo,
    /// Increasing trend.
    Increasing,
    /// Decreasing trend.
    Decreasing,
    /// Oscillating pattern.
    Oscillating,
}
242
243/// Session statistics for debugging analysis.
244#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
245pub struct SessionStatistics {
246    /// Total issues identified
247    pub total_issues: usize,
248    /// Issues by category
249    pub issues_by_category: HashMap<IssueCategory, usize>,
250    /// Total recommendations generated
251    pub total_recommendations: usize,
252    /// Average confidence of recommendations
253    pub avg_recommendation_confidence: f64,
254    /// Analysis time taken
255    pub analysis_duration: chrono::Duration,
256}
257
258impl Default for AutoDebugConfig {
259    fn default() -> Self {
260        Self {
261            max_history_size: 1000,
262            min_samples_for_analysis: 10,
263            recommendation_confidence_threshold: 0.7,
264            enable_advanced_patterns: true,
265            enable_hyperparameter_suggestions: true,
266            enable_architectural_recommendations: true,
267        }
268    }
269}
270
271impl AutoDebugger {
272    /// Create a new auto-debugger.
273    pub fn new() -> Self {
274        Self {
275            config: AutoDebugConfig::default(),
276            performance_history: VecDeque::new(),
277            layer_history: HashMap::new(),
278            dynamics_history: VecDeque::new(),
279            issue_patterns: IssuePatternDatabase::new(),
280            session_state: DebuggingSession::new(),
281        }
282    }
283
284    /// Create auto-debugger with custom configuration.
285    pub fn with_config(config: AutoDebugConfig) -> Self {
286        Self {
287            config,
288            performance_history: VecDeque::new(),
289            layer_history: HashMap::new(),
290            dynamics_history: VecDeque::new(),
291            issue_patterns: IssuePatternDatabase::new(),
292            session_state: DebuggingSession::new(),
293        }
294    }
295
296    /// Record performance metrics for analysis.
297    pub fn record_performance_metrics(&mut self, metrics: ModelPerformanceMetrics) {
298        self.performance_history.push_back(metrics);
299
300        while self.performance_history.len() > self.config.max_history_size {
301            self.performance_history.pop_front();
302        }
303    }
304
305    /// Record layer statistics for analysis.
306    pub fn record_layer_stats(&mut self, stats: LayerActivationStats) {
307        let layer_name = stats.layer_name.clone();
308
309        let layer_history = self.layer_history.entry(layer_name).or_insert_with(VecDeque::new);
310        layer_history.push_back(stats);
311
312        while layer_history.len() > self.config.max_history_size {
313            layer_history.pop_front();
314        }
315    }
316
317    /// Record training dynamics for analysis.
318    pub fn record_training_dynamics(&mut self, dynamics: TrainingDynamics) {
319        self.dynamics_history.push_back(dynamics);
320
321        while self.dynamics_history.len() > self.config.max_history_size {
322            self.dynamics_history.pop_front();
323        }
324    }
325
326    /// Perform comprehensive auto-debugging analysis.
327    pub fn perform_analysis(&mut self) -> Result<DebuggingReport> {
328        let analysis_start = chrono::Utc::now();
329
330        if self.performance_history.len() < self.config.min_samples_for_analysis {
331            return Err(anyhow::anyhow!("Insufficient data for analysis"));
332        }
333
334        // Clear previous session state
335        self.session_state = DebuggingSession::new();
336
337        // Analyze different aspects
338        self.analyze_learning_rate_issues()?;
339        self.analyze_convergence_issues()?;
340        self.analyze_gradient_flow_issues()?;
341        self.analyze_layer_health_issues()?;
342        self.analyze_memory_issues()?;
343        self.analyze_overfitting_underfitting()?;
344
345        // Generate recommendations based on identified issues
346        self.generate_recommendations()?;
347
348        // Update session statistics
349        self.session_state.session_stats.analysis_duration = chrono::Utc::now() - analysis_start;
350        self.update_session_statistics();
351
352        Ok(DebuggingReport {
353            session_info: self.session_state.clone(),
354            identified_issues: self.session_state.identified_issues.clone(),
355            recommendations: self.session_state.recommendations.clone(),
356            summary: self.generate_analysis_summary(),
357        })
358    }
359
360    /// Analyze learning rate related issues.
361    fn analyze_learning_rate_issues(&mut self) -> Result<()> {
362        let recent_metrics: Vec<_> = self.performance_history.iter().rev().take(20).collect();
363        if recent_metrics.len() < 10 {
364            return Ok(());
365        }
366
367        let mut issues_to_add = Vec::new();
368
369        // Check for loss explosion
370        let recent_losses: Vec<f64> = recent_metrics.iter().map(|m| m.loss).collect();
371        if let Some(max_loss) = recent_losses.iter().max_by(|a, b| a.partial_cmp(b).unwrap()) {
372            if let Some(min_loss) = recent_losses.iter().min_by(|a, b| a.partial_cmp(b).unwrap()) {
373                if max_loss / min_loss > 10.0 {
374                    issues_to_add.push(IdentifiedIssue {
375                        category: IssueCategory::LearningRate,
376                        description: "Learning rate too high - loss explosion detected".to_string(),
377                        severity: IssueSeverity::Critical,
378                        confidence: 0.9,
379                        evidence: vec![
380                            format!("Loss ratio: {:.2}", max_loss / min_loss),
381                            "Rapid loss increase observed".to_string(),
382                        ],
383                        potential_causes: vec![
384                            "Learning rate set too high".to_string(),
385                            "Gradient clipping disabled".to_string(),
386                            "Numerical instability".to_string(),
387                        ],
388                        identified_at: chrono::Utc::now(),
389                    });
390                }
391            }
392        }
393
394        // Check for learning stagnation
395        let loss_variance = self.calculate_variance(&recent_losses);
396        let recent_metrics_len = recent_metrics.len();
397        if loss_variance < 1e-6 && recent_metrics_len >= 15 {
398            issues_to_add.push(IdentifiedIssue {
399                category: IssueCategory::LearningRate,
400                description: "Learning rate too low - training stagnation".to_string(),
401                severity: IssueSeverity::Major,
402                confidence: 0.8,
403                evidence: vec![
404                    format!("Loss variance: {:.2e}", loss_variance),
405                    "No learning progress in recent steps".to_string(),
406                ],
407                potential_causes: vec![
408                    "Learning rate set too low".to_string(),
409                    "Learning rate decay too aggressive".to_string(),
410                    "Model has converged".to_string(),
411                ],
412                identified_at: chrono::Utc::now(),
413            });
414        }
415
416        // Add all collected issues
417        for issue in issues_to_add {
418            self.add_issue(issue);
419        }
420
421        Ok(())
422    }
423
424    /// Analyze convergence related issues.
425    fn analyze_convergence_issues(&mut self) -> Result<()> {
426        if let Some(latest_dynamics) = self.dynamics_history.back() {
427            match latest_dynamics.convergence_status {
428                ConvergenceStatus::Diverging => {
429                    self.add_issue(IdentifiedIssue {
430                        category: IssueCategory::Convergence,
431                        description: "Training is diverging".to_string(),
432                        severity: IssueSeverity::Critical,
433                        confidence: 0.95,
434                        evidence: vec!["Convergence status: Diverging".to_string()],
435                        potential_causes: vec![
436                            "Learning rate too high".to_string(),
437                            "Gradient explosion".to_string(),
438                            "Numerical instability".to_string(),
439                        ],
440                        identified_at: chrono::Utc::now(),
441                    });
442                },
443                ConvergenceStatus::Plateau => {
444                    if let Some(plateau_info) = &latest_dynamics.plateau_detection {
445                        if plateau_info.duration_steps > 100 {
446                            self.add_issue(IdentifiedIssue {
447                                category: IssueCategory::Convergence,
448                                description: "Training has plateaued".to_string(),
449                                severity: IssueSeverity::Moderate,
450                                confidence: 0.8,
451                                evidence: vec![
452                                    format!(
453                                        "Plateau duration: {} steps",
454                                        plateau_info.duration_steps
455                                    ),
456                                    format!("Plateau value: {:.4}", plateau_info.plateau_value),
457                                ],
458                                potential_causes: vec![
459                                    "Learning rate too low".to_string(),
460                                    "Model capacity insufficient".to_string(),
461                                    "Local minimum reached".to_string(),
462                                ],
463                                identified_at: chrono::Utc::now(),
464                            });
465                        }
466                    }
467                },
468                ConvergenceStatus::Oscillating => {
469                    self.add_issue(IdentifiedIssue {
470                        category: IssueCategory::NumericalStability,
471                        description: "Training is oscillating".to_string(),
472                        severity: IssueSeverity::Moderate,
473                        confidence: 0.7,
474                        evidence: vec!["Convergence status: Oscillating".to_string()],
475                        potential_causes: vec![
476                            "Learning rate too high".to_string(),
477                            "Batch size too small".to_string(),
478                            "Momentum settings suboptimal".to_string(),
479                        ],
480                        identified_at: chrono::Utc::now(),
481                    });
482                },
483                _ => {},
484            }
485        }
486
487        Ok(())
488    }
489
490    /// Analyze gradient flow issues.
491    fn analyze_gradient_flow_issues(&mut self) -> Result<()> {
492        let mut issues_to_add = Vec::new();
493
494        // Check layer statistics for gradient flow problems
495        for (layer_name, layer_history) in &self.layer_history {
496            if let Some(latest_stats) = layer_history.back() {
497                // Check for dead neurons
498                if latest_stats.dead_neurons_ratio > 0.5 {
499                    issues_to_add.push(IdentifiedIssue {
500                        category: IssueCategory::GradientFlow,
501                        description: format!("High dead neuron ratio in layer {}", layer_name),
502                        severity: IssueSeverity::Major,
503                        confidence: 0.85,
504                        evidence: vec![
505                            format!(
506                                "Dead neurons: {:.1}%",
507                                latest_stats.dead_neurons_ratio * 100.0
508                            ),
509                            format!("Layer: {}", layer_name),
510                        ],
511                        potential_causes: vec![
512                            "Dying ReLU problem".to_string(),
513                            "Poor weight initialization".to_string(),
514                            "Learning rate too high".to_string(),
515                        ],
516                        identified_at: chrono::Utc::now(),
517                    });
518                }
519
520                // Check for activation saturation
521                if latest_stats.saturated_neurons_ratio > 0.3 {
522                    issues_to_add.push(IdentifiedIssue {
523                        category: IssueCategory::GradientFlow,
524                        description: format!("High activation saturation in layer {}", layer_name),
525                        severity: IssueSeverity::Moderate,
526                        confidence: 0.8,
527                        evidence: vec![
528                            format!(
529                                "Saturated neurons: {:.1}%",
530                                latest_stats.saturated_neurons_ratio * 100.0
531                            ),
532                            format!("Layer: {}", layer_name),
533                        ],
534                        potential_causes: vec![
535                            "Vanishing gradient problem".to_string(),
536                            "Poor activation function choice".to_string(),
537                            "Input normalization issues".to_string(),
538                        ],
539                        identified_at: chrono::Utc::now(),
540                    });
541                }
542            }
543        }
544
545        // Add all collected issues
546        for issue in issues_to_add {
547            self.add_issue(issue);
548        }
549
550        Ok(())
551    }
552
553    /// Analyze layer health issues.
554    fn analyze_layer_health_issues(&mut self) -> Result<()> {
555        let mut issues_to_add = Vec::new();
556
557        for (layer_name, layer_history) in &self.layer_history {
558            if layer_history.len() >= 5 {
559                let recent_stats: Vec<_> = layer_history.iter().rev().take(5).collect();
560
561                // Check for activation variance trends
562                let variances: Vec<f64> = recent_stats.iter().map(|s| s.std_activation).collect();
563                let avg_variance = variances.iter().sum::<f64>() / variances.len() as f64;
564
565                if avg_variance < 0.01 {
566                    issues_to_add.push(IdentifiedIssue {
567                        category: IssueCategory::Architecture,
568                        description: format!("Low activation variance in layer {}", layer_name),
569                        severity: IssueSeverity::Minor,
570                        confidence: 0.6,
571                        evidence: vec![
572                            format!("Average variance: {:.4}", avg_variance),
573                            format!("Layer: {}", layer_name),
574                        ],
575                        potential_causes: vec![
576                            "Poor weight initialization".to_string(),
577                            "Input normalization too aggressive".to_string(),
578                            "Activation function saturation".to_string(),
579                        ],
580                        identified_at: chrono::Utc::now(),
581                    });
582                }
583            }
584        }
585
586        // Add all collected issues
587        for issue in issues_to_add {
588            self.add_issue(issue);
589        }
590
591        Ok(())
592    }
593
594    /// Analyze memory usage issues.
595    fn analyze_memory_issues(&mut self) -> Result<()> {
596        if self.performance_history.len() >= 10 {
597            let recent_memory: Vec<f64> = self
598                .performance_history
599                .iter()
600                .rev()
601                .take(10)
602                .map(|m| m.memory_usage_mb)
603                .collect();
604
605            // Check for memory leaks
606            let memory_trend = self.calculate_trend(&recent_memory);
607            if memory_trend > 10.0 {
608                // MB per step
609                self.add_issue(IdentifiedIssue {
610                    category: IssueCategory::Memory,
611                    description: "Memory leak detected".to_string(),
612                    severity: IssueSeverity::Critical,
613                    confidence: 0.9,
614                    evidence: vec![
615                        format!("Memory growth rate: {:.2} MB/step", memory_trend),
616                        "Increasing memory usage trend".to_string(),
617                    ],
618                    potential_causes: vec![
619                        "Gradient accumulation without clearing".to_string(),
620                        "Cached tensors not being released".to_string(),
621                        "Memory fragmentation".to_string(),
622                    ],
623                    identified_at: chrono::Utc::now(),
624                });
625            }
626
627            // Check for excessive memory usage
628            if let Some(max_memory) = recent_memory.iter().max_by(|a, b| a.partial_cmp(b).unwrap())
629            {
630                if *max_memory > 16384.0 {
631                    // 16GB
632                    self.add_issue(IdentifiedIssue {
633                        category: IssueCategory::Memory,
634                        description: "Excessive memory usage detected".to_string(),
635                        severity: IssueSeverity::Major,
636                        confidence: 0.8,
637                        evidence: vec![
638                            format!("Peak memory: {:.0} MB", max_memory),
639                            "High memory consumption".to_string(),
640                        ],
641                        potential_causes: vec![
642                            "Batch size too large".to_string(),
643                            "Model too large for available memory".to_string(),
644                            "Inefficient memory allocation".to_string(),
645                        ],
646                        identified_at: chrono::Utc::now(),
647                    });
648                }
649            }
650        }
651
652        Ok(())
653    }
654
655    /// Analyze overfitting and underfitting issues.
656    fn analyze_overfitting_underfitting(&mut self) -> Result<()> {
657        let mut issues_to_add = Vec::new();
658
659        if let Some(latest_dynamics) = self.dynamics_history.back() {
660            // Check for overfitting indicators
661            for indicator in &latest_dynamics.overfitting_indicators {
662                match indicator {
663                    super::types::OverfittingIndicator::TrainValidationGap { gap } => {
664                        if *gap > 0.1 {
665                            issues_to_add.push(IdentifiedIssue {
666                                category: IssueCategory::Overfitting,
667                                description: "Large training-validation gap detected".to_string(),
668                                severity: IssueSeverity::Major,
669                                confidence: 0.85,
670                                evidence: vec![
671                                    format!("Train-validation gap: {:.3}", gap),
672                                    "Overfitting indicator present".to_string(),
673                                ],
674                                potential_causes: vec![
675                                    "Model complexity too high".to_string(),
676                                    "Insufficient regularization".to_string(),
677                                    "Training set too small".to_string(),
678                                ],
679                                identified_at: chrono::Utc::now(),
680                            });
681                        }
682                    },
683                    _ => {},
684                }
685            }
686
687            // Check for underfitting indicators
688            for indicator in &latest_dynamics.underfitting_indicators {
689                match indicator {
690                    super::types::UnderfittingIndicator::HighTrainingLoss { loss, threshold } => {
691                        issues_to_add.push(IdentifiedIssue {
692                            category: IssueCategory::Underfitting,
693                            description: "High training loss indicates underfitting".to_string(),
694                            severity: IssueSeverity::Moderate,
695                            confidence: 0.7,
696                            evidence: vec![
697                                format!("Training loss: {:.3}", loss),
698                                format!("Threshold: {:.3}", threshold),
699                            ],
700                            potential_causes: vec![
701                                "Model capacity too low".to_string(),
702                                "Learning rate too low".to_string(),
703                                "Insufficient training time".to_string(),
704                            ],
705                            identified_at: chrono::Utc::now(),
706                        });
707                    },
708                    super::types::UnderfittingIndicator::SlowConvergence {
709                        steps_taken,
710                        expected,
711                    } => {
712                        issues_to_add.push(IdentifiedIssue {
713                            category: IssueCategory::Underfitting,
714                            description: "Slow convergence detected".to_string(),
715                            severity: IssueSeverity::Minor,
716                            confidence: 0.6,
717                            evidence: vec![
718                                format!("Steps taken: {}", steps_taken),
719                                format!("Expected: {}", expected),
720                            ],
721                            potential_causes: vec![
722                                "Learning rate too conservative".to_string(),
723                                "Optimizer choice suboptimal".to_string(),
724                                "Poor initialization".to_string(),
725                            ],
726                            identified_at: chrono::Utc::now(),
727                        });
728                    },
729                    _ => {},
730                }
731            }
732        }
733
734        // Add all collected issues
735        for issue in issues_to_add {
736            self.add_issue(issue);
737        }
738
739        Ok(())
740    }
741
742    /// Generate recommendations based on identified issues.
743    fn generate_recommendations(&mut self) -> Result<()> {
744        for issue in &self.session_state.identified_issues {
745            let recommendations = self.generate_recommendations_for_issue(issue);
746            self.session_state.recommendations.extend(recommendations);
747        }
748
749        // Sort recommendations by priority and confidence
750        self.session_state.recommendations.sort_by(|a, b| {
751            b.priority
752                .partial_cmp(&a.priority)
753                .unwrap()
754                .then(b.confidence.partial_cmp(&a.confidence).unwrap())
755        });
756
757        Ok(())
758    }
759
760    /// Generate specific recommendations for an issue.
761    fn generate_recommendations_for_issue(
762        &self,
763        issue: &IdentifiedIssue,
764    ) -> Vec<DebuggingRecommendation> {
765        match issue.category {
766            IssueCategory::LearningRate => {
767                if issue.description.contains("too high") {
768                    vec![DebuggingRecommendation {
769                        category: RecommendationCategory::HyperparameterTuning,
770                        title: "Reduce Learning Rate".to_string(),
771                        description: "Lower the learning rate to stabilize training".to_string(),
772                        actions: vec![
773                            "Reduce learning rate by factor of 2-10".to_string(),
774                            "Enable gradient clipping".to_string(),
775                            "Consider learning rate scheduling".to_string(),
776                        ],
777                        expected_impact: "Stabilized training with reduced loss oscillations"
778                            .to_string(),
779                        confidence: 0.9,
780                        priority: AutoDebugRecommendationPriority::High,
781                        hyperparameter_suggestions: vec![HyperparameterSuggestion {
782                            parameter_name: "learning_rate".to_string(),
783                            current_value: None,
784                            suggested_value: 0.0001,
785                            reasoning: "Reduce to prevent loss explosion".to_string(),
786                            expected_effect: "More stable training".to_string(),
787                        }],
788                    }]
789                } else if issue.description.contains("too low") {
790                    vec![DebuggingRecommendation {
791                        category: RecommendationCategory::HyperparameterTuning,
792                        title: "Increase Learning Rate".to_string(),
793                        description: "Increase learning rate to improve convergence speed"
794                            .to_string(),
795                        actions: vec![
796                            "Increase learning rate by factor of 2-5".to_string(),
797                            "Use learning rate warmup".to_string(),
798                            "Consider adaptive learning rate methods".to_string(),
799                        ],
800                        expected_impact: "Faster convergence and better final performance"
801                            .to_string(),
802                        confidence: 0.8,
803                        priority: AutoDebugRecommendationPriority::Medium,
804                        hyperparameter_suggestions: vec![HyperparameterSuggestion {
805                            parameter_name: "learning_rate".to_string(),
806                            current_value: None,
807                            suggested_value: 0.001,
808                            reasoning: "Increase to improve learning speed".to_string(),
809                            expected_effect: "Faster convergence".to_string(),
810                        }],
811                    }]
812                } else {
813                    Vec::new()
814                }
815            },
816            IssueCategory::Memory => {
817                vec![DebuggingRecommendation {
818                    category: RecommendationCategory::ResourceOptimization,
819                    title: "Optimize Memory Usage".to_string(),
820                    description: "Implement memory optimization strategies".to_string(),
821                    actions: vec![
822                        "Reduce batch size".to_string(),
823                        "Enable gradient checkpointing".to_string(),
824                        "Clear cached tensors regularly".to_string(),
825                        "Use mixed precision training".to_string(),
826                    ],
827                    expected_impact: "Reduced memory consumption and stable training".to_string(),
828                    confidence: 0.85,
829                    priority: AutoDebugRecommendationPriority::High,
830                    hyperparameter_suggestions: vec![HyperparameterSuggestion {
831                        parameter_name: "batch_size".to_string(),
832                        current_value: None,
833                        suggested_value: 16.0,
834                        reasoning: "Reduce to lower memory usage".to_string(),
835                        expected_effect: "Lower memory consumption".to_string(),
836                    }],
837                }]
838            },
839            IssueCategory::Overfitting => {
840                vec![DebuggingRecommendation {
841                    category: RecommendationCategory::TrainingStrategy,
842                    title: "Address Overfitting".to_string(),
843                    description: "Implement regularization strategies to reduce overfitting"
844                        .to_string(),
845                    actions: vec![
846                        "Add dropout layers".to_string(),
847                        "Increase weight decay".to_string(),
848                        "Use data augmentation".to_string(),
849                        "Reduce model complexity".to_string(),
850                        "Implement early stopping".to_string(),
851                    ],
852                    expected_impact: "Better generalization and validation performance".to_string(),
853                    confidence: 0.8,
854                    priority: AutoDebugRecommendationPriority::Medium,
855                    hyperparameter_suggestions: vec![HyperparameterSuggestion {
856                        parameter_name: "dropout_rate".to_string(),
857                        current_value: None,
858                        suggested_value: 0.1,
859                        reasoning: "Add regularization to reduce overfitting".to_string(),
860                        expected_effect: "Better generalization".to_string(),
861                    }],
862                }]
863            },
864            IssueCategory::GradientFlow => {
865                vec![DebuggingRecommendation {
866                    category: RecommendationCategory::ArchitecturalModification,
867                    title: "Improve Gradient Flow".to_string(),
868                    description: "Address gradient flow issues in the network".to_string(),
869                    actions: vec![
870                        "Use different activation functions (e.g., Leaky ReLU, Swish)".to_string(),
871                        "Add batch normalization".to_string(),
872                        "Implement residual connections".to_string(),
873                        "Adjust weight initialization".to_string(),
874                    ],
875                    expected_impact: "Better gradient flow and training stability".to_string(),
876                    confidence: 0.75,
877                    priority: AutoDebugRecommendationPriority::Medium,
878                    hyperparameter_suggestions: Vec::new(),
879                }]
880            },
881            _ => Vec::new(),
882        }
883    }
884
885    /// Add an issue to the current session.
886    fn add_issue(&mut self, issue: IdentifiedIssue) {
887        self.session_state.identified_issues.push(issue);
888    }
889
890    /// Calculate variance of a sequence of values.
891    fn calculate_variance(&self, values: &[f64]) -> f64 {
892        if values.len() < 2 {
893            return 0.0;
894        }
895
896        let mean = values.iter().sum::<f64>() / values.len() as f64;
897        let variance =
898            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
899
900        variance
901    }
902
903    /// Calculate trend (slope) of a sequence of values.
904    fn calculate_trend(&self, values: &[f64]) -> f64 {
905        if values.len() < 2 {
906            return 0.0;
907        }
908
909        let n = values.len() as f64;
910        let x_mean = (n - 1.0) / 2.0;
911        let y_mean = values.iter().sum::<f64>() / n;
912
913        let numerator: f64 = values
914            .iter()
915            .enumerate()
916            .map(|(i, &y)| (i as f64 - x_mean) * (y - y_mean))
917            .sum();
918
919        let denominator: f64 = (0..values.len()).map(|i| (i as f64 - x_mean).powi(2)).sum();
920
921        if denominator == 0.0 {
922            0.0
923        } else {
924            numerator / denominator
925        }
926    }
927
928    /// Update session statistics.
929    fn update_session_statistics(&mut self) {
930        let mut issues_by_category = HashMap::new();
931        for issue in &self.session_state.identified_issues {
932            *issues_by_category.entry(issue.category.clone()).or_insert(0) += 1;
933        }
934
935        let avg_confidence = if self.session_state.recommendations.is_empty() {
936            0.0
937        } else {
938            self.session_state.recommendations.iter().map(|r| r.confidence).sum::<f64>()
939                / self.session_state.recommendations.len() as f64
940        };
941
942        self.session_state.session_stats = SessionStatistics {
943            total_issues: self.session_state.identified_issues.len(),
944            issues_by_category,
945            total_recommendations: self.session_state.recommendations.len(),
946            avg_recommendation_confidence: avg_confidence,
947            analysis_duration: self.session_state.session_stats.analysis_duration,
948        };
949    }
950
951    /// Generate analysis summary.
952    fn generate_analysis_summary(&self) -> String {
953        let critical_issues = self
954            .session_state
955            .identified_issues
956            .iter()
957            .filter(|i| i.severity == IssueSeverity::Critical)
958            .count();
959
960        let major_issues = self
961            .session_state
962            .identified_issues
963            .iter()
964            .filter(|i| i.severity == IssueSeverity::Major)
965            .count();
966
967        let high_priority_recommendations = self
968            .session_state
969            .recommendations
970            .iter()
971            .filter(|r| r.priority == AutoDebugRecommendationPriority::High)
972            .count();
973
974        format!(
975            "Auto-debugging analysis completed. Found {} critical issues, {} major issues. \
976            Generated {} recommendations with {} high-priority actions. \
977            Average recommendation confidence: {:.2}",
978            critical_issues,
979            major_issues,
980            self.session_state.recommendations.len(),
981            high_priority_recommendations,
982            self.session_state.session_stats.avg_recommendation_confidence
983        )
984    }
985}
986
/// Comprehensive debugging report.
///
/// Serializable (serde) snapshot of an auto-debugging session, together with
/// the issues, recommendations and a human-readable summary.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DebuggingReport {
    /// Session information: start time, issues, recommendations and statistics.
    ///
    /// NOTE(review): `DebuggingSession` itself carries `identified_issues` and
    /// `recommendations`, so a report built from the same session may contain
    /// those lists twice when serialized. Confirm whether this is intended.
    pub session_info: DebuggingSession,
    /// All identified issues
    pub identified_issues: Vec<IdentifiedIssue>,
    /// Generated recommendations
    pub recommendations: Vec<DebuggingRecommendation>,
    /// Analysis summary (free-form, human-readable text)
    pub summary: String,
}
999
1000impl IssuePatternDatabase {
1001    /// Create a new pattern database with default patterns.
1002    pub fn new() -> Self {
1003        Self {
1004            learning_rate_patterns: Self::create_learning_rate_patterns(),
1005            gradient_patterns: Self::create_gradient_patterns(),
1006            convergence_patterns: Self::create_convergence_patterns(),
1007            layer_patterns: Self::create_layer_patterns(),
1008        }
1009    }
1010
1011    /// Create default learning rate patterns.
1012    fn create_learning_rate_patterns() -> Vec<IssuePattern> {
1013        vec![IssuePattern {
1014            name: "Loss Explosion".to_string(),
1015            description: "Rapid increase in loss indicating learning rate too high".to_string(),
1016            conditions: vec![PatternCondition {
1017                metric: "loss".to_string(),
1018                operator: ComparisonOperator::Increasing,
1019                threshold: 2.0,
1020                consecutive_count: 3,
1021            }],
1022            issue_category: IssueCategory::LearningRate,
1023            confidence_weight: 0.9,
1024            solutions: vec![
1025                "Reduce learning rate by factor of 10".to_string(),
1026                "Enable gradient clipping".to_string(),
1027            ],
1028        }]
1029    }
1030
1031    /// Create default gradient patterns.
1032    fn create_gradient_patterns() -> Vec<IssuePattern> {
1033        vec![]
1034    }
1035
1036    /// Create default convergence patterns.
1037    fn create_convergence_patterns() -> Vec<IssuePattern> {
1038        vec![]
1039    }
1040
1041    /// Create default layer patterns.
1042    fn create_layer_patterns() -> Vec<IssuePattern> {
1043        vec![]
1044    }
1045}
1046
1047impl DebuggingSession {
1048    /// Create a new debugging session.
1049    fn new() -> Self {
1050        Self {
1051            session_start: chrono::Utc::now(),
1052            identified_issues: Vec::new(),
1053            recommendations: Vec::new(),
1054            session_stats: SessionStatistics {
1055                total_issues: 0,
1056                issues_by_category: HashMap::new(),
1057                total_recommendations: 0,
1058                avg_recommendation_confidence: 0.0,
1059                analysis_duration: chrono::Duration::zero(),
1060            },
1061        }
1062    }
1063}
1064
1065impl Default for AutoDebugger {
1066    fn default() -> Self {
1067        Self::new()
1068    }
1069}
1070
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal `Major` issue with the given category and description.
    fn make_issue(category: IssueCategory, description: &str) -> IdentifiedIssue {
        IdentifiedIssue {
            category,
            description: description.to_string(),
            severity: IssueSeverity::Major,
            confidence: 0.8,
            evidence: vec!["Test evidence".to_string()],
            potential_causes: vec!["Test cause".to_string()],
            identified_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_auto_debugger_creation() {
        let debugger = AutoDebugger::new();
        assert_eq!(debugger.performance_history.len(), 0);
        assert_eq!(debugger.layer_history.len(), 0);
    }

    #[test]
    fn test_issue_identification() {
        let mut debugger = AutoDebugger::new();
        debugger.add_issue(make_issue(IssueCategory::LearningRate, "Test issue"));
        assert_eq!(debugger.session_state.identified_issues.len(), 1);
    }

    #[test]
    fn test_variance_calculation() {
        let debugger = AutoDebugger::new();
        // Sample variance (n - 1 denominator) of 1..=5 is exactly 10 / 4.
        assert_eq!(debugger.calculate_variance(&[1.0, 2.0, 3.0, 4.0, 5.0]), 2.5);
        // Constant input has zero spread.
        assert_eq!(debugger.calculate_variance(&[7.0, 7.0, 7.0]), 0.0);
        // Fewer than two samples are defined to have zero variance.
        assert_eq!(debugger.calculate_variance(&[]), 0.0);
        assert_eq!(debugger.calculate_variance(&[42.0]), 0.0);
    }

    #[test]
    fn test_trend_calculation() {
        let debugger = AutoDebugger::new();
        // Perfect lines have an exact least-squares slope.
        assert_eq!(debugger.calculate_trend(&[1.0, 2.0, 3.0, 4.0, 5.0]), 1.0);
        assert_eq!(debugger.calculate_trend(&[5.0, 4.0, 3.0, 2.0, 1.0]), -1.0);
        assert_eq!(debugger.calculate_trend(&[3.0, 3.0, 3.0]), 0.0);
        // A single sample carries no trend.
        assert_eq!(debugger.calculate_trend(&[9.0]), 0.0);
    }

    #[test]
    fn test_recommendation_pipeline_for_high_learning_rate() {
        let mut debugger = AutoDebugger::new();
        debugger.add_issue(make_issue(IssueCategory::LearningRate, "Learning rate too high"));

        debugger.generate_recommendations().unwrap();
        {
            let recs = &debugger.session_state.recommendations;
            assert_eq!(recs.len(), 1);
            assert_eq!(recs[0].title, "Reduce Learning Rate");
        }

        debugger.update_session_statistics();
        let stats = &debugger.session_state.session_stats;
        assert_eq!(stats.total_issues, 1);
        assert_eq!(stats.total_recommendations, 1);
        assert!((stats.avg_recommendation_confidence - 0.9).abs() < 1e-12);

        assert_eq!(
            debugger.generate_analysis_summary(),
            "Auto-debugging analysis completed. Found 0 critical issues, 1 major issues. \
             Generated 1 recommendations with 1 high-priority actions. \
             Average recommendation confidence: 0.90"
        );
    }

    #[test]
    fn test_recommendations_per_category() {
        let debugger = AutoDebugger::new();

        // Learning-rate issues that mention neither direction get no advice.
        let vague = make_issue(IssueCategory::LearningRate, "Learning rate looks unusual");
        assert!(debugger.generate_recommendations_for_issue(&vague).is_empty());

        let memory = make_issue(IssueCategory::Memory, "Memory pressure");
        let recs = debugger.generate_recommendations_for_issue(&memory);
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0].title, "Optimize Memory Usage");
    }
}