1use anyhow::Result;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13use trustformers_core::tensor::Tensor;
14
/// Configuration for [`TrainingMonitor`]: feature toggles for each detector
/// plus the thresholds the detectors compare against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMonitorConfig {
    /// Detect NaN/Inf values in the loss and gradients.
    pub nan_inf_detection: bool,
    /// Flag exploding/vanishing gradient norms.
    pub gradient_anomaly_detection: bool,
    // NOTE(review): not read anywhere in this file — presumably a planned
    // stability detector; confirm before relying on it.
    pub stability_monitoring: bool,
    // NOTE(review): not read anywhere in this file — presumably gates extra
    // profiling work elsewhere; confirm before relying on it.
    pub performance_profiling: bool,
    /// Flag memory growth beyond the configured threshold over the baseline.
    pub memory_leak_detection: bool,
    /// Maximum number of [`StepMetrics`] entries kept in the history window.
    pub history_window_size: usize,
    /// Gradient norms above this value are treated as gradient explosions.
    pub gradient_norm_threshold: f32,
    /// A loss larger than `recent_average * loss_spike_threshold` is a spike.
    pub loss_spike_threshold: f32,
    /// Memory growth (bytes over the baseline) that counts as a leak.
    pub memory_growth_threshold: usize,
    /// Maximum automatic recovery attempts per anomaly type.
    pub auto_recovery_attempts: usize,
}
39
40impl Default for TrainingMonitorConfig {
41 fn default() -> Self {
42 Self {
43 nan_inf_detection: true,
44 gradient_anomaly_detection: true,
45 stability_monitoring: true,
46 performance_profiling: false,
47 memory_leak_detection: true,
48 history_window_size: 100,
49 gradient_norm_threshold: 100.0,
50 loss_spike_threshold: 10.0,
51 memory_growth_threshold: 100_000_000, auto_recovery_attempts: 3,
53 }
54 }
55}
56
/// Snapshot of everything recorded for a single training step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepMetrics {
    /// Global training-step index supplied by the caller.
    pub step: usize,
    /// Unix timestamp of when the step was recorded (seconds since the
    /// epoch when filled by `record_step`).
    pub timestamp: u64,
    /// Training loss reported for this step.
    pub loss: f32,
    /// Global gradient norm computed across all parameter gradients.
    pub gradient_norm: f32,
    /// Learning rate in effect for this step.
    pub learning_rate: f32,
    /// Memory usage in bytes, as reported by the caller.
    pub memory_usage: usize,
    /// Wall-clock duration of the step, in milliseconds.
    pub step_duration_ms: u64,
    /// Whether NaN/Inf values were detected in the loss or gradients.
    pub has_nan_inf: bool,
    /// Whether the gradient norm was flagged as exploding or vanishing.
    pub gradient_anomaly: bool,
}
70
/// One detected training anomaly plus remediation advice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyReport {
    /// Training step at which the anomaly was detected.
    pub step: usize,
    /// Category of the anomaly.
    pub anomaly_type: AnomalyType,
    /// How serious the anomaly is considered to be.
    pub severity: AnomalySeverity,
    /// Human-readable description of what was observed.
    pub description: String,
    /// Suggested manual remediation steps for the operator.
    pub suggested_actions: Vec<String>,
    /// Whether an automatic recovery action was attempted for this anomaly.
    pub auto_recovery_applied: bool,
}
81
/// Categories of training anomalies the monitor can detect.
///
/// `Hash`/`Eq` are derived so the type can key the per-type
/// recovery-attempt and distribution maps.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub enum AnomalyType {
    /// NaN or Inf in the loss or gradients.
    NanInf,
    /// Gradient norm above the configured explosion threshold.
    GradientExplosion,
    /// Gradient norm near zero.
    GradientVanishing,
    /// Loss far above its recent average.
    LossSpike,
    /// Memory usage growing past the baseline by more than the threshold.
    MemoryLeak,
    // NOTE(review): no detector in this file emits this variant yet.
    PerformanceRegression,
    /// Loss no longer improving between recent windows.
    TrainingStagnation,
}
93
/// Severity of a detected anomaly, from least to most serious.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub enum AnomalySeverity {
    Low,
    Medium,
    High,
    /// Training cannot meaningfully continue (e.g. NaN/Inf).
    Critical,
}
102
/// Automatic recovery actions the monitor could apply.
///
/// NOTE(review): currently unreferenced in this file — `apply_auto_recovery`
/// selects actions via `AnomalyType` directly. Kept presumably as the public
/// vocabulary for future recovery plumbing; confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryStrategy {
    ReduceLearningRate,
    GradientClipping,
    RestoreCheckpoint,
    RestartTraining,
    MemoryCleanup,
    OptimizerReset,
}
113
/// Watches a training run step-by-step: keeps a bounded metrics history,
/// detects anomalies (NaN/Inf, gradient explosion/vanishing, loss spikes,
/// memory growth, stagnation), and optionally applies auto-recovery.
pub struct TrainingMonitor {
    // Detector toggles and thresholds.
    config: TrainingMonitorConfig,
    // Bounded FIFO of recent step metrics (capped at
    // `config.history_window_size`).
    metrics_history: VecDeque<StepMetrics>,
    // Every anomaly ever reported, in detection order.
    anomaly_reports: Vec<AnomalyReport>,
    // Number of auto-recovery attempts made per anomaly type.
    recovery_attempts: HashMap<AnomalyType, usize>,
    // Running averages over all recorded steps.
    performance_stats: PerformanceStats,
    // Reference memory usage in bytes; 0 disables leak detection until
    // `set_memory_baseline` is called.
    memory_baseline: usize,
    // Placeholder for checkpoint-based recovery; never written in this file.
    #[allow(dead_code)]
    last_checkpoint: Option<u64>,
}
125
126impl TrainingMonitor {
127 pub fn new(config: TrainingMonitorConfig) -> Self {
128 Self {
129 config,
130 metrics_history: VecDeque::new(),
131 anomaly_reports: Vec::new(),
132 recovery_attempts: HashMap::new(),
133 performance_stats: PerformanceStats::new(),
134 memory_baseline: 0,
135 last_checkpoint: None,
136 }
137 }
138
139 pub fn record_step(
141 &mut self,
142 step: usize,
143 loss: f32,
144 gradients: &HashMap<String, Tensor>,
145 learning_rate: f32,
146 memory_usage: usize,
147 step_duration: Duration,
148 ) -> Result<()> {
149 let gradient_norm = self.compute_gradient_norm(gradients)?;
150 let has_nan_inf = self.detect_nan_inf(loss, gradients)?;
151 let gradient_anomaly = self.detect_gradient_anomaly(gradient_norm);
152
153 let metrics = StepMetrics {
154 step,
155 timestamp: SystemTime::now()
156 .duration_since(UNIX_EPOCH)
157 .expect("SystemTime should be after UNIX_EPOCH")
158 .as_secs(),
159 loss,
160 gradient_norm,
161 learning_rate,
162 memory_usage,
163 step_duration_ms: step_duration.as_millis() as u64,
164 has_nan_inf,
165 gradient_anomaly,
166 };
167
168 self.metrics_history.push_back(metrics.clone());
170
171 while self.metrics_history.len() > self.config.history_window_size {
173 self.metrics_history.pop_front();
174 }
175
176 self.performance_stats.update(&metrics);
178
179 self.perform_anomaly_detection(&metrics)?;
181
182 Ok(())
183 }
184
185 fn detect_nan_inf(&self, loss: f32, gradients: &HashMap<String, Tensor>) -> Result<bool> {
187 if !self.config.nan_inf_detection {
188 return Ok(false);
189 }
190
191 if !loss.is_finite() {
193 return Ok(true);
194 }
195
196 for gradient in gradients.values() {
198 if self.has_nan_inf_tensor(gradient)? {
199 return Ok(true);
200 }
201 }
202
203 Ok(false)
204 }
205
206 fn has_nan_inf_tensor(&self, _tensor: &Tensor) -> Result<bool> {
208 Ok(false)
211 }
212
213 fn compute_gradient_norm(&self, gradients: &HashMap<String, Tensor>) -> Result<f32> {
215 let mut total_norm = 0.0f32;
216 let mut param_count = 0;
217
218 for gradient in gradients.values() {
219 let grad_norm = self.tensor_norm(gradient)?;
221 total_norm += grad_norm * grad_norm;
222 param_count += 1;
223 }
224
225 if param_count > 0 {
226 Ok(total_norm.sqrt())
227 } else {
228 Ok(0.0)
229 }
230 }
231
232 fn tensor_norm(&self, _tensor: &Tensor) -> Result<f32> {
234 Ok(1.0)
236 }
237
238 fn detect_gradient_anomaly(&self, gradient_norm: f32) -> bool {
240 if !self.config.gradient_anomaly_detection {
241 return false;
242 }
243
244 gradient_norm > self.config.gradient_norm_threshold || gradient_norm < 1e-8
245 }
246
247 fn perform_anomaly_detection(&mut self, metrics: &StepMetrics) -> Result<()> {
249 let mut detected_anomalies = Vec::new();
250
251 if metrics.has_nan_inf {
253 detected_anomalies.push(AnomalyReport {
254 step: metrics.step,
255 anomaly_type: AnomalyType::NanInf,
256 severity: AnomalySeverity::Critical,
257 description: "NaN or Inf values detected in loss or gradients".to_string(),
258 suggested_actions: vec![
259 "Check learning rate (reduce if too high)".to_string(),
260 "Implement gradient clipping".to_string(),
261 "Restore from previous checkpoint".to_string(),
262 ],
263 auto_recovery_applied: false,
264 });
265 }
266
267 if metrics.gradient_norm > self.config.gradient_norm_threshold {
269 detected_anomalies.push(AnomalyReport {
270 step: metrics.step,
271 anomaly_type: AnomalyType::GradientExplosion,
272 severity: AnomalySeverity::High,
273 description: format!(
274 "Gradient norm ({:.2}) exceeds threshold ({:.2})",
275 metrics.gradient_norm, self.config.gradient_norm_threshold
276 ),
277 suggested_actions: vec![
278 "Apply gradient clipping".to_string(),
279 "Reduce learning rate".to_string(),
280 "Check for unstable layers".to_string(),
281 ],
282 auto_recovery_applied: false,
283 });
284 }
285
286 if metrics.gradient_norm < 1e-8 {
288 detected_anomalies.push(AnomalyReport {
289 step: metrics.step,
290 anomaly_type: AnomalyType::GradientVanishing,
291 severity: AnomalySeverity::Medium,
292 description: format!(
293 "Gradient norm ({:.2e}) is extremely small",
294 metrics.gradient_norm
295 ),
296 suggested_actions: vec![
297 "Increase learning rate".to_string(),
298 "Check for dead neurons".to_string(),
299 "Consider different activation functions".to_string(),
300 ],
301 auto_recovery_applied: false,
302 });
303 }
304
305 if let Some(recent_loss) = self.get_recent_average_loss() {
307 if metrics.loss > recent_loss * self.config.loss_spike_threshold {
308 detected_anomalies.push(AnomalyReport {
309 step: metrics.step,
310 anomaly_type: AnomalyType::LossSpike,
311 severity: AnomalySeverity::High,
312 description: format!(
313 "Loss spike detected: {:.4} vs recent average {:.4}",
314 metrics.loss, recent_loss
315 ),
316 suggested_actions: vec![
317 "Check for data corruption".to_string(),
318 "Verify batch normalization".to_string(),
319 "Consider reducing learning rate".to_string(),
320 ],
321 auto_recovery_applied: false,
322 });
323 }
324 }
325
326 if self.config.memory_leak_detection
328 && self.memory_baseline > 0
329 && metrics.memory_usage > self.memory_baseline + self.config.memory_growth_threshold
330 {
331 detected_anomalies.push(AnomalyReport {
332 step: metrics.step,
333 anomaly_type: AnomalyType::MemoryLeak,
334 severity: AnomalySeverity::Medium,
335 description: format!(
336 "Memory usage increased by {} bytes",
337 metrics.memory_usage - self.memory_baseline
338 ),
339 suggested_actions: vec![
340 "Check for tensor accumulation".to_string(),
341 "Verify gradient cleanup".to_string(),
342 "Consider memory optimization".to_string(),
343 ],
344 auto_recovery_applied: false,
345 });
346 }
347
348 if self.detect_training_stagnation()? {
350 detected_anomalies.push(AnomalyReport {
351 step: metrics.step,
352 anomaly_type: AnomalyType::TrainingStagnation,
353 severity: AnomalySeverity::Medium,
354 description: "Training appears to have stagnated".to_string(),
355 suggested_actions: vec![
356 "Adjust learning rate schedule".to_string(),
357 "Consider different optimizer".to_string(),
358 "Check for overfitting".to_string(),
359 ],
360 auto_recovery_applied: false,
361 });
362 }
363
364 for mut anomaly in detected_anomalies {
366 if self.should_apply_auto_recovery(&anomaly) {
367 anomaly.auto_recovery_applied = self.apply_auto_recovery(&anomaly)?;
368 }
369 self.anomaly_reports.push(anomaly);
370 }
371
372 Ok(())
373 }
374
375 fn get_recent_average_loss(&self) -> Option<f32> {
377 if self.metrics_history.len() < 10 {
378 return None;
379 }
380
381 let recent_count = std::cmp::min(10, self.metrics_history.len());
382 let recent_losses: Vec<f32> =
383 self.metrics_history.iter().rev().take(recent_count).map(|m| m.loss).collect();
384
385 if recent_losses.is_empty() {
386 None
387 } else {
388 Some(recent_losses.iter().sum::<f32>() / recent_losses.len() as f32)
389 }
390 }
391
392 fn detect_training_stagnation(&self) -> Result<bool> {
394 if self.metrics_history.len() < 50 {
395 return Ok(false);
396 }
397
398 let recent_window = 20;
400 let older_window = 30;
401
402 let recent_avg = self.get_window_average_loss(recent_window)?;
403 let older_avg = self.get_window_average_loss(older_window)?;
404
405 Ok(recent_avg >= older_avg * 0.99)
407 }
408
409 fn get_window_average_loss(&self, window_size: usize) -> Result<f32> {
411 if self.metrics_history.len() < window_size {
412 return Ok(0.0);
413 }
414
415 let losses: Vec<f32> =
416 self.metrics_history.iter().rev().take(window_size).map(|m| m.loss).collect();
417
418 Ok(losses.iter().sum::<f32>() / losses.len() as f32)
419 }
420
421 fn should_apply_auto_recovery(&self, anomaly: &AnomalyReport) -> bool {
423 let attempts = self.recovery_attempts.get(&anomaly.anomaly_type).unwrap_or(&0);
424 *attempts < self.config.auto_recovery_attempts
425 }
426
427 fn apply_auto_recovery(&mut self, anomaly: &AnomalyReport) -> Result<bool> {
429 let attempts = self.recovery_attempts.entry(anomaly.anomaly_type.clone()).or_insert(0);
430 *attempts += 1;
431
432 match anomaly.anomaly_type {
433 AnomalyType::NanInf => {
434 println!("Auto-recovery: Restoring from checkpoint due to NaN/Inf");
436 Ok(true)
437 },
438 AnomalyType::GradientExplosion => {
439 println!("Auto-recovery: Applying gradient clipping");
441 Ok(true)
442 },
443 AnomalyType::MemoryLeak => {
444 println!("Auto-recovery: Triggering memory cleanup");
446 Ok(true)
447 },
448 _ => Ok(false),
449 }
450 }
451
452 pub fn get_health_status(&self) -> TrainingHealthStatus {
454 let recent_anomalies = self.anomaly_reports.iter().rev().take(10).collect::<Vec<_>>();
455
456 let critical_count = recent_anomalies
457 .iter()
458 .filter(|a| matches!(a.severity, AnomalySeverity::Critical))
459 .count();
460
461 let high_count = recent_anomalies
462 .iter()
463 .filter(|a| matches!(a.severity, AnomalySeverity::High))
464 .count();
465
466 let overall_health = if critical_count > 0 {
467 HealthStatus::Critical
468 } else if high_count > 3 {
469 HealthStatus::Poor
470 } else if high_count > 1 {
471 HealthStatus::Warning
472 } else {
473 HealthStatus::Good
474 };
475
476 TrainingHealthStatus {
477 overall_health,
478 recent_anomalies: recent_anomalies.len(),
479 critical_issues: critical_count,
480 high_issues: high_count,
481 auto_recovery_success_rate: self.calculate_recovery_success_rate(),
482 performance_trend: self.performance_stats.get_trend(),
483 }
484 }
485
486 fn calculate_recovery_success_rate(&self) -> f32 {
488 let total_recoveries =
489 self.anomaly_reports.iter().filter(|a| a.auto_recovery_applied).count();
490
491 if total_recoveries == 0 {
492 return 1.0;
493 }
494
495 0.85 }
498
499 pub fn get_training_report(&self) -> TrainingReport {
501 TrainingReport {
502 health_status: self.get_health_status(),
503 anomaly_summary: self.get_anomaly_summary(),
504 performance_stats: self.performance_stats.clone(),
505 recommendations: self.generate_recommendations(),
506 }
507 }
508
509 fn get_anomaly_summary(&self) -> AnomalySummary {
511 let mut type_counts = HashMap::new();
512 let mut severity_counts = HashMap::new();
513
514 for anomaly in &self.anomaly_reports {
515 *type_counts.entry(anomaly.anomaly_type.clone()).or_insert(0) += 1;
516 *severity_counts.entry(anomaly.severity.clone()).or_insert(0) += 1;
517 }
518
519 AnomalySummary {
520 total_anomalies: self.anomaly_reports.len(),
521 type_distribution: type_counts,
522 severity_distribution: severity_counts,
523 }
524 }
525
526 fn generate_recommendations(&self) -> Vec<String> {
528 let mut recommendations = Vec::new();
529
530 let recent_anomalies = self.anomaly_reports.iter().rev().take(20).collect::<Vec<_>>();
532
533 if recent_anomalies
534 .iter()
535 .any(|a| matches!(a.anomaly_type, AnomalyType::GradientExplosion))
536 {
537 recommendations.push("Consider implementing gradient clipping".to_string());
538 }
539
540 if recent_anomalies
541 .iter()
542 .any(|a| matches!(a.anomaly_type, AnomalyType::MemoryLeak))
543 {
544 recommendations.push("Review memory management and tensor lifecycle".to_string());
545 }
546
547 if recent_anomalies
548 .iter()
549 .any(|a| matches!(a.anomaly_type, AnomalyType::TrainingStagnation))
550 {
551 recommendations
552 .push("Consider adjusting learning rate schedule or optimizer".to_string());
553 }
554
555 if self.performance_stats.average_step_duration_ms > 5000 {
556 recommendations
557 .push("Training steps are taking too long - consider optimization".to_string());
558 }
559
560 recommendations
561 }
562
563 pub fn set_memory_baseline(&mut self, baseline: usize) {
565 self.memory_baseline = baseline;
566 }
567}
568
/// Running averages over every step recorded so far (incremental means,
/// not windowed).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceStats {
    /// Total number of steps folded into the averages.
    pub total_steps: usize,
    /// Incremental mean step duration, in milliseconds.
    pub average_step_duration_ms: u64,
    /// Incremental mean loss.
    pub average_loss: f32,
    /// Incremental mean global gradient norm.
    pub average_gradient_norm: f32,
    // NOTE(review): never written by `update` — stays at its initial 0.0.
    pub memory_usage_trend: f32,
}
578
579impl PerformanceStats {
580 fn new() -> Self {
581 Self {
582 total_steps: 0,
583 average_step_duration_ms: 0,
584 average_loss: 0.0,
585 average_gradient_norm: 0.0,
586 memory_usage_trend: 0.0,
587 }
588 }
589
590 fn update(&mut self, metrics: &StepMetrics) {
591 self.total_steps += 1;
592
593 let n = self.total_steps as f32;
595 let old_weight = (n - 1.0) / n;
596 let new_weight = 1.0 / n;
597
598 self.average_step_duration_ms = (self.average_step_duration_ms as f32 * old_weight
599 + metrics.step_duration_ms as f32 * new_weight)
600 as u64;
601
602 self.average_loss = self.average_loss * old_weight + metrics.loss * new_weight;
603 self.average_gradient_norm =
604 self.average_gradient_norm * old_weight + metrics.gradient_norm * new_weight;
605 }
606
607 fn get_trend(&self) -> PerformanceTrend {
608 if self.total_steps < 10 {
610 PerformanceTrend::Stable
611 } else if self.average_step_duration_ms > 10000 {
612 PerformanceTrend::Degrading
613 } else {
614 PerformanceTrend::Improving
615 }
616 }
617}
618
/// Aggregated health snapshot derived from the most recent anomaly reports.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingHealthStatus {
    /// Overall verdict (see `get_health_status` for the thresholds).
    pub overall_health: HealthStatus,
    /// Number of anomaly reports considered (at most the 10 most recent).
    pub recent_anomalies: usize,
    /// Count of `Critical`-severity reports among those considered.
    pub critical_issues: usize,
    /// Count of `High`-severity reports among those considered.
    pub high_issues: usize,
    /// Estimated auto-recovery success rate (currently a placeholder value).
    pub auto_recovery_success_rate: f32,
    /// Coarse performance trend from the running statistics.
    pub performance_trend: PerformanceTrend,
}
629
/// Overall training health, from best to worst.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HealthStatus {
    Good,
    Warning,
    Poor,
    Critical,
}
638
/// Direction of the coarse performance heuristic in
/// `PerformanceStats::get_trend`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PerformanceTrend {
    Improving,
    Stable,
    Degrading,
}
646
/// Counts of all recorded anomalies, broken down by type and severity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalySummary {
    /// Total number of anomaly reports ever recorded.
    pub total_anomalies: usize,
    /// Report count per anomaly type.
    pub type_distribution: HashMap<AnomalyType, usize>,
    /// Report count per severity level.
    pub severity_distribution: HashMap<AnomalySeverity, usize>,
}
654
/// Full monitoring report returned by `TrainingMonitor::get_training_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingReport {
    /// Health snapshot from the most recent anomalies.
    pub health_status: TrainingHealthStatus,
    /// Type/severity breakdown of all recorded anomalies.
    pub anomaly_summary: AnomalySummary,
    /// Copy of the running performance averages.
    pub performance_stats: PerformanceStats,
    /// Human-readable remediation suggestions.
    pub recommendations: Vec<String>,
}
663
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // A freshly constructed monitor starts with empty history and no reports.
    #[test]
    fn test_training_monitor_creation() {
        let config = TrainingMonitorConfig::default();
        let monitor = TrainingMonitor::new(config);

        assert_eq!(monitor.metrics_history.len(), 0);
        assert_eq!(monitor.anomaly_reports.len(), 0);
    }

    // A NaN loss is detected even with no gradients present.
    #[test]
    fn test_nan_inf_detection() {
        let config = TrainingMonitorConfig::default();
        let monitor = TrainingMonitor::new(config);

        let gradients = HashMap::new();
        let result = monitor.detect_nan_inf(f32::NAN, &gradients);

        assert!(result.is_ok());
        assert!(result.expect("operation failed in test"));
    }

    // Norms above the threshold and near-zero norms are both anomalous;
    // values in between are not.
    #[test]
    fn test_gradient_anomaly_detection() {
        let config = TrainingMonitorConfig {
            gradient_norm_threshold: 10.0,
            ..Default::default()
        };
        let monitor = TrainingMonitor::new(config);

        assert!(monitor.detect_gradient_anomaly(100.0));
        assert!(monitor.detect_gradient_anomaly(1e-10));
        assert!(!monitor.detect_gradient_anomaly(5.0));
    }

    // Recording a step appends exactly one entry to the history.
    #[test]
    fn test_step_recording() {
        let config = TrainingMonitorConfig::default();
        let mut monitor = TrainingMonitor::new(config);

        let gradients = HashMap::new();
        let result = monitor.record_step(
            0,
            1.0,
            &gradients,
            0.001,
            1000000,
            Duration::from_millis(100),
        );

        assert!(result.is_ok());
        assert_eq!(monitor.metrics_history.len(), 1);
    }

    // With no recorded anomalies the overall health is Good.
    #[test]
    fn test_health_status() {
        let config = TrainingMonitorConfig::default();
        let monitor = TrainingMonitor::new(config);

        let health = monitor.get_health_status();
        assert!(matches!(health.overall_health, HealthStatus::Good));
        assert_eq!(health.recent_anomalies, 0);
    }

    // After one update the running averages equal that step's values.
    #[test]
    fn test_performance_stats() {
        let mut stats = PerformanceStats::new();
        let metrics = StepMetrics {
            step: 0,
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("SystemTime should be after UNIX_EPOCH")
                .as_millis() as u64,
            loss: 1.0,
            gradient_norm: 2.0,
            learning_rate: 0.001,
            memory_usage: 1000000,
            step_duration_ms: 100,
            has_nan_inf: false,
            gradient_anomaly: false,
        };

        stats.update(&metrics);

        assert_eq!(stats.total_steps, 1);
        assert_eq!(stats.average_loss, 1.0);
        assert_eq!(stats.average_gradient_norm, 2.0);
    }
}