trustformers_debug/model_diagnostics/training.rs

//! Training dynamics and convergence analysis.
//!
//! This module provides training dynamics analysis: convergence detection,
//! overfitting/underfitting identification, plateau detection, and training
//! stability assessment for tuning the training process.

use std::collections::VecDeque;

use super::types::{
    ConvergenceStatus, ModelPerformanceMetrics, OverfittingIndicator, PlateauInfo,
    TrainingDynamics, TrainingStability, UnderfittingIndicator,
};

/// Training dynamics analyzer for monitoring and analyzing training behavior.
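///
/// # Example
///
/// A minimal sketch of the intended flow. `make_metrics` is a hypothetical
/// helper that builds a `ModelPerformanceMetrics` (the tests below use a
/// similar `create_test_metrics`), so the example is marked `ignore`:
///
/// ```ignore
/// let mut analyzer = TrainingDynamicsAnalyzer::new();
///
/// // Feed metrics from the training loop, then ask for an analysis.
/// for (step, loss) in [(1, 0.9), (2, 0.7), (3, 0.6)] {
///     analyzer.add_metrics(make_metrics(step, loss));
/// }
/// let dynamics = analyzer.analyze_training_dynamics();
/// println!("convergence: {:?}", dynamics.convergence_status);
/// ```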
#[derive(Debug)]
pub struct TrainingDynamicsAnalyzer {
    /// Historical metrics for analysis
    metrics_history: VecDeque<ModelPerformanceMetrics>,
    /// Configuration for analysis thresholds
    config: TrainingAnalysisConfig,
    /// Current training state
    current_state: TrainingState,
}

/// Configuration for training analysis.
#[derive(Debug, Clone)]
pub struct TrainingAnalysisConfig {
    /// Window size for convergence analysis
    pub convergence_window: usize,
    /// Minimum improvement threshold for convergence
    pub min_improvement_threshold: f64,
    /// Maximum variance threshold for stability
    pub max_variance_threshold: f64,
    /// Minimum plateau duration to consider
    pub min_plateau_duration: usize,
    /// Train-validation gap threshold for overfitting
    pub overfitting_gap_threshold: f64,
    /// Minimum learning rate for underfitting detection
    pub min_learning_rate: f64,
}

impl Default for TrainingAnalysisConfig {
    fn default() -> Self {
        Self {
            convergence_window: 20,
            min_improvement_threshold: 0.001,
            max_variance_threshold: 0.1,
            min_plateau_duration: 10,
            overfitting_gap_threshold: 0.05,
            min_learning_rate: 1e-6,
        }
    }
}

/// Current training state information.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingState {
    /// Steps since last improvement
    steps_since_improvement: usize,
    /// Best loss achieved so far
    best_loss: f64,
    /// Current plateau information
    current_plateau: Option<PlateauInfo>,
    /// Convergence status history
    convergence_history: VecDeque<ConvergenceStatus>,
}

impl Default for TrainingState {
    fn default() -> Self {
        Self {
            steps_since_improvement: 0,
            best_loss: f64::INFINITY,
            current_plateau: None,
            convergence_history: VecDeque::new(),
        }
    }
}

impl TrainingDynamicsAnalyzer {
    /// Create a new training dynamics analyzer.
    pub fn new() -> Self {
        Self {
            metrics_history: VecDeque::new(),
            config: TrainingAnalysisConfig::default(),
            current_state: TrainingState::default(),
        }
    }

    /// Create a new analyzer with custom configuration.
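    ///
    /// # Example
    ///
    /// A sketch of overriding a single threshold while keeping the other
    /// defaults (struct-update syntax on `TrainingAnalysisConfig::default()`):
    ///
    /// ```ignore
    /// let config = TrainingAnalysisConfig {
    ///     convergence_window: 10,
    ///     ..TrainingAnalysisConfig::default()
    /// };
    /// let analyzer = TrainingDynamicsAnalyzer::with_config(config);
    /// ```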
    pub fn with_config(config: TrainingAnalysisConfig) -> Self {
        Self {
            metrics_history: VecDeque::new(),
            config,
            current_state: TrainingState::default(),
        }
    }

    /// Add new training metrics for analysis.
    pub fn add_metrics(&mut self, metrics: ModelPerformanceMetrics) {
        // Update training state
        if metrics.loss < self.current_state.best_loss {
            self.current_state.best_loss = metrics.loss;
            self.current_state.steps_since_improvement = 0;
        } else {
            self.current_state.steps_since_improvement += 1;
        }

        self.metrics_history.push_back(metrics);

        // Maintain a bounded history size
        if self.metrics_history.len() > 1000 {
            self.metrics_history.pop_front();
        }

        // Update convergence history
        let status = self.detect_convergence_status();
        self.current_state.convergence_history.push_back(status);
        if self.current_state.convergence_history.len() > 50 {
            self.current_state.convergence_history.pop_front();
        }
    }

    /// Record training dynamics information.
    pub fn record_training_dynamics(&mut self, _dynamics: TrainingDynamics) {
        // Training dynamics are computed via analysis rather than stored
        // directly; this method is provided for API compatibility.
    }

    /// Analyze current training dynamics.
    pub fn analyze_training_dynamics(&self) -> TrainingDynamics {
        let convergence_status = self.detect_convergence_status();
        let training_stability = self.assess_training_stability();
        let learning_efficiency = self.calculate_learning_efficiency();
        let overfitting_indicators = self.detect_overfitting_indicators();
        let underfitting_indicators = self.detect_underfitting_indicators();
        let plateau_detection = self.detect_plateau();

        TrainingDynamics {
            convergence_status,
            training_stability,
            learning_efficiency,
            overfitting_indicators,
            underfitting_indicators,
            plateau_detection,
        }
    }

    /// Detect the current convergence status from the most recent window of
    /// recorded losses.
    pub fn detect_convergence_status(&self) -> ConvergenceStatus {
        if self.metrics_history.len() < self.config.convergence_window {
            return ConvergenceStatus::Unknown;
        }

        // Collect the most recent window of losses and restore chronological
        // order (oldest first) so the trend helpers below read it left to
        // right; iterating `.rev()` alone would hand them a reversed series
        // and invert the divergence/convergence checks.
        let mut losses: Vec<f64> = self
            .metrics_history
            .iter()
            .rev()
            .take(self.config.convergence_window)
            .map(|m| m.loss)
            .collect();
        losses.reverse();

        // Check for convergence patterns
        if self.is_converged(&losses) {
            ConvergenceStatus::Converged
        } else if self.is_diverging(&losses) {
            ConvergenceStatus::Diverging
        } else if self.is_oscillating(&losses) {
            ConvergenceStatus::Oscillating
        } else if self.is_plateau(&losses) {
            ConvergenceStatus::Plateau
        } else if self.is_converging(&losses) {
            ConvergenceStatus::Converging
        } else {
            ConvergenceStatus::Unknown
        }
    }

    /// Assess training stability.
    pub fn assess_training_stability(&self) -> TrainingStability {
        if self.metrics_history.len() < 10 {
            return TrainingStability::Unknown;
        }

        let recent_losses: Vec<f64> =
            self.metrics_history.iter().rev().take(20).map(|m| m.loss).collect();

        let variance = self.calculate_variance(&recent_losses);

        if variance > self.config.max_variance_threshold {
            TrainingStability::Unstable
        } else if variance > self.config.max_variance_threshold / 2.0 {
            TrainingStability::HighVariance
        } else {
            TrainingStability::Stable
        }
    }

    /// Calculate a learning-efficiency score in `[0.0, 1.0]`: the relative
    /// loss improvement over the recorded history, discounted by the square
    /// root of the number of recorded steps.
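    ///
    /// Worked example (mirrored by `test_learning_efficiency_calculation`
    /// below): for the loss history `[1.0, 0.5, 0.25]` the relative
    /// improvement is `(1.0 - 0.25) / 1.0 = 0.75`; discounted by
    /// `sqrt(3) ≈ 1.73` this gives an efficiency of about `0.43`.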
    pub fn calculate_learning_efficiency(&self) -> f64 {
        if self.metrics_history.len() < 2 {
            return 0.0;
        }

        let initial_loss = self.metrics_history.front().unwrap().loss;
        let current_loss = self.metrics_history.back().unwrap().loss;
        let steps = self.metrics_history.len();

        if initial_loss <= current_loss {
            return 0.0;
        }

        let improvement = (initial_loss - current_loss) / initial_loss;
        let efficiency = improvement / (steps as f64).sqrt();

        efficiency.min(1.0)
    }

    /// Detect overfitting indicators.
    pub fn detect_overfitting_indicators(&self) -> Vec<OverfittingIndicator> {
        let mut indicators = Vec::new();

        // Validation metrics are not tracked yet, so training loss is used
        // as a stand-in for these checks.
        if self.metrics_history.len() > 10 {
            let recent_losses: Vec<f64> =
                self.metrics_history.iter().rev().take(10).map(|m| m.loss).collect();

            // Near-zero training loss suggests the model may have memorized
            // the training set.
            let avg_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;
            if avg_loss < 0.01 {
                indicators.push(OverfittingIndicator::PerfectTrainingAccuracy);
            }

            // High loss variance, used here as a proxy for validation variance
            let variance = self.calculate_variance(&recent_losses);
            if variance > 0.05 {
                indicators.push(OverfittingIndicator::HighVarianceInValidation);
            }
        }

        indicators
    }

    /// Detect underfitting indicators.
    pub fn detect_underfitting_indicators(&self) -> Vec<UnderfittingIndicator> {
        let mut indicators = Vec::new();

        if let Some(current_metrics) = self.metrics_history.back() {
            // High training loss
            if current_metrics.loss > 1.0 {
                indicators.push(UnderfittingIndicator::HighTrainingLoss {
                    loss: current_metrics.loss,
                    threshold: 1.0,
                });
            }

            // Low training accuracy
            if let Some(accuracy) = current_metrics.accuracy {
                if accuracy < 0.5 {
                    indicators.push(UnderfittingIndicator::LowTrainingAccuracy {
                        accuracy,
                        threshold: 0.5,
                    });
                }
            }

            // Slow convergence
            if self.current_state.steps_since_improvement > 50 {
                indicators.push(UnderfittingIndicator::SlowConvergence {
                    steps_taken: self.metrics_history.len(),
                    expected: self.metrics_history.len() / 2,
                });
            }

            // No learning
            if self.current_state.steps_since_improvement > 100 {
                indicators.push(UnderfittingIndicator::NoLearning {
                    steps_without_improvement: self.current_state.steps_since_improvement,
                });
            }
        }

        indicators
    }

    /// Detect a plateau: the most recent `min_plateau_duration` losses with
    /// variance below `min_improvement_threshold`.
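    ///
    /// With the default config this means the last 10 recorded losses have a
    /// variance below `0.001`. For example, a constant loss of `0.1` over 10
    /// steps (variance `0.0`) is reported as a plateau at value `0.1`, as
    /// exercised by `test_plateau_detection` below.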
    pub fn detect_plateau(&self) -> Option<PlateauInfo> {
        if self.metrics_history.len() < self.config.min_plateau_duration {
            return None;
        }

        let recent_losses: Vec<f64> = self
            .metrics_history
            .iter()
            .rev()
            .take(self.config.min_plateau_duration)
            .map(|m| m.loss)
            .collect();

        let variance = self.calculate_variance(&recent_losses);
        let mean_loss = recent_losses.iter().sum::<f64>() / recent_losses.len() as f64;

        // Check if variance is low enough to indicate a plateau
        if variance < self.config.min_improvement_threshold {
            let start_step = self.metrics_history.len() - self.config.min_plateau_duration;
            Some(PlateauInfo {
                start_step,
                duration_steps: self.config.min_plateau_duration,
                plateau_value: mean_loss,
                variance,
            })
        } else {
            None
        }
    }

    /// Generate training recommendations based on current dynamics.
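    ///
    /// # Example
    ///
    /// A sketch of surfacing the output to an operator (assumes only the
    /// fields defined on `TrainingRecommendation` below):
    ///
    /// ```ignore
    /// for rec in analyzer.generate_training_recommendations() {
    ///     eprintln!("[{:?}] {}: {}", rec.priority, rec.description, rec.action);
    /// }
    /// ```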
    pub fn generate_training_recommendations(&self) -> Vec<TrainingRecommendation> {
        let mut recommendations = Vec::new();
        let dynamics = self.analyze_training_dynamics();

        match dynamics.convergence_status {
            ConvergenceStatus::Diverging => {
                recommendations.push(TrainingRecommendation {
                    category: "Convergence".to_string(),
                    priority: TrainingRecommendationPriority::Critical,
                    description: "Training is diverging".to_string(),
                    action: "Reduce learning rate immediately".to_string(),
                    expected_impact: 0.8,
                });
            },
            ConvergenceStatus::Plateau => {
                recommendations.push(TrainingRecommendation {
                    category: "Convergence".to_string(),
                    priority: TrainingRecommendationPriority::High,
                    description: "Training has reached a plateau".to_string(),
                    action: "Consider learning rate scheduling or data augmentation".to_string(),
                    expected_impact: 0.6,
                });
            },
            _ => {},
        }

        if let TrainingStability::Unstable = dynamics.training_stability {
            recommendations.push(TrainingRecommendation {
                category: "Stability".to_string(),
                priority: TrainingRecommendationPriority::High,
                description: "Training is unstable".to_string(),
                action: "Reduce learning rate or add gradient clipping".to_string(),
                expected_impact: 0.7,
            });
        }

        if dynamics.learning_efficiency < 0.3 {
            recommendations.push(TrainingRecommendation {
                category: "Efficiency".to_string(),
                priority: TrainingRecommendationPriority::Medium,
                description: "Low learning efficiency detected".to_string(),
                action: "Consider architecture changes or hyperparameter tuning".to_string(),
                expected_impact: 0.5,
            });
        }

        recommendations
    }

    // Helper methods for convergence detection. All helpers expect `losses`
    // in chronological order (oldest first).
    fn is_converged(&self, losses: &[f64]) -> bool {
        if losses.len() < 5 {
            return false;
        }

        // Converged: the five most recent losses are essentially flat and
        // the latest loss is already very small.
        let recent_variance = self.calculate_variance(&losses[losses.len() - 5..]);
        recent_variance < self.config.min_improvement_threshold && losses[losses.len() - 1] < 0.01
    }

    fn is_diverging(&self, losses: &[f64]) -> bool {
        if losses.len() < 3 {
            return false;
        }

        // Check if loss is consistently increasing
        losses.windows(2).all(|w| w[1] >= w[0])
            && (losses.last().unwrap() / losses.first().unwrap()) > 1.1
    }

    fn is_oscillating(&self, losses: &[f64]) -> bool {
        if losses.len() < 6 {
            return false;
        }

        // Check for an oscillating pattern: count sign changes between
        // consecutive loss deltas
        let mut direction_changes = 0;
        for window in losses.windows(3) {
            let trend1 = window[1] - window[0];
            let trend2 = window[2] - window[1];
            if trend1.signum() != trend2.signum() {
                direction_changes += 1;
            }
        }

        direction_changes > losses.len() / 3
    }

    fn is_plateau(&self, losses: &[f64]) -> bool {
        let variance = self.calculate_variance(losses);
        variance < self.config.min_improvement_threshold
    }

    fn is_converging(&self, losses: &[f64]) -> bool {
        if losses.len() < 3 {
            return false;
        }

        // Check if loss is generally decreasing
        let trend = self.calculate_trend(losses);
        trend < -self.config.min_improvement_threshold
    }

    /// Sample variance (Bessel-corrected) of `values`.
    fn calculate_variance(&self, values: &[f64]) -> f64 {
        if values.len() < 2 {
            return 0.0;
        }

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64
    }

    /// Ordinary least-squares slope of `values` against their index:
    /// `sum((x - x_mean) * (y - y_mean)) / sum((x - x_mean)^2)`.
    /// A negative slope indicates a decreasing trend.
    fn calculate_trend(&self, values: &[f64]) -> f64 {
        if values.len() < 2 {
            return 0.0;
        }

        let n = values.len() as f64;
        let x_mean = (n - 1.0) / 2.0;
        let y_mean = values.iter().sum::<f64>() / n;

        let mut numerator = 0.0;
        let mut denominator = 0.0;

        for (i, &y) in values.iter().enumerate() {
            let x = i as f64;
            numerator += (x - x_mean) * (y - y_mean);
            denominator += (x - x_mean).powi(2);
        }

        if denominator == 0.0 {
            0.0
        } else {
            numerator / denominator
        }
    }

    /// Clear analysis history.
    pub fn clear(&mut self) {
        self.metrics_history.clear();
        self.current_state = TrainingState::default();
    }

    /// Get current training state information.
    pub fn get_training_state(&self) -> &TrainingState {
        &self.current_state
    }

    /// Generate a comprehensive training dynamics report.
    pub async fn generate_report(&self) -> anyhow::Result<TrainingDynamicsReport> {
        let training_dynamics = self.analyze_training_dynamics();
        let recommendations = self.generate_recommendations();

        Ok(TrainingDynamicsReport {
            training_dynamics,
            recommendations,
            current_state: self.current_state.clone(),
        })
    }

    /// Generate recommendations for the report: dynamics-driven
    /// recommendations first, followed by a general monitoring fallback.
    fn generate_recommendations(&self) -> Vec<TrainingRecommendation> {
        let mut recommendations = self.generate_training_recommendations();

        // Always include a baseline monitoring recommendation
        recommendations.push(TrainingRecommendation {
            category: "General".to_string(),
            priority: TrainingRecommendationPriority::Low,
            description: "Continue monitoring training dynamics".to_string(),
            action: "Monitor training progress and adjust parameters as needed".to_string(),
            expected_impact: 0.1,
        });

        recommendations
    }
}

impl Default for TrainingDynamicsAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

/// Training recommendation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingRecommendation {
    /// Category of the recommendation
    pub category: String,
    /// Priority level
    pub priority: TrainingRecommendationPriority,
    /// Description of the issue
    pub description: String,
    /// Recommended action
    pub action: String,
    /// Expected impact (0.0 to 1.0)
    pub expected_impact: f64,
}

/// Priority levels for training recommendations.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum TrainingRecommendationPriority {
    /// Low priority
    Low,
    /// Medium priority
    Medium,
    /// High priority
    High,
    /// Critical priority
    Critical,
}

/// Comprehensive training dynamics report.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingDynamicsReport {
    /// Training dynamics analysis
    pub training_dynamics: TrainingDynamics,
    /// Generated recommendations
    pub recommendations: Vec<TrainingRecommendation>,
    /// Current training state
    pub current_state: TrainingState,
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    fn create_test_metrics(step: usize, loss: f64) -> ModelPerformanceMetrics {
        ModelPerformanceMetrics {
            training_step: step,
            loss,
            accuracy: Some(0.8),
            learning_rate: 0.001,
            batch_size: 32,
            throughput_samples_per_sec: 100.0,
            memory_usage_mb: 1000.0,
            gpu_utilization: Some(0.9),
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn test_training_dynamics_analyzer_creation() {
        let analyzer = TrainingDynamicsAnalyzer::new();
        assert_eq!(analyzer.metrics_history.len(), 0);
    }

    #[test]
    fn test_add_metrics() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();
        let metrics = create_test_metrics(1, 0.5);

        analyzer.add_metrics(metrics);
        assert_eq!(analyzer.metrics_history.len(), 1);
        assert_eq!(analyzer.current_state.best_loss, 0.5);
    }

    #[test]
    fn test_convergence_detection() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Add a converging sequence
        for i in 1..=25 {
            let loss = 1.0 / (i as f64);
            let metrics = create_test_metrics(i, loss);
            analyzer.add_metrics(metrics);
        }

        let status = analyzer.detect_convergence_status();
        assert!(matches!(
            status,
            ConvergenceStatus::Converging | ConvergenceStatus::Converged
        ));
    }

    #[test]
    fn test_learning_efficiency_calculation() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        analyzer.add_metrics(create_test_metrics(1, 1.0));
        analyzer.add_metrics(create_test_metrics(2, 0.5));
        analyzer.add_metrics(create_test_metrics(3, 0.25));

        let efficiency = analyzer.calculate_learning_efficiency();
        assert!(efficiency > 0.0);
        // Matches the documented formula: 0.75 improvement / sqrt(3 steps).
        assert!((efficiency - 0.75 / (3.0f64).sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_plateau_detection() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Add a plateau sequence
        for i in 1..=15 {
            let metrics = create_test_metrics(i, 0.1); // Constant loss
            analyzer.add_metrics(metrics);
        }

        let plateau = analyzer.detect_plateau();
        assert!(plateau.is_some());
    }

    #[test]
    fn test_training_stability_assessment() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Add a stable sequence
        for i in 1..=20 {
            let loss = 0.5 + (i as f64 * 0.001); // Very small variance
            let metrics = create_test_metrics(i, loss);
            analyzer.add_metrics(metrics);
        }

        let stability = analyzer.assess_training_stability();
        assert!(matches!(stability, TrainingStability::Stable));
    }
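
    // Added sketch test: a strictly increasing loss sequence should be
    // flagged as diverging and produce a critical recommendation. The growth
    // rate and assertions assume the default `TrainingAnalysisConfig`.
    #[test]
    fn test_divergence_detection_and_recommendation() {
        let mut analyzer = TrainingDynamicsAnalyzer::new();

        // Loss grows ~5% per step, comfortably past the 1.1 ratio check
        // in `is_diverging`.
        for i in 1..=25 {
            let loss = 0.5 * 1.05f64.powi(i as i32);
            analyzer.add_metrics(create_test_metrics(i, loss));
        }

        assert!(matches!(
            analyzer.detect_convergence_status(),
            ConvergenceStatus::Diverging
        ));

        let recommendations = analyzer.generate_training_recommendations();
        assert!(recommendations
            .iter()
            .any(|r| matches!(r.priority, TrainingRecommendationPriority::Critical)));
    }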
}