// trustformers_debug/model_diagnostics/training.rs

use std::collections::VecDeque;

use super::types::{
    ConvergenceStatus, ModelPerformanceMetrics, OverfittingIndicator, PlateauInfo,
    TrainingDynamics, TrainingStability, UnderfittingIndicator,
};
/// Analyzes the dynamics of a training run — convergence, stability,
/// plateaus, and over-/underfitting signals — from a rolling window of
/// per-step performance metrics.
#[derive(Debug)]
pub struct TrainingDynamicsAnalyzer {
    /// Rolling history of per-step metrics (bounded; oldest dropped first).
    metrics_history: VecDeque<ModelPerformanceMetrics>,
    /// Thresholds and window sizes used by the analysis heuristics.
    config: TrainingAnalysisConfig,
    /// Mutable bookkeeping: best loss seen, improvement counter, and the
    /// rolling convergence-status history.
    current_state: TrainingState,
}

/// Tunable thresholds and window sizes for training-dynamics analysis.
#[derive(Debug, Clone)]
pub struct TrainingAnalysisConfig {
    /// Number of most recent steps inspected when classifying convergence.
    pub convergence_window: usize,
    /// Minimum loss change (also used as a variance floor) below which the
    /// curve is considered flat.
    pub min_improvement_threshold: f64,
    /// Loss variance above which training is considered unstable.
    pub max_variance_threshold: f64,
    /// Minimum number of flat steps before a plateau is reported.
    pub min_plateau_duration: usize,
    /// Train/validation gap considered indicative of overfitting.
    pub overfitting_gap_threshold: f64,
    /// Smallest learning rate considered meaningful.
    pub min_learning_rate: f64,
}

impl Default for TrainingAnalysisConfig {
    /// Conservative defaults tuned for loss values on a roughly unit scale.
    fn default() -> Self {
        Self {
            convergence_window: 20,
            min_improvement_threshold: 0.001,
            max_variance_threshold: 0.1,
            min_plateau_duration: 10,
            overfitting_gap_threshold: 0.05,
            min_learning_rate: 1e-6,
        }
    }
}

/// Mutable bookkeeping carried across `add_metrics` calls.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingState {
    /// Steps elapsed since the best loss last improved.
    steps_since_improvement: usize,
    /// Lowest loss observed so far (`f64::INFINITY` before any sample).
    best_loss: f64,
    /// Plateau currently in progress, if one has been detected.
    current_plateau: Option<PlateauInfo>,
    /// Rolling history of convergence classifications (bounded).
    convergence_history: VecDeque<ConvergenceStatus>,
}

impl Default for TrainingState {
    /// Fresh state: `best_loss` starts at infinity so the very first metrics
    /// sample always counts as an improvement.
    fn default() -> Self {
        Self {
            steps_since_improvement: 0,
            best_loss: f64::INFINITY,
            current_plateau: None,
            convergence_history: VecDeque::new(),
        }
    }
}

