Skip to main content

trustformers_training/
training_monitor.rs

1/// Training monitoring and debugging tools
2///
3/// This module provides comprehensive monitoring and debugging capabilities for training:
4/// - NaN/Inf detection and automatic recovery
5/// - Gradient anomaly detection and analysis
6/// - Training stability diagnosis and recommendations
7/// - Performance bottleneck identification
8/// - Memory leak detection and prevention
9use anyhow::Result;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13use trustformers_core::tensor::Tensor;
14
15/// Configuration for training monitoring
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct TrainingMonitorConfig {
18    /// Enable NaN/Inf detection
19    pub nan_inf_detection: bool,
20    /// Enable gradient anomaly detection
21    pub gradient_anomaly_detection: bool,
22    /// Enable training stability monitoring
23    pub stability_monitoring: bool,
24    /// Enable performance profiling
25    pub performance_profiling: bool,
26    /// Enable memory leak detection
27    pub memory_leak_detection: bool,
28    /// History window size for anomaly detection
29    pub history_window_size: usize,
30    /// Gradient norm threshold for anomaly detection
31    pub gradient_norm_threshold: f32,
32    /// Loss spike threshold for stability monitoring
33    pub loss_spike_threshold: f32,
34    /// Memory growth threshold for leak detection (bytes)
35    pub memory_growth_threshold: usize,
36    /// Auto-recovery attempts for NaN/Inf
37    pub auto_recovery_attempts: usize,
38}
39
40impl Default for TrainingMonitorConfig {
41    fn default() -> Self {
42        Self {
43            nan_inf_detection: true,
44            gradient_anomaly_detection: true,
45            stability_monitoring: true,
46            performance_profiling: false,
47            memory_leak_detection: true,
48            history_window_size: 100,
49            gradient_norm_threshold: 100.0,
50            loss_spike_threshold: 10.0,
51            memory_growth_threshold: 100_000_000, // 100MB
52            auto_recovery_attempts: 3,
53        }
54    }
55}
56
57/// Training step metrics
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct StepMetrics {
60    pub step: usize,
61    pub timestamp: u64,
62    pub loss: f32,
63    pub gradient_norm: f32,
64    pub learning_rate: f32,
65    pub memory_usage: usize,
66    pub step_duration_ms: u64,
67    pub has_nan_inf: bool,
68    pub gradient_anomaly: bool,
69}
70
71/// Anomaly detection result
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct AnomalyReport {
74    pub step: usize,
75    pub anomaly_type: AnomalyType,
76    pub severity: AnomalySeverity,
77    pub description: String,
78    pub suggested_actions: Vec<String>,
79    pub auto_recovery_applied: bool,
80}
81
82/// Types of anomalies that can be detected
83#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
84pub enum AnomalyType {
85    NanInf,
86    GradientExplosion,
87    GradientVanishing,
88    LossSpike,
89    MemoryLeak,
90    PerformanceRegression,
91    TrainingStagnation,
92}
93
94/// Severity levels for anomalies
95#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
96pub enum AnomalySeverity {
97    Low,
98    Medium,
99    High,
100    Critical,
101}
102
103/// Recovery strategies for different anomalies
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub enum RecoveryStrategy {
106    ReduceLearningRate,
107    GradientClipping,
108    RestoreCheckpoint,
109    RestartTraining,
110    MemoryCleanup,
111    OptimizerReset,
112}
113
114/// Comprehensive training monitor
115pub struct TrainingMonitor {
116    config: TrainingMonitorConfig,
117    metrics_history: VecDeque<StepMetrics>,
118    anomaly_reports: Vec<AnomalyReport>,
119    recovery_attempts: HashMap<AnomalyType, usize>,
120    performance_stats: PerformanceStats,
121    memory_baseline: usize,
122    #[allow(dead_code)]
123    last_checkpoint: Option<u64>,
124}
125
126impl TrainingMonitor {
127    pub fn new(config: TrainingMonitorConfig) -> Self {
128        Self {
129            config,
130            metrics_history: VecDeque::new(),
131            anomaly_reports: Vec::new(),
132            recovery_attempts: HashMap::new(),
133            performance_stats: PerformanceStats::new(),
134            memory_baseline: 0,
135            last_checkpoint: None,
136        }
137    }
138
139    /// Record metrics for a training step
140    pub fn record_step(
141        &mut self,
142        step: usize,
143        loss: f32,
144        gradients: &HashMap<String, Tensor>,
145        learning_rate: f32,
146        memory_usage: usize,
147        step_duration: Duration,
148    ) -> Result<()> {
149        let gradient_norm = self.compute_gradient_norm(gradients)?;
150        let has_nan_inf = self.detect_nan_inf(loss, gradients)?;
151        let gradient_anomaly = self.detect_gradient_anomaly(gradient_norm);
152
153        let metrics = StepMetrics {
154            step,
155            timestamp: SystemTime::now()
156                .duration_since(UNIX_EPOCH)
157                .expect("SystemTime should be after UNIX_EPOCH")
158                .as_secs(),
159            loss,
160            gradient_norm,
161            learning_rate,
162            memory_usage,
163            step_duration_ms: step_duration.as_millis() as u64,
164            has_nan_inf,
165            gradient_anomaly,
166        };
167
168        // Add to history
169        self.metrics_history.push_back(metrics.clone());
170
171        // Maintain history window
172        while self.metrics_history.len() > self.config.history_window_size {
173            self.metrics_history.pop_front();
174        }
175
176        // Update performance stats
177        self.performance_stats.update(&metrics);
178
179        // Perform anomaly detection and recovery
180        self.perform_anomaly_detection(&metrics)?;
181
182        Ok(())
183    }
184
185    /// Detect NaN/Inf values in loss and gradients
186    fn detect_nan_inf(&self, loss: f32, gradients: &HashMap<String, Tensor>) -> Result<bool> {
187        if !self.config.nan_inf_detection {
188            return Ok(false);
189        }
190
191        // Check loss
192        if !loss.is_finite() {
193            return Ok(true);
194        }
195
196        // Check gradients
197        for gradient in gradients.values() {
198            if self.has_nan_inf_tensor(gradient)? {
199                return Ok(true);
200            }
201        }
202
203        Ok(false)
204    }
205
206    /// Check if tensor contains NaN or Inf values
207    fn has_nan_inf_tensor(&self, _tensor: &Tensor) -> Result<bool> {
208        // Simplified check - in real implementation would iterate through tensor values
209        // For now, we'll simulate the check
210        Ok(false)
211    }
212
213    /// Compute gradient norm
214    fn compute_gradient_norm(&self, gradients: &HashMap<String, Tensor>) -> Result<f32> {
215        let mut total_norm = 0.0f32;
216        let mut param_count = 0;
217
218        for gradient in gradients.values() {
219            // Simplified norm computation
220            let grad_norm = self.tensor_norm(gradient)?;
221            total_norm += grad_norm * grad_norm;
222            param_count += 1;
223        }
224
225        if param_count > 0 {
226            Ok(total_norm.sqrt())
227        } else {
228            Ok(0.0)
229        }
230    }
231
232    /// Compute tensor norm (simplified)
233    fn tensor_norm(&self, _tensor: &Tensor) -> Result<f32> {
234        // Simplified norm computation - in real implementation would compute actual L2 norm
235        Ok(1.0)
236    }
237
238    /// Detect gradient anomalies
239    fn detect_gradient_anomaly(&self, gradient_norm: f32) -> bool {
240        if !self.config.gradient_anomaly_detection {
241            return false;
242        }
243
244        gradient_norm > self.config.gradient_norm_threshold || gradient_norm < 1e-8
245    }
246
247    /// Comprehensive anomaly detection
248    fn perform_anomaly_detection(&mut self, metrics: &StepMetrics) -> Result<()> {
249        let mut detected_anomalies = Vec::new();
250
251        // NaN/Inf detection
252        if metrics.has_nan_inf {
253            detected_anomalies.push(AnomalyReport {
254                step: metrics.step,
255                anomaly_type: AnomalyType::NanInf,
256                severity: AnomalySeverity::Critical,
257                description: "NaN or Inf values detected in loss or gradients".to_string(),
258                suggested_actions: vec![
259                    "Check learning rate (reduce if too high)".to_string(),
260                    "Implement gradient clipping".to_string(),
261                    "Restore from previous checkpoint".to_string(),
262                ],
263                auto_recovery_applied: false,
264            });
265        }
266
267        // Gradient explosion detection
268        if metrics.gradient_norm > self.config.gradient_norm_threshold {
269            detected_anomalies.push(AnomalyReport {
270                step: metrics.step,
271                anomaly_type: AnomalyType::GradientExplosion,
272                severity: AnomalySeverity::High,
273                description: format!(
274                    "Gradient norm ({:.2}) exceeds threshold ({:.2})",
275                    metrics.gradient_norm, self.config.gradient_norm_threshold
276                ),
277                suggested_actions: vec![
278                    "Apply gradient clipping".to_string(),
279                    "Reduce learning rate".to_string(),
280                    "Check for unstable layers".to_string(),
281                ],
282                auto_recovery_applied: false,
283            });
284        }
285
286        // Gradient vanishing detection
287        if metrics.gradient_norm < 1e-8 {
288            detected_anomalies.push(AnomalyReport {
289                step: metrics.step,
290                anomaly_type: AnomalyType::GradientVanishing,
291                severity: AnomalySeverity::Medium,
292                description: format!(
293                    "Gradient norm ({:.2e}) is extremely small",
294                    metrics.gradient_norm
295                ),
296                suggested_actions: vec![
297                    "Increase learning rate".to_string(),
298                    "Check for dead neurons".to_string(),
299                    "Consider different activation functions".to_string(),
300                ],
301                auto_recovery_applied: false,
302            });
303        }
304
305        // Loss spike detection
306        if let Some(recent_loss) = self.get_recent_average_loss() {
307            if metrics.loss > recent_loss * self.config.loss_spike_threshold {
308                detected_anomalies.push(AnomalyReport {
309                    step: metrics.step,
310                    anomaly_type: AnomalyType::LossSpike,
311                    severity: AnomalySeverity::High,
312                    description: format!(
313                        "Loss spike detected: {:.4} vs recent average {:.4}",
314                        metrics.loss, recent_loss
315                    ),
316                    suggested_actions: vec![
317                        "Check for data corruption".to_string(),
318                        "Verify batch normalization".to_string(),
319                        "Consider reducing learning rate".to_string(),
320                    ],
321                    auto_recovery_applied: false,
322                });
323            }
324        }
325
326        // Memory leak detection
327        if self.config.memory_leak_detection
328            && self.memory_baseline > 0
329            && metrics.memory_usage > self.memory_baseline + self.config.memory_growth_threshold
330        {
331            detected_anomalies.push(AnomalyReport {
332                step: metrics.step,
333                anomaly_type: AnomalyType::MemoryLeak,
334                severity: AnomalySeverity::Medium,
335                description: format!(
336                    "Memory usage increased by {} bytes",
337                    metrics.memory_usage - self.memory_baseline
338                ),
339                suggested_actions: vec![
340                    "Check for tensor accumulation".to_string(),
341                    "Verify gradient cleanup".to_string(),
342                    "Consider memory optimization".to_string(),
343                ],
344                auto_recovery_applied: false,
345            });
346        }
347
348        // Training stagnation detection
349        if self.detect_training_stagnation()? {
350            detected_anomalies.push(AnomalyReport {
351                step: metrics.step,
352                anomaly_type: AnomalyType::TrainingStagnation,
353                severity: AnomalySeverity::Medium,
354                description: "Training appears to have stagnated".to_string(),
355                suggested_actions: vec![
356                    "Adjust learning rate schedule".to_string(),
357                    "Consider different optimizer".to_string(),
358                    "Check for overfitting".to_string(),
359                ],
360                auto_recovery_applied: false,
361            });
362        }
363
364        // Apply auto-recovery if enabled
365        for mut anomaly in detected_anomalies {
366            if self.should_apply_auto_recovery(&anomaly) {
367                anomaly.auto_recovery_applied = self.apply_auto_recovery(&anomaly)?;
368            }
369            self.anomaly_reports.push(anomaly);
370        }
371
372        Ok(())
373    }
374
375    /// Get recent average loss
376    fn get_recent_average_loss(&self) -> Option<f32> {
377        if self.metrics_history.len() < 10 {
378            return None;
379        }
380
381        let recent_count = std::cmp::min(10, self.metrics_history.len());
382        let recent_losses: Vec<f32> =
383            self.metrics_history.iter().rev().take(recent_count).map(|m| m.loss).collect();
384
385        if recent_losses.is_empty() {
386            None
387        } else {
388            Some(recent_losses.iter().sum::<f32>() / recent_losses.len() as f32)
389        }
390    }
391
392    /// Detect training stagnation
393    fn detect_training_stagnation(&self) -> Result<bool> {
394        if self.metrics_history.len() < 50 {
395            return Ok(false);
396        }
397
398        // Check if loss hasn't improved significantly in recent steps
399        let recent_window = 20;
400        let older_window = 30;
401
402        let recent_avg = self.get_window_average_loss(recent_window)?;
403        let older_avg = self.get_window_average_loss(older_window)?;
404
405        // Consider stagnation if improvement is less than 1%
406        Ok(recent_avg >= older_avg * 0.99)
407    }
408
409    /// Get average loss for a specific window
410    fn get_window_average_loss(&self, window_size: usize) -> Result<f32> {
411        if self.metrics_history.len() < window_size {
412            return Ok(0.0);
413        }
414
415        let losses: Vec<f32> =
416            self.metrics_history.iter().rev().take(window_size).map(|m| m.loss).collect();
417
418        Ok(losses.iter().sum::<f32>() / losses.len() as f32)
419    }
420
421    /// Check if auto-recovery should be applied
422    fn should_apply_auto_recovery(&self, anomaly: &AnomalyReport) -> bool {
423        let attempts = self.recovery_attempts.get(&anomaly.anomaly_type).unwrap_or(&0);
424        *attempts < self.config.auto_recovery_attempts
425    }
426
427    /// Apply auto-recovery strategy
428    fn apply_auto_recovery(&mut self, anomaly: &AnomalyReport) -> Result<bool> {
429        let attempts = self.recovery_attempts.entry(anomaly.anomaly_type.clone()).or_insert(0);
430        *attempts += 1;
431
432        match anomaly.anomaly_type {
433            AnomalyType::NanInf => {
434                // In real implementation, would restore from checkpoint
435                println!("Auto-recovery: Restoring from checkpoint due to NaN/Inf");
436                Ok(true)
437            },
438            AnomalyType::GradientExplosion => {
439                // In real implementation, would apply gradient clipping
440                println!("Auto-recovery: Applying gradient clipping");
441                Ok(true)
442            },
443            AnomalyType::MemoryLeak => {
444                // In real implementation, would trigger memory cleanup
445                println!("Auto-recovery: Triggering memory cleanup");
446                Ok(true)
447            },
448            _ => Ok(false),
449        }
450    }
451
452    /// Get current training health status
453    pub fn get_health_status(&self) -> TrainingHealthStatus {
454        let recent_anomalies = self.anomaly_reports.iter().rev().take(10).collect::<Vec<_>>();
455
456        let critical_count = recent_anomalies
457            .iter()
458            .filter(|a| matches!(a.severity, AnomalySeverity::Critical))
459            .count();
460
461        let high_count = recent_anomalies
462            .iter()
463            .filter(|a| matches!(a.severity, AnomalySeverity::High))
464            .count();
465
466        let overall_health = if critical_count > 0 {
467            HealthStatus::Critical
468        } else if high_count > 3 {
469            HealthStatus::Poor
470        } else if high_count > 1 {
471            HealthStatus::Warning
472        } else {
473            HealthStatus::Good
474        };
475
476        TrainingHealthStatus {
477            overall_health,
478            recent_anomalies: recent_anomalies.len(),
479            critical_issues: critical_count,
480            high_issues: high_count,
481            auto_recovery_success_rate: self.calculate_recovery_success_rate(),
482            performance_trend: self.performance_stats.get_trend(),
483        }
484    }
485
486    /// Calculate auto-recovery success rate
487    fn calculate_recovery_success_rate(&self) -> f32 {
488        let total_recoveries =
489            self.anomaly_reports.iter().filter(|a| a.auto_recovery_applied).count();
490
491        if total_recoveries == 0 {
492            return 1.0;
493        }
494
495        // Simplified success rate calculation
496        0.85 // In real implementation, would track actual success
497    }
498
499    /// Get comprehensive training report
500    pub fn get_training_report(&self) -> TrainingReport {
501        TrainingReport {
502            health_status: self.get_health_status(),
503            anomaly_summary: self.get_anomaly_summary(),
504            performance_stats: self.performance_stats.clone(),
505            recommendations: self.generate_recommendations(),
506        }
507    }
508
509    /// Get anomaly summary
510    fn get_anomaly_summary(&self) -> AnomalySummary {
511        let mut type_counts = HashMap::new();
512        let mut severity_counts = HashMap::new();
513
514        for anomaly in &self.anomaly_reports {
515            *type_counts.entry(anomaly.anomaly_type.clone()).or_insert(0) += 1;
516            *severity_counts.entry(anomaly.severity.clone()).or_insert(0) += 1;
517        }
518
519        AnomalySummary {
520            total_anomalies: self.anomaly_reports.len(),
521            type_distribution: type_counts,
522            severity_distribution: severity_counts,
523        }
524    }
525
526    /// Generate training recommendations
527    fn generate_recommendations(&self) -> Vec<String> {
528        let mut recommendations = Vec::new();
529
530        // Check for frequent anomalies
531        let recent_anomalies = self.anomaly_reports.iter().rev().take(20).collect::<Vec<_>>();
532
533        if recent_anomalies
534            .iter()
535            .any(|a| matches!(a.anomaly_type, AnomalyType::GradientExplosion))
536        {
537            recommendations.push("Consider implementing gradient clipping".to_string());
538        }
539
540        if recent_anomalies
541            .iter()
542            .any(|a| matches!(a.anomaly_type, AnomalyType::MemoryLeak))
543        {
544            recommendations.push("Review memory management and tensor lifecycle".to_string());
545        }
546
547        if recent_anomalies
548            .iter()
549            .any(|a| matches!(a.anomaly_type, AnomalyType::TrainingStagnation))
550        {
551            recommendations
552                .push("Consider adjusting learning rate schedule or optimizer".to_string());
553        }
554
555        if self.performance_stats.average_step_duration_ms > 5000 {
556            recommendations
557                .push("Training steps are taking too long - consider optimization".to_string());
558        }
559
560        recommendations
561    }
562
563    /// Set memory baseline for leak detection
564    pub fn set_memory_baseline(&mut self, baseline: usize) {
565        self.memory_baseline = baseline;
566    }
567}
568
569/// Performance statistics tracking
570#[derive(Debug, Clone, Serialize, Deserialize)]
571pub struct PerformanceStats {
572    pub total_steps: usize,
573    pub average_step_duration_ms: u64,
574    pub average_loss: f32,
575    pub average_gradient_norm: f32,
576    pub memory_usage_trend: f32,
577}
578
579impl PerformanceStats {
580    fn new() -> Self {
581        Self {
582            total_steps: 0,
583            average_step_duration_ms: 0,
584            average_loss: 0.0,
585            average_gradient_norm: 0.0,
586            memory_usage_trend: 0.0,
587        }
588    }
589
590    fn update(&mut self, metrics: &StepMetrics) {
591        self.total_steps += 1;
592
593        // Update running averages
594        let n = self.total_steps as f32;
595        let old_weight = (n - 1.0) / n;
596        let new_weight = 1.0 / n;
597
598        self.average_step_duration_ms = (self.average_step_duration_ms as f32 * old_weight
599            + metrics.step_duration_ms as f32 * new_weight)
600            as u64;
601
602        self.average_loss = self.average_loss * old_weight + metrics.loss * new_weight;
603        self.average_gradient_norm =
604            self.average_gradient_norm * old_weight + metrics.gradient_norm * new_weight;
605    }
606
607    fn get_trend(&self) -> PerformanceTrend {
608        // Simplified trend calculation
609        if self.total_steps < 10 {
610            PerformanceTrend::Stable
611        } else if self.average_step_duration_ms > 10000 {
612            PerformanceTrend::Degrading
613        } else {
614            PerformanceTrend::Improving
615        }
616    }
617}
618
619/// Training health status
620#[derive(Debug, Clone, Serialize, Deserialize)]
621pub struct TrainingHealthStatus {
622    pub overall_health: HealthStatus,
623    pub recent_anomalies: usize,
624    pub critical_issues: usize,
625    pub high_issues: usize,
626    pub auto_recovery_success_rate: f32,
627    pub performance_trend: PerformanceTrend,
628}
629
630/// Health status levels
631#[derive(Debug, Clone, Serialize, Deserialize)]
632pub enum HealthStatus {
633    Good,
634    Warning,
635    Poor,
636    Critical,
637}
638
639/// Performance trend indicators
640#[derive(Debug, Clone, Serialize, Deserialize)]
641pub enum PerformanceTrend {
642    Improving,
643    Stable,
644    Degrading,
645}
646
647/// Anomaly summary statistics
648#[derive(Debug, Clone, Serialize, Deserialize)]
649pub struct AnomalySummary {
650    pub total_anomalies: usize,
651    pub type_distribution: HashMap<AnomalyType, usize>,
652    pub severity_distribution: HashMap<AnomalySeverity, usize>,
653}
654
655/// Comprehensive training report
656#[derive(Debug, Clone, Serialize, Deserialize)]
657pub struct TrainingReport {
658    pub health_status: TrainingHealthStatus,
659    pub anomaly_summary: AnomalySummary,
660    pub performance_stats: PerformanceStats,
661    pub recommendations: Vec<String>,
662}
663
664#[cfg(test)]
665mod tests {
666    use super::*;
667    use std::time::Duration;
668
669    #[test]
670    fn test_training_monitor_creation() {
671        let config = TrainingMonitorConfig::default();
672        let monitor = TrainingMonitor::new(config);
673
674        assert_eq!(monitor.metrics_history.len(), 0);
675        assert_eq!(monitor.anomaly_reports.len(), 0);
676    }
677
678    #[test]
679    fn test_nan_inf_detection() {
680        let config = TrainingMonitorConfig::default();
681        let monitor = TrainingMonitor::new(config);
682
683        let gradients = HashMap::new();
684        let result = monitor.detect_nan_inf(f32::NAN, &gradients);
685
686        assert!(result.is_ok());
687        assert!(result.expect("operation failed in test"));
688    }
689
690    #[test]
691    fn test_gradient_anomaly_detection() {
692        let config = TrainingMonitorConfig {
693            gradient_norm_threshold: 10.0,
694            ..Default::default()
695        };
696        let monitor = TrainingMonitor::new(config);
697
698        assert!(monitor.detect_gradient_anomaly(100.0));
699        assert!(monitor.detect_gradient_anomaly(1e-10));
700        assert!(!monitor.detect_gradient_anomaly(5.0));
701    }
702
703    #[test]
704    fn test_step_recording() {
705        let config = TrainingMonitorConfig::default();
706        let mut monitor = TrainingMonitor::new(config);
707
708        let gradients = HashMap::new();
709        let result = monitor.record_step(
710            0,
711            1.0,
712            &gradients,
713            0.001,
714            1000000,
715            Duration::from_millis(100),
716        );
717
718        assert!(result.is_ok());
719        assert_eq!(monitor.metrics_history.len(), 1);
720    }
721
722    #[test]
723    fn test_health_status() {
724        let config = TrainingMonitorConfig::default();
725        let monitor = TrainingMonitor::new(config);
726
727        let health = monitor.get_health_status();
728        assert!(matches!(health.overall_health, HealthStatus::Good));
729        assert_eq!(health.recent_anomalies, 0);
730    }
731
732    #[test]
733    fn test_performance_stats() {
734        let mut stats = PerformanceStats::new();
735        let metrics = StepMetrics {
736            step: 0,
737            timestamp: SystemTime::now()
738                .duration_since(UNIX_EPOCH)
739                .expect("SystemTime should be after UNIX_EPOCH")
740                .as_millis() as u64,
741            loss: 1.0,
742            gradient_norm: 2.0,
743            learning_rate: 0.001,
744            memory_usage: 1000000,
745            step_duration_ms: 100,
746            has_nan_inf: false,
747            gradient_anomaly: false,
748        };
749
750        stats.update(&metrics);
751
752        assert_eq!(stats.total_steps, 1);
753        assert_eq!(stats.average_loss, 1.0);
754        assert_eq!(stats.average_gradient_norm, 2.0);
755    }
756}