// trustformers_debug/model_diagnostics/training.rs
use std::collections::VecDeque;
9use super::types::{
10 ConvergenceStatus, ModelPerformanceMetrics, OverfittingIndicator, PlateauInfo,
11 TrainingDynamics, TrainingStability, UnderfittingIndicator,
12};
13
/// Analyzes training dynamics (convergence, stability, overfitting,
/// plateaus) from a rolling history of per-step performance metrics.
#[derive(Debug)]
pub struct TrainingDynamicsAnalyzer {
    // Rolling window of recorded metrics; bounded, oldest entries evicted.
    metrics_history: VecDeque<ModelPerformanceMetrics>,
    // Thresholds and window sizes governing the analysis.
    config: TrainingAnalysisConfig,
    // Mutable bookkeeping: best loss, improvement counter, recent statuses.
    current_state: TrainingState,
}
24
/// Tunable thresholds for training-dynamics analysis.
#[derive(Debug, Clone)]
pub struct TrainingAnalysisConfig {
    /// Number of most-recent steps examined when classifying convergence.
    pub convergence_window: usize,
    /// Loss variance below this counts as "no meaningful improvement".
    pub min_improvement_threshold: f64,
    /// Loss variance above this marks training as unstable.
    pub max_variance_threshold: f64,
    /// Minimum number of flat steps required to report a plateau.
    pub min_plateau_duration: usize,
    /// Train/validation gap above which overfitting is suspected.
    /// NOTE(review): not referenced by the analysis code visible in this
    /// file — confirm intended use.
    pub overfitting_gap_threshold: f64,
    /// Lower bound for the learning rate.
    /// NOTE(review): not referenced by the analysis code visible in this
    /// file — confirm intended use.
    pub min_learning_rate: f64,
}
41
42impl Default for TrainingAnalysisConfig {
43 fn default() -> Self {
44 Self {
45 convergence_window: 20,
46 min_improvement_threshold: 0.001,
47 max_variance_threshold: 0.1,
48 min_plateau_duration: 10,
49 overfitting_gap_threshold: 0.05,
50 min_learning_rate: 1e-6,
51 }
52 }
53}
54
/// Serializable bookkeeping carried between `add_metrics` calls.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingState {
    // Consecutive steps since `best_loss` last improved.
    steps_since_improvement: usize,
    // Lowest loss observed so far; starts at +infinity.
    best_loss: f64,
    // Currently detected plateau, if any.
    // NOTE(review): never assigned by the code visible in this file —
    // confirm where this is populated.
    current_plateau: Option<PlateauInfo>,
    // Recent convergence classifications (bounded rolling window).
    convergence_history: VecDeque<ConvergenceStatus>,
}
67
68impl Default for TrainingState {
69 fn default() -> Self {
70 Self {
71 steps_since_improvement: 0,
72 best_loss: f64::INFINITY,
73 current_plateau: None,
74 convergence_history: VecDeque::new(),
75 }
76 }
77}
78
79impl TrainingDynamicsAnalyzer {
80 pub fn new() -> Self {
82 Self {
83 metrics_history: VecDeque::new(),
84 config: TrainingAnalysisConfig::default(),
85 current_state: TrainingState::default(),
86 }
87 }
88
89 pub fn with_config(config: TrainingAnalysisConfig) -> Self {
91 Self {
92 metrics_history: VecDeque::new(),
93 config,
94 current_state: TrainingState::default(),
95 }
96 }
97
98 pub fn add_metrics(&mut self, metrics: ModelPerformanceMetrics) {
100 if metrics.loss < self.current_state.best_loss {
102 self.current_state.best_loss = metrics.loss;
103 self.current_state.steps_since_improvement = 0;
104 } else {
105 self.current_state.steps_since_improvement += 1;
106 }
107
108 self.metrics_history.push_back(metrics);
109
110 if self.metrics_history.len() > 1000 {
112 self.metrics_history.pop_front();
113 }
114
115 let status = self.detect_convergence_status();
117 self.current_state.convergence_history.push_back(status);
118 if self.current_state.convergence_history.len() > 50 {
119 self.current_state.convergence_history.pop_front();
120 }
121 }
122
123 pub fn record_training_dynamics(&mut self, _dynamics: TrainingDynamics) {
125 }
128
129 pub fn analyze_training_dynamics(&self) -> TrainingDynamics {
131 let convergence_status = self.detect_convergence_status();
132 let training_stability = self.assess_training_stability();
133 let learning_efficiency = self.calculate_learning_efficiency();
134 let overfitting_indicators = self.detect_overfitting_indicators();
135 let underfitting_indicators = self.detect_underfitting_indicators();
136 let plateau_detection = self.detect_plateau();
137
138 TrainingDynamics {
139 convergence_status,
140 training_stability,
141 learning_efficiency,
142 overfitting_indicators,
143 underfitting_indicators,
144 plateau_detection,
145 }
146 }
147
148 pub fn detect_convergence_status(&self) -> ConvergenceStatus {
150 if self.metrics_history.len() < self.config.convergence_window {
151 return ConvergenceStatus::Unknown;
152 }
153
154 let recent_metrics: Vec<_> =
155 self.metrics_history.iter().rev().take(self.config.convergence_window).collect();
156
157 let losses: Vec<f64> = recent_metrics.iter().map(|m| m.loss).collect();
158
159 if self.is_converged(&losses) {
161 ConvergenceStatus::Converged
162 } else if self.is_diverging(&losses) {
163 ConvergenceStatus::Diverging
164 } else if self.is_oscillating(&losses) {
165 ConvergenceStatus::Oscillating
166 } else if self.is_plateau(&losses) {
167 ConvergenceStatus::Plateau
168 } else if self.is_converging(&losses) {
169 ConvergenceStatus::Converging
170 } else {
171 ConvergenceStatus::Unknown
172 }
173 }
174
175 pub fn assess_training_stability(&self) -> TrainingStability {
177 if self.metrics_history.len() < 10 {
178 return TrainingStability::Unknown;
179 }
180
181 let recent_losses: Vec<f64> =
182 self.metrics_history.iter().rev().take(20).map(|m| m.loss).collect();
183
184 let variance = self.calculate_variance(&recent_losses);
185
186 if variance > self.config.max_variance_threshold {
187 TrainingStability::Unstable
188 } else if variance > self.config.max_variance_threshold / 2.0 {
189 TrainingStability::HighVariance
190 } else {
191 TrainingStability::Stable
192 }
193 }
194
195 pub fn calculate_learning_efficiency(&self) -> f64 {
197 if self.metrics_history.len() < 2 {
198 return 0.0;
199 }
200
201 let initial_loss = self
202 .metrics_history
203 .front()
204 .expect("metrics_history has at least 2 elements")
205 .loss;
206 let current_loss = self
207 .metrics_history
208 .back()
209 .expect("metrics_history has at least 2 elements")
210 .loss;
211 let steps = self.metrics_history.len();
212
213 if initial_loss <= current_loss {
214 return 0.0;
215 }
216
217 let improvement = (initial_loss - current_loss) / initial_loss;
218 let efficiency = improvement / (steps as f64).sqrt();
219
220 efficiency.min(1.0)
221 }
222
223 pub fn detect_overfitting_indicators(&self) -> Vec<OverfittingIndicator> {
225 let mut indicators = Vec::new();
226
227 if self.metrics_history.len() > 10 {
229 let recent_losses: Vec<f64> =
230 self.metrics_history.iter().rev().take(10).map(|m| m.loss).collect();
231
232 let avg_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
234 if avg_loss < 0.01 {
235 indicators.push(OverfittingIndicator::PerfectTrainingAccuracy);
236 }
237
238 let variance = self.calculate_variance(&recent_losses);
240 if variance > 0.05 {
241 indicators.push(OverfittingIndicator::HighVarianceInValidation);
242 }
243 }
244
245 indicators
246 }
247
248 pub fn detect_underfitting_indicators(&self) -> Vec<UnderfittingIndicator> {
250 let mut indicators = Vec::new();
251
252 if let Some(current_metrics) = self.metrics_history.back() {
253 if current_metrics.loss > 1.0 {
255 indicators.push(UnderfittingIndicator::HighTrainingLoss {
256 loss: current_metrics.loss,
257 threshold: 1.0,
258 });
259 }
260
261 if let Some(accuracy) = current_metrics.accuracy {
263 if accuracy < 0.5 {
264 indicators.push(UnderfittingIndicator::LowTrainingAccuracy {
265 accuracy,
266 threshold: 0.5,
267 });
268 }
269 }
270
271 if self.current_state.steps_since_improvement > 50 {
273 indicators.push(UnderfittingIndicator::SlowConvergence {
274 steps_taken: self.metrics_history.len(),
275 expected: self.metrics_history.len() / 2,
276 });
277 }
278
279 if self.current_state.steps_since_improvement > 100 {
281 indicators.push(UnderfittingIndicator::NoLearning {
282 steps_without_improvement: self.current_state.steps_since_improvement,
283 });
284 }
285 }
286
287 indicators
288 }
289
290 pub fn detect_plateau(&self) -> Option<PlateauInfo> {
292 if self.metrics_history.len() < self.config.min_plateau_duration {
293 return None;
294 }
295
296 let recent_losses: Vec<f64> = self
297 .metrics_history
298 .iter()
299 .rev()
300 .take(self.config.min_plateau_duration)
301 .map(|m| m.loss)
302 .collect();
303
304 let variance = self.calculate_variance(&recent_losses);
305 let mean_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
306
307 if variance < self.config.min_improvement_threshold {
309 let start_step = self.metrics_history.len() - self.config.min_plateau_duration;
310 Some(PlateauInfo {
311 start_step,
312 duration_steps: self.config.min_plateau_duration,
313 plateau_value: mean_loss,
314 variance,
315 })
316 } else {
317 None
318 }
319 }
320
321 pub fn generate_training_recommendations(&self) -> Vec<TrainingRecommendation> {
323 let mut recommendations = Vec::new();
324 let dynamics = self.analyze_training_dynamics();
325
326 match dynamics.convergence_status {
327 ConvergenceStatus::Diverging => {
328 recommendations.push(TrainingRecommendation {
329 category: "Convergence".to_string(),
330 priority: TrainingRecommendationPriority::Critical,
331 description: "Training is diverging".to_string(),
332 action: "Reduce learning rate immediately".to_string(),
333 expected_impact: 0.8,
334 });
335 },
336 ConvergenceStatus::Plateau => {
337 recommendations.push(TrainingRecommendation {
338 category: "Convergence".to_string(),
339 priority: TrainingRecommendationPriority::High,
340 description: "Training has reached a plateau".to_string(),
341 action: "Consider learning rate scheduling or data augmentation".to_string(),
342 expected_impact: 0.6,
343 });
344 },
345 _ => {},
346 }
347
348 if let TrainingStability::Unstable = dynamics.training_stability {
349 recommendations.push(TrainingRecommendation {
350 category: "Stability".to_string(),
351 priority: TrainingRecommendationPriority::High,
352 description: "Training is unstable".to_string(),
353 action: "Reduce learning rate or add gradient clipping".to_string(),
354 expected_impact: 0.7,
355 });
356 }
357
358 if dynamics.learning_efficiency < 0.3 {
359 recommendations.push(TrainingRecommendation {
360 category: "Efficiency".to_string(),
361 priority: TrainingRecommendationPriority::Medium,
362 description: "Low learning efficiency detected".to_string(),
363 action: "Consider architecture changes or hyperparameter tuning".to_string(),
364 expected_impact: 0.5,
365 });
366 }
367
368 recommendations
369 }
370
371 fn is_converged(&self, losses: &[f64]) -> bool {
373 if losses.len() < 5 {
374 return false;
375 }
376
377 let recent_variance = self.calculate_variance(&losses[..5]);
378 recent_variance < self.config.min_improvement_threshold && losses[0] < 0.01
379 }
380
381 fn is_diverging(&self, losses: &[f64]) -> bool {
382 if losses.len() < 3 {
383 return false;
384 }
385
386 losses.windows(2).all(|w| w[1] >= w[0])
388 && (losses.last().expect("losses has at least 3 elements")
389 / losses.first().expect("losses has at least 3 elements"))
390 > 1.1
391 }
392
393 fn is_oscillating(&self, losses: &[f64]) -> bool {
394 if losses.len() < 6 {
395 return false;
396 }
397
398 let mut direction_changes = 0;
400 for window in losses.windows(3) {
401 let trend1 = window[1] - window[0];
402 let trend2 = window[2] - window[1];
403 if trend1.signum() != trend2.signum() {
404 direction_changes += 1;
405 }
406 }
407
408 direction_changes > losses.len() / 3
409 }
410
411 fn is_plateau(&self, losses: &[f64]) -> bool {
412 let variance = self.calculate_variance(losses);
413 variance < self.config.min_improvement_threshold
414 }
415
416 fn is_converging(&self, losses: &[f64]) -> bool {
417 if losses.len() < 3 {
418 return false;
419 }
420
421 let trend = self.calculate_trend(losses);
423 trend < -self.config.min_improvement_threshold
424 }
425
426 fn calculate_variance(&self, values: &[f64]) -> f64 {
427 if values.len() < 2 {
428 return 0.0;
429 }
430
431 let mean = values.iter().sum::<f64>() / values.len() as f64;
432 let variance =
433 values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
434 variance
435 }
436
437 fn calculate_trend(&self, values: &[f64]) -> f64 {
438 if values.len() < 2 {
439 return 0.0;
440 }
441
442 let n = values.len() as f64;
443 let x_mean = (n - 1.0) / 2.0;
444 let y_mean = values.iter().sum::<f64>() / n;
445
446 let mut numerator = 0.0;
447 let mut denominator = 0.0;
448
449 for (i, &y) in values.iter().enumerate() {
450 let x = i as f64;
451 numerator += (x - x_mean) * (y - y_mean);
452 denominator += (x - x_mean).powi(2);
453 }
454
455 if denominator == 0.0 {
456 0.0
457 } else {
458 numerator / denominator
459 }
460 }
461
462 pub fn clear(&mut self) {
464 self.metrics_history.clear();
465 self.current_state = TrainingState::default();
466 }
467
468 pub fn get_training_state(&self) -> &TrainingState {
470 &self.current_state
471 }
472
473 pub async fn generate_report(&self) -> anyhow::Result<TrainingDynamicsReport> {
475 let training_dynamics = self.analyze_training_dynamics();
476 let recommendations = self.generate_recommendations();
477
478 Ok(TrainingDynamicsReport {
479 training_dynamics,
480 recommendations,
481 current_state: self.current_state.clone(),
482 })
483 }
484
485 fn generate_recommendations(&self) -> Vec<TrainingRecommendation> {
487 let mut recommendations = Vec::new();
488
489 recommendations.push(TrainingRecommendation {
491 category: "General".to_string(),
492 description: "Continue monitoring training dynamics".to_string(),
493 action: "Monitor training progress and adjust parameters as needed".to_string(),
494 priority: TrainingRecommendationPriority::Low,
495 expected_impact: 0.1,
496 });
497
498 recommendations
499 }
500}
501
502impl Default for TrainingDynamicsAnalyzer {
503 fn default() -> Self {
504 Self::new()
505 }
506}
507
/// A single actionable recommendation derived from training analysis.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingRecommendation {
    /// Broad grouping, e.g. "Convergence", "Stability", "Efficiency".
    pub category: String,
    /// Urgency of acting on this recommendation.
    pub priority: TrainingRecommendationPriority,
    /// Human-readable description of the detected issue.
    pub description: String,
    /// Suggested remedial action.
    pub action: String,
    /// Estimated benefit of applying the action, in [0, 1].
    pub expected_impact: f64,
}
522
/// Urgency levels for training recommendations, in increasing severity.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum TrainingRecommendationPriority {
    /// Informational; no action required.
    Low,
    /// Worth addressing when convenient.
    Medium,
    /// Should be addressed soon.
    High,
    /// Requires immediate attention (e.g. diverging training).
    Critical,
}
535
/// Aggregate report combining dynamics analysis, recommendations, and a
/// snapshot of the analyzer's internal state.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingDynamicsReport {
    /// Full dynamics snapshot at report time.
    pub training_dynamics: TrainingDynamics,
    /// Recommendations generated for this report.
    pub recommendations: Vec<TrainingRecommendation>,
    /// Copy of the analyzer's tracking state when the report was built.
    pub current_state: TrainingState,
}
546
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds a metrics snapshot with fixed throughput/memory values so
    /// tests only vary `step` and `loss`.
    fn create_test_metrics(step: usize, loss: f64) -> ModelPerformanceMetrics {
        ModelPerformanceMetrics {
            training_step: step,
            loss,
            accuracy: Some(0.8),
            learning_rate: 0.001,
            batch_size: 32,
            throughput_samples_per_sec: 100.0,
            memory_usage_mb: 1000.0,
            gpu_utilization: Some(0.9),
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn test_training_dynamics_analyzer_creation() {
        let analyzer = TrainingDynamicsAnalyzer::new();
        assert_eq!(analyzer.metrics_history.len(), 0);
    }

    #[test]
    fn test_add_metrics() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();
        let metrics = create_test_metrics(1, 0.5);

        analyzer.add_metrics(metrics);
        assert_eq!(analyzer.metrics_history.len(), 1);
        assert_eq!(analyzer.current_state.best_loss, 0.5);
    }

    #[test]
    fn test_convergence_detection() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Steadily decreasing loss: 1, 1/2, ..., 1/25.
        for i in 1..=25 {
            let loss = 1.0 / (i as f64);
            let metrics = create_test_metrics(i, loss);
            analyzer.add_metrics(metrics);
        }

        let status = analyzer.detect_convergence_status();
        // The original test computed `matches!` and discarded the bool, so
        // it could never fail; assert the expected classification.
        assert!(matches!(
            status,
            ConvergenceStatus::Converging | ConvergenceStatus::Converged
        ));
    }

    #[test]
    fn test_learning_efficiency_calculation() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        analyzer.add_metrics(create_test_metrics(1, 1.0));
        analyzer.add_metrics(create_test_metrics(2, 0.5));
        analyzer.add_metrics(create_test_metrics(3, 0.25));

        let efficiency = analyzer.calculate_learning_efficiency();
        assert!(efficiency > 0.0);
    }

    #[test]
    fn test_plateau_detection() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Perfectly flat loss for longer than `min_plateau_duration`.
        for i in 1..=15 {
            let metrics = create_test_metrics(i, 0.1);
            analyzer.add_metrics(metrics);
        }

        let plateau = analyzer.detect_plateau();
        assert!(plateau.is_some());
    }

    #[test]
    fn test_training_stability_assessment() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Slowly drifting loss with tiny variance: should be Stable.
        for i in 1..=20 {
            let loss = 0.5 + (i as f64 * 0.001);
            let metrics = create_test_metrics(i, loss);
            analyzer.add_metrics(metrics);
        }

        let stability = analyzer.assess_training_stability();
        // Previously the `matches!` result was discarded; assert it.
        assert!(matches!(stability, TrainingStability::Stable));
    }
}