79impl TrainingDynamicsAnalyzer {
80 pub fn new() -> Self {
82 Self {
83 metrics_history: VecDeque::new(),
84 config: TrainingAnalysisConfig::default(),
85 current_state: TrainingState::default(),
86 }
87 }
88
89 pub fn with_config(config: TrainingAnalysisConfig) -> Self {
91 Self {
92 metrics_history: VecDeque::new(),
93 config,
94 current_state: TrainingState::default(),
95 }
96 }
97
98 pub fn add_metrics(&mut self, metrics: ModelPerformanceMetrics) {
100 if metrics.loss < self.current_state.best_loss {
102 self.current_state.best_loss = metrics.loss;
103 self.current_state.steps_since_improvement = 0;
104 } else {
105 self.current_state.steps_since_improvement += 1;
106 }
107
108 self.metrics_history.push_back(metrics);
109
110 if self.metrics_history.len() > 1000 {
112 self.metrics_history.pop_front();
113 }
114
115 let status = self.detect_convergence_status();
117 self.current_state.convergence_history.push_back(status);
118 if self.current_state.convergence_history.len() > 50 {
119 self.current_state.convergence_history.pop_front();
120 }
121 }
122
123 pub fn record_training_dynamics(&mut self, _dynamics: TrainingDynamics) {
125 }
128
129 pub fn analyze_training_dynamics(&self) -> TrainingDynamics {
131 let convergence_status = self.detect_convergence_status();
132 let training_stability = self.assess_training_stability();
133 let learning_efficiency = self.calculate_learning_efficiency();
134 let overfitting_indicators = self.detect_overfitting_indicators();
135 let underfitting_indicators = self.detect_underfitting_indicators();
136 let plateau_detection = self.detect_plateau();
137
138 TrainingDynamics {
139 convergence_status,
140 training_stability,
141 learning_efficiency,
142 overfitting_indicators,
143 underfitting_indicators,
144 plateau_detection,
145 }
146 }
147
148 pub fn detect_convergence_status(&self) -> ConvergenceStatus {
150 if self.metrics_history.len() < self.config.convergence_window {
151 return ConvergenceStatus::Unknown;
152 }
153
154 let recent_metrics: Vec<_> =
155 self.metrics_history.iter().rev().take(self.config.convergence_window).collect();
156
157 let losses: Vec<f64> = recent_metrics.iter().map(|m| m.loss).collect();
158
159 if self.is_converged(&losses) {
161 ConvergenceStatus::Converged
162 } else if self.is_diverging(&losses) {
163 ConvergenceStatus::Diverging
164 } else if self.is_oscillating(&losses) {
165 ConvergenceStatus::Oscillating
166 } else if self.is_plateau(&losses) {
167 ConvergenceStatus::Plateau
168 } else if self.is_converging(&losses) {
169 ConvergenceStatus::Converging
170 } else {
171 ConvergenceStatus::Unknown
172 }
173 }
174
175 pub fn assess_training_stability(&self) -> TrainingStability {
177 if self.metrics_history.len() < 10 {
178 return TrainingStability::Unknown;
179 }
180
181 let recent_losses: Vec<f64> =
182 self.metrics_history.iter().rev().take(20).map(|m| m.loss).collect();
183
184 let variance = self.calculate_variance(&recent_losses);
185
186 if variance > self.config.max_variance_threshold {
187 TrainingStability::Unstable
188 } else if variance > self.config.max_variance_threshold / 2.0 {
189 TrainingStability::HighVariance
190 } else {
191 TrainingStability::Stable
192 }
193 }
194
195 pub fn calculate_learning_efficiency(&self) -> f64 {
197 if self.metrics_history.len() < 2 {
198 return 0.0;
199 }
200
201 let initial_loss = self.metrics_history.front().unwrap().loss;
202 let current_loss = self.metrics_history.back().unwrap().loss;
203 let steps = self.metrics_history.len();
204
205 if initial_loss <= current_loss {
206 return 0.0;
207 }
208
209 let improvement = (initial_loss - current_loss) / initial_loss;
210 let efficiency = improvement / (steps as f64).sqrt();
211
212 efficiency.min(1.0)
213 }
214
215 pub fn detect_overfitting_indicators(&self) -> Vec<OverfittingIndicator> {
217 let mut indicators = Vec::new();
218
219 if self.metrics_history.len() > 10 {
221 let recent_losses: Vec<f64> =
222 self.metrics_history.iter().rev().take(10).map(|m| m.loss).collect();
223
224 let avg_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
226 if avg_loss < 0.01 {
227 indicators.push(OverfittingIndicator::PerfectTrainingAccuracy);
228 }
229
230 let variance = self.calculate_variance(&recent_losses);
232 if variance > 0.05 {
233 indicators.push(OverfittingIndicator::HighVarianceInValidation);
234 }
235 }
236
237 indicators
238 }
239
240 pub fn detect_underfitting_indicators(&self) -> Vec<UnderfittingIndicator> {
242 let mut indicators = Vec::new();
243
244 if let Some(current_metrics) = self.metrics_history.back() {
245 if current_metrics.loss > 1.0 {
247 indicators.push(UnderfittingIndicator::HighTrainingLoss {
248 loss: current_metrics.loss,
249 threshold: 1.0,
250 });
251 }
252
253 if let Some(accuracy) = current_metrics.accuracy {
255 if accuracy < 0.5 {
256 indicators.push(UnderfittingIndicator::LowTrainingAccuracy {
257 accuracy,
258 threshold: 0.5,
259 });
260 }
261 }
262
263 if self.current_state.steps_since_improvement > 50 {
265 indicators.push(UnderfittingIndicator::SlowConvergence {
266 steps_taken: self.metrics_history.len(),
267 expected: self.metrics_history.len() / 2,
268 });
269 }
270
271 if self.current_state.steps_since_improvement > 100 {
273 indicators.push(UnderfittingIndicator::NoLearning {
274 steps_without_improvement: self.current_state.steps_since_improvement,
275 });
276 }
277 }
278
279 indicators
280 }
281
282 pub fn detect_plateau(&self) -> Option<PlateauInfo> {
284 if self.metrics_history.len() < self.config.min_plateau_duration {
285 return None;
286 }
287
288 let recent_losses: Vec<f64> = self
289 .metrics_history
290 .iter()
291 .rev()
292 .take(self.config.min_plateau_duration)
293 .map(|m| m.loss)
294 .collect();
295
296 let variance = self.calculate_variance(&recent_losses);
297 let mean_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
298
299 if variance < self.config.min_improvement_threshold {
301 let start_step = self.metrics_history.len() - self.config.min_plateau_duration;
302 Some(PlateauInfo {
303 start_step,
304 duration_steps: self.config.min_plateau_duration,
305 plateau_value: mean_loss,
306 variance,
307 })
308 } else {
309 None
310 }
311 }
312
313 pub fn generate_training_recommendations(&self) -> Vec<TrainingRecommendation> {
315 let mut recommendations = Vec::new();
316 let dynamics = self.analyze_training_dynamics();
317
318 match dynamics.convergence_status {
319 ConvergenceStatus::Diverging => {
320 recommendations.push(TrainingRecommendation {
321 category: "Convergence".to_string(),
322 priority: TrainingRecommendationPriority::Critical,
323 description: "Training is diverging".to_string(),
324 action: "Reduce learning rate immediately".to_string(),
325 expected_impact: 0.8,
326 });
327 },
328 ConvergenceStatus::Plateau => {
329 recommendations.push(TrainingRecommendation {
330 category: "Convergence".to_string(),
331 priority: TrainingRecommendationPriority::High,
332 description: "Training has reached a plateau".to_string(),
333 action: "Consider learning rate scheduling or data augmentation".to_string(),
334 expected_impact: 0.6,
335 });
336 },
337 _ => {},
338 }
339
340 if let TrainingStability::Unstable = dynamics.training_stability {
341 recommendations.push(TrainingRecommendation {
342 category: "Stability".to_string(),
343 priority: TrainingRecommendationPriority::High,
344 description: "Training is unstable".to_string(),
345 action: "Reduce learning rate or add gradient clipping".to_string(),
346 expected_impact: 0.7,
347 });
348 }
349
350 if dynamics.learning_efficiency < 0.3 {
351 recommendations.push(TrainingRecommendation {
352 category: "Efficiency".to_string(),
353 priority: TrainingRecommendationPriority::Medium,
354 description: "Low learning efficiency detected".to_string(),
355 action: "Consider architecture changes or hyperparameter tuning".to_string(),
356 expected_impact: 0.5,
357 });
358 }
359
360 recommendations
361 }
362
363 fn is_converged(&self, losses: &[f64]) -> bool {
365 if losses.len() < 5 {
366 return false;
367 }
368
369 let recent_variance = self.calculate_variance(&losses[..5]);
370 recent_variance < self.config.min_improvement_threshold && losses[0] < 0.01
371 }
372
373 fn is_diverging(&self, losses: &[f64]) -> bool {
374 if losses.len() < 3 {
375 return false;
376 }
377
378 losses.windows(2).all(|w| w[1] >= w[0])
380 && (losses.last().unwrap() / losses.first().unwrap()) > 1.1
381 }
382
383 fn is_oscillating(&self, losses: &[f64]) -> bool {
384 if losses.len() < 6 {
385 return false;
386 }
387
388 let mut direction_changes = 0;
390 for window in losses.windows(3) {
391 let trend1 = window[1] - window[0];
392 let trend2 = window[2] - window[1];
393 if trend1.signum() != trend2.signum() {
394 direction_changes += 1;
395 }
396 }
397
398 direction_changes > losses.len() / 3
399 }
400
401 fn is_plateau(&self, losses: &[f64]) -> bool {
402 let variance = self.calculate_variance(losses);
403 variance < self.config.min_improvement_threshold
404 }
405
406 fn is_converging(&self, losses: &[f64]) -> bool {
407 if losses.len() < 3 {
408 return false;
409 }
410
411 let trend = self.calculate_trend(losses);
413 trend < -self.config.min_improvement_threshold
414 }
415
416 fn calculate_variance(&self, values: &[f64]) -> f64 {
417 if values.len() < 2 {
418 return 0.0;
419 }
420
421 let mean = values.iter().sum::<f64>() / values.len() as f64;
422 let variance =
423 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
424 variance
425 }
426
427 fn calculate_trend(&self, values: &[f64]) -> f64 {
428 if values.len() < 2 {
429 return 0.0;
430 }
431
432 let n = values.len() as f64;
433 let x_mean = (n - 1.0) / 2.0;
434 let y_mean = values.iter().sum::<f64>() / n;
435
436 let mut numerator = 0.0;
437 let mut denominator = 0.0;
438
439 for (i, &y) in values.iter().enumerate() {
440 let x = i as f64;
441 numerator += (x - x_mean) * (y - y_mean);
442 denominator += (x - x_mean).powi(2);
443 }
444
445 if denominator == 0.0 {
446 0.0
447 } else {
448 numerator / denominator
449 }
450 }
451
452 pub fn clear(&mut self) {
454 self.metrics_history.clear();
455 self.current_state = TrainingState::default();
456 }
457
458 pub fn get_training_state(&self) -> &TrainingState {
460 &self.current_state
461 }
462
463 pub async fn generate_report(&self) -> anyhow::Result<TrainingDynamicsReport> {
465 let training_dynamics = self.analyze_training_dynamics();
466 let recommendations = self.generate_recommendations();
467
468 Ok(TrainingDynamicsReport {
469 training_dynamics,
470 recommendations,
471 current_state: self.current_state.clone(),
472 })
473 }
474
475 fn generate_recommendations(&self) -> Vec<TrainingRecommendation> {
477 let mut recommendations = Vec::new();
478
479 recommendations.push(TrainingRecommendation {
481 category: "General".to_string(),
482 description: "Continue monitoring training dynamics".to_string(),
483 action: "Monitor training progress and adjust parameters as needed".to_string(),
484 priority: TrainingRecommendationPriority::Low,
485 expected_impact: 0.1,
486 });
487
488 recommendations
489 }
490}
491
impl Default for TrainingDynamicsAnalyzer {
    /// Equivalent to [`TrainingDynamicsAnalyzer::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// A single actionable suggestion produced by the analyzer.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingRecommendation {
    /// Grouping label, e.g. "Convergence" or "Stability".
    pub category: String,
    /// How urgently the action should be taken.
    pub priority: TrainingRecommendationPriority,
    /// Human-readable summary of the observed problem.
    pub description: String,
    /// Concrete step the operator should take.
    pub action: String,
    /// Estimated benefit of taking the action, in [0, 1].
    pub expected_impact: f64,
}

