Skip to main content

trustformers_debug/model_diagnostics/
training.rs

1//! Training dynamics and convergence analysis.
2//!
3//! This module provides comprehensive training dynamics analysis including
4//! convergence detection, overfitting/underfitting identification, plateau
5//! detection, and training stability assessment for optimizing training processes.
6
7use std::collections::VecDeque;
8
9use super::types::{
10    ConvergenceStatus, ModelPerformanceMetrics, OverfittingIndicator, PlateauInfo,
11    TrainingDynamics, TrainingStability, UnderfittingIndicator,
12};
13
14/// Training dynamics analyzer for monitoring and analyzing training behavior.
15#[derive(Debug)]
16pub struct TrainingDynamicsAnalyzer {
17    /// Historical metrics for analysis
18    metrics_history: VecDeque<ModelPerformanceMetrics>,
19    /// Configuration for analysis thresholds
20    config: TrainingAnalysisConfig,
21    /// Current training state
22    current_state: TrainingState,
23}
24
25/// Configuration for training analysis.
26#[derive(Debug, Clone)]
27pub struct TrainingAnalysisConfig {
28    /// Window size for convergence analysis
29    pub convergence_window: usize,
30    /// Minimum improvement threshold for convergence
31    pub min_improvement_threshold: f64,
32    /// Maximum variance threshold for stability
33    pub max_variance_threshold: f64,
34    /// Minimum plateau duration to consider
35    pub min_plateau_duration: usize,
36    /// Train-validation gap threshold for overfitting
37    pub overfitting_gap_threshold: f64,
38    /// Minimum learning rate for underfitting detection
39    pub min_learning_rate: f64,
40}
41
42impl Default for TrainingAnalysisConfig {
43    fn default() -> Self {
44        Self {
45            convergence_window: 20,
46            min_improvement_threshold: 0.001,
47            max_variance_threshold: 0.1,
48            min_plateau_duration: 10,
49            overfitting_gap_threshold: 0.05,
50            min_learning_rate: 1e-6,
51        }
52    }
53}
54
55/// Current training state information.
56#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
57pub struct TrainingState {
58    /// Steps since last improvement
59    steps_since_improvement: usize,
60    /// Best loss achieved so far
61    best_loss: f64,
62    /// Current plateau information
63    current_plateau: Option<PlateauInfo>,
64    /// Convergence status history
65    convergence_history: VecDeque<ConvergenceStatus>,
66}
67
68impl Default for TrainingState {
69    fn default() -> Self {
70        Self {
71            steps_since_improvement: 0,
72            best_loss: f64::INFINITY,
73            current_plateau: None,
74            convergence_history: VecDeque::new(),
75        }
76    }
77}
78
79impl TrainingDynamicsAnalyzer {
80    /// Create a new training dynamics analyzer.
81    pub fn new() -> Self {
82        Self {
83            metrics_history: VecDeque::new(),
84            config: TrainingAnalysisConfig::default(),
85            current_state: TrainingState::default(),
86        }
87    }
88
89    /// Create a new analyzer with custom configuration.
90    pub fn with_config(config: TrainingAnalysisConfig) -> Self {
91        Self {
92            metrics_history: VecDeque::new(),
93            config,
94            current_state: TrainingState::default(),
95        }
96    }
97
98    /// Add new training metrics for analysis.
99    pub fn add_metrics(&mut self, metrics: ModelPerformanceMetrics) {
100        // Update training state
101        if metrics.loss < self.current_state.best_loss {
102            self.current_state.best_loss = metrics.loss;
103            self.current_state.steps_since_improvement = 0;
104        } else {
105            self.current_state.steps_since_improvement += 1;
106        }
107
108        self.metrics_history.push_back(metrics);
109
110        // Maintain reasonable history size
111        if self.metrics_history.len() > 1000 {
112            self.metrics_history.pop_front();
113        }
114
115        // Update convergence history
116        let status = self.detect_convergence_status();
117        self.current_state.convergence_history.push_back(status);
118        if self.current_state.convergence_history.len() > 50 {
119            self.current_state.convergence_history.pop_front();
120        }
121    }
122
123    /// Record training dynamics information.
124    pub fn record_training_dynamics(&mut self, _dynamics: TrainingDynamics) {
125        // Training dynamics are computed via analysis rather than stored directly
126        // This method is provided for API compatibility
127    }
128
129    /// Analyze current training dynamics.
130    pub fn analyze_training_dynamics(&self) -> TrainingDynamics {
131        let convergence_status = self.detect_convergence_status();
132        let training_stability = self.assess_training_stability();
133        let learning_efficiency = self.calculate_learning_efficiency();
134        let overfitting_indicators = self.detect_overfitting_indicators();
135        let underfitting_indicators = self.detect_underfitting_indicators();
136        let plateau_detection = self.detect_plateau();
137
138        TrainingDynamics {
139            convergence_status,
140            training_stability,
141            learning_efficiency,
142            overfitting_indicators,
143            underfitting_indicators,
144            plateau_detection,
145        }
146    }
147
148    /// Detect current convergence status.
149    pub fn detect_convergence_status(&self) -> ConvergenceStatus {
150        if self.metrics_history.len() < self.config.convergence_window {
151            return ConvergenceStatus::Unknown;
152        }
153
154        let recent_metrics: Vec<_> =
155            self.metrics_history.iter().rev().take(self.config.convergence_window).collect();
156
157        let losses: Vec<f64> = recent_metrics.iter().map(|m| m.loss).collect();
158
159        // Check for convergence patterns
160        if self.is_converged(&losses) {
161            ConvergenceStatus::Converged
162        } else if self.is_diverging(&losses) {
163            ConvergenceStatus::Diverging
164        } else if self.is_oscillating(&losses) {
165            ConvergenceStatus::Oscillating
166        } else if self.is_plateau(&losses) {
167            ConvergenceStatus::Plateau
168        } else if self.is_converging(&losses) {
169            ConvergenceStatus::Converging
170        } else {
171            ConvergenceStatus::Unknown
172        }
173    }
174
175    /// Assess training stability.
176    pub fn assess_training_stability(&self) -> TrainingStability {
177        if self.metrics_history.len() < 10 {
178            return TrainingStability::Unknown;
179        }
180
181        let recent_losses: Vec<f64> =
182            self.metrics_history.iter().rev().take(20).map(|m| m.loss).collect();
183
184        let variance = self.calculate_variance(&recent_losses);
185
186        if variance > self.config.max_variance_threshold {
187            TrainingStability::Unstable
188        } else if variance > self.config.max_variance_threshold / 2.0 {
189            TrainingStability::HighVariance
190        } else {
191            TrainingStability::Stable
192        }
193    }
194
195    /// Calculate learning efficiency score.
196    pub fn calculate_learning_efficiency(&self) -> f64 {
197        if self.metrics_history.len() < 2 {
198            return 0.0;
199        }
200
201        let initial_loss = self
202            .metrics_history
203            .front()
204            .expect("metrics_history has at least 2 elements")
205            .loss;
206        let current_loss = self
207            .metrics_history
208            .back()
209            .expect("metrics_history has at least 2 elements")
210            .loss;
211        let steps = self.metrics_history.len();
212
213        if initial_loss <= current_loss {
214            return 0.0;
215        }
216
217        let improvement = (initial_loss - current_loss) / initial_loss;
218        let efficiency = improvement / (steps as f64).sqrt();
219
220        efficiency.min(1.0)
221    }
222
223    /// Detect overfitting indicators.
224    pub fn detect_overfitting_indicators(&self) -> Vec<OverfittingIndicator> {
225        let mut indicators = Vec::new();
226
227        // Check for validation accuracy indicators (simulated for now)
228        if self.metrics_history.len() > 10 {
229            let recent_losses: Vec<f64> =
230                self.metrics_history.iter().rev().take(10).map(|m| m.loss).collect();
231
232            // Simulate validation gap detection
233            let avg_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
234            if avg_loss < 0.01 {
235                indicators.push(OverfittingIndicator::PerfectTrainingAccuracy);
236            }
237
238            // Check for loss variance indicating overfitting
239            let variance = self.calculate_variance(&recent_losses);
240            if variance > 0.05 {
241                indicators.push(OverfittingIndicator::HighVarianceInValidation);
242            }
243        }
244
245        indicators
246    }
247
248    /// Detect underfitting indicators.
249    pub fn detect_underfitting_indicators(&self) -> Vec<UnderfittingIndicator> {
250        let mut indicators = Vec::new();
251
252        if let Some(current_metrics) = self.metrics_history.back() {
253            // High training loss
254            if current_metrics.loss > 1.0 {
255                indicators.push(UnderfittingIndicator::HighTrainingLoss {
256                    loss: current_metrics.loss,
257                    threshold: 1.0,
258                });
259            }
260
261            // Low accuracy (simulated)
262            if let Some(accuracy) = current_metrics.accuracy {
263                if accuracy < 0.5 {
264                    indicators.push(UnderfittingIndicator::LowTrainingAccuracy {
265                        accuracy,
266                        threshold: 0.5,
267                    });
268                }
269            }
270
271            // Slow convergence
272            if self.current_state.steps_since_improvement > 50 {
273                indicators.push(UnderfittingIndicator::SlowConvergence {
274                    steps_taken: self.metrics_history.len(),
275                    expected: self.metrics_history.len() / 2,
276                });
277            }
278
279            // No learning
280            if self.current_state.steps_since_improvement > 100 {
281                indicators.push(UnderfittingIndicator::NoLearning {
282                    steps_without_improvement: self.current_state.steps_since_improvement,
283                });
284            }
285        }
286
287        indicators
288    }
289
290    /// Detect plateau in training.
291    pub fn detect_plateau(&self) -> Option<PlateauInfo> {
292        if self.metrics_history.len() < self.config.min_plateau_duration {
293            return None;
294        }
295
296        let recent_losses: Vec<f64> = self
297            .metrics_history
298            .iter()
299            .rev()
300            .take(self.config.min_plateau_duration)
301            .map(|m| m.loss)
302            .collect();
303
304        let variance = self.calculate_variance(&recent_losses);
305        let mean_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
306
307        // Check if variance is low enough to indicate plateau
308        if variance < self.config.min_improvement_threshold {
309            let start_step = self.metrics_history.len() - self.config.min_plateau_duration;
310            Some(PlateauInfo {
311                start_step,
312                duration_steps: self.config.min_plateau_duration,
313                plateau_value: mean_loss,
314                variance,
315            })
316        } else {
317            None
318        }
319    }
320
321    /// Generate training recommendations based on current dynamics.
322    pub fn generate_training_recommendations(&self) -> Vec<TrainingRecommendation> {
323        let mut recommendations = Vec::new();
324        let dynamics = self.analyze_training_dynamics();
325
326        match dynamics.convergence_status {
327            ConvergenceStatus::Diverging => {
328                recommendations.push(TrainingRecommendation {
329                    category: "Convergence".to_string(),
330                    priority: TrainingRecommendationPriority::Critical,
331                    description: "Training is diverging".to_string(),
332                    action: "Reduce learning rate immediately".to_string(),
333                    expected_impact: 0.8,
334                });
335            },
336            ConvergenceStatus::Plateau => {
337                recommendations.push(TrainingRecommendation {
338                    category: "Convergence".to_string(),
339                    priority: TrainingRecommendationPriority::High,
340                    description: "Training has reached a plateau".to_string(),
341                    action: "Consider learning rate scheduling or data augmentation".to_string(),
342                    expected_impact: 0.6,
343                });
344            },
345            _ => {},
346        }
347
348        if let TrainingStability::Unstable = dynamics.training_stability {
349            recommendations.push(TrainingRecommendation {
350                category: "Stability".to_string(),
351                priority: TrainingRecommendationPriority::High,
352                description: "Training is unstable".to_string(),
353                action: "Reduce learning rate or add gradient clipping".to_string(),
354                expected_impact: 0.7,
355            });
356        }
357
358        if dynamics.learning_efficiency < 0.3 {
359            recommendations.push(TrainingRecommendation {
360                category: "Efficiency".to_string(),
361                priority: TrainingRecommendationPriority::Medium,
362                description: "Low learning efficiency detected".to_string(),
363                action: "Consider architecture changes or hyperparameter tuning".to_string(),
364                expected_impact: 0.5,
365            });
366        }
367
368        recommendations
369    }
370
371    // Helper methods for convergence detection
372    fn is_converged(&self, losses: &[f64]) -> bool {
373        if losses.len() < 5 {
374            return false;
375        }
376
377        let recent_variance = self.calculate_variance(&losses[..5]);
378        recent_variance < self.config.min_improvement_threshold && losses[0] < 0.01
379    }
380
381    fn is_diverging(&self, losses: &[f64]) -> bool {
382        if losses.len() < 3 {
383            return false;
384        }
385
386        // Check if loss is consistently increasing
387        losses.windows(2).all(|w| w[1] >= w[0])
388            && (losses.last().expect("losses has at least 3 elements")
389                / losses.first().expect("losses has at least 3 elements"))
390                > 1.1
391    }
392
393    fn is_oscillating(&self, losses: &[f64]) -> bool {
394        if losses.len() < 6 {
395            return false;
396        }
397
398        // Check for oscillating pattern
399        let mut direction_changes = 0;
400        for window in losses.windows(3) {
401            let trend1 = window[1] - window[0];
402            let trend2 = window[2] - window[1];
403            if trend1.signum() != trend2.signum() {
404                direction_changes += 1;
405            }
406        }
407
408        direction_changes > losses.len() / 3
409    }
410
411    fn is_plateau(&self, losses: &[f64]) -> bool {
412        let variance = self.calculate_variance(losses);
413        variance < self.config.min_improvement_threshold
414    }
415
416    fn is_converging(&self, losses: &[f64]) -> bool {
417        if losses.len() < 3 {
418            return false;
419        }
420
421        // Check if loss is generally decreasing
422        let trend = self.calculate_trend(losses);
423        trend < -self.config.min_improvement_threshold
424    }
425
426    fn calculate_variance(&self, values: &[f64]) -> f64 {
427        if values.len() < 2 {
428            return 0.0;
429        }
430
431        let mean = values.iter().sum::<f64>() / values.len() as f64;
432        let variance =
433            values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
434        variance
435    }
436
437    fn calculate_trend(&self, values: &[f64]) -> f64 {
438        if values.len() < 2 {
439            return 0.0;
440        }
441
442        let n = values.len() as f64;
443        let x_mean = (n - 1.0) / 2.0;
444        let y_mean = values.iter().sum::<f64>() / n;
445
446        let mut numerator = 0.0;
447        let mut denominator = 0.0;
448
449        for (i, &y) in values.iter().enumerate() {
450            let x = i as f64;
451            numerator += (x - x_mean) * (y - y_mean);
452            denominator += (x - x_mean).powi(2);
453        }
454
455        if denominator == 0.0 {
456            0.0
457        } else {
458            numerator / denominator
459        }
460    }
461
462    /// Clear analysis history.
463    pub fn clear(&mut self) {
464        self.metrics_history.clear();
465        self.current_state = TrainingState::default();
466    }
467
468    /// Get current training state information.
469    pub fn get_training_state(&self) -> &TrainingState {
470        &self.current_state
471    }
472
473    /// Generate comprehensive training dynamics report.
474    pub async fn generate_report(&self) -> anyhow::Result<TrainingDynamicsReport> {
475        let training_dynamics = self.analyze_training_dynamics();
476        let recommendations = self.generate_recommendations();
477
478        Ok(TrainingDynamicsReport {
479            training_dynamics,
480            recommendations,
481            current_state: self.current_state.clone(),
482        })
483    }
484
485    /// Generate training recommendations.
486    fn generate_recommendations(&self) -> Vec<TrainingRecommendation> {
487        let mut recommendations = Vec::new();
488
489        // Add basic recommendations based on current state
490        recommendations.push(TrainingRecommendation {
491            category: "General".to_string(),
492            description: "Continue monitoring training dynamics".to_string(),
493            action: "Monitor training progress and adjust parameters as needed".to_string(),
494            priority: TrainingRecommendationPriority::Low,
495            expected_impact: 0.1,
496        });
497
498        recommendations
499    }
500}
501
502impl Default for TrainingDynamicsAnalyzer {
503    fn default() -> Self {
504        Self::new()
505    }
506}
507
508/// Training recommendation.
509#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
510pub struct TrainingRecommendation {
511    /// Category of the recommendation
512    pub category: String,
513    /// Priority level
514    pub priority: TrainingRecommendationPriority,
515    /// Description of the issue
516    pub description: String,
517    /// Recommended action
518    pub action: String,
519    /// Expected impact (0.0 to 1.0)
520    pub expected_impact: f64,
521}
522
523/// Priority levels for training recommendations.
524#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
525pub enum TrainingRecommendationPriority {
526    /// Low priority
527    Low,
528    /// Medium priority
529    Medium,
530    /// High priority
531    High,
532    /// Critical priority
533    Critical,
534}
535
536/// Comprehensive training dynamics report.
537#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
538pub struct TrainingDynamicsReport {
539    /// Training dynamics analysis
540    pub training_dynamics: TrainingDynamics,
541    /// Generated recommendations
542    pub recommendations: Vec<TrainingRecommendation>,
543    /// Current training state
544    pub current_state: TrainingState,
545}
546
547#[cfg(test)]
548mod tests {
549    use super::*;
550    use chrono::Utc;
551
552    fn create_test_metrics(step: usize, loss: f64) -> ModelPerformanceMetrics {
553        ModelPerformanceMetrics {
554            training_step: step,
555            loss,
556            accuracy: Some(0.8),
557            learning_rate: 0.001,
558            batch_size: 32,
559            throughput_samples_per_sec: 100.0,
560            memory_usage_mb: 1000.0,
561            gpu_utilization: Some(0.9),
562            timestamp: Utc::now(),
563        }
564    }
565
566    #[test]
567    fn test_training_dynamics_analyzer_creation() {
568        let analyzer = TrainingDynamicsAnalyzer::new();
569        assert_eq!(analyzer.metrics_history.len(), 0);
570    }
571
572    #[test]
573    fn test_add_metrics() {
574        let mut analyzer = TrainingDynamicsAnalyzer::new();
575        let metrics = create_test_metrics(1, 0.5);
576
577        analyzer.add_metrics(metrics);
578        assert_eq!(analyzer.metrics_history.len(), 1);
579        assert_eq!(analyzer.current_state.best_loss, 0.5);
580    }
581
582    #[test]
583    fn test_convergence_detection() {
584        let mut analyzer = TrainingDynamicsAnalyzer::new();
585
586        // Add converging sequence
587        for i in 1..=25 {
588            let loss = 1.0 / (i as f64);
589            let metrics = create_test_metrics(i, loss);
590            analyzer.add_metrics(metrics);
591        }
592
593        let status = analyzer.detect_convergence_status();
594        matches!(
595            status,
596            ConvergenceStatus::Converging | ConvergenceStatus::Converged
597        );
598    }
599
600    #[test]
601    fn test_learning_efficiency_calculation() {
602        let mut analyzer = TrainingDynamicsAnalyzer::new();
603
604        analyzer.add_metrics(create_test_metrics(1, 1.0));
605        analyzer.add_metrics(create_test_metrics(2, 0.5));
606        analyzer.add_metrics(create_test_metrics(3, 0.25));
607
608        let efficiency = analyzer.calculate_learning_efficiency();
609        assert!(efficiency > 0.0);
610    }
611
612    #[test]
613    fn test_plateau_detection() {
614        let mut analyzer = TrainingDynamicsAnalyzer::new();
615
616        // Add plateau sequence
617        for i in 1..=15 {
618            let metrics = create_test_metrics(i, 0.1); // Constant loss
619            analyzer.add_metrics(metrics);
620        }
621
622        let plateau = analyzer.detect_plateau();
623        assert!(plateau.is_some());
624    }
625
626    #[test]
627    fn test_training_stability_assessment() {
628        let mut analyzer = TrainingDynamicsAnalyzer::new();
629
630        // Add stable sequence
631        for i in 1..=20 {
632            let loss = 0.5 + (i as f64 * 0.001); // Very small variance
633            let metrics = create_test_metrics(i, loss);
634            analyzer.add_metrics(metrics);
635        }
636
637        let stability = analyzer.assess_training_stability();
638        matches!(stability, TrainingStability::Stable);
639    }
640}