1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
/// Configuration for the training-dynamics analyzer: feature toggles for each
/// analysis pass plus the numeric thresholds shared by the detectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamicsConfig {
    /// Run loss-curve trend/statistics analysis.
    pub enable_loss_curve_analysis: bool,
    /// Run learning-rate schedule/effectiveness analysis.
    pub enable_learning_rate_analysis: bool,
    /// Run batch-size efficiency analysis.
    pub enable_batch_size_analysis: bool,
    /// Run convergence detection.
    pub enable_convergence_detection: bool,
    /// Run plateau identification.
    pub enable_plateau_identification: bool,
    /// Window length for moving-average computations.
    // NOTE(review): not read anywhere in the visible portion of this file
    // (the moving-average helper hardcodes 5/20/100) — confirm usage.
    pub moving_average_window: usize,
    /// Relative loss-stability threshold used by the convergence criteria.
    pub convergence_tolerance: f32,
    /// Relative std threshold below which a loss series counts as plateaued.
    pub plateau_threshold: f32,
    /// Minimum recorded epochs before convergence detection runs.
    pub min_epochs_for_convergence: usize,
    /// Maximum retained metrics entries; older entries are evicted FIFO.
    pub max_history_length: usize,
}
34
// Defaults: every analysis enabled, 10-epoch moving-average window, 1e-6
// convergence tolerance, 1e-4 plateau threshold, 20-epoch minimum before
// convergence detection, and a 10_000-entry history cap.
impl Default for TrainingDynamicsConfig {
    fn default() -> Self {
        Self {
            enable_loss_curve_analysis: true,
            enable_learning_rate_analysis: true,
            enable_batch_size_analysis: true,
            enable_convergence_detection: true,
            enable_plateau_identification: true,
            moving_average_window: 10,
            convergence_tolerance: 1e-6,
            plateau_threshold: 1e-4,
            min_epochs_for_convergence: 20,
            max_history_length: 10000,
        }
    }
}
51
/// One snapshot of per-epoch/per-step training telemetry fed into the analyzer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Epoch index this snapshot belongs to.
    pub epoch: usize,
    /// Global optimization step.
    pub step: usize,
    /// Training loss at this point.
    pub train_loss: f32,
    /// Validation loss, when a validation pass was run.
    pub validation_loss: Option<f32>,
    /// Learning rate in effect.
    pub learning_rate: f32,
    /// Batch size in effect.
    pub batch_size: usize,
    /// Global gradient norm, when recorded.
    pub gradient_norm: Option<f32>,
    /// Accuracy metric, when available.
    pub accuracy: Option<f32>,
    /// Wall-clock timestamp of the snapshot.
    // NOTE(review): units (seconds vs. millis) are not established here.
    pub timestamp: f64,
}
65
/// Results of the loss-curve analysis pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossCurveAnalysis {
    /// Classified direction of the recent loss trajectory.
    pub trend: LossTrend,
    /// Smoothness score in (0, 1]; higher means a smoother curve.
    pub smoothness: f32,
    /// Std-dev of relative step-to-step loss changes.
    pub volatility: f32,
    /// Average loss improvement per epoch.
    pub improvement_rate: f32,
    /// Lowest loss observed so far.
    pub best_loss: f32,
    /// Most recent loss.
    pub current_loss: f32,
    /// Percent reduction from the first recorded loss to the current one.
    pub loss_reduction_percentage: f32,
    /// Epochs since the best loss was last attained.
    pub epochs_since_improvement: usize,
    /// Trailing averages over short/medium/long windows.
    pub moving_averages: MovingAverages,
    /// Summary statistics over the full loss series.
    pub loss_statistics: LossStatistics,
}

/// Coarse classification of the loss trajectory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossTrend {
    Decreasing,
    Increasing,
    Oscillating,
    Plateaued,
    /// Not enough data to classify.
    Unknown,
}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MovingAverages {
92 pub short_term: f32, pub medium_term: f32, pub long_term: f32, }
96
/// Descriptive statistics over the recorded loss series.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossStatistics {
    pub mean: f32,
    /// Population standard deviation (divides by N).
    pub std: f32,
    pub min: f32,
    pub max: f32,
    pub median: f32,
    pub percentile_25: f32,
    pub percentile_75: f32,
    /// Lag-1 autocorrelation of the series.
    pub autocorrelation: f32,
}
108
/// Results of the learning-rate analysis pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningRateAnalysis {
    /// Learning rate of the most recent metrics entry.
    pub current_lr: f32,
    /// Heuristically detected schedule shape.
    pub lr_schedule_type: LRScheduleType,
    /// Mean effectiveness score, clamped to [0, 1].
    pub lr_impact_score: f32,
    /// LR of the historically most effective epoch.
    pub optimal_lr_estimate: f32,
    /// Spread (std-dev) of effectiveness across the history.
    pub lr_sensitivity: f32,
    /// Per-epoch LR/effectiveness datapoints.
    pub lr_history: Vec<LearningRatePoint>,
    /// Suggested LR adjustments.
    pub recommendations: Vec<LRRecommendation>,
}

/// Heuristic classification of the learning-rate schedule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LRScheduleType {
    Constant,
    StepDecay,
    ExponentialDecay,
    // NOTE(review): the visible detector never returns the next three
    // variants; they may be produced elsewhere.
    CosineAnnealing,
    ReduceOnPlateau,
    Warmup,
    Cyclical,
    Unknown,
}

/// One epoch's learning-rate datapoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningRatePoint {
    pub epoch: usize,
    pub learning_rate: f32,
    /// Previous epoch's loss minus this epoch's (positive = improvement).
    pub loss_change: f32,
    pub gradient_norm: Option<f32>,
    /// Loss improvement per unit LR; 0.0 when the loss did not improve.
    pub effectiveness: f32,
}

/// A single learning-rate adjustment suggestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRRecommendation {
    pub action: LRAction,
    /// Confidence in [0, 1].
    pub confidence: f32,
    pub rationale: String,
    /// Anticipated relative benefit of applying the action.
    pub expected_improvement: f32,
}

/// Possible learning-rate adjustments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LRAction {
    Increase,
    Decrease,
    KeepCurrent,
    AddScheduler,
    ChangeScheduler,
    AddWarmup,
}
159
/// Results of the batch-size analysis pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeAnalysis {
    pub current_batch_size: usize,
    /// Blend of loss improvement (0.6) and gradient stability (0.4), capped at 1.0.
    pub batch_size_efficiency: f32,
    /// Coefficient of variation of recorded gradient norms, capped at 1.0.
    pub gradient_noise_level: f32,
    /// Average loss improvement per epoch, clamped to [0, 1].
    pub convergence_speed: f32,
    /// Batch size normalized against 1024 — a rough proxy, not measured.
    pub memory_utilization: f32,
    pub optimal_batch_size_estimate: usize,
    pub batch_size_history: Vec<BatchSizePoint>,
    pub recommendations: Vec<BatchSizeRecommendation>,
}