/// Urgency level of a [`TrainingRecommendation`], from least to most severe.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum TrainingRecommendationPriority {
    Low,
    Medium,
    High,
    Critical,
}

/// Serializable snapshot combining analysis results, recommendations, and
/// the analyzer's internal state.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingDynamicsReport {
    /// Full dynamics snapshot at report time.
    pub training_dynamics: TrainingDynamics,
    /// Recommendations included with the report.
    pub recommendations: Vec<TrainingRecommendation>,
    /// Copy of the analyzer's bookkeeping state at report time.
    pub current_state: TrainingState,
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds a minimal metrics sample with the given step and loss.
    fn create_test_metrics(step: usize, loss: f64) -> ModelPerformanceMetrics {
        ModelPerformanceMetrics {
            training_step: step,
            loss,
            accuracy: Some(0.8),
            learning_rate: 0.001,
            batch_size: 32,
            throughput_samples_per_sec: 100.0,
            memory_usage_mb: 1000.0,
            gpu_utilization: Some(0.9),
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn test_training_dynamics_analyzer_creation() {
        let analyzer = TrainingDynamicsAnalyzer::new();
        assert_eq!(analyzer.metrics_history.len(), 0);
    }

    #[test]
    fn test_add_metrics() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();
        let metrics = create_test_metrics(1, 0.5);

        analyzer.add_metrics(metrics);
        assert_eq!(analyzer.metrics_history.len(), 1);
        assert_eq!(analyzer.current_state.best_loss, 0.5);
    }

    #[test]
    fn test_convergence_detection() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Smoothly decreasing 1/i loss curve: a textbook converging run.
        for i in 1..=25 {
            analyzer.add_metrics(create_test_metrics(i, 1.0 / (i as f64)));
        }

        let status = analyzer.detect_convergence_status();
        // BUG FIX: the original `matches!` expression discarded its boolean
        // result, so this test could never fail; wrap it in `assert!`.
        assert!(matches!(
            status,
            ConvergenceStatus::Converging | ConvergenceStatus::Converged
        ));
    }

    #[test]
    fn test_learning_efficiency_calculation() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        analyzer.add_metrics(create_test_metrics(1, 1.0));
        analyzer.add_metrics(create_test_metrics(2, 0.5));
        analyzer.add_metrics(create_test_metrics(3, 0.25));

        let efficiency = analyzer.calculate_learning_efficiency();
        assert!(efficiency > 0.0);
    }

    #[test]
    fn test_plateau_detection() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Perfectly flat loss for longer than min_plateau_duration.
        for i in 1..=15 {
            analyzer.add_metrics(create_test_metrics(i, 0.1));
        }

        let plateau = analyzer.detect_plateau();
        assert!(plateau.is_some());
    }

    #[test]
    fn test_training_stability_assessment() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Tiny linear drift: variance stays far below the instability
        // thresholds.
        for i in 1..=20 {
            analyzer.add_metrics(create_test_metrics(i, 0.5 + (i as f64 * 0.001)));
        }

        let stability = analyzer.assess_training_stability();
        // BUG FIX: `matches!` result was discarded; assert it.
        assert!(matches!(stability, TrainingStability::Stable));
    }
}