1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
/// Configuration for [`TrainingDynamicsAnalyzer`]: per-analysis toggles plus
/// the numeric thresholds the detectors compare against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamicsConfig {
    // Toggles: each one gates a section of the generated report.
    pub enable_loss_curve_analysis: bool,
    pub enable_learning_rate_analysis: bool,
    pub enable_batch_size_analysis: bool,
    pub enable_convergence_detection: bool,
    pub enable_plateau_identification: bool,
    /// Window length for moving-average style computations.
    /// NOTE(review): not referenced in the visible code — confirm it is used
    /// by the plateau logic elsewhere in this file.
    pub moving_average_window: usize,
    /// Relative loss-stability threshold used by the convergence criteria.
    pub convergence_tolerance: f32,
    /// Relative std/mean level below which the loss counts as plateaued.
    pub plateau_threshold: f32,
    /// Minimum recorded epochs before convergence detection runs at all.
    pub min_epochs_for_convergence: usize,
    /// Hard cap on the retained metrics history (oldest entries evicted).
    pub max_history_length: usize,
}
34
impl Default for TrainingDynamicsConfig {
    /// All analyses enabled, with conservative defaults: 10-sample averaging
    /// window, 1e-6 convergence tolerance, 1e-4 plateau threshold, at least
    /// 20 epochs before convergence checks, and a 10k-entry history cap.
    fn default() -> Self {
        Self {
            enable_loss_curve_analysis: true,
            enable_learning_rate_analysis: true,
            enable_batch_size_analysis: true,
            enable_convergence_detection: true,
            enable_plateau_identification: true,
            moving_average_window: 10,
            convergence_tolerance: 1e-6,
            plateau_threshold: 1e-4,
            min_epochs_for_convergence: 20,
            max_history_length: 10000,
        }
    }
}
51
/// One observation of training state, recorded once per epoch/step via
/// `TrainingDynamicsAnalyzer::record_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub step: usize,
    pub train_loss: f32,
    /// Validation loss, when an eval pass ran at this point.
    pub validation_loss: Option<f32>,
    pub learning_rate: f32,
    pub batch_size: usize,
    /// Global gradient norm, when the training loop records it.
    pub gradient_norm: Option<f32>,
    pub accuracy: Option<f32>,
    /// Wall-clock timestamp. NOTE(review): units (seconds vs millis since
    /// epoch) are not established by the visible code — confirm with producer.
    pub timestamp: f64,
}
65
/// Output of the loss-curve analysis pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossCurveAnalysis {
    pub trend: LossTrend,
    /// In (0, 1]; 1.0 means consecutive losses barely change.
    pub smoothness: f32,
    /// Std of relative step-to-step loss changes.
    pub volatility: f32,
    /// Average loss decrease per epoch (first minus last, over count).
    pub improvement_rate: f32,
    pub best_loss: f32,
    pub current_loss: f32,
    /// Percent reduction from the first recorded loss to the current one.
    pub loss_reduction_percentage: f32,
    /// Epochs since the best loss was last observed.
    pub epochs_since_improvement: usize,
    pub moving_averages: MovingAverages,
    pub loss_statistics: LossStatistics,
}
80
/// Coarse classification of the loss trajectory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossTrend {
    Decreasing,
    Increasing,
    /// More than half of the steps reverse direction.
    Oscillating,
    /// Relative variation below the configured plateau threshold.
    Plateaued,
    /// Not enough data (fewer than 3 points) to classify.
    Unknown,
}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MovingAverages {
92 pub short_term: f32, pub medium_term: f32, pub long_term: f32, }
96
/// Descriptive statistics over the recorded loss series.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossStatistics {
    pub mean: f32,
    /// Population standard deviation (normalized by n).
    pub std: f32,
    pub min: f32,
    pub max: f32,
    /// Index-based percentiles (no interpolation).
    pub median: f32,
    pub percentile_25: f32,
    pub percentile_75: f32,
    /// Lag-1 autocorrelation of the loss series.
    pub autocorrelation: f32,
}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct LearningRateAnalysis {
112 pub current_lr: f32,
113 pub lr_schedule_type: LRScheduleType,
114 pub lr_impact_score: f32,
115 pub optimal_lr_estimate: f32,
116 pub lr_sensitivity: f32,
117 pub lr_history: Vec<LearningRatePoint>,
118 pub recommendations: Vec<LRRecommendation>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub enum LRScheduleType {
123 Constant,
124 StepDecay,
125 ExponentialDecay,
126 CosineAnnealing,
127 ReduceOnPlateau,
128 Warmup,
129 Cyclical,
130 Unknown,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct LearningRatePoint {
135 pub epoch: usize,
136 pub learning_rate: f32,
137 pub loss_change: f32,
138 pub gradient_norm: Option<f32>,
139 pub effectiveness: f32,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct LRRecommendation {
144 pub action: LRAction,
145 pub confidence: f32,
146 pub rationale: String,
147 pub expected_improvement: f32,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub enum LRAction {
152 Increase,
153 Decrease,
154 KeepCurrent,
155 AddScheduler,
156 ChangeScheduler,
157 AddWarmup,
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct BatchSizeAnalysis {
163 pub current_batch_size: usize,
164 pub batch_size_efficiency: f32,
165 pub gradient_noise_level: f32,
166 pub convergence_speed: f32,
167 pub memory_utilization: f32,
168 pub optimal_batch_size_estimate: usize,
169 pub batch_size_history: Vec<BatchSizePoint>,
170 pub recommendations: Vec<BatchSizeRecommendation>,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct BatchSizePoint {
175 pub epoch: usize,
176 pub batch_size: usize,
177 pub loss_improvement: f32,
178 pub gradient_stability: f32,
179 pub throughput: f32,
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct BatchSizeRecommendation {
184 pub suggested_batch_size: usize,
185 pub confidence: f32,
186 pub rationale: String,
187 pub expected_benefits: Vec<String>,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct ConvergenceAnalysis {
193 pub convergence_status: ConvergenceStatus,
194 pub convergence_probability: f32,
195 pub epochs_to_convergence_estimate: Option<usize>,
196 pub convergence_criteria: Vec<ConvergenceCriterion>,
197 pub early_stopping_recommendation: Option<EarlyStoppingRecommendation>,
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub enum ConvergenceStatus {
202 Converging,
203 Converged,
204 Diverging,
205 Oscillating,
206 TooEarly,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct ConvergenceCriterion {
211 pub criterion_type: ConvergenceCriterionType,
212 pub current_value: f32,
213 pub threshold: f32,
214 pub satisfied: bool,
215 pub confidence: f32,
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub enum ConvergenceCriterionType {
220 LossStability,
221 GradientMagnitude,
222 LossImprovement,
223 ValidationGap,
224 LearningRateDecay,
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct EarlyStoppingRecommendation {
229 pub should_stop: bool,
230 pub confidence: f32,
231 pub rationale: String,
232 pub suggested_epochs_remaining: usize,
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct PlateauAnalysis {
238 pub plateau_detected: bool,
239 pub plateau_duration: usize,
240 pub plateau_level: f32,
241 pub plateau_type: PlateauType,
242 pub escape_probability: f32,
243 pub plateau_characteristics: PlateauCharacteristics,
244 pub recommendations: Vec<PlateauRecommendation>,
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub enum PlateauType {
249 LossPlayteau,
250 GradientPlateau,
251 AccuracyPlateau,
252 LearningRatePlateau,
253}
254
255#[derive(Debug, Clone, Serialize, Deserialize)]
256pub struct PlateauCharacteristics {
257 pub stability: f32,
258 pub noise_level: f32,
259 pub gradient_magnitude: f32,
260 pub overfitting_risk: f32,
261}
262
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct PlateauRecommendation {
265 pub action: PlateauAction,
266 pub priority: Priority,
267 pub description: String,
268 pub implementation: String,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub enum PlateauAction {
273 IncreaseLearningRate,
274 DecreaseLearningRate,
275 ChangeBatchSize,
276 AddRegularization,
277 RemoveRegularization,
278 ChangeOptimizer,
279 AddNoise,
280 EarlyStopping,
281 ContinueTraining,
282}
283
284#[derive(Debug, Clone, Serialize, Deserialize)]
285pub enum Priority {
286 Critical,
287 High,
288 Medium,
289 Low,
290}
291
292#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct TrainingDynamicsReport {
295 pub loss_curve_analysis: Option<LossCurveAnalysis>,
296 pub learning_rate_analysis: Option<LearningRateAnalysis>,
297 pub batch_size_analysis: Option<BatchSizeAnalysis>,
298 pub convergence_analysis: Option<ConvergenceAnalysis>,
299 pub plateau_analysis: Option<PlateauAnalysis>,
300 pub training_summary: TrainingSummary,
301 pub recommendations: Vec<TrainingRecommendation>,
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
305pub struct TrainingSummary {
306 pub total_epochs: usize,
307 pub total_steps: usize,
308 pub training_efficiency: f32,
309 pub convergence_health: f32,
310 pub stability_score: f32,
311 pub overall_progress: f32,
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize)]
315pub struct TrainingRecommendation {
316 pub category: TrainingCategory,
317 pub priority: Priority,
318 pub description: String,
319 pub implementation: String,
320 pub expected_impact: f32,
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub enum TrainingCategory {
325 LearningRate,
326 BatchSize,
327 Optimization,
328 Regularization,
329 EarlyStopping,
330 Architecture,
331}
332
/// Stateful analyzer that accumulates per-epoch [`TrainingMetrics`] and
/// produces [`TrainingDynamicsReport`]s on demand.
#[derive(Debug)]
pub struct TrainingDynamicsAnalyzer {
    /// Thresholds and feature toggles governing the analyses.
    config: TrainingDynamicsConfig,
    /// Rolling metrics window, bounded by `config.max_history_length`.
    metrics_history: VecDeque<TrainingMetrics>,
    /// Cache of generated reports keyed by string.
    /// NOTE(review): never read or written in the visible code — confirm it
    /// is used elsewhere before relying on (or removing) it.
    analysis_cache: HashMap<String, TrainingDynamicsReport>,
}
340
341impl TrainingDynamicsAnalyzer {
342 pub fn new(config: TrainingDynamicsConfig) -> Self {
344 Self {
345 config,
346 metrics_history: VecDeque::new(),
347 analysis_cache: HashMap::new(),
348 }
349 }
350
351 pub fn record_metrics(&mut self, metrics: TrainingMetrics) {
353 self.metrics_history.push_back(metrics);
354
355 while self.metrics_history.len() > self.config.max_history_length {
357 self.metrics_history.pop_front();
358 }
359 }
360
361 pub async fn analyze(&mut self) -> Result<TrainingDynamicsReport> {
363 let mut report = TrainingDynamicsReport {
364 loss_curve_analysis: None,
365 learning_rate_analysis: None,
366 batch_size_analysis: None,
367 convergence_analysis: None,
368 plateau_analysis: None,
369 training_summary: TrainingSummary {
370 total_epochs: 0,
371 total_steps: 0,
372 training_efficiency: 0.0,
373 convergence_health: 0.0,
374 stability_score: 0.0,
375 overall_progress: 0.0,
376 },
377 recommendations: Vec::new(),
378 };
379
380 if self.config.enable_loss_curve_analysis {
381 report.loss_curve_analysis = Some(self.analyze_loss_curve().await?);
382 }
383
384 if self.config.enable_learning_rate_analysis {
385 report.learning_rate_analysis = Some(self.analyze_learning_rate().await?);
386 }
387
388 if self.config.enable_batch_size_analysis {
389 report.batch_size_analysis = Some(self.analyze_batch_size().await?);
390 }
391
392 if self.config.enable_convergence_detection {
393 report.convergence_analysis = Some(self.detect_convergence().await?);
394 }
395
396 if self.config.enable_plateau_identification {
397 report.plateau_analysis = Some(self.identify_plateau().await?);
398 }
399
400 self.generate_training_summary(&mut report);
401 self.generate_training_recommendations(&mut report);
402
403 Ok(report)
404 }
405
406 async fn analyze_loss_curve(&self) -> Result<LossCurveAnalysis> {
408 if self.metrics_history.is_empty() {
409 return Ok(LossCurveAnalysis {
410 trend: LossTrend::Unknown,
411 smoothness: 0.0,
412 volatility: 0.0,
413 improvement_rate: 0.0,
414 best_loss: 0.0,
415 current_loss: 0.0,
416 loss_reduction_percentage: 0.0,
417 epochs_since_improvement: 0,
418 moving_averages: MovingAverages {
419 short_term: 0.0,
420 medium_term: 0.0,
421 long_term: 0.0,
422 },
423 loss_statistics: LossStatistics {
424 mean: 0.0,
425 std: 0.0,
426 min: 0.0,
427 max: 0.0,
428 median: 0.0,
429 percentile_25: 0.0,
430 percentile_75: 0.0,
431 autocorrelation: 0.0,
432 },
433 });
434 }
435
436 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
437
438 let trend = self.detect_loss_trend(&losses);
439 let smoothness = self.calculate_smoothness(&losses);
440 let volatility = self.calculate_volatility(&losses);
441 let improvement_rate = self.calculate_improvement_rate(&losses);
442
443 let best_loss = losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
444 let current_loss = *losses.last().expect("losses is non-empty from metrics_history");
445 let loss_reduction_percentage = if losses.len() > 1 {
446 ((losses[0] - current_loss) / losses[0].abs()) * 100.0
447 } else {
448 0.0
449 };
450
451 let epochs_since_improvement = self.calculate_epochs_since_improvement(&losses, best_loss);
452 let moving_averages = self.calculate_moving_averages(&losses);
453 let loss_statistics = self.calculate_loss_statistics(&losses);
454
455 Ok(LossCurveAnalysis {
456 trend,
457 smoothness,
458 volatility,
459 improvement_rate,
460 best_loss,
461 current_loss,
462 loss_reduction_percentage,
463 epochs_since_improvement,
464 moving_averages,
465 loss_statistics,
466 })
467 }
468
469 fn detect_loss_trend(&self, losses: &[f32]) -> LossTrend {
471 if losses.len() < 3 {
472 return LossTrend::Unknown;
473 }
474
475 let window_size = (losses.len() / 4).max(5).min(20);
476 let recent_losses = &losses[losses.len().saturating_sub(window_size)..];
477 let early_losses = &losses[..window_size.min(losses.len())];
478
479 let recent_mean = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
480 let early_mean = early_losses.iter().sum::<f32>() / early_losses.len() as f32;
481
482 let improvement = (early_mean - recent_mean) / early_mean.abs();
483
484 let recent_std = self.calculate_std(recent_losses);
486 let recent_mean_abs = recent_mean.abs();
487
488 if recent_std / recent_mean_abs.max(1e-8) < self.config.plateau_threshold {
489 return LossTrend::Plateaued;
490 }
491
492 let oscillation_score = self.detect_oscillation(losses);
494 if oscillation_score > 0.5 {
495 return LossTrend::Oscillating;
496 }
497
498 if improvement > 0.01 {
499 LossTrend::Decreasing
500 } else if improvement < -0.01 {
501 LossTrend::Increasing
502 } else {
503 LossTrend::Plateaued
504 }
505 }
506
507 fn calculate_smoothness(&self, losses: &[f32]) -> f32 {
509 if losses.len() < 2 {
510 return 1.0;
511 }
512
513 let differences: Vec<f32> = losses.windows(2).map(|w| (w[1] - w[0]).abs()).collect();
514
515 let mean_diff = differences.iter().sum::<f32>() / differences.len() as f32;
516 let mean_loss = losses.iter().sum::<f32>() / losses.len() as f32;
517
518 1.0 / (1.0 + mean_diff / mean_loss.abs().max(1e-8))
520 }
521
522 fn calculate_volatility(&self, losses: &[f32]) -> f32 {
524 if losses.len() < 2 {
525 return 0.0;
526 }
527
528 let returns: Vec<f32> =
529 losses.windows(2).map(|w| (w[1] - w[0]) / w[0].abs().max(1e-8)).collect();
530
531 self.calculate_std(&returns)
532 }
533
534 fn calculate_improvement_rate(&self, losses: &[f32]) -> f32 {
536 if losses.len() < 2 {
537 return 0.0;
538 }
539
540 let total_improvement = losses[0] - losses[losses.len() - 1];
541 let epochs = losses.len() as f32;
542
543 total_improvement / epochs
544 }
545
546 fn calculate_epochs_since_improvement(&self, losses: &[f32], best_loss: f32) -> usize {
548 for (i, &loss) in losses.iter().rev().enumerate() {
549 if (loss - best_loss).abs() < 1e-8 {
550 return i;
551 }
552 }
553 losses.len()
554 }
555
556 fn calculate_moving_averages(&self, losses: &[f32]) -> MovingAverages {
558 let short_window = 5.min(losses.len());
559 let medium_window = 20.min(losses.len());
560 let long_window = 100.min(losses.len());
561
562 let short_term = if short_window > 0 {
563 losses[losses.len() - short_window..].iter().sum::<f32>() / short_window as f32
564 } else {
565 0.0
566 };
567
568 let medium_term = if medium_window > 0 {
569 losses[losses.len() - medium_window..].iter().sum::<f32>() / medium_window as f32
570 } else {
571 0.0
572 };
573
574 let long_term = if long_window > 0 {
575 losses[losses.len() - long_window..].iter().sum::<f32>() / long_window as f32
576 } else {
577 0.0
578 };
579
580 MovingAverages {
581 short_term,
582 medium_term,
583 long_term,
584 }
585 }
586
587 fn calculate_loss_statistics(&self, losses: &[f32]) -> LossStatistics {
589 if losses.is_empty() {
590 return LossStatistics {
591 mean: 0.0,
592 std: 0.0,
593 min: 0.0,
594 max: 0.0,
595 median: 0.0,
596 percentile_25: 0.0,
597 percentile_75: 0.0,
598 autocorrelation: 0.0,
599 };
600 }
601
602 let mean = losses.iter().sum::<f32>() / losses.len() as f32;
603 let std = self.calculate_std(losses);
604
605 let mut sorted_losses = losses.to_vec();
606 sorted_losses.sort_by(|a, b| a.partial_cmp(b).unwrap());
607
608 let min = sorted_losses[0];
609 let max = sorted_losses[sorted_losses.len() - 1];
610 let median = sorted_losses[sorted_losses.len() / 2];
611 let percentile_25 = sorted_losses[sorted_losses.len() / 4];
612 let percentile_75 = sorted_losses[3 * sorted_losses.len() / 4];
613
614 let autocorrelation = self.calculate_autocorrelation(losses, 1);
615
616 LossStatistics {
617 mean,
618 std,
619 min,
620 max,
621 median,
622 percentile_25,
623 percentile_75,
624 autocorrelation,
625 }
626 }
627
628 fn calculate_std(&self, values: &[f32]) -> f32 {
630 if values.len() < 2 {
631 return 0.0;
632 }
633
634 let mean = values.iter().sum::<f32>() / values.len() as f32;
635 let variance =
636 values.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
637
638 variance.sqrt()
639 }
640
641 fn detect_oscillation(&self, losses: &[f32]) -> f32 {
643 if losses.len() < 4 {
644 return 0.0;
645 }
646
647 let mut direction_changes = 0;
648 let mut total_comparisons = 0;
649
650 for i in 1..losses.len() - 1 {
651 let prev_direction = losses[i] > losses[i - 1];
652 let next_direction = losses[i + 1] > losses[i];
653
654 if prev_direction != next_direction {
655 direction_changes += 1;
656 }
657 total_comparisons += 1;
658 }
659
660 direction_changes as f32 / total_comparisons as f32
661 }
662
663 fn calculate_autocorrelation(&self, values: &[f32], lag: usize) -> f32 {
665 if values.len() <= lag {
666 return 0.0;
667 }
668
669 let mean = values.iter().sum::<f32>() / values.len() as f32;
670
671 let mut numerator = 0.0;
672 let mut denominator = 0.0;
673
674 for i in 0..values.len() - lag {
675 numerator += (values[i] - mean) * (values[i + lag] - mean);
676 }
677
678 for &value in values {
679 denominator += (value - mean).powi(2);
680 }
681
682 if denominator > 1e-8 {
683 numerator / denominator
684 } else {
685 0.0
686 }
687 }
688
689 async fn analyze_learning_rate(&self) -> Result<LearningRateAnalysis> {
691 if self.metrics_history.is_empty() {
692 return Ok(LearningRateAnalysis {
693 current_lr: 0.0,
694 lr_schedule_type: LRScheduleType::Unknown,
695 lr_impact_score: 0.0,
696 optimal_lr_estimate: 0.0,
697 lr_sensitivity: 0.0,
698 lr_history: Vec::new(),
699 recommendations: Vec::new(),
700 });
701 }
702
703 let current_lr = self.metrics_history.back().unwrap().learning_rate;
704 let lr_schedule_type = self.detect_lr_schedule_type();
705
706 let lr_history = self.build_lr_history();
707 let lr_impact_score = self.calculate_lr_impact_score(&lr_history);
708 let optimal_lr_estimate = self.estimate_optimal_lr(&lr_history);
709 let lr_sensitivity = self.calculate_lr_sensitivity(&lr_history);
710 let recommendations = self.generate_lr_recommendations(current_lr, &lr_history);
711
712 Ok(LearningRateAnalysis {
713 current_lr,
714 lr_schedule_type,
715 lr_impact_score,
716 optimal_lr_estimate,
717 lr_sensitivity,
718 lr_history,
719 recommendations,
720 })
721 }
722
    /// Heuristically classifies the LR schedule from the recorded LR series.
    /// Checks run in order: constant, step decay, exponential decay,
    /// cyclical; anything else (including cosine/warmup/plateau schedules)
    /// falls through to `Unknown`.
    fn detect_lr_schedule_type(&self) -> LRScheduleType {
        let lrs: Vec<f32> = self.metrics_history.iter().map(|m| m.learning_rate).collect();

        if lrs.len() < 3 {
            return LRScheduleType::Unknown;
        }

        // Essentially no variation at all -> constant schedule.
        let lr_std = self.calculate_std(&lrs);
        if lr_std < 1e-8 {
            return LRScheduleType::Constant;
        }

        // Count discrete drops of more than 10% between consecutive points.
        let mut step_drops = 0;
        for window in lrs.windows(2) {
            if window[1] < window[0] * 0.9 {
                step_drops += 1;
            }
        }

        if step_drops > lrs.len() / 20 {
            return LRScheduleType::StepDecay;
        }

        // A negative linear trend in log-space indicates exponential decay.
        // NOTE(review): ln() of a zero/negative LR yields -inf/NaN and would
        // distort this trend — assumes learning rates are strictly positive.
        let log_lrs: Vec<f32> = lrs.iter().map(|&lr| lr.ln()).collect();
        let exponential_trend = self.calculate_linear_trend(&log_lrs);
        if exponential_trend < -0.01 {
            return LRScheduleType::ExponentialDecay;
        }

        // Strong autocorrelation at some lag suggests a cyclical schedule.
        let cyclical_score = self.detect_cyclical_pattern(&lrs);
        if cyclical_score > 0.3 {
            return LRScheduleType::Cyclical;
        }

        LRScheduleType::Unknown
    }
764
765 fn calculate_linear_trend(&self, values: &[f32]) -> f32 {
767 if values.len() < 2 {
768 return 0.0;
769 }
770
771 let n = values.len() as f32;
772 let x_mean = (n - 1.0) / 2.0;
773 let y_mean = values.iter().sum::<f32>() / n;
774
775 let mut numerator = 0.0;
776 let mut denominator = 0.0;
777
778 for (i, &y) in values.iter().enumerate() {
779 let x = i as f32;
780 numerator += (x - x_mean) * (y - y_mean);
781 denominator += (x - x_mean).powi(2);
782 }
783
784 if denominator > 1e-8 {
785 numerator / denominator
786 } else {
787 0.0
788 }
789 }
790
791 fn detect_cyclical_pattern(&self, values: &[f32]) -> f32 {
793 let mut max_autocorr: f32 = 0.0;
795 for lag in 2..=values.len() / 4 {
796 let autocorr = self.calculate_autocorrelation(values, lag).abs();
797 max_autocorr = max_autocorr.max(autocorr);
798 }
799 max_autocorr
800 }
801
802 fn build_lr_history(&self) -> Vec<LearningRatePoint> {
804 let mut history = Vec::new();
805
806 for (i, metrics) in self.metrics_history.iter().enumerate() {
807 let loss_change = if i > 0 {
808 self.metrics_history[i - 1].train_loss - metrics.train_loss
809 } else {
810 0.0
811 };
812
813 let effectiveness = if loss_change > 0.0 {
814 loss_change / metrics.learning_rate.max(1e-8)
815 } else {
816 0.0
817 };
818
819 history.push(LearningRatePoint {
820 epoch: metrics.epoch,
821 learning_rate: metrics.learning_rate,
822 loss_change,
823 gradient_norm: metrics.gradient_norm,
824 effectiveness,
825 });
826 }
827
828 history
829 }
830
831 fn calculate_lr_impact_score(&self, lr_history: &[LearningRatePoint]) -> f32 {
833 if lr_history.is_empty() {
834 return 0.0;
835 }
836
837 let avg_effectiveness =
838 lr_history.iter().map(|p| p.effectiveness).sum::<f32>() / lr_history.len() as f32;
839
840 avg_effectiveness.max(0.0).min(1.0)
841 }
842
843 fn estimate_optimal_lr(&self, lr_history: &[LearningRatePoint]) -> f32 {
845 if lr_history.is_empty() {
846 return 0.001; }
848
849 lr_history
851 .iter()
852 .max_by(|a, b| a.effectiveness.partial_cmp(&b.effectiveness).unwrap())
853 .map(|p| p.learning_rate)
854 .unwrap_or(0.001)
855 }
856
857 fn calculate_lr_sensitivity(&self, lr_history: &[LearningRatePoint]) -> f32 {
859 if lr_history.len() < 2 {
860 return 0.0;
861 }
862
863 let effectiveness_values: Vec<f32> = lr_history.iter().map(|p| p.effectiveness).collect();
864
865 self.calculate_std(&effectiveness_values)
866 }
867
868 fn generate_lr_recommendations(
870 &self,
871 current_lr: f32,
872 lr_history: &[LearningRatePoint],
873 ) -> Vec<LRRecommendation> {
874 let mut recommendations = Vec::new();
875
876 if lr_history.is_empty() {
877 return recommendations;
878 }
879
880 let recent_effectiveness =
881 lr_history.iter().rev().take(5).map(|p| p.effectiveness).sum::<f32>()
882 / 5.0f32.min(lr_history.len() as f32);
883
884 if recent_effectiveness < 0.1 {
885 recommendations.push(LRRecommendation {
886 action: LRAction::Decrease,
887 confidence: 0.7,
888 rationale: "Low learning effectiveness detected".to_string(),
889 expected_improvement: 0.3,
890 });
891 }
892
893 let optimal_lr = self.estimate_optimal_lr(lr_history);
894 if current_lr > optimal_lr * 2.0 {
895 recommendations.push(LRRecommendation {
896 action: LRAction::Decrease,
897 confidence: 0.8,
898 rationale: "Current LR significantly higher than estimated optimal".to_string(),
899 expected_improvement: 0.4,
900 });
901 } else if current_lr < optimal_lr * 0.5 {
902 recommendations.push(LRRecommendation {
903 action: LRAction::Increase,
904 confidence: 0.6,
905 rationale: "Current LR significantly lower than estimated optimal".to_string(),
906 expected_improvement: 0.3,
907 });
908 }
909
910 recommendations
911 }
912
913 async fn analyze_batch_size(&self) -> Result<BatchSizeAnalysis> {
915 if self.metrics_history.is_empty() {
916 return Ok(BatchSizeAnalysis {
917 current_batch_size: 0,
918 batch_size_efficiency: 0.0,
919 gradient_noise_level: 0.0,
920 convergence_speed: 0.0,
921 memory_utilization: 0.0,
922 optimal_batch_size_estimate: 32,
923 batch_size_history: Vec::new(),
924 recommendations: Vec::new(),
925 });
926 }
927
928 let current_batch_size = self.metrics_history.back().unwrap().batch_size;
929 let batch_size_history = self.build_batch_size_history();
930
931 let batch_size_efficiency = self.calculate_batch_size_efficiency(&batch_size_history);
932 let gradient_noise_level = self.estimate_gradient_noise_level();
933 let convergence_speed = self.estimate_convergence_speed();
934 let memory_utilization = self.estimate_memory_utilization(current_batch_size);
935 let optimal_batch_size_estimate = self.estimate_optimal_batch_size(&batch_size_history);
936 let recommendations =
937 self.generate_batch_size_recommendations(current_batch_size, &batch_size_history);
938
939 Ok(BatchSizeAnalysis {
940 current_batch_size,
941 batch_size_efficiency,
942 gradient_noise_level,
943 convergence_speed,
944 memory_utilization,
945 optimal_batch_size_estimate,
946 batch_size_history,
947 recommendations,
948 })
949 }
950
951 fn build_batch_size_history(&self) -> Vec<BatchSizePoint> {
953 let mut history = Vec::new();
954
955 for (i, metrics) in self.metrics_history.iter().enumerate() {
956 let loss_improvement = if i > 0 {
957 self.metrics_history[i - 1].train_loss - metrics.train_loss
958 } else {
959 0.0
960 };
961
962 let gradient_stability =
963 metrics.gradient_norm.map(|gn| 1.0 / (1.0 + gn)).unwrap_or(0.5);
964 let throughput = 1.0; history.push(BatchSizePoint {
967 epoch: metrics.epoch,
968 batch_size: metrics.batch_size,
969 loss_improvement,
970 gradient_stability,
971 throughput,
972 });
973 }
974
975 history
976 }
977
978 fn calculate_batch_size_efficiency(&self, batch_history: &[BatchSizePoint]) -> f32 {
980 if batch_history.is_empty() {
981 return 0.0;
982 }
983
984 let avg_improvement =
985 batch_history.iter().map(|p| p.loss_improvement.max(0.0)).sum::<f32>()
986 / batch_history.len() as f32;
987
988 let avg_stability = batch_history.iter().map(|p| p.gradient_stability).sum::<f32>()
989 / batch_history.len() as f32;
990
991 (avg_improvement * 0.6 + avg_stability * 0.4).min(1.0)
992 }
993
994 fn estimate_gradient_noise_level(&self) -> f32 {
996 let gradient_norms: Vec<f32> =
997 self.metrics_history.iter().filter_map(|m| m.gradient_norm).collect();
998
999 if gradient_norms.is_empty() {
1000 return 0.5;
1001 }
1002
1003 let std = self.calculate_std(&gradient_norms);
1004 let mean = gradient_norms.iter().sum::<f32>() / gradient_norms.len() as f32;
1005
1006 if mean > 1e-8 {
1007 (std / mean).min(1.0)
1008 } else {
1009 0.5
1010 }
1011 }
1012
1013 fn estimate_convergence_speed(&self) -> f32 {
1015 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1016
1017 if losses.len() < 2 {
1018 return 0.0;
1019 }
1020
1021 let improvement_per_epoch = (losses[0] - losses[losses.len() - 1]) / losses.len() as f32;
1022 improvement_per_epoch.max(0.0).min(1.0)
1023 }
1024
1025 fn estimate_memory_utilization(&self, batch_size: usize) -> f32 {
1027 let normalized_batch_size = batch_size as f32 / 1024.0; normalized_batch_size.min(1.0)
1030 }
1031
1032 fn estimate_optimal_batch_size(&self, batch_history: &[BatchSizePoint]) -> usize {
1034 if batch_history.is_empty() {
1035 return 32;
1036 }
1037
1038 batch_history
1040 .iter()
1041 .max_by(|a, b| {
1042 let score_a = a.loss_improvement * 0.6 + a.gradient_stability * 0.4;
1043 let score_b = b.loss_improvement * 0.6 + b.gradient_stability * 0.4;
1044 score_a.partial_cmp(&score_b).unwrap()
1045 })
1046 .map(|p| p.batch_size)
1047 .unwrap_or(32)
1048 }
1049
1050 fn generate_batch_size_recommendations(
1052 &self,
1053 current_batch_size: usize,
1054 _batch_history: &[BatchSizePoint],
1055 ) -> Vec<BatchSizeRecommendation> {
1056 let mut recommendations = Vec::new();
1057
1058 if current_batch_size < 16 {
1059 recommendations.push(BatchSizeRecommendation {
1060 suggested_batch_size: 32,
1061 confidence: 0.7,
1062 rationale: "Very small batch size may lead to noisy gradients".to_string(),
1063 expected_benefits: vec![
1064 "More stable gradients".to_string(),
1065 "Better convergence".to_string(),
1066 ],
1067 });
1068 } else if current_batch_size > 512 {
1069 recommendations.push(BatchSizeRecommendation {
1070 suggested_batch_size: 256,
1071 confidence: 0.6,
1072 rationale: "Large batch size may slow convergence".to_string(),
1073 expected_benefits: vec![
1074 "Faster convergence".to_string(),
1075 "Lower memory usage".to_string(),
1076 ],
1077 });
1078 }
1079
1080 recommendations
1081 }
1082
1083 async fn detect_convergence(&self) -> Result<ConvergenceAnalysis> {
1085 if self.metrics_history.len() < self.config.min_epochs_for_convergence {
1086 return Ok(ConvergenceAnalysis {
1087 convergence_status: ConvergenceStatus::TooEarly,
1088 convergence_probability: 0.0,
1089 epochs_to_convergence_estimate: None,
1090 convergence_criteria: Vec::new(),
1091 early_stopping_recommendation: None,
1092 });
1093 }
1094
1095 let convergence_criteria = self.evaluate_convergence_criteria();
1096 let convergence_status = self.determine_convergence_status(&convergence_criteria);
1097 let convergence_probability = self.calculate_convergence_probability(&convergence_criteria);
1098 let epochs_to_convergence_estimate = self.estimate_epochs_to_convergence();
1099 let early_stopping_recommendation =
1100 self.generate_early_stopping_recommendation(&convergence_criteria);
1101
1102 Ok(ConvergenceAnalysis {
1103 convergence_status,
1104 convergence_probability,
1105 epochs_to_convergence_estimate,
1106 convergence_criteria,
1107 early_stopping_recommendation,
1108 })
1109 }
1110
1111 fn evaluate_convergence_criteria(&self) -> Vec<ConvergenceCriterion> {
1113 let mut criteria = Vec::new();
1114
1115 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1117 let recent_window = 10.min(losses.len());
1118 let recent_losses = &losses[losses.len() - recent_window..];
1119 let loss_std = self.calculate_std(recent_losses);
1120 let loss_mean = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
1121 let loss_stability = loss_std / loss_mean.abs().max(1e-8);
1122
1123 criteria.push(ConvergenceCriterion {
1124 criterion_type: ConvergenceCriterionType::LossStability,
1125 current_value: loss_stability,
1126 threshold: self.config.convergence_tolerance,
1127 satisfied: loss_stability < self.config.convergence_tolerance,
1128 confidence: 0.8,
1129 });
1130
1131 if let Some(recent_grad_norm) = self.metrics_history.back().and_then(|m| m.gradient_norm) {
1133 criteria.push(ConvergenceCriterion {
1134 criterion_type: ConvergenceCriterionType::GradientMagnitude,
1135 current_value: recent_grad_norm,
1136 threshold: 1e-4,
1137 satisfied: recent_grad_norm < 1e-4,
1138 confidence: 0.7,
1139 });
1140 }
1141
1142 if losses.len() >= 10 {
1144 let old_window = &losses[losses.len() - 20..losses.len() - 10];
1145 let new_window = &losses[losses.len() - 10..];
1146 let old_mean = old_window.iter().sum::<f32>() / old_window.len() as f32;
1147 let new_mean = new_window.iter().sum::<f32>() / new_window.len() as f32;
1148 let improvement = (old_mean - new_mean) / old_mean.abs().max(1e-8);
1149
1150 criteria.push(ConvergenceCriterion {
1151 criterion_type: ConvergenceCriterionType::LossImprovement,
1152 current_value: improvement,
1153 threshold: 1e-3,
1154 satisfied: improvement < 1e-3,
1155 confidence: 0.6,
1156 });
1157 }
1158
1159 criteria
1160 }
1161
1162 fn determine_convergence_status(&self, criteria: &[ConvergenceCriterion]) -> ConvergenceStatus {
1164 let satisfied_count = criteria.iter().filter(|c| c.satisfied).count();
1165 let total_count = criteria.len();
1166
1167 if total_count == 0 {
1168 return ConvergenceStatus::TooEarly;
1169 }
1170
1171 let satisfaction_rate = satisfied_count as f32 / total_count as f32;
1172
1173 if satisfaction_rate > 0.8 {
1174 ConvergenceStatus::Converged
1175 } else if satisfaction_rate > 0.5 {
1176 ConvergenceStatus::Converging
1177 } else {
1178 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1180 let recent_trend =
1181 self.calculate_linear_trend(&losses[losses.len().saturating_sub(20)..]);
1182
1183 if recent_trend > 0.01 {
1184 ConvergenceStatus::Diverging
1185 } else {
1186 ConvergenceStatus::Oscillating
1187 }
1188 }
1189 }
1190
1191 fn calculate_convergence_probability(&self, criteria: &[ConvergenceCriterion]) -> f32 {
1193 if criteria.is_empty() {
1194 return 0.0;
1195 }
1196
1197 let weighted_satisfaction: f32 =
1198 criteria.iter().map(|c| if c.satisfied { c.confidence } else { 0.0 }).sum();
1199
1200 let total_weight: f32 = criteria.iter().map(|c| c.confidence).sum();
1201
1202 if total_weight > 0.0 {
1203 weighted_satisfaction / total_weight
1204 } else {
1205 0.0
1206 }
1207 }
1208
1209 fn estimate_epochs_to_convergence(&self) -> Option<usize> {
1211 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1212
1213 if losses.len() < 5 {
1214 return None;
1215 }
1216
1217 let improvement_rate = self.calculate_improvement_rate(&losses);
1218
1219 if improvement_rate <= 0.0 {
1220 return None;
1221 }
1222
1223 let current_loss = *losses.last().expect("losses has at least 5 elements after len check");
1224 let target_loss = current_loss * (1.0 - self.config.convergence_tolerance);
1225 let remaining_improvement = current_loss - target_loss;
1226
1227 let epochs_needed = (remaining_improvement / improvement_rate).ceil() as usize;
1228
1229 Some(epochs_needed.min(1000)) }
1231
1232 fn generate_early_stopping_recommendation(
1234 &self,
1235 criteria: &[ConvergenceCriterion],
1236 ) -> Option<EarlyStoppingRecommendation> {
1237 let convergence_probability = self.calculate_convergence_probability(criteria);
1238
1239 if convergence_probability > 0.9 {
1240 Some(EarlyStoppingRecommendation {
1241 should_stop: true,
1242 confidence: convergence_probability,
1243 rationale: "High convergence probability detected".to_string(),
1244 suggested_epochs_remaining: 0,
1245 })
1246 } else if convergence_probability > 0.7 {
1247 Some(EarlyStoppingRecommendation {
1248 should_stop: false,
1249 confidence: convergence_probability,
1250 rationale: "Approaching convergence, continue for a few more epochs".to_string(),
1251 suggested_epochs_remaining: 5,
1252 })
1253 } else {
1254 None
1255 }
1256 }
1257
1258 async fn identify_plateau(&self) -> Result<PlateauAnalysis> {
1260 let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1261
1262 if losses.len() < 10 {
1263 return Ok(PlateauAnalysis {
1264 plateau_detected: false,
1265 plateau_duration: 0,
1266 plateau_level: 0.0,
1267 plateau_type: PlateauType::LossPlayteau,
1268 escape_probability: 0.0,
1269 plateau_characteristics: PlateauCharacteristics {
1270 stability: 0.0,
1271 noise_level: 0.0,
1272 gradient_magnitude: 0.0,
1273 overfitting_risk: 0.0,
1274 },
1275 recommendations: Vec::new(),
1276 });
1277 }
1278
1279 let window_size = 10.min(losses.len());
1280 let recent_losses = &losses[losses.len() - window_size..];
1281
1282 let plateau_detected = self.detect_plateau_in_window(recent_losses);
1283 let plateau_duration =
1284 if plateau_detected { self.calculate_plateau_duration(&losses) } else { 0 };
1285
1286 let plateau_level = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
1287 let plateau_type = PlateauType::LossPlayteau; let escape_probability =
1289 self.estimate_plateau_escape_probability(&losses, plateau_duration);
1290 let plateau_characteristics = self.analyze_plateau_characteristics(recent_losses);
1291 let recommendations =
1292 self.generate_plateau_recommendations(plateau_detected, plateau_duration);
1293
1294 Ok(PlateauAnalysis {
1295 plateau_detected,
1296 plateau_duration,
1297 plateau_level,
1298 plateau_type,
1299 escape_probability,
1300 plateau_characteristics,
1301 recommendations,
1302 })
1303 }
1304
1305 fn detect_plateau_in_window(&self, values: &[f32]) -> bool {
1307 if values.len() < 3 {
1308 return false;
1309 }
1310
1311 let std = self.calculate_std(values);
1312 let mean = values.iter().sum::<f32>() / values.len() as f32;
1313
1314 std / mean.abs().max(1e-8) < self.config.plateau_threshold
1315 }
1316
1317 fn calculate_plateau_duration(&self, losses: &[f32]) -> usize {
1319 let threshold = self.config.plateau_threshold;
1320 let mut duration = 0;
1321
1322 for window in losses.windows(10).rev() {
1323 let std = self.calculate_std(window);
1324 let mean = window.iter().sum::<f32>() / window.len() as f32;
1325
1326 if std / mean.abs().max(1e-8) < threshold {
1327 duration += 1;
1328 } else {
1329 break;
1330 }
1331 }
1332
1333 duration
1334 }
1335
1336 fn estimate_plateau_escape_probability(&self, losses: &[f32], plateau_duration: usize) -> f32 {
1338 if plateau_duration == 0 {
1339 return 1.0;
1340 }
1341
1342 let duration_factor = 1.0 / (1.0 + plateau_duration as f32 * 0.1);
1344
1345 let recent_trend = if losses.len() >= 5 {
1347 self.calculate_linear_trend(&losses[losses.len() - 5..])
1348 } else {
1349 0.0
1350 };
1351
1352 let trend_factor = if recent_trend < 0.0 { 0.8 } else { 0.3 };
1353
1354 (duration_factor * trend_factor).max(0.1).min(0.9)
1355 }
1356
1357 fn analyze_plateau_characteristics(&self, plateau_values: &[f32]) -> PlateauCharacteristics {
1359 let stability = 1.0 - self.calculate_std(plateau_values);
1360 let noise_level = self.calculate_std(plateau_values);
1361
1362 let gradient_magnitude =
1363 self.metrics_history.back().and_then(|m| m.gradient_norm).unwrap_or(0.0);
1364
1365 let overfitting_risk =
1367 if let Some(val_loss) = self.metrics_history.back().and_then(|m| m.validation_loss) {
1368 let train_loss = self.metrics_history.back().unwrap().train_loss;
1369 ((val_loss - train_loss) / train_loss.abs().max(1e-8)).max(0.0).min(1.0)
1370 } else {
1371 0.5
1372 };
1373
1374 PlateauCharacteristics {
1375 stability: stability.max(0.0).min(1.0),
1376 noise_level: noise_level.min(1.0),
1377 gradient_magnitude,
1378 overfitting_risk,
1379 }
1380 }
1381
1382 fn generate_plateau_recommendations(
1384 &self,
1385 plateau_detected: bool,
1386 plateau_duration: usize,
1387 ) -> Vec<PlateauRecommendation> {
1388 let mut recommendations = Vec::new();
1389
1390 if !plateau_detected {
1391 return recommendations;
1392 }
1393
1394 if plateau_duration > 20 {
1395 recommendations.push(PlateauRecommendation {
1396 action: PlateauAction::IncreaseLearningRate,
1397 priority: Priority::High,
1398 description: "Long plateau detected, consider increasing learning rate".to_string(),
1399 implementation: "Multiply current learning rate by 2-5x temporarily".to_string(),
1400 });
1401 } else if plateau_duration > 10 {
1402 recommendations.push(PlateauRecommendation {
1403 action: PlateauAction::ChangeBatchSize,
1404 priority: Priority::Medium,
1405 description: "Moderate plateau detected, try changing batch size".to_string(),
1406 implementation: "Increase or decrease batch size by 50%".to_string(),
1407 });
1408 }
1409
1410 if plateau_duration > 30 {
1411 recommendations.push(PlateauRecommendation {
1412 action: PlateauAction::EarlyStopping,
1413 priority: Priority::Critical,
1414 description: "Very long plateau, consider early stopping".to_string(),
1415 implementation: "Stop training and use best checkpoint".to_string(),
1416 });
1417 }
1418
1419 recommendations
1420 }
1421
1422 fn generate_training_summary(&self, report: &mut TrainingDynamicsReport) {
1424 let total_epochs = self.metrics_history.back().map(|m| m.epoch).unwrap_or(0);
1425 let total_steps = self.metrics_history.back().map(|m| m.step).unwrap_or(0);
1426
1427 let training_efficiency = if let Some(loss_analysis) = &report.loss_curve_analysis {
1428 loss_analysis.improvement_rate.max(0.0).min(1.0)
1429 } else {
1430 0.0
1431 };
1432
1433 let convergence_health = if let Some(conv_analysis) = &report.convergence_analysis {
1434 conv_analysis.convergence_probability
1435 } else {
1436 0.0
1437 };
1438
1439 let stability_score = if let Some(loss_analysis) = &report.loss_curve_analysis {
1440 loss_analysis.smoothness
1441 } else {
1442 0.0
1443 };
1444
1445 let overall_progress =
1446 (training_efficiency * 0.4 + convergence_health * 0.3 + stability_score * 0.3)
1447 .max(0.0)
1448 .min(1.0);
1449
1450 report.training_summary = TrainingSummary {
1451 total_epochs,
1452 total_steps,
1453 training_efficiency,
1454 convergence_health,
1455 stability_score,
1456 overall_progress,
1457 };
1458 }
1459
1460 fn generate_training_recommendations(&self, report: &mut TrainingDynamicsReport) {
1462 let mut recommendations = Vec::new();
1463
1464 if let Some(lr_analysis) = &report.learning_rate_analysis {
1466 for lr_rec in &lr_analysis.recommendations {
1467 recommendations.push(TrainingRecommendation {
1468 category: TrainingCategory::LearningRate,
1469 priority: if lr_rec.confidence > 0.8 {
1470 Priority::High
1471 } else {
1472 Priority::Medium
1473 },
1474 description: lr_rec.rationale.clone(),
1475 implementation: format!("{:?} learning rate", lr_rec.action),
1476 expected_impact: lr_rec.expected_improvement,
1477 });
1478 }
1479 }
1480
1481 if let Some(plateau_analysis) = &report.plateau_analysis {
1483 for plateau_rec in &plateau_analysis.recommendations {
1484 recommendations.push(TrainingRecommendation {
1485 category: TrainingCategory::Optimization,
1486 priority: plateau_rec.priority.clone(),
1487 description: plateau_rec.description.clone(),
1488 implementation: plateau_rec.implementation.clone(),
1489 expected_impact: 0.5, });
1491 }
1492 }
1493
1494 if let Some(conv_analysis) = &report.convergence_analysis {
1496 if let Some(early_stop) = &conv_analysis.early_stopping_recommendation {
1497 if early_stop.should_stop {
1498 recommendations.push(TrainingRecommendation {
1499 category: TrainingCategory::EarlyStopping,
1500 priority: Priority::High,
1501 description: early_stop.rationale.clone(),
1502 implementation: "Stop training and save current model".to_string(),
1503 expected_impact: 0.8,
1504 });
1505 }
1506 }
1507 }
1508
1509 report.recommendations = recommendations;
1510 }
1511
1512 pub async fn generate_report(&self) -> Result<TrainingDynamicsReport> {
1514 let mut temp_analyzer = TrainingDynamicsAnalyzer {
1515 config: self.config.clone(),
1516 metrics_history: self.metrics_history.clone(),
1517 analysis_cache: HashMap::new(),
1518 };
1519
1520 temp_analyzer.analyze().await
1521 }
1522
1523 pub fn clear(&mut self) {
1525 self.metrics_history.clear();
1526 self.analysis_cache.clear();
1527 }
1528
1529 pub fn get_training_summary(&self) -> TrainingStateSummary {
1531 let current_metrics = self.metrics_history.back();
1532
1533 TrainingStateSummary {
1534 total_epochs: current_metrics.map(|m| m.epoch).unwrap_or(0),
1535 total_steps: current_metrics.map(|m| m.step).unwrap_or(0),
1536 current_loss: current_metrics.map(|m| m.train_loss).unwrap_or(0.0),
1537 current_lr: current_metrics.map(|m| m.learning_rate).unwrap_or(0.0),
1538 metrics_collected: self.metrics_history.len(),
1539 }
1540 }
1541}
1542
/// Lightweight snapshot of the analyzer's current training state, as
/// returned by `get_training_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStateSummary {
    // Epoch of the most recent metrics entry (0 when none recorded).
    pub total_epochs: usize,
    // Step of the most recent metrics entry (0 when none recorded).
    pub total_steps: usize,
    // Latest training loss (0.0 when none recorded).
    pub current_loss: f32,
    // Latest learning rate (0.0 when none recorded).
    pub current_lr: f32,
    // Number of metrics entries currently held in the history buffer.
    pub metrics_collected: usize,
}