/// One epoch's batch-size datapoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizePoint {
    pub epoch: usize,
    pub batch_size: usize,
    /// Previous epoch's loss minus this epoch's.
    pub loss_improvement: f32,
    /// 1 / (1 + gradient norm); 0.5 when no norm was recorded.
    pub gradient_stability: f32,
    /// Currently a constant placeholder (1.0) — no timing data is recorded.
    pub throughput: f32,
}

/// A single batch-size adjustment suggestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeRecommendation {
    pub suggested_batch_size: usize,
    pub confidence: f32,
    pub rationale: String,
    pub expected_benefits: Vec<String>,
}
189
/// Results of convergence detection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceAnalysis {
    pub convergence_status: ConvergenceStatus,
    /// Confidence-weighted share of satisfied criteria in [0, 1].
    pub convergence_probability: f32,
    /// Linear extrapolation, capped at 1000; `None` when not improving.
    pub epochs_to_convergence_estimate: Option<usize>,
    pub convergence_criteria: Vec<ConvergenceCriterion>,
    pub early_stopping_recommendation: Option<EarlyStoppingRecommendation>,
}

/// Overall convergence state of the run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConvergenceStatus {
    Converging,
    Converged,
    Diverging,
    Oscillating,
    /// Fewer epochs than `min_epochs_for_convergence` recorded so far.
    TooEarly,
}

/// A single measurable convergence test with its threshold and outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceCriterion {
    pub criterion_type: ConvergenceCriterionType,
    pub current_value: f32,
    pub threshold: f32,
    /// Whether `current_value` passed the threshold.
    pub satisfied: bool,
    /// Weight used when aggregating into the convergence probability.
    pub confidence: f32,
}

/// Kinds of convergence tests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConvergenceCriterionType {
    LossStability,
    GradientMagnitude,
    LossImprovement,
    // NOTE(review): the next two are not produced by the visible evaluator.
    ValidationGap,
    LearningRateDecay,
}

/// Whether training should stop now, with supporting rationale.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingRecommendation {
    pub should_stop: bool,
    pub confidence: f32,
    pub rationale: String,
    pub suggested_epochs_remaining: usize,
}
234
/// Results of plateau identification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauAnalysis {
    pub plateau_detected: bool,
    /// Number of epochs the plateau has persisted.
    pub plateau_duration: usize,
    /// Loss level at which the curve has flattened.
    pub plateau_level: f32,
    pub plateau_type: PlateauType,
    /// Estimated likelihood of escaping the plateau.
    pub escape_probability: f32,
    pub plateau_characteristics: PlateauCharacteristics,
    pub recommendations: Vec<PlateauRecommendation>,
}

/// Which training signal has plateaued.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PlateauType {
    // NOTE(review): misspelled ("Playteau" -> "Plateau"). Renaming would
    // break callers and previously serialized data, so it is left unchanged;
    // fix in a coordinated refactor (a `#[serde(alias)]` can keep old data
    // readable).
    LossPlayteau,
    GradientPlateau,
    AccuracyPlateau,
    LearningRatePlateau,
}
254
/// Quantitative description of a detected plateau.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauCharacteristics {
    pub stability: f32,
    pub noise_level: f32,
    pub gradient_magnitude: f32,
    pub overfitting_risk: f32,
}

/// A suggested intervention for escaping (or accepting) a plateau.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauRecommendation {
    pub action: PlateauAction,
    pub priority: Priority,
    pub description: String,
    /// Concrete steps for applying the action.
    pub implementation: String,
}

/// Possible plateau interventions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PlateauAction {
    IncreaseLearningRate,
    DecreaseLearningRate,
    ChangeBatchSize,
    AddRegularization,
    RemoveRegularization,
    ChangeOptimizer,
    AddNoise,
    EarlyStopping,
    ContinueTraining,
}

/// Urgency level shared by the recommendation types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Priority {
    Critical,
    High,
    Medium,
    Low,
}
291
/// Aggregated analyzer output; each section is present only when the
/// corresponding config flag is enabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamicsReport {
    pub loss_curve_analysis: Option<LossCurveAnalysis>,
    pub learning_rate_analysis: Option<LearningRateAnalysis>,
    pub batch_size_analysis: Option<BatchSizeAnalysis>,
    pub convergence_analysis: Option<ConvergenceAnalysis>,
    pub plateau_analysis: Option<PlateauAnalysis>,
    pub training_summary: TrainingSummary,
    pub recommendations: Vec<TrainingRecommendation>,
}

/// High-level progress/health scores for the whole run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSummary {
    pub total_epochs: usize,
    pub total_steps: usize,
    pub training_efficiency: f32,
    pub convergence_health: f32,
    pub stability_score: f32,
    pub overall_progress: f32,
}

/// A cross-cutting training improvement suggestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecommendation {
    pub category: TrainingCategory,
    pub priority: Priority,
    pub description: String,
    pub implementation: String,
    /// Anticipated relative benefit of applying the recommendation.
    // NOTE(review): value range is not established by the visible code.
    pub expected_impact: f32,
}

