Skip to main content

trustformers_debug/
model_diagnostics_main.rs

1//! Model-level diagnostics and analysis tools.
2//!
3//! This module has been refactored into a modular architecture for better
4//! organization and maintainability. All previous functionality remains
5//! available through comprehensive re-exports to ensure backward compatibility.
6
7use crate::DebugConfig;
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10
11// Import all modular components directly from the model_diagnostics directory
12use crate::model_diagnostics::*;
13
14/// Main model diagnostics system that coordinates all diagnostic components.
15#[derive(Debug)]
16pub struct ModelDiagnostics {
17    #[allow(dead_code)]
18    config: DebugConfig,
19    performance_analyzer: PerformanceAnalyzer,
20    architecture_analyzer: ArchitectureAnalyzer,
21    training_analyzer: TrainingDynamicsAnalyzer,
22    layer_analyzer: LayerAnalyzer,
23    alert_manager: AlertManager,
24    auto_debugger: AutoDebugger,
25    analytics_engine: AdvancedAnalytics,
26    current_step: usize,
27}
28
29impl ModelDiagnostics {
30    /// Create new model diagnostics system.
31    pub fn new(config: &DebugConfig) -> Self {
32        Self {
33            config: config.clone(),
34            performance_analyzer: PerformanceAnalyzer::new(),
35            architecture_analyzer: ArchitectureAnalyzer::new(),
36            training_analyzer: TrainingDynamicsAnalyzer::new(),
37            layer_analyzer: LayerAnalyzer::new(),
38            alert_manager: AlertManager::new(),
39            auto_debugger: AutoDebugger::new(),
40            analytics_engine: AdvancedAnalytics::new(),
41            current_step: 0,
42        }
43    }
44
45    /// Record performance metrics.
46    pub fn record_performance(&mut self, metrics: ModelPerformanceMetrics) -> Result<()> {
47        self.performance_analyzer.record_metrics(metrics.clone());
48        self.auto_debugger.record_performance_metrics(metrics.clone());
49        self.analytics_engine.record_performance_metrics(&metrics);
50
51        // Process metrics for alerts
52        self.alert_manager.process_performance_metrics(&metrics)?;
53
54        Ok(())
55    }
56
57    /// Record architecture information.
58    pub fn record_architecture(&mut self, arch_info: ModelArchitectureInfo) {
59        self.architecture_analyzer.record_architecture(arch_info);
60    }
61
62    /// Record layer activation statistics.
63    pub fn record_layer_stats(&mut self, stats: LayerActivationStats) -> Result<()> {
64        self.layer_analyzer.record_layer_stats(stats.clone());
65        self.auto_debugger.record_layer_stats(stats.clone());
66
67        // Process for alerts
68        self.alert_manager.process_layer_stats(&stats)?;
69
70        Ok(())
71    }
72
73    /// Record training dynamics.
74    pub fn record_training_dynamics(&mut self, dynamics: TrainingDynamics) -> Result<()> {
75        self.training_analyzer.record_training_dynamics(dynamics.clone());
76        self.auto_debugger.record_training_dynamics(dynamics.clone());
77
78        // Process for alerts
79        self.alert_manager.process_training_dynamics(&dynamics)?;
80
81        Ok(())
82    }
83
84    /// Calculate overall model health score.
85    fn calculate_health_score(&self) -> f64 {
86        // Simple implementation - could be enhanced
87        let performance_score = 0.8; // Based on performance metrics
88        let architecture_score = 0.7; // Based on architecture analysis
89        let training_score = 0.9; // Based on training dynamics
90
91        (performance_score + architecture_score + training_score) / 3.0
92    }
93
94    /// Aggregate recommendations from all diagnostic components.
95    fn aggregate_recommendations(&self) -> Vec<String> {
96        let mut recommendations = Vec::new();
97
98        // Collect recommendations from architecture analyzer
99        if let Ok(arch_analysis) = self.architecture_analyzer.analyze_architecture() {
100            for recommendation in arch_analysis.recommendations {
101                recommendations.push(format!("[Architecture] {}", recommendation));
102            }
103        }
104
105        // Collect recommendations from performance analyzer
106        let perf_summary = self.performance_analyzer.generate_performance_summary();
107        // Add performance-related recommendations based on metrics
108        if perf_summary.current_loss > perf_summary.best_loss * 1.5 {
109            recommendations.push(
110                "[Performance] Current loss significantly higher than best - check for training instability"
111                    .to_string(),
112            );
113        }
114        if perf_summary.peak_memory_mb > 16384.0 {
115            // > 16GB
116            recommendations.push(
117                "[Performance] High memory usage detected - consider gradient checkpointing or smaller batch size"
118                    .to_string(),
119            );
120        }
121
122        // Collect recommendations from training analyzer
123        let training_dynamics = self.training_analyzer.analyze_training_dynamics();
124        match training_dynamics.training_stability {
125            TrainingStability::Unstable => {
126                recommendations.push(
127                    "[Training] Training stability issues detected - consider reducing learning rate or applying gradient clipping"
128                        .to_string(),
129                );
130            },
131            TrainingStability::Unknown => {
132                recommendations.push(
133                    "[Training] Collect more training metrics for better stability assessment"
134                        .to_string(),
135                );
136            },
137            _ => {},
138        }
139
140        // Check for plateau conditions
141        if let Some(plateau) = &training_dynamics.plateau_detection {
142            if plateau.duration_steps > 100 {
143                recommendations.push(
144                    "[Training] Training plateau detected - consider learning rate adjustment or early stopping"
145                        .to_string(),
146                );
147            }
148        }
149
150        // Check convergence status
151        match training_dynamics.convergence_status {
152            ConvergenceStatus::Diverging => {
153                recommendations.push(
154                    "[Training] Model is diverging - reduce learning rate immediately".to_string(),
155                );
156            },
157            ConvergenceStatus::Plateau => {
158                recommendations.push(
159                    "[Training] Training has reached a plateau - consider changing optimization strategy or early stopping"
160                        .to_string(),
161                );
162            },
163            ConvergenceStatus::Oscillating => {
164                recommendations.push(
165                    "[Training] Training is oscillating - reduce learning rate or increase batch size"
166                        .to_string(),
167                );
168            },
169            _ => {},
170        }
171
172        // Add recommendations based on overfitting/underfitting indicators
173        if !training_dynamics.overfitting_indicators.is_empty() {
174            recommendations.push(
175                "[Training] Overfitting detected - consider regularization, dropout, or early stopping"
176                    .to_string(),
177            );
178        }
179        if !training_dynamics.underfitting_indicators.is_empty() {
180            recommendations.push(
181                "[Training] Underfitting detected - consider increasing model capacity or training longer"
182                    .to_string(),
183            );
184        }
185
186        // Collect recommendations from analytics engine
187        if let Ok(analytics_report) = self.analytics_engine.generate_analytics_report() {
188            for recommendation in analytics_report.recommendations {
189                recommendations.push(format!("[Analytics] {}", recommendation));
190            }
191        }
192
193        // Remove duplicates while preserving order
194        let mut seen = std::collections::HashSet::new();
195        recommendations.retain(|r| seen.insert(r.clone()));
196
197        recommendations
198    }
199
200    /// Get current step.
201    pub fn current_step(&self) -> usize {
202        self.current_step
203    }
204
205    /// Analyze current training dynamics.
206    pub fn analyze_training_dynamics(&self) -> TrainingDynamics {
207        self.training_analyzer.analyze_training_dynamics()
208    }
209
210    /// Increment step counter.
211    pub fn increment_step(&mut self) {
212        self.current_step += 1;
213    }
214
215    /// Start the diagnostics system.
216    pub async fn start(&mut self) -> Result<()> {
217        // Initialize all components
218        Ok(())
219    }
220
221    /// Generate comprehensive diagnostics report (async version).
222    pub async fn generate_report(&self) -> Result<ModelDiagnosticsReport> {
223        self.generate_report_sync()
224    }
225
226    /// Generate comprehensive diagnostics report (sync version).
227    pub fn generate_report_sync(&self) -> Result<ModelDiagnosticsReport> {
228        let performance_summary = self.performance_analyzer.generate_performance_summary();
229        let architectural_analysis = self.architecture_analyzer.analyze_architecture().ok();
230        let training_dynamics = self.training_analyzer.analyze_training_dynamics();
231        let alerts = self.alert_manager.get_active_alerts().to_vec();
232
233        // Generate auto-debugging analysis
234        let auto_debugging_results = None; // Simplified for now
235
236        // Generate analytics report
237        let analytics_report = self.analytics_engine.generate_analytics_report().ok();
238
239        Ok(ModelDiagnosticsReport {
240            current_step: self.current_step,
241            training_dynamics,
242            performance_summary,
243            architectural_analysis,
244            alerts: alerts.into_iter().map(|a| a.alert).collect(),
245            recommendations: self.aggregate_recommendations(),
246            health_score: self.calculate_health_score(),
247            auto_debugging_results,
248            analytics_report,
249        })
250    }
251}
252
/// Comprehensive model diagnostics report.
///
/// Produced by [`ModelDiagnostics::generate_report_sync`] and its async
/// wrapper [`ModelDiagnostics::generate_report`]; serializable with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDiagnosticsReport {
    /// Value of the diagnostics step counter when the report was generated
    pub current_step: usize,
    /// Training dynamics analysis
    pub training_dynamics: TrainingDynamics,
    /// Performance summary
    pub performance_summary: PerformanceSummary,
    /// Architectural analysis results, or `None` if the analysis failed
    pub architectural_analysis: Option<ArchitecturalAnalysis>,
    /// Diagnostic alerts that were active when the report was generated
    pub alerts: Vec<ModelDiagnosticAlert>,
    /// Deduplicated optimization recommendations, each prefixed with its source
    /// (e.g. `[Training]`)
    pub recommendations: Vec<String>,
    /// Overall model health score
    pub health_score: f64,
    /// Auto-debugging results (not yet populated by `ModelDiagnostics`)
    pub auto_debugging_results: Option<DebuggingReport>,
    /// Advanced analytics report, or `None` if generating it failed
    pub analytics_report: Option<AnalyticsReport>,
}
275
276impl Default for ModelDiagnosticsReport {
277    fn default() -> Self {
278        Self {
279            current_step: 0,
280            training_dynamics: TrainingDynamics {
281                convergence_status: ConvergenceStatus::Unknown,
282                training_stability: TrainingStability::Unknown,
283                learning_efficiency: 0.0,
284                overfitting_indicators: Vec::new(),
285                underfitting_indicators: Vec::new(),
286                plateau_detection: None,
287            },
288            performance_summary: PerformanceSummary::default(),
289            architectural_analysis: None,
290            alerts: Vec::new(),
291            recommendations: Vec::new(),
292            health_score: 0.0,
293            auto_debugging_results: None,
294            analytics_report: None,
295        }
296    }
297}