// trustformers_debug/model_diagnostics_main.rs

use crate::DebugConfig;

use anyhow::Result;
use serde::{Deserialize, Serialize};

use crate::model_diagnostics::*;
/// Central hub for model debugging: fans recorded metrics out to the
/// individual analyzers and assembles their findings into a single report.
#[derive(Debug)]
pub struct ModelDiagnostics {
    // Cloned at construction; no method currently reads it (hence dead_code).
    #[allow(dead_code)]
    config: DebugConfig,
    // Records performance metrics and produces performance summaries.
    performance_analyzer: PerformanceAnalyzer,
    // Records architecture info and produces architectural analyses.
    architecture_analyzer: ArchitectureAnalyzer,
    // Records training dynamics and assesses stability/convergence.
    training_analyzer: TrainingDynamicsAnalyzer,
    // Records per-layer activation statistics.
    layer_analyzer: LayerAnalyzer,
    // Evaluates incoming metrics/stats and raises active alerts.
    alert_manager: AlertManager,
    // Mirrors recorded data for automated debugging.
    auto_debugger: AutoDebugger,
    // Mirrors performance metrics and produces analytics reports.
    analytics_engine: AdvancedAnalytics,
    // Step counter advanced via increment_step(), embedded in reports.
    current_step: usize,
}
28
29impl ModelDiagnostics {
30 pub fn new(config: &DebugConfig) -> Self {
32 Self {
33 config: config.clone(),
34 performance_analyzer: PerformanceAnalyzer::new(),
35 architecture_analyzer: ArchitectureAnalyzer::new(),
36 training_analyzer: TrainingDynamicsAnalyzer::new(),
37 layer_analyzer: LayerAnalyzer::new(),
38 alert_manager: AlertManager::new(),
39 auto_debugger: AutoDebugger::new(),
40 analytics_engine: AdvancedAnalytics::new(),
41 current_step: 0,
42 }
43 }
44
45 pub fn record_performance(&mut self, metrics: ModelPerformanceMetrics) -> Result<()> {
47 self.performance_analyzer.record_metrics(metrics.clone());
48 self.auto_debugger.record_performance_metrics(metrics.clone());
49 self.analytics_engine.record_performance_metrics(&metrics);
50
51 self.alert_manager.process_performance_metrics(&metrics)?;
53
54 Ok(())
55 }
56
57 pub fn record_architecture(&mut self, arch_info: ModelArchitectureInfo) {
59 self.architecture_analyzer.record_architecture(arch_info);
60 }
61
62 pub fn record_layer_stats(&mut self, stats: LayerActivationStats) -> Result<()> {
64 self.layer_analyzer.record_layer_stats(stats.clone());
65 self.auto_debugger.record_layer_stats(stats.clone());
66
67 self.alert_manager.process_layer_stats(&stats)?;
69
70 Ok(())
71 }
72
73 pub fn record_training_dynamics(&mut self, dynamics: TrainingDynamics) -> Result<()> {
75 self.training_analyzer.record_training_dynamics(dynamics.clone());
76 self.auto_debugger.record_training_dynamics(dynamics.clone());
77
78 self.alert_manager.process_training_dynamics(&dynamics)?;
80
81 Ok(())
82 }
83
84 fn calculate_health_score(&self) -> f64 {
86 let performance_score = 0.8; let architecture_score = 0.7; let training_score = 0.9; (performance_score + architecture_score + training_score) / 3.0
92 }
93
94 fn aggregate_recommendations(&self) -> Vec<String> {
96 let mut recommendations = Vec::new();
97
98 if let Ok(arch_analysis) = self.architecture_analyzer.analyze_architecture() {
100 for recommendation in arch_analysis.recommendations {
101 recommendations.push(format!("[Architecture] {}", recommendation));
102 }
103 }
104
105 let perf_summary = self.performance_analyzer.generate_performance_summary();
107 if perf_summary.current_loss > perf_summary.best_loss * 1.5 {
109 recommendations.push(
110 "[Performance] Current loss significantly higher than best - check for training instability"
111 .to_string(),
112 );
113 }
114 if perf_summary.peak_memory_mb > 16384.0 {
115 recommendations.push(
117 "[Performance] High memory usage detected - consider gradient checkpointing or smaller batch size"
118 .to_string(),
119 );
120 }
121
122 let training_dynamics = self.training_analyzer.analyze_training_dynamics();
124 match training_dynamics.training_stability {
125 TrainingStability::Unstable => {
126 recommendations.push(
127 "[Training] Training stability issues detected - consider reducing learning rate or applying gradient clipping"
128 .to_string(),
129 );
130 },
131 TrainingStability::Unknown => {
132 recommendations.push(
133 "[Training] Collect more training metrics for better stability assessment"
134 .to_string(),
135 );
136 },
137 _ => {},
138 }
139
140 if let Some(plateau) = &training_dynamics.plateau_detection {
142 if plateau.duration_steps > 100 {
143 recommendations.push(
144 "[Training] Training plateau detected - consider learning rate adjustment or early stopping"
145 .to_string(),
146 );
147 }
148 }
149
150 match training_dynamics.convergence_status {
152 ConvergenceStatus::Diverging => {
153 recommendations.push(
154 "[Training] Model is diverging - reduce learning rate immediately".to_string(),
155 );
156 },
157 ConvergenceStatus::Plateau => {
158 recommendations.push(
159 "[Training] Training has reached a plateau - consider changing optimization strategy or early stopping"
160 .to_string(),
161 );
162 },
163 ConvergenceStatus::Oscillating => {
164 recommendations.push(
165 "[Training] Training is oscillating - reduce learning rate or increase batch size"
166 .to_string(),
167 );
168 },
169 _ => {},
170 }
171
172 if !training_dynamics.overfitting_indicators.is_empty() {
174 recommendations.push(
175 "[Training] Overfitting detected - consider regularization, dropout, or early stopping"
176 .to_string(),
177 );
178 }
179 if !training_dynamics.underfitting_indicators.is_empty() {
180 recommendations.push(
181 "[Training] Underfitting detected - consider increasing model capacity or training longer"
182 .to_string(),
183 );
184 }
185
186 if let Ok(analytics_report) = self.analytics_engine.generate_analytics_report() {
188 for recommendation in analytics_report.recommendations {
189 recommendations.push(format!("[Analytics] {}", recommendation));
190 }
191 }
192
193 let mut seen = std::collections::HashSet::new();
195 recommendations.retain(|r| seen.insert(r.clone()));
196
197 recommendations
198 }
199
200 pub fn current_step(&self) -> usize {
202 self.current_step
203 }
204
205 pub fn analyze_training_dynamics(&self) -> TrainingDynamics {
207 self.training_analyzer.analyze_training_dynamics()
208 }
209
210 pub fn increment_step(&mut self) {
212 self.current_step += 1;
213 }
214
215 pub async fn start(&mut self) -> Result<()> {
217 Ok(())
219 }
220
221 pub async fn generate_report(&self) -> Result<ModelDiagnosticsReport> {
223 self.generate_report_sync()
224 }
225
226 pub fn generate_report_sync(&self) -> Result<ModelDiagnosticsReport> {
228 let performance_summary = self.performance_analyzer.generate_performance_summary();
229 let architectural_analysis = self.architecture_analyzer.analyze_architecture().ok();
230 let training_dynamics = self.training_analyzer.analyze_training_dynamics();
231 let alerts = self.alert_manager.get_active_alerts().to_vec();
232
233 let auto_debugging_results = None; let analytics_report = self.analytics_engine.generate_analytics_report().ok();
238
239 Ok(ModelDiagnosticsReport {
240 current_step: self.current_step,
241 training_dynamics,
242 performance_summary,
243 architectural_analysis,
244 alerts: alerts.into_iter().map(|a| a.alert).collect(),
245 recommendations: self.aggregate_recommendations(),
246 health_score: self.calculate_health_score(),
247 auto_debugging_results,
248 analytics_report,
249 })
250 }
251}
252
/// Serializable snapshot of model health produced by
/// `ModelDiagnostics::generate_report_sync`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDiagnosticsReport {
    // Step counter value at report time.
    pub current_step: usize,
    // Convergence/stability assessment from the training analyzer.
    pub training_dynamics: TrainingDynamics,
    // Loss/memory summary from the performance analyzer.
    pub performance_summary: PerformanceSummary,
    // None when architecture analysis failed or no architecture was recorded.
    pub architectural_analysis: Option<ArchitecturalAnalysis>,
    // Alerts active in the alert manager at report time.
    pub alerts: Vec<ModelDiagnosticAlert>,
    // De-duplicated, subsystem-tagged advice strings.
    pub recommendations: Vec<String>,
    // Aggregate health score in [0, 1].
    pub health_score: f64,
    // Auto-debugger findings; currently always None (see report assembly).
    pub auto_debugging_results: Option<DebuggingReport>,
    // None when analytics report generation failed.
    pub analytics_report: Option<AnalyticsReport>,
}
275
276impl Default for ModelDiagnosticsReport {
277 fn default() -> Self {
278 Self {
279 current_step: 0,
280 training_dynamics: TrainingDynamics {
281 convergence_status: ConvergenceStatus::Unknown,
282 training_stability: TrainingStability::Unknown,
283 learning_efficiency: 0.0,
284 overfitting_indicators: Vec::new(),
285 underfitting_indicators: Vec::new(),
286 plateau_detection: None,
287 },
288 performance_summary: PerformanceSummary::default(),
289 architectural_analysis: None,
290 alerts: Vec::new(),
291 recommendations: Vec::new(),
292 health_score: 0.0,
293 auto_debugging_results: None,
294 analytics_report: None,
295 }
296 }
297}