/// Area of training a recommendation applies to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrainingCategory {
    LearningRate,
    BatchSize,
    Optimization,
    Regularization,
    EarlyStopping,
    Architecture,
}
332
/// Stateful analyzer that accumulates `TrainingMetrics` snapshots and
/// produces `TrainingDynamicsReport`s from the configured analysis passes.
#[derive(Debug)]
pub struct TrainingDynamicsAnalyzer {
    // Thresholds and feature toggles for the analysis passes.
    config: TrainingDynamicsConfig,
    // FIFO history, capped at `config.max_history_length`.
    metrics_history: VecDeque<TrainingMetrics>,
    // Cache keyed by string; never read or written in the visible portion of
    // this file — NOTE(review): possibly dead, confirm before removing.
    analysis_cache: HashMap<String, TrainingDynamicsReport>,
}
340
341impl TrainingDynamicsAnalyzer {
342 pub fn new(config: TrainingDynamicsConfig) -> Self {
344 Self {
345 config,
346 metrics_history: VecDeque::new(),
347 analysis_cache: HashMap::new(),
348 }
349 }
350
351 pub fn record_metrics(&mut self, metrics: TrainingMetrics) {
353 self.metrics_history.push_back(metrics);
354
355 while self.metrics_history.len() > self.config.max_history_length {
357 self.metrics_history.pop_front();
358 }
359 }
360
361 pub async fn analyze(&mut self) -> Result<TrainingDynamicsReport> {
363 let mut report = TrainingDynamicsReport {
364 loss_curve_analysis: None,
365 learning_rate_analysis: None,
366 batch_size_analysis: None,
367 convergence_analysis: None,
368 plateau_analysis: None,
369 training_summary: TrainingSummary {
370 total_epochs: 0,
371 total_steps: 0,
372 training_efficiency: 0.0,
373 convergence_health: 0.0,
374 stability_score: 0.0,
375 overall_progress: 0.0,
376 },
377 recommendations: Vec::new(),
378 };
379
380 if self.config.enable_loss_curve_analysis {
381 report.loss_curve_analysis = Some(self.analyze_loss_curve().await?);
382 }
383
384 if self.config.enable_learning_rate_analysis {
385 report.learning_rate_analysis = Some(self.analyze_learning_rate().await?);
386 }
387
388 if self.config.enable_batch_size_analysis {
389 report.batch_size_analysis = Some(self.analyze_batch_size().await?);
390 }
391
392 if self.config.enable_convergence_detection {
393 report.convergence_analysis = Some(self.detect_convergence().await?);
394 }
395
396 if self.config.enable_plateau_identification {
397 report.plateau_analysis = Some(self.identify_plateau().await?);
398 }
399
400 self.generate_training_summary(&mut report);
401 self.generate_training_recommendations(&mut report);
402
403 Ok(report)
404 }
405
406 async fn analyze_loss_curve(&self) -> Result<LossCurveAnalysis> {
408 if self.metrics_history.is_empty() {
409 return Ok(LossCurveAnalysis {
410 trend: LossTrend::Unknown,
411 smoothness: 0.0,
412 volatility: 0.0,
413 improvement_rate: 0.0,
414 best_loss: 0.0,
415 current_loss: 0.0,
416 loss_reduction_percentage: 0.0,
417 epochs_since_improvement: 0,
418 moving_averages: MovingAverages {
419 short_term: 0.0,
420 medium_term: 0.0,
421 long_term: 0.0,
422 },
423 loss_statistics: LossStatistics {
424 mean: 0.0,
425 std: 0.0,
426 min: 0.0,
427 max: 0.0,
428 median: 0.0,
429 percentile_25: 0.0,
430 percentile_75: 0.0,
431 autocorrelation: 0.0,
432 },
433 });
434 }
435
436 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
437
438 let trend = self.detect_loss_trend(&losses);
439 let smoothness = self.calculate_smoothness(&losses);
440 let volatility = self.calculate_volatility(&losses);
441 let improvement_rate = self.calculate_improvement_rate(&losses);
442
443 let best_loss = losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
444 let current_loss = *losses.last().expect("losses is non-empty from metrics_history");
445 let loss_reduction_percentage = if losses.len() > 1 {
446 ((losses[0] - current_loss) / losses[0].abs()) * 100.0
447 } else {
448 0.0
449 };
450
451 let epochs_since_improvement = self.calculate_epochs_since_improvement(&losses, best_loss);
452 let moving_averages = self.calculate_moving_averages(&losses);
453 let loss_statistics = self.calculate_loss_statistics(&losses);
454
455 Ok(LossCurveAnalysis {
456 trend,
457 smoothness,
458 volatility,
459 improvement_rate,
460 best_loss,
461 current_loss,
462 loss_reduction_percentage,
463 epochs_since_improvement,
464 moving_averages,
465 loss_statistics,
466 })
467 }
468
469 fn detect_loss_trend(&self, losses: &[f32]) -> LossTrend {
471 if losses.len() < 3 {
472 return LossTrend::Unknown;
473 }
474
475 let window_size = (losses.len() / 4).max(5).min(20);
476 let recent_losses = &losses[losses.len().saturating_sub(window_size)..];
477 let early_losses = &losses[..window_size.min(losses.len())];
478
479 let recent_mean = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
480 let early_mean = early_losses.iter().sum::<f32>() / early_losses.len() as f32;
481
482 let improvement = (early_mean - recent_mean) / early_mean.abs();
483
484 let recent_std = self.calculate_std(recent_losses);
486 let recent_mean_abs = recent_mean.abs();
487
488 if recent_std / recent_mean_abs.max(1e-8) < self.config.plateau_threshold {
489 return LossTrend::Plateaued;
490 }
491
492 let oscillation_score = self.detect_oscillation(losses);
494 if oscillation_score > 0.5 {
495 return LossTrend::Oscillating;
496 }
497
498 if improvement > 0.01 {
499 LossTrend::Decreasing
500 } else if improvement < -0.01 {
501 LossTrend::Increasing
502 } else {
503 LossTrend::Plateaued
504 }
505 }
506
507 fn calculate_smoothness(&self, losses: &[f32]) -> f32 {
509 if losses.len() < 2 {
510 return 1.0;
511 }
512
513 let differences: Vec<f32> = losses.windows(2).map(|w| (w[1] - w[0]).abs()).collect();
514
515 let mean_diff = differences.iter().sum::<f32>() / differences.len() as f32;
516 let mean_loss = losses.iter().sum::<f32>() / losses.len() as f32;
517
518 1.0 / (1.0 + mean_diff / mean_loss.abs().max(1e-8))
520 }
521
522 fn calculate_volatility(&self, losses: &[f32]) -> f32 {
524 if losses.len() < 2 {
525 return 0.0;
526 }
527
528 let returns: Vec<f32> =
529 losses.windows(2).map(|w| (w[1] - w[0]) / w[0].abs().max(1e-8)).collect();
530
531 self.calculate_std(&returns)
532 }
533
534 fn calculate_improvement_rate(&self, losses: &[f32]) -> f32 {
536 if losses.len() < 2 {
537 return 0.0;
538 }
539
540 let total_improvement = losses[0] - losses[losses.len() - 1];
541 let epochs = losses.len() as f32;
542
543 total_improvement / epochs
544 }
545
546 fn calculate_epochs_since_improvement(&self, losses: &[f32], best_loss: f32) -> usize {
548 for (i, &loss) in losses.iter().rev().enumerate() {
549 if (loss - best_loss).abs() < 1e-8 {
550 return i;
551 }
552 }
553 losses.len()
554 }
555
556 fn calculate_moving_averages(&self, losses: &[f32]) -> MovingAverages {
558 let short_window = 5.min(losses.len());
559 let medium_window = 20.min(losses.len());
560 let long_window = 100.min(losses.len());
561
562 let short_term = if short_window > 0 {
563 losses[losses.len() - short_window..].iter().sum::<f32>() / short_window as f32
564 } else {
565 0.0
566 };
567
568 let medium_term = if medium_window > 0 {
569 losses[losses.len() - medium_window..].iter().sum::<f32>() / medium_window as f32
570 } else {
571 0.0
572 };
573
574 let long_term = if long_window > 0 {
575 losses[losses.len() - long_window..].iter().sum::<f32>() / long_window as f32
576 } else {
577 0.0
578 };
579
580 MovingAverages {
581 short_term,
582 medium_term,
583 long_term,
584 }
585 }
586
587 fn calculate_loss_statistics(&self, losses: &[f32]) -> LossStatistics {
589 if losses.is_empty() {
590 return LossStatistics {
591 mean: 0.0,
592 std: 0.0,
593 min: 0.0,
594 max: 0.0,
595 median: 0.0,
596 percentile_25: 0.0,
597 percentile_75: 0.0,
598 autocorrelation: 0.0,
599 };
600 }
601
602 let mean = losses.iter().sum::<f32>() / losses.len() as f32;
603 let std = self.calculate_std(losses);
604
605 let mut sorted_losses = losses.to_vec();
606 sorted_losses.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
607
608 let min = sorted_losses[0];
609 let max = sorted_losses[sorted_losses.len() - 1];
610 let median = sorted_losses[sorted_losses.len() / 2];
611 let percentile_25 = sorted_losses[sorted_losses.len() / 4];
612 let percentile_75 = sorted_losses[3 * sorted_losses.len() / 4];
613
614 let autocorrelation = self.calculate_autocorrelation(losses, 1);
615
616 LossStatistics {
617 mean,
618 std,
619 min,
620 max,
621 median,
622 percentile_25,
623 percentile_75,
624 autocorrelation,
625 }
626 }
627
628 fn calculate_std(&self, values: &[f32]) -> f32 {
630 if values.len() < 2 {
631 return 0.0;
632 }
633
634 let mean = values.iter().sum::<f32>() / values.len() as f32;
635 let variance =
636 values.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
637
638 variance.sqrt()
639 }
640
641 fn detect_oscillation(&self, losses: &[f32]) -> f32 {
643 if losses.len() < 4 {
644 return 0.0;
645 }
646
647 let mut direction_changes = 0;
648 let mut total_comparisons = 0;
649
650 for i in 1..losses.len() - 1 {
651 let prev_direction = losses[i] > losses[i - 1];
652 let next_direction = losses[i + 1] > losses[i];
653
654 if prev_direction != next_direction {
655 direction_changes += 1;
656 }
657 total_comparisons += 1;
658 }
659
660 direction_changes as f32 / total_comparisons as f32
661 }
662
663 fn calculate_autocorrelation(&self, values: &[f32], lag: usize) -> f32 {
665 if values.len() <= lag {
666 return 0.0;
667 }
668
669 let mean = values.iter().sum::<f32>() / values.len() as f32;
670
671 let mut numerator = 0.0;
672 let mut denominator = 0.0;
673
674 for i in 0..values.len() - lag {
675 numerator += (values[i] - mean) * (values[i + lag] - mean);
676 }
677
678 for &value in values {
679 denominator += (value - mean).powi(2);
680 }
681
682 if denominator > 1e-8 {
683 numerator / denominator
684 } else {
685 0.0
686 }
687 }
688
689 async fn analyze_learning_rate(&self) -> Result<LearningRateAnalysis> {
691 if self.metrics_history.is_empty() {
692 return Ok(LearningRateAnalysis {
693 current_lr: 0.0,
694 lr_schedule_type: LRScheduleType::Unknown,
695 lr_impact_score: 0.0,
696 optimal_lr_estimate: 0.0,
697 lr_sensitivity: 0.0,
698 lr_history: Vec::new(),
699 recommendations: Vec::new(),
700 });
701 }
702
703 let current_lr = self
704 .metrics_history
705 .back()
706 .expect("metrics_history should not be empty after empty check")
707 .learning_rate;
708 let lr_schedule_type = self.detect_lr_schedule_type();
709
710 let lr_history = self.build_lr_history();
711 let lr_impact_score = self.calculate_lr_impact_score(&lr_history);
712 let optimal_lr_estimate = self.estimate_optimal_lr(&lr_history);
713 let lr_sensitivity = self.calculate_lr_sensitivity(&lr_history);
714 let recommendations = self.generate_lr_recommendations(current_lr, &lr_history);
715
716 Ok(LearningRateAnalysis {
717 current_lr,
718 lr_schedule_type,
719 lr_impact_score,
720 optimal_lr_estimate,
721 lr_sensitivity,
722 lr_history,
723 recommendations,
724 })
725 }
726
727 fn detect_lr_schedule_type(&self) -> LRScheduleType {
729 let lrs: Vec<f32> = self.metrics_history.iter().map(|m| m.learning_rate).collect();
730
731 if lrs.len() < 3 {
732 return LRScheduleType::Unknown;
733 }
734
735 let lr_std = self.calculate_std(&lrs);
737 if lr_std < 1e-8 {
738 return LRScheduleType::Constant;
739 }
740
741 let mut step_drops = 0;
743 for window in lrs.windows(2) {
744 if window[1] < window[0] * 0.9 {
745 step_drops += 1;
746 }
747 }
748
749 if step_drops > lrs.len() / 20 {
750 return LRScheduleType::StepDecay;
751 }
752
753 let log_lrs: Vec<f32> = lrs.iter().map(|&lr| lr.ln()).collect();
755 let exponential_trend = self.calculate_linear_trend(&log_lrs);
756 if exponential_trend < -0.01 {
757 return LRScheduleType::ExponentialDecay;
758 }
759
760 let cyclical_score = self.detect_cyclical_pattern(&lrs);
762 if cyclical_score > 0.3 {
763 return LRScheduleType::Cyclical;
764 }
765
766 LRScheduleType::Unknown
767 }
768
769 fn calculate_linear_trend(&self, values: &[f32]) -> f32 {
771 if values.len() < 2 {
772 return 0.0;
773 }
774
775 let n = values.len() as f32;
776 let x_mean = (n - 1.0) / 2.0;
777 let y_mean = values.iter().sum::<f32>() / n;
778
779 let mut numerator = 0.0;
780 let mut denominator = 0.0;
781
782 for (i, &y) in values.iter().enumerate() {
783 let x = i as f32;
784 numerator += (x - x_mean) * (y - y_mean);
785 denominator += (x - x_mean).powi(2);
786 }
787
788 if denominator > 1e-8 {
789 numerator / denominator
790 } else {
791 0.0
792 }
793 }
794
795 fn detect_cyclical_pattern(&self, values: &[f32]) -> f32 {
797 let mut max_autocorr: f32 = 0.0;
799 for lag in 2..=values.len() / 4 {
800 let autocorr = self.calculate_autocorrelation(values, lag).abs();
801 max_autocorr = max_autocorr.max(autocorr);
802 }
803 max_autocorr
804 }
805
806 fn build_lr_history(&self) -> Vec<LearningRatePoint> {
808 let mut history = Vec::new();
809
810 for (i, metrics) in self.metrics_history.iter().enumerate() {
811 let loss_change = if i > 0 {
812 self.metrics_history[i - 1].train_loss - metrics.train_loss
813 } else {
814 0.0
815 };
816
817 let effectiveness = if loss_change > 0.0 {
818 loss_change / metrics.learning_rate.max(1e-8)
819 } else {
820 0.0
821 };
822
823 history.push(LearningRatePoint {
824 epoch: metrics.epoch,
825 learning_rate: metrics.learning_rate,
826 loss_change,
827 gradient_norm: metrics.gradient_norm,
828 effectiveness,
829 });
830 }
831
832 history
833 }
834
835 fn calculate_lr_impact_score(&self, lr_history: &[LearningRatePoint]) -> f32 {
837 if lr_history.is_empty() {
838 return 0.0;
839 }
840
841 let avg_effectiveness =
842 lr_history.iter().map(|p| p.effectiveness).sum::<f32>() / lr_history.len() as f32;
843
844 avg_effectiveness.max(0.0).min(1.0)
845 }
846
847 fn estimate_optimal_lr(&self, lr_history: &[LearningRatePoint]) -> f32 {
849 if lr_history.is_empty() {
850 return 0.001; }
852
853 lr_history
855 .iter()
856 .max_by(|a, b| {
857 a.effectiveness
858 .partial_cmp(&b.effectiveness)
859 .unwrap_or(std::cmp::Ordering::Equal)
860 })
861 .map(|p| p.learning_rate)
862 .unwrap_or(0.001)
863 }
864
865 fn calculate_lr_sensitivity(&self, lr_history: &[LearningRatePoint]) -> f32 {
867 if lr_history.len() < 2 {
868 return 0.0;
869 }
870
871 let effectiveness_values: Vec<f32> = lr_history.iter().map(|p| p.effectiveness).collect();
872
873 self.calculate_std(&effectiveness_values)
874 }
875
876 fn generate_lr_recommendations(
878 &self,
879 current_lr: f32,
880 lr_history: &[LearningRatePoint],
881 ) -> Vec<LRRecommendation> {
882 let mut recommendations = Vec::new();
883
884 if lr_history.is_empty() {
885 return recommendations;
886 }
887
888 let recent_effectiveness =
889 lr_history.iter().rev().take(5).map(|p| p.effectiveness).sum::<f32>()
890 / 5.0f32.min(lr_history.len() as f32);
891
892 if recent_effectiveness < 0.1 {
893 recommendations.push(LRRecommendation {
894 action: LRAction::Decrease,
895 confidence: 0.7,
896 rationale: "Low learning effectiveness detected".to_string(),
897 expected_improvement: 0.3,
898 });
899 }
900
901 let optimal_lr = self.estimate_optimal_lr(lr_history);
902 if current_lr > optimal_lr * 2.0 {
903 recommendations.push(LRRecommendation {
904 action: LRAction::Decrease,
905 confidence: 0.8,
906 rationale: "Current LR significantly higher than estimated optimal".to_string(),
907 expected_improvement: 0.4,
908 });
909 } else if current_lr < optimal_lr * 0.5 {
910 recommendations.push(LRRecommendation {
911 action: LRAction::Increase,
912 confidence: 0.6,
913 rationale: "Current LR significantly lower than estimated optimal".to_string(),
914 expected_improvement: 0.3,
915 });
916 }
917
918 recommendations
919 }
920
921 async fn analyze_batch_size(&self) -> Result<BatchSizeAnalysis> {
923 if self.metrics_history.is_empty() {
924 return Ok(BatchSizeAnalysis {
925 current_batch_size: 0,
926 batch_size_efficiency: 0.0,
927 gradient_noise_level: 0.0,
928 convergence_speed: 0.0,
929 memory_utilization: 0.0,
930 optimal_batch_size_estimate: 32,
931 batch_size_history: Vec::new(),
932 recommendations: Vec::new(),
933 });
934 }
935
936 let current_batch_size = self
937 .metrics_history
938 .back()
939 .expect("metrics_history should not be empty after empty check")
940 .batch_size;
941 let batch_size_history = self.build_batch_size_history();
942
943 let batch_size_efficiency = self.calculate_batch_size_efficiency(&batch_size_history);
944 let gradient_noise_level = self.estimate_gradient_noise_level();
945 let convergence_speed = self.estimate_convergence_speed();
946 let memory_utilization = self.estimate_memory_utilization(current_batch_size);
947 let optimal_batch_size_estimate = self.estimate_optimal_batch_size(&batch_size_history);
948 let recommendations =
949 self.generate_batch_size_recommendations(current_batch_size, &batch_size_history);
950
951 Ok(BatchSizeAnalysis {
952 current_batch_size,
953 batch_size_efficiency,
954 gradient_noise_level,
955 convergence_speed,
956 memory_utilization,
957 optimal_batch_size_estimate,
958 batch_size_history,
959 recommendations,
960 })
961 }
962
963 fn build_batch_size_history(&self) -> Vec<BatchSizePoint> {
965 let mut history = Vec::new();
966
967 for (i, metrics) in self.metrics_history.iter().enumerate() {
968 let loss_improvement = if i > 0 {
969 self.metrics_history[i - 1].train_loss - metrics.train_loss
970 } else {
971 0.0
972 };
973
974 let gradient_stability =
975 metrics.gradient_norm.map(|gn| 1.0 / (1.0 + gn)).unwrap_or(0.5);
976 let throughput = 1.0; history.push(BatchSizePoint {
979 epoch: metrics.epoch,
980 batch_size: metrics.batch_size,
981 loss_improvement,
982 gradient_stability,
983 throughput,
984 });
985 }
986
987 history
988 }
989
990 fn calculate_batch_size_efficiency(&self, batch_history: &[BatchSizePoint]) -> f32 {
992 if batch_history.is_empty() {
993 return 0.0;
994 }
995
996 let avg_improvement =
997 batch_history.iter().map(|p| p.loss_improvement.max(0.0)).sum::<f32>()
998 / batch_history.len() as f32;
999
1000 let avg_stability = batch_history.iter().map(|p| p.gradient_stability).sum::<f32>()
1001 / batch_history.len() as f32;
1002
1003 (avg_improvement * 0.6 + avg_stability * 0.4).min(1.0)
1004 }
1005
1006 fn estimate_gradient_noise_level(&self) -> f32 {
1008 let gradient_norms: Vec<f32> =
1009 self.metrics_history.iter().filter_map(|m| m.gradient_norm).collect();
1010
1011 if gradient_norms.is_empty() {
1012 return 0.5;
1013 }
1014
1015 let std = self.calculate_std(&gradient_norms);
1016 let mean = gradient_norms.iter().sum::<f32>() / gradient_norms.len() as f32;
1017
1018 if mean > 1e-8 {
1019 (std / mean).min(1.0)
1020 } else {
1021 0.5
1022 }
1023 }
1024
1025 fn estimate_convergence_speed(&self) -> f32 {
1027 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1028
1029 if losses.len() < 2 {
1030 return 0.0;
1031 }
1032
1033 let improvement_per_epoch = (losses[0] - losses[losses.len() - 1]) / losses.len() as f32;
1034 improvement_per_epoch.max(0.0).min(1.0)
1035 }
1036
1037 fn estimate_memory_utilization(&self, batch_size: usize) -> f32 {
1039 let normalized_batch_size = batch_size as f32 / 1024.0; normalized_batch_size.min(1.0)
1042 }
1043
1044 fn estimate_optimal_batch_size(&self, batch_history: &[BatchSizePoint]) -> usize {
1046 if batch_history.is_empty() {
1047 return 32;
1048 }
1049
1050 batch_history
1052 .iter()
1053 .max_by(|a, b| {
1054 let score_a = a.loss_improvement * 0.6 + a.gradient_stability * 0.4;
1055 let score_b = b.loss_improvement * 0.6 + b.gradient_stability * 0.4;
1056 score_a.partial_cmp(&score_b).unwrap_or(std::cmp::Ordering::Equal)
1057 })
1058 .map(|p| p.batch_size)
1059 .unwrap_or(32)
1060 }
1061
1062 fn generate_batch_size_recommendations(
1064 &self,
1065 current_batch_size: usize,
1066 _batch_history: &[BatchSizePoint],
1067 ) -> Vec<BatchSizeRecommendation> {
1068 let mut recommendations = Vec::new();
1069
1070 if current_batch_size < 16 {
1071 recommendations.push(BatchSizeRecommendation {
1072 suggested_batch_size: 32,
1073 confidence: 0.7,
1074 rationale: "Very small batch size may lead to noisy gradients".to_string(),
1075 expected_benefits: vec![
1076 "More stable gradients".to_string(),
1077 "Better convergence".to_string(),
1078 ],
1079 });
1080 } else if current_batch_size > 512 {
1081 recommendations.push(BatchSizeRecommendation {
1082 suggested_batch_size: 256,
1083 confidence: 0.6,
1084 rationale: "Large batch size may slow convergence".to_string(),
1085 expected_benefits: vec![
1086 "Faster convergence".to_string(),
1087 "Lower memory usage".to_string(),
1088 ],
1089 });
1090 }
1091
1092 recommendations
1093 }
1094
1095 async fn detect_convergence(&self) -> Result<ConvergenceAnalysis> {
1097 if self.metrics_history.len() < self.config.min_epochs_for_convergence {
1098 return Ok(ConvergenceAnalysis {
1099 convergence_status: ConvergenceStatus::TooEarly,
1100 convergence_probability: 0.0,
1101 epochs_to_convergence_estimate: None,
1102 convergence_criteria: Vec::new(),
1103 early_stopping_recommendation: None,
1104 });
1105 }
1106
1107 let convergence_criteria = self.evaluate_convergence_criteria();
1108 let convergence_status = self.determine_convergence_status(&convergence_criteria);
1109 let convergence_probability = self.calculate_convergence_probability(&convergence_criteria);
1110 let epochs_to_convergence_estimate = self.estimate_epochs_to_convergence();
1111 let early_stopping_recommendation =
1112 self.generate_early_stopping_recommendation(&convergence_criteria);
1113
1114 Ok(ConvergenceAnalysis {
1115 convergence_status,
1116 convergence_probability,
1117 epochs_to_convergence_estimate,
1118 convergence_criteria,
1119 early_stopping_recommendation,
1120 })
1121 }
1122
1123 fn evaluate_convergence_criteria(&self) -> Vec<ConvergenceCriterion> {
1125 let mut criteria = Vec::new();
1126
1127 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1129 let recent_window = 10.min(losses.len());
1130 let recent_losses = &losses[losses.len() - recent_window..];
1131 let loss_std = self.calculate_std(recent_losses);
1132 let loss_mean = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
1133 let loss_stability = loss_std / loss_mean.abs().max(1e-8);
1134
1135 criteria.push(ConvergenceCriterion {
1136 criterion_type: ConvergenceCriterionType::LossStability,
1137 current_value: loss_stability,
1138 threshold: self.config.convergence_tolerance,
1139 satisfied: loss_stability < self.config.convergence_tolerance,
1140 confidence: 0.8,
1141 });
1142
1143 if let Some(recent_grad_norm) = self.metrics_history.back().and_then(|m| m.gradient_norm) {
1145 criteria.push(ConvergenceCriterion {
1146 criterion_type: ConvergenceCriterionType::GradientMagnitude,
1147 current_value: recent_grad_norm,
1148 threshold: 1e-4,
1149 satisfied: recent_grad_norm < 1e-4,
1150 confidence: 0.7,
1151 });
1152 }
1153
1154 if losses.len() >= 10 {
1156 let old_window = &losses[losses.len() - 20..losses.len() - 10];
1157 let new_window = &losses[losses.len() - 10..];
1158 let old_mean = old_window.iter().sum::<f32>() / old_window.len() as f32;
1159 let new_mean = new_window.iter().sum::<f32>() / new_window.len() as f32;
1160 let improvement = (old_mean - new_mean) / old_mean.abs().max(1e-8);
1161
1162 criteria.push(ConvergenceCriterion {
1163 criterion_type: ConvergenceCriterionType::LossImprovement,
1164 current_value: improvement,
1165 threshold: 1e-3,
1166 satisfied: improvement < 1e-3,
1167 confidence: 0.6,
1168 });
1169 }
1170
1171 criteria
1172 }
1173
1174 fn determine_convergence_status(&self, criteria: &[ConvergenceCriterion]) -> ConvergenceStatus {
1176 let satisfied_count = criteria.iter().filter(|c| c.satisfied).count();
1177 let total_count = criteria.len();
1178
1179 if total_count == 0 {
1180 return ConvergenceStatus::TooEarly;
1181 }
1182
1183 let satisfaction_rate = satisfied_count as f32 / total_count as f32;
1184
1185 if satisfaction_rate > 0.8 {
1186 ConvergenceStatus::Converged
1187 } else if satisfaction_rate > 0.5 {
1188 ConvergenceStatus::Converging
1189 } else {
1190 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1192 let recent_trend =
1193 self.calculate_linear_trend(&losses[losses.len().saturating_sub(20)..]);
1194
1195 if recent_trend > 0.01 {
1196 ConvergenceStatus::Diverging
1197 } else {
1198 ConvergenceStatus::Oscillating
1199 }
1200 }
1201 }
1202
1203 fn calculate_convergence_probability(&self, criteria: &[ConvergenceCriterion]) -> f32 {
1205 if criteria.is_empty() {
1206 return 0.0;
1207 }
1208
1209 let weighted_satisfaction: f32 =
1210 criteria.iter().map(|c| if c.satisfied { c.confidence } else { 0.0 }).sum();
1211
1212 let total_weight: f32 = criteria.iter().map(|c| c.confidence).sum();
1213
1214 if total_weight > 0.0 {
1215 weighted_satisfaction / total_weight
1216 } else {
1217 0.0
1218 }
1219 }
1220
1221 fn estimate_epochs_to_convergence(&self) -> Option<usize> {
1223 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1224
1225 if losses.len() < 5 {
1226 return None;
1227 }
1228
1229 let improvement_rate = self.calculate_improvement_rate(&losses);
1230
1231 if improvement_rate <= 0.0 {
1232 return None;
1233 }
1234
1235 let current_loss = *losses.last().expect("losses has at least 5 elements after len check");
1236 let target_loss = current_loss * (1.0 - self.config.convergence_tolerance);
1237 let remaining_improvement = current_loss - target_loss;
1238
1239 let epochs_needed = (remaining_improvement / improvement_rate).ceil() as usize;
1240
1241 Some(epochs_needed.min(1000)) }
1243
1244 fn generate_early_stopping_recommendation(
1246 &self,
1247 criteria: &[ConvergenceCriterion],
1248 ) -> Option<EarlyStoppingRecommendation> {
1249 let convergence_probability = self.calculate_convergence_probability(criteria);
1250
1251 if convergence_probability > 0.9 {
1252 Some(EarlyStoppingRecommendation {
1253 should_stop: true,
1254 confidence: convergence_probability,
1255 rationale: "High convergence probability detected".to_string(),
1256 suggested_epochs_remaining: 0,
1257 })
1258 } else if convergence_probability > 0.7 {
1259 Some(EarlyStoppingRecommendation {
1260 should_stop: false,
1261 confidence: convergence_probability,
1262 rationale: "Approaching convergence, continue for a few more epochs".to_string(),
1263 suggested_epochs_remaining: 5,
1264 })
1265 } else {
1266 None
1267 }
1268 }
1269
1270 async fn identify_plateau(&self) -> Result<PlateauAnalysis> {
1272 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1273
1274 if losses.len() < 10 {
1275 return Ok(PlateauAnalysis {
1276 plateau_detected: false,
1277 plateau_duration: 0,
1278 plateau_level: 0.0,
1279 plateau_type: PlateauType::LossPlayteau,
1280 escape_probability: 0.0,
1281 plateau_characteristics: PlateauCharacteristics {
1282 stability: 0.0,
1283 noise_level: 0.0,
1284 gradient_magnitude: 0.0,
1285 overfitting_risk: 0.0,
1286 },
1287 recommendations: Vec::new(),
1288 });
1289 }
1290
1291 let window_size = 10.min(losses.len());
1292 let recent_losses = &losses[losses.len() - window_size..];
1293
1294 let plateau_detected = self.detect_plateau_in_window(recent_losses);
1295 let plateau_duration =
1296 if plateau_detected { self.calculate_plateau_duration(&losses) } else { 0 };
1297
1298 let plateau_level = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
1299 let plateau_type = PlateauType::LossPlayteau; let escape_probability =
1301 self.estimate_plateau_escape_probability(&losses, plateau_duration);
1302 let plateau_characteristics = self.analyze_plateau_characteristics(recent_losses);
1303 let recommendations =
1304 self.generate_plateau_recommendations(plateau_detected, plateau_duration);
1305
1306 Ok(PlateauAnalysis {
1307 plateau_detected,
1308 plateau_duration,
1309 plateau_level,
1310 plateau_type,
1311 escape_probability,
1312 plateau_characteristics,
1313 recommendations,
1314 })
1315 }
1316
1317 fn detect_plateau_in_window(&self, values: &[f32]) -> bool {
1319 if values.len() < 3 {
1320 return false;
1321 }
1322
1323 let std = self.calculate_std(values);
1324 let mean = values.iter().sum::<f32>() / values.len() as f32;
1325
1326 std / mean.abs().max(1e-8) < self.config.plateau_threshold
1327 }
1328
1329 fn calculate_plateau_duration(&self, losses: &[f32]) -> usize {
1331 let threshold = self.config.plateau_threshold;
1332 let mut duration = 0;
1333
1334 for window in losses.windows(10).rev() {
1335 let std = self.calculate_std(window);
1336 let mean = window.iter().sum::<f32>() / window.len() as f32;
1337
1338 if std / mean.abs().max(1e-8) < threshold {
1339 duration += 1;
1340 } else {
1341 break;
1342 }
1343 }
1344
1345 duration
1346 }
1347
1348 fn estimate_plateau_escape_probability(&self, losses: &[f32], plateau_duration: usize) -> f32 {
1350 if plateau_duration == 0 {
1351 return 1.0;
1352 }
1353
1354 let duration_factor = 1.0 / (1.0 + plateau_duration as f32 * 0.1);
1356
1357 let recent_trend = if losses.len() >= 5 {
1359 self.calculate_linear_trend(&losses[losses.len() - 5..])
1360 } else {
1361 0.0
1362 };
1363
1364 let trend_factor = if recent_trend < 0.0 { 0.8 } else { 0.3 };
1365
1366 (duration_factor * trend_factor).max(0.1).min(0.9)
1367 }
1368
1369 fn analyze_plateau_characteristics(&self, plateau_values: &[f32]) -> PlateauCharacteristics {
1371 let stability = 1.0 - self.calculate_std(plateau_values);
1372 let noise_level = self.calculate_std(plateau_values);
1373
1374 let gradient_magnitude =
1375 self.metrics_history.back().and_then(|m| m.gradient_norm).unwrap_or(0.0);
1376
1377 let overfitting_risk =
1379 if let Some(val_loss) = self.metrics_history.back().and_then(|m| m.validation_loss) {
1380 let train_loss = self
1381 .metrics_history
1382 .back()
1383 .expect("metrics_history should not be empty in this branch")
1384 .train_loss;
1385 ((val_loss - train_loss) / train_loss.abs().max(1e-8)).max(0.0).min(1.0)
1386 } else {
1387 0.5
1388 };
1389
1390 PlateauCharacteristics {
1391 stability: stability.max(0.0).min(1.0),
1392 noise_level: noise_level.min(1.0),
1393 gradient_magnitude,
1394 overfitting_risk,
1395 }
1396 }
1397
1398 fn generate_plateau_recommendations(
1400 &self,
1401 plateau_detected: bool,
1402 plateau_duration: usize,
1403 ) -> Vec<PlateauRecommendation> {
1404 let mut recommendations = Vec::new();
1405
1406 if !plateau_detected {
1407 return recommendations;
1408 }
1409
1410 if plateau_duration > 20 {
1411 recommendations.push(PlateauRecommendation {
1412 action: PlateauAction::IncreaseLearningRate,
1413 priority: Priority::High,
1414 description: "Long plateau detected, consider increasing learning rate".to_string(),
1415 implementation: "Multiply current learning rate by 2-5x temporarily".to_string(),
1416 });
1417 } else if plateau_duration > 10 {
1418 recommendations.push(PlateauRecommendation {
1419 action: PlateauAction::ChangeBatchSize,
1420 priority: Priority::Medium,
1421 description: "Moderate plateau detected, try changing batch size".to_string(),
1422 implementation: "Increase or decrease batch size by 50%".to_string(),
1423 });
1424 }
1425
1426 if plateau_duration > 30 {
1427 recommendations.push(PlateauRecommendation {
1428 action: PlateauAction::EarlyStopping,
1429 priority: Priority::Critical,
1430 description: "Very long plateau, consider early stopping".to_string(),
1431 implementation: "Stop training and use best checkpoint".to_string(),
1432 });
1433 }
1434
1435 recommendations
1436 }
1437
1438 fn generate_training_summary(&self, report: &mut TrainingDynamicsReport) {
1440 let total_epochs = self.metrics_history.back().map(|m| m.epoch).unwrap_or(0);
1441 let total_steps = self.metrics_history.back().map(|m| m.step).unwrap_or(0);
1442
1443 let training_efficiency = if let Some(loss_analysis) = &report.loss_curve_analysis {
1444 loss_analysis.improvement_rate.max(0.0).min(1.0)
1445 } else {
1446 0.0
1447 };
1448
1449 let convergence_health = if let Some(conv_analysis) = &report.convergence_analysis {
1450 conv_analysis.convergence_probability
1451 } else {
1452 0.0
1453 };
1454
1455 let stability_score = if let Some(loss_analysis) = &report.loss_curve_analysis {
1456 loss_analysis.smoothness
1457 } else {
1458 0.0
1459 };
1460
1461 let overall_progress =
1462 (training_efficiency * 0.4 + convergence_health * 0.3 + stability_score * 0.3)
1463 .max(0.0)
1464 .min(1.0);
1465
1466 report.training_summary = TrainingSummary {
1467 total_epochs,
1468 total_steps,
1469 training_efficiency,
1470 convergence_health,
1471 stability_score,
1472 overall_progress,
1473 };
1474 }
1475
1476 fn generate_training_recommendations(&self, report: &mut TrainingDynamicsReport) {
1478 let mut recommendations = Vec::new();
1479
1480 if let Some(lr_analysis) = &report.learning_rate_analysis {
1482 for lr_rec in &lr_analysis.recommendations {
1483 recommendations.push(TrainingRecommendation {
1484 category: TrainingCategory::LearningRate,
1485 priority: if lr_rec.confidence > 0.8 {
1486 Priority::High
1487 } else {
1488 Priority::Medium
1489 },
1490 description: lr_rec.rationale.clone(),
1491 implementation: format!("{:?} learning rate", lr_rec.action),
1492 expected_impact: lr_rec.expected_improvement,
1493 });
1494 }
1495 }
1496
1497 if let Some(plateau_analysis) = &report.plateau_analysis {
1499 for plateau_rec in &plateau_analysis.recommendations {
1500 recommendations.push(TrainingRecommendation {
1501 category: TrainingCategory::Optimization,
1502 priority: plateau_rec.priority.clone(),
1503 description: plateau_rec.description.clone(),
1504 implementation: plateau_rec.implementation.clone(),
1505 expected_impact: 0.5, });
1507 }
1508 }
1509
1510 if let Some(conv_analysis) = &report.convergence_analysis {
1512 if let Some(early_stop) = &conv_analysis.early_stopping_recommendation {
1513 if early_stop.should_stop {
1514 recommendations.push(TrainingRecommendation {
1515 category: TrainingCategory::EarlyStopping,
1516 priority: Priority::High,
1517 description: early_stop.rationale.clone(),
1518 implementation: "Stop training and save current model".to_string(),
1519 expected_impact: 0.8,
1520 });
1521 }
1522 }
1523 }
1524
1525 report.recommendations = recommendations;
1526 }
1527
    /// Generate a full training-dynamics report without mutating `self`.
    ///
    /// Builds a temporary analyzer from a clone of the config and metrics
    /// history (with a fresh, empty cache) and runs the analysis on it —
    /// presumably because `analyze` requires `&mut self`; confirm against its
    /// definition. Note this copies the entire metrics history (up to
    /// `max_history_length` entries) on every call.
    pub async fn generate_report(&self) -> Result<TrainingDynamicsReport> {
        let mut temp_analyzer = TrainingDynamicsAnalyzer {
            config: self.config.clone(),
            metrics_history: self.metrics_history.clone(),
            analysis_cache: HashMap::new(),
        };

        temp_analyzer.analyze().await
    }
1538
    /// Reset the analyzer: drop all recorded metrics and cached analyses.
    pub fn clear(&mut self) {
        self.metrics_history.clear();
        self.analysis_cache.clear();
    }
1544
1545 pub fn get_training_summary(&self) -> TrainingStateSummary {
1547 let current_metrics = self.metrics_history.back();
1548
1549 TrainingStateSummary {
1550 total_epochs: current_metrics.map(|m| m.epoch).unwrap_or(0),
1551 total_steps: current_metrics.map(|m| m.step).unwrap_or(0),
1552 current_loss: current_metrics.map(|m| m.train_loss).unwrap_or(0.0),
1553 current_lr: current_metrics.map(|m| m.learning_rate).unwrap_or(0.0),
1554 metrics_collected: self.metrics_history.len(),
1555 }
1556 }
1557}
1558
/// Lightweight, serializable snapshot of the analyzer's current training
/// state, as returned by `get_training_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStateSummary {
    // Epoch of the most recent metrics entry (0 if none recorded).
    pub total_epochs: usize,
    // Step of the most recent metrics entry (0 if none recorded).
    pub total_steps: usize,
    // Latest recorded training loss (0.0 if none recorded).
    pub current_loss: f32,
    // Latest recorded learning rate (0.0 if none recorded).
    pub current_lr: f32,
    // Number of metrics entries currently held in history.
    pub metrics_collected: usize,
}