// Source file: trustformers_debug/training_dynamics.rs

//! Training Dynamics Analysis
//!
//! Advanced analysis tools for understanding training dynamics, including
//! loss-curve analysis, convergence detection, and learning-rate impact assessment.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
/// Configuration for training dynamics analysis.
///
/// Each `enable_*` flag toggles one section of the report built by
/// `TrainingDynamicsAnalyzer::analyze`; a disabled section is left as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamicsConfig {
    /// Enable loss curve analysis
    pub enable_loss_curve_analysis: bool,
    /// Enable learning rate impact analysis
    pub enable_learning_rate_analysis: bool,
    /// Enable batch size effects analysis
    pub enable_batch_size_analysis: bool,
    /// Enable convergence detection
    pub enable_convergence_detection: bool,
    /// Enable plateau identification
    pub enable_plateau_identification: bool,
    /// Window size for moving averages.
    ///
    /// NOTE(review): the moving averages computed in this file use fixed windows
    /// of 5/20/100 samples; this field is not read by the code reviewed here.
    pub moving_average_window: usize,
    /// Convergence tolerance
    pub convergence_tolerance: f32,
    /// Plateau detection threshold.
    ///
    /// Among other uses, loss-trend detection reports a plateau when the recent
    /// window's `std / max(|mean|, 1e-8)` falls below this value.
    pub plateau_threshold: f32,
    /// Minimum epochs for convergence detection
    pub min_epochs_for_convergence: usize,
    /// Maximum number of metric snapshots retained; older ones are evicted first.
    pub max_history_length: usize,
}
34
35impl Default for TrainingDynamicsConfig {
36    fn default() -> Self {
37        Self {
38            enable_loss_curve_analysis: true,
39            enable_learning_rate_analysis: true,
40            enable_batch_size_analysis: true,
41            enable_convergence_detection: true,
42            enable_plateau_identification: true,
43            moving_average_window: 10,
44            convergence_tolerance: 1e-6,
45            plateau_threshold: 1e-4,
46            min_epochs_for_convergence: 20,
47            max_history_length: 10000,
48        }
49    }
50}
51
/// Training metrics at a specific point in time.
///
/// One snapshot is appended per `TrainingDynamicsAnalyzer::record_metrics` call;
/// the analyses compare each snapshot with the one recorded just before it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Epoch index at which the snapshot was taken.
    pub epoch: usize,
    /// Step counter at which the snapshot was taken.
    pub step: usize,
    /// Training loss; the primary input of the loss-curve, LR and batch-size analyses.
    pub train_loss: f32,
    /// Validation loss, if measured for this snapshot.
    pub validation_loss: Option<f32>,
    /// Learning rate in effect for this snapshot.
    pub learning_rate: f32,
    /// Batch size in effect for this snapshot.
    pub batch_size: usize,
    /// Gradient norm, if measured (presumably a global norm — TODO confirm).
    pub gradient_norm: Option<f32>,
    /// Accuracy, if measured.
    pub accuracy: Option<f32>,
    /// Capture time. NOTE(review): units/epoch not defined here; presumably seconds — confirm.
    pub timestamp: f64,
}
65
/// Loss curve analysis results.
///
/// Computed from the `train_loss` of every retained snapshot, oldest first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossCurveAnalysis {
    /// Overall direction of the curve.
    pub trend: LossTrend,
    /// Smoothness score; `1.0` means a perfectly flat curve.
    pub smoothness: f32,
    /// Standard deviation of the per-step relative loss changes.
    pub volatility: f32,
    /// Average loss improvement per recorded sample (positive = loss going down).
    pub improvement_rate: f32,
    /// Lowest training loss seen in the retained history.
    pub best_loss: f32,
    /// Most recent training loss.
    pub current_loss: f32,
    /// Drop from the first to the latest loss, as a percentage of the first loss.
    pub loss_reduction_percentage: f32,
    /// Number of samples recorded since `best_loss` was last observed
    /// (history entries, which equal epochs only if recorded once per epoch).
    pub epochs_since_improvement: usize,
    /// Trailing means of the loss.
    pub moving_averages: MovingAverages,
    /// Distribution statistics of the loss values.
    pub loss_statistics: LossStatistics,
}
80
/// Qualitative direction of the training-loss curve.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LossTrend {
    /// Recent losses are meaningfully lower than early ones.
    Decreasing,
    /// Recent losses are meaningfully higher than early ones.
    Increasing,
    /// The curve changes direction frequently.
    Oscillating,
    /// Little relative change or spread in recent losses.
    Plateaued,
    /// Not enough data to decide.
    Unknown,
}
89
/// Trailing means of the training loss over fixed-size sample windows.
///
/// Each window is truncated to the available history when fewer samples exist.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MovingAverages {
    pub short_term: f32,  // Mean of the last (up to) 5 samples
    pub medium_term: f32, // Mean of the last (up to) 20 samples
    pub long_term: f32,   // Mean of the last (up to) 100 samples
}
96
/// Distribution statistics of the recorded training losses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossStatistics {
    /// Arithmetic mean.
    pub mean: f32,
    /// Population standard deviation.
    pub std: f32,
    /// Smallest value.
    pub min: f32,
    /// Largest value.
    pub max: f32,
    /// Middle value of the sorted losses (simple nearest-index rule).
    pub median: f32,
    /// 25th percentile (simple nearest-index rule).
    pub percentile_25: f32,
    /// 75th percentile (simple nearest-index rule).
    pub percentile_75: f32,
    /// Lag-1 autocorrelation of the loss sequence.
    pub autocorrelation: f32,
}
108
/// Learning rate impact analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningRateAnalysis {
    /// Learning rate of the most recent snapshot.
    pub current_lr: f32,
    /// Heuristically detected schedule shape.
    pub lr_schedule_type: LRScheduleType,
    /// Mean effectiveness, clamped into `[0, 1]`.
    pub lr_impact_score: f32,
    /// LR of the most effective recorded point.
    pub optimal_lr_estimate: f32,
    /// Standard deviation of per-point effectiveness.
    pub lr_sensitivity: f32,
    /// Per-snapshot LR/effectiveness records, oldest first.
    pub lr_history: Vec<LearningRatePoint>,
    /// Suggested LR adjustments.
    pub recommendations: Vec<LRRecommendation>,
}
120
/// Shape of the learning-rate schedule.
///
/// NOTE(review): the schedule detector in this file only ever yields `Constant`,
/// `StepDecay`, `ExponentialDecay`, `Cyclical` or `Unknown`; the remaining
/// variants are never produced by it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LRScheduleType {
    Constant,
    StepDecay,
    ExponentialDecay,
    CosineAnnealing,
    ReduceOnPlateau,
    Warmup,
    Cyclical,
    Unknown,
}
132
/// Learning-rate record for a single metrics snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningRatePoint {
    /// Epoch of the snapshot.
    pub epoch: usize,
    /// Learning rate of the snapshot.
    pub learning_rate: f32,
    /// Previous `train_loss` minus this one (positive = improvement); `0.0` for the first point.
    pub loss_change: f32,
    /// Gradient norm copied from the snapshot.
    pub gradient_norm: Option<f32>,
    /// `loss_change / max(learning_rate, 1e-8)` when `loss_change > 0`, else `0.0`.
    pub effectiveness: f32,
}
141
/// A suggested learning-rate adjustment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LRRecommendation {
    /// What to do with the learning rate.
    pub action: LRAction,
    /// Heuristic confidence in the suggestion.
    pub confidence: f32,
    /// Human-readable reason.
    pub rationale: String,
    /// Heuristic estimate of the benefit.
    pub expected_improvement: f32,
}
149
/// Kind of learning-rate adjustment suggested by an `LRRecommendation`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LRAction {
    Increase,
    Decrease,
    KeepCurrent,
    AddScheduler,
    ChangeScheduler,
    AddWarmup,
}
159
/// Batch size effects analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeAnalysis {
    /// Batch size of the most recent snapshot.
    pub current_batch_size: usize,
    /// Blend of mean positive loss improvement and mean gradient stability, capped at 1.0.
    pub batch_size_efficiency: f32,
    /// Estimated gradient noise (computation not visible in this excerpt).
    pub gradient_noise_level: f32,
    /// Estimated convergence speed (computation not visible in this excerpt).
    pub convergence_speed: f32,
    /// Estimated memory utilization (computation not visible in this excerpt).
    pub memory_utilization: f32,
    /// Suggested batch size; `32` when no history is recorded.
    pub optimal_batch_size_estimate: usize,
    /// Per-snapshot batch-size records, oldest first.
    pub batch_size_history: Vec<BatchSizePoint>,
    /// Suggested batch-size changes.
    pub recommendations: Vec<BatchSizeRecommendation>,
}
172
/// Batch-size record for a single metrics snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizePoint {
    /// Epoch of the snapshot.
    pub epoch: usize,
    /// Batch size of the snapshot.
    pub batch_size: usize,
    /// Previous `train_loss` minus this one; `0.0` for the first point.
    pub loss_improvement: f32,
    /// `1 / (1 + gradient_norm)`, or `0.5` when no gradient norm was recorded.
    pub gradient_stability: f32,
    /// Placeholder; currently always `1.0`.
    pub throughput: f32,
}
181
/// A suggested batch-size change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSizeRecommendation {
    /// Batch size to switch to.
    pub suggested_batch_size: usize,
    /// Heuristic confidence in the suggestion.
    pub confidence: f32,
    /// Human-readable reason.
    pub rationale: String,
    /// Human-readable list of anticipated benefits.
    pub expected_benefits: Vec<String>,
}
189
/// Convergence detection results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceAnalysis {
    /// Overall convergence state.
    pub convergence_status: ConvergenceStatus,
    /// Estimated probability that training has converged or will converge.
    pub convergence_probability: f32,
    /// Estimated epochs remaining until convergence, if one can be given.
    pub epochs_to_convergence_estimate: Option<usize>,
    /// Individual criteria that were evaluated.
    pub convergence_criteria: Vec<ConvergenceCriterion>,
    /// Early-stopping advice, if any.
    pub early_stopping_recommendation: Option<EarlyStoppingRecommendation>,
}
199
/// Coarse convergence state of a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConvergenceStatus {
    Converging,
    Converged,
    Diverging,
    Oscillating,
    /// Not enough history to judge.
    TooEarly,
}
208
/// One evaluated convergence criterion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceCriterion {
    /// Which quantity the criterion checks.
    pub criterion_type: ConvergenceCriterionType,
    /// Measured value of that quantity.
    pub current_value: f32,
    /// Threshold the value is compared against.
    pub threshold: f32,
    /// Whether the criterion is met.
    pub satisfied: bool,
    /// Heuristic confidence in the verdict.
    pub confidence: f32,
}
217
/// Quantity examined by a `ConvergenceCriterion`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConvergenceCriterionType {
    LossStability,
    GradientMagnitude,
    LossImprovement,
    ValidationGap,
    LearningRateDecay,
}
226
/// Advice on whether to stop training early.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingRecommendation {
    /// Whether stopping is advised now.
    pub should_stop: bool,
    /// Heuristic confidence in the advice.
    pub confidence: f32,
    /// Human-readable reason.
    pub rationale: String,
    /// Suggested number of further epochs to run.
    pub suggested_epochs_remaining: usize,
}
234
/// Plateau identification results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauAnalysis {
    /// Whether a plateau is currently detected.
    pub plateau_detected: bool,
    /// Length of the plateau (presumably in recorded samples — TODO confirm).
    pub plateau_duration: usize,
    /// Metric level at which the plateau sits.
    pub plateau_level: f32,
    /// Which quantity plateaued.
    pub plateau_type: PlateauType,
    /// Estimated probability of leaving the plateau without intervention.
    pub escape_probability: f32,
    /// Descriptive measurements of the plateau.
    pub plateau_characteristics: PlateauCharacteristics,
    /// Suggested interventions.
    pub recommendations: Vec<PlateauRecommendation>,
}
246
/// Quantity that has plateaued.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PlateauType {
    // NOTE(review): misspelling of "LossPlateau". Renaming would break the public
    // API and any serialized data; consider `#[serde(alias)]` plus a deprecation path.
    LossPlayteau,
    GradientPlateau,
    AccuracyPlateau,
    LearningRatePlateau,
}
254
/// Descriptive measurements of a detected plateau.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauCharacteristics {
    /// How stable the metric is on the plateau.
    pub stability: f32,
    /// Noise level of the metric on the plateau.
    pub noise_level: f32,
    /// Gradient magnitude while on the plateau.
    pub gradient_magnitude: f32,
    /// Estimated risk of overfitting.
    pub overfitting_risk: f32,
}
262
/// A suggested intervention for escaping a plateau.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlateauRecommendation {
    /// Intervention to apply.
    pub action: PlateauAction,
    /// Urgency of the intervention.
    pub priority: Priority,
    /// What the intervention is and why.
    pub description: String,
    /// How to carry it out.
    pub implementation: String,
}
270
/// Kinds of intervention a `PlateauRecommendation` can propose.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PlateauAction {
    IncreaseLearningRate,
    DecreaseLearningRate,
    ChangeBatchSize,
    AddRegularization,
    RemoveRegularization,
    ChangeOptimizer,
    AddNoise,
    EarlyStopping,
    ContinueTraining,
}
283
/// Urgency level of a recommendation.
///
/// Variants are declared from most to least urgent. No ordering traits are
/// derived, so compare priorities by matching on the variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Priority {
    Critical,
    High,
    Medium,
    Low,
}
291
/// Comprehensive training dynamics report.
///
/// Each `Option` section is `None` when the corresponding analysis is disabled
/// in `TrainingDynamicsConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamicsReport {
    pub loss_curve_analysis: Option<LossCurveAnalysis>,
    pub learning_rate_analysis: Option<LearningRateAnalysis>,
    pub batch_size_analysis: Option<BatchSizeAnalysis>,
    pub convergence_analysis: Option<ConvergenceAnalysis>,
    pub plateau_analysis: Option<PlateauAnalysis>,
    /// Aggregate scores derived from the sections above.
    pub training_summary: TrainingSummary,
    /// Cross-cutting recommendations derived from the sections above.
    pub recommendations: Vec<TrainingRecommendation>,
}
303
/// Aggregate scores summarizing a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSummary {
    /// Number of epochs covered.
    pub total_epochs: usize,
    /// Number of steps covered.
    pub total_steps: usize,
    /// Heuristic efficiency score.
    pub training_efficiency: f32,
    /// Heuristic convergence-health score.
    pub convergence_health: f32,
    /// Heuristic stability score.
    pub stability_score: f32,
    /// Heuristic overall-progress score.
    pub overall_progress: f32,
}
313
/// A cross-cutting training recommendation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecommendation {
    /// Area of the training setup the recommendation concerns.
    pub category: TrainingCategory,
    /// Urgency of the recommendation.
    pub priority: Priority,
    /// What to change and why.
    pub description: String,
    /// How to carry it out.
    pub implementation: String,
    /// Heuristic estimate of the benefit.
    pub expected_impact: f32,
}
322
/// Area of the training setup a `TrainingRecommendation` targets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrainingCategory {
    LearningRate,
    BatchSize,
    Optimization,
    Regularization,
    EarlyStopping,
    Architecture,
}
332
/// Training dynamics analyzer.
///
/// Accumulates `TrainingMetrics` snapshots (oldest first, capped at
/// `config.max_history_length`) and turns them into a `TrainingDynamicsReport`
/// via `analyze`.
#[derive(Debug)]
pub struct TrainingDynamicsAnalyzer {
    /// Analysis switches and thresholds.
    config: TrainingDynamicsConfig,
    /// Recorded snapshots; the front is the oldest.
    metrics_history: VecDeque<TrainingMetrics>,
    /// Report cache keyed by string.
    /// NOTE(review): in the code reviewed here it is only created empty in `new`;
    /// confirm whether anything reads or populates it.
    analysis_cache: HashMap<String, TrainingDynamicsReport>,
}
340
341impl TrainingDynamicsAnalyzer {
342    /// Create a new training dynamics analyzer
343    pub fn new(config: TrainingDynamicsConfig) -> Self {
344        Self {
345            config,
346            metrics_history: VecDeque::new(),
347            analysis_cache: HashMap::new(),
348        }
349    }
350
351    /// Record training metrics
352    pub fn record_metrics(&mut self, metrics: TrainingMetrics) {
353        self.metrics_history.push_back(metrics);
354
355        // Limit history size
356        while self.metrics_history.len() > self.config.max_history_length {
357            self.metrics_history.pop_front();
358        }
359    }
360
361    /// Perform comprehensive training dynamics analysis
362    pub async fn analyze(&mut self) -> Result<TrainingDynamicsReport> {
363        let mut report = TrainingDynamicsReport {
364            loss_curve_analysis: None,
365            learning_rate_analysis: None,
366            batch_size_analysis: None,
367            convergence_analysis: None,
368            plateau_analysis: None,
369            training_summary: TrainingSummary {
370                total_epochs: 0,
371                total_steps: 0,
372                training_efficiency: 0.0,
373                convergence_health: 0.0,
374                stability_score: 0.0,
375                overall_progress: 0.0,
376            },
377            recommendations: Vec::new(),
378        };
379
380        if self.config.enable_loss_curve_analysis {
381            report.loss_curve_analysis = Some(self.analyze_loss_curve().await?);
382        }
383
384        if self.config.enable_learning_rate_analysis {
385            report.learning_rate_analysis = Some(self.analyze_learning_rate().await?);
386        }
387
388        if self.config.enable_batch_size_analysis {
389            report.batch_size_analysis = Some(self.analyze_batch_size().await?);
390        }
391
392        if self.config.enable_convergence_detection {
393            report.convergence_analysis = Some(self.detect_convergence().await?);
394        }
395
396        if self.config.enable_plateau_identification {
397            report.plateau_analysis = Some(self.identify_plateau().await?);
398        }
399
400        self.generate_training_summary(&mut report);
401        self.generate_training_recommendations(&mut report);
402
403        Ok(report)
404    }
405
406    /// Analyze loss curve patterns
407    async fn analyze_loss_curve(&self) -> Result<LossCurveAnalysis> {
408        if self.metrics_history.is_empty() {
409            return Ok(LossCurveAnalysis {
410                trend: LossTrend::Unknown,
411                smoothness: 0.0,
412                volatility: 0.0,
413                improvement_rate: 0.0,
414                best_loss: 0.0,
415                current_loss: 0.0,
416                loss_reduction_percentage: 0.0,
417                epochs_since_improvement: 0,
418                moving_averages: MovingAverages {
419                    short_term: 0.0,
420                    medium_term: 0.0,
421                    long_term: 0.0,
422                },
423                loss_statistics: LossStatistics {
424                    mean: 0.0,
425                    std: 0.0,
426                    min: 0.0,
427                    max: 0.0,
428                    median: 0.0,
429                    percentile_25: 0.0,
430                    percentile_75: 0.0,
431                    autocorrelation: 0.0,
432                },
433            });
434        }
435
436        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
437
438        let trend = self.detect_loss_trend(&losses);
439        let smoothness = self.calculate_smoothness(&losses);
440        let volatility = self.calculate_volatility(&losses);
441        let improvement_rate = self.calculate_improvement_rate(&losses);
442
443        let best_loss = losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
444        let current_loss = *losses.last().expect("losses is non-empty from metrics_history");
445        let loss_reduction_percentage = if losses.len() > 1 {
446            ((losses[0] - current_loss) / losses[0].abs()) * 100.0
447        } else {
448            0.0
449        };
450
451        let epochs_since_improvement = self.calculate_epochs_since_improvement(&losses, best_loss);
452        let moving_averages = self.calculate_moving_averages(&losses);
453        let loss_statistics = self.calculate_loss_statistics(&losses);
454
455        Ok(LossCurveAnalysis {
456            trend,
457            smoothness,
458            volatility,
459            improvement_rate,
460            best_loss,
461            current_loss,
462            loss_reduction_percentage,
463            epochs_since_improvement,
464            moving_averages,
465            loss_statistics,
466        })
467    }
468
469    /// Detect overall trend in loss curve
470    fn detect_loss_trend(&self, losses: &[f32]) -> LossTrend {
471        if losses.len() < 3 {
472            return LossTrend::Unknown;
473        }
474
475        let window_size = (losses.len() / 4).max(5).min(20);
476        let recent_losses = &losses[losses.len().saturating_sub(window_size)..];
477        let early_losses = &losses[..window_size.min(losses.len())];
478
479        let recent_mean = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
480        let early_mean = early_losses.iter().sum::<f32>() / early_losses.len() as f32;
481
482        let improvement = (early_mean - recent_mean) / early_mean.abs();
483
484        // Check for plateau
485        let recent_std = self.calculate_std(recent_losses);
486        let recent_mean_abs = recent_mean.abs();
487
488        if recent_std / recent_mean_abs.max(1e-8) < self.config.plateau_threshold {
489            return LossTrend::Plateaued;
490        }
491
492        // Check for oscillation
493        let oscillation_score = self.detect_oscillation(losses);
494        if oscillation_score > 0.5 {
495            return LossTrend::Oscillating;
496        }
497
498        if improvement > 0.01 {
499            LossTrend::Decreasing
500        } else if improvement < -0.01 {
501            LossTrend::Increasing
502        } else {
503            LossTrend::Plateaued
504        }
505    }
506
507    /// Calculate smoothness of loss curve
508    fn calculate_smoothness(&self, losses: &[f32]) -> f32 {
509        if losses.len() < 2 {
510            return 1.0;
511        }
512
513        let differences: Vec<f32> = losses.windows(2).map(|w| (w[1] - w[0]).abs()).collect();
514
515        let mean_diff = differences.iter().sum::<f32>() / differences.len() as f32;
516        let mean_loss = losses.iter().sum::<f32>() / losses.len() as f32;
517
518        // Smoothness is inverse of relative variation
519        1.0 / (1.0 + mean_diff / mean_loss.abs().max(1e-8))
520    }
521
522    /// Calculate volatility of loss curve
523    fn calculate_volatility(&self, losses: &[f32]) -> f32 {
524        if losses.len() < 2 {
525            return 0.0;
526        }
527
528        let returns: Vec<f32> =
529            losses.windows(2).map(|w| (w[1] - w[0]) / w[0].abs().max(1e-8)).collect();
530
531        self.calculate_std(&returns)
532    }
533
534    /// Calculate improvement rate
535    fn calculate_improvement_rate(&self, losses: &[f32]) -> f32 {
536        if losses.len() < 2 {
537            return 0.0;
538        }
539
540        let total_improvement = losses[0] - losses[losses.len() - 1];
541        let epochs = losses.len() as f32;
542
543        total_improvement / epochs
544    }
545
546    /// Calculate epochs since last improvement
547    fn calculate_epochs_since_improvement(&self, losses: &[f32], best_loss: f32) -> usize {
548        for (i, &loss) in losses.iter().rev().enumerate() {
549            if (loss - best_loss).abs() < 1e-8 {
550                return i;
551            }
552        }
553        losses.len()
554    }
555
556    /// Calculate moving averages
557    fn calculate_moving_averages(&self, losses: &[f32]) -> MovingAverages {
558        let short_window = 5.min(losses.len());
559        let medium_window = 20.min(losses.len());
560        let long_window = 100.min(losses.len());
561
562        let short_term = if short_window > 0 {
563            losses[losses.len() - short_window..].iter().sum::<f32>() / short_window as f32
564        } else {
565            0.0
566        };
567
568        let medium_term = if medium_window > 0 {
569            losses[losses.len() - medium_window..].iter().sum::<f32>() / medium_window as f32
570        } else {
571            0.0
572        };
573
574        let long_term = if long_window > 0 {
575            losses[losses.len() - long_window..].iter().sum::<f32>() / long_window as f32
576        } else {
577            0.0
578        };
579
580        MovingAverages {
581            short_term,
582            medium_term,
583            long_term,
584        }
585    }
586
587    /// Calculate comprehensive loss statistics
588    fn calculate_loss_statistics(&self, losses: &[f32]) -> LossStatistics {
589        if losses.is_empty() {
590            return LossStatistics {
591                mean: 0.0,
592                std: 0.0,
593                min: 0.0,
594                max: 0.0,
595                median: 0.0,
596                percentile_25: 0.0,
597                percentile_75: 0.0,
598                autocorrelation: 0.0,
599            };
600        }
601
602        let mean = losses.iter().sum::<f32>() / losses.len() as f32;
603        let std = self.calculate_std(losses);
604
605        let mut sorted_losses = losses.to_vec();
606        sorted_losses.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
607
608        let min = sorted_losses[0];
609        let max = sorted_losses[sorted_losses.len() - 1];
610        let median = sorted_losses[sorted_losses.len() / 2];
611        let percentile_25 = sorted_losses[sorted_losses.len() / 4];
612        let percentile_75 = sorted_losses[3 * sorted_losses.len() / 4];
613
614        let autocorrelation = self.calculate_autocorrelation(losses, 1);
615
616        LossStatistics {
617            mean,
618            std,
619            min,
620            max,
621            median,
622            percentile_25,
623            percentile_75,
624            autocorrelation,
625        }
626    }
627
628    /// Calculate standard deviation
629    fn calculate_std(&self, values: &[f32]) -> f32 {
630        if values.len() < 2 {
631            return 0.0;
632        }
633
634        let mean = values.iter().sum::<f32>() / values.len() as f32;
635        let variance =
636            values.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
637
638        variance.sqrt()
639    }
640
641    /// Detect oscillation in loss curve
642    fn detect_oscillation(&self, losses: &[f32]) -> f32 {
643        if losses.len() < 4 {
644            return 0.0;
645        }
646
647        let mut direction_changes = 0;
648        let mut total_comparisons = 0;
649
650        for i in 1..losses.len() - 1 {
651            let prev_direction = losses[i] > losses[i - 1];
652            let next_direction = losses[i + 1] > losses[i];
653
654            if prev_direction != next_direction {
655                direction_changes += 1;
656            }
657            total_comparisons += 1;
658        }
659
660        direction_changes as f32 / total_comparisons as f32
661    }
662
663    /// Calculate autocorrelation
664    fn calculate_autocorrelation(&self, values: &[f32], lag: usize) -> f32 {
665        if values.len() <= lag {
666            return 0.0;
667        }
668
669        let mean = values.iter().sum::<f32>() / values.len() as f32;
670
671        let mut numerator = 0.0;
672        let mut denominator = 0.0;
673
674        for i in 0..values.len() - lag {
675            numerator += (values[i] - mean) * (values[i + lag] - mean);
676        }
677
678        for &value in values {
679            denominator += (value - mean).powi(2);
680        }
681
682        if denominator > 1e-8 {
683            numerator / denominator
684        } else {
685            0.0
686        }
687    }
688
689    /// Analyze learning rate impact
690    async fn analyze_learning_rate(&self) -> Result<LearningRateAnalysis> {
691        if self.metrics_history.is_empty() {
692            return Ok(LearningRateAnalysis {
693                current_lr: 0.0,
694                lr_schedule_type: LRScheduleType::Unknown,
695                lr_impact_score: 0.0,
696                optimal_lr_estimate: 0.0,
697                lr_sensitivity: 0.0,
698                lr_history: Vec::new(),
699                recommendations: Vec::new(),
700            });
701        }
702
703        let current_lr = self
704            .metrics_history
705            .back()
706            .expect("metrics_history should not be empty after empty check")
707            .learning_rate;
708        let lr_schedule_type = self.detect_lr_schedule_type();
709
710        let lr_history = self.build_lr_history();
711        let lr_impact_score = self.calculate_lr_impact_score(&lr_history);
712        let optimal_lr_estimate = self.estimate_optimal_lr(&lr_history);
713        let lr_sensitivity = self.calculate_lr_sensitivity(&lr_history);
714        let recommendations = self.generate_lr_recommendations(current_lr, &lr_history);
715
716        Ok(LearningRateAnalysis {
717            current_lr,
718            lr_schedule_type,
719            lr_impact_score,
720            optimal_lr_estimate,
721            lr_sensitivity,
722            lr_history,
723            recommendations,
724        })
725    }
726
727    /// Detect learning rate schedule type
728    fn detect_lr_schedule_type(&self) -> LRScheduleType {
729        let lrs: Vec<f32> = self.metrics_history.iter().map(|m| m.learning_rate).collect();
730
731        if lrs.len() < 3 {
732            return LRScheduleType::Unknown;
733        }
734
735        // Check for constant LR
736        let lr_std = self.calculate_std(&lrs);
737        if lr_std < 1e-8 {
738            return LRScheduleType::Constant;
739        }
740
741        // Check for step decay (sudden drops)
742        let mut step_drops = 0;
743        for window in lrs.windows(2) {
744            if window[1] < window[0] * 0.9 {
745                step_drops += 1;
746            }
747        }
748
749        if step_drops > lrs.len() / 20 {
750            return LRScheduleType::StepDecay;
751        }
752
753        // Check for exponential decay
754        let log_lrs: Vec<f32> = lrs.iter().map(|&lr| lr.ln()).collect();
755        let exponential_trend = self.calculate_linear_trend(&log_lrs);
756        if exponential_trend < -0.01 {
757            return LRScheduleType::ExponentialDecay;
758        }
759
760        // Check for cyclical patterns
761        let cyclical_score = self.detect_cyclical_pattern(&lrs);
762        if cyclical_score > 0.3 {
763            return LRScheduleType::Cyclical;
764        }
765
766        LRScheduleType::Unknown
767    }
768
769    /// Calculate linear trend
770    fn calculate_linear_trend(&self, values: &[f32]) -> f32 {
771        if values.len() < 2 {
772            return 0.0;
773        }
774
775        let n = values.len() as f32;
776        let x_mean = (n - 1.0) / 2.0;
777        let y_mean = values.iter().sum::<f32>() / n;
778
779        let mut numerator = 0.0;
780        let mut denominator = 0.0;
781
782        for (i, &y) in values.iter().enumerate() {
783            let x = i as f32;
784            numerator += (x - x_mean) * (y - y_mean);
785            denominator += (x - x_mean).powi(2);
786        }
787
788        if denominator > 1e-8 {
789            numerator / denominator
790        } else {
791            0.0
792        }
793    }
794
795    /// Detect cyclical patterns
796    fn detect_cyclical_pattern(&self, values: &[f32]) -> f32 {
797        // Simplified cyclical detection using autocorrelation
798        let mut max_autocorr: f32 = 0.0;
799        for lag in 2..=values.len() / 4 {
800            let autocorr = self.calculate_autocorrelation(values, lag).abs();
801            max_autocorr = max_autocorr.max(autocorr);
802        }
803        max_autocorr
804    }
805
806    /// Build learning rate history with effectiveness scores
807    fn build_lr_history(&self) -> Vec<LearningRatePoint> {
808        let mut history = Vec::new();
809
810        for (i, metrics) in self.metrics_history.iter().enumerate() {
811            let loss_change = if i > 0 {
812                self.metrics_history[i - 1].train_loss - metrics.train_loss
813            } else {
814                0.0
815            };
816
817            let effectiveness = if loss_change > 0.0 {
818                loss_change / metrics.learning_rate.max(1e-8)
819            } else {
820                0.0
821            };
822
823            history.push(LearningRatePoint {
824                epoch: metrics.epoch,
825                learning_rate: metrics.learning_rate,
826                loss_change,
827                gradient_norm: metrics.gradient_norm,
828                effectiveness,
829            });
830        }
831
832        history
833    }
834
835    /// Calculate learning rate impact score
836    fn calculate_lr_impact_score(&self, lr_history: &[LearningRatePoint]) -> f32 {
837        if lr_history.is_empty() {
838            return 0.0;
839        }
840
841        let avg_effectiveness =
842            lr_history.iter().map(|p| p.effectiveness).sum::<f32>() / lr_history.len() as f32;
843
844        avg_effectiveness.max(0.0).min(1.0)
845    }
846
847    /// Estimate optimal learning rate
848    fn estimate_optimal_lr(&self, lr_history: &[LearningRatePoint]) -> f32 {
849        if lr_history.is_empty() {
850            return 0.001; // Default
851        }
852
853        // Find LR with highest effectiveness
854        lr_history
855            .iter()
856            .max_by(|a, b| {
857                a.effectiveness
858                    .partial_cmp(&b.effectiveness)
859                    .unwrap_or(std::cmp::Ordering::Equal)
860            })
861            .map(|p| p.learning_rate)
862            .unwrap_or(0.001)
863    }
864
865    /// Calculate learning rate sensitivity
866    fn calculate_lr_sensitivity(&self, lr_history: &[LearningRatePoint]) -> f32 {
867        if lr_history.len() < 2 {
868            return 0.0;
869        }
870
871        let effectiveness_values: Vec<f32> = lr_history.iter().map(|p| p.effectiveness).collect();
872
873        self.calculate_std(&effectiveness_values)
874    }
875
876    /// Generate learning rate recommendations
877    fn generate_lr_recommendations(
878        &self,
879        current_lr: f32,
880        lr_history: &[LearningRatePoint],
881    ) -> Vec<LRRecommendation> {
882        let mut recommendations = Vec::new();
883
884        if lr_history.is_empty() {
885            return recommendations;
886        }
887
888        let recent_effectiveness =
889            lr_history.iter().rev().take(5).map(|p| p.effectiveness).sum::<f32>()
890                / 5.0f32.min(lr_history.len() as f32);
891
892        if recent_effectiveness < 0.1 {
893            recommendations.push(LRRecommendation {
894                action: LRAction::Decrease,
895                confidence: 0.7,
896                rationale: "Low learning effectiveness detected".to_string(),
897                expected_improvement: 0.3,
898            });
899        }
900
901        let optimal_lr = self.estimate_optimal_lr(lr_history);
902        if current_lr > optimal_lr * 2.0 {
903            recommendations.push(LRRecommendation {
904                action: LRAction::Decrease,
905                confidence: 0.8,
906                rationale: "Current LR significantly higher than estimated optimal".to_string(),
907                expected_improvement: 0.4,
908            });
909        } else if current_lr < optimal_lr * 0.5 {
910            recommendations.push(LRRecommendation {
911                action: LRAction::Increase,
912                confidence: 0.6,
913                rationale: "Current LR significantly lower than estimated optimal".to_string(),
914                expected_improvement: 0.3,
915            });
916        }
917
918        recommendations
919    }
920
921    /// Analyze batch size effects
922    async fn analyze_batch_size(&self) -> Result<BatchSizeAnalysis> {
923        if self.metrics_history.is_empty() {
924            return Ok(BatchSizeAnalysis {
925                current_batch_size: 0,
926                batch_size_efficiency: 0.0,
927                gradient_noise_level: 0.0,
928                convergence_speed: 0.0,
929                memory_utilization: 0.0,
930                optimal_batch_size_estimate: 32,
931                batch_size_history: Vec::new(),
932                recommendations: Vec::new(),
933            });
934        }
935
936        let current_batch_size = self
937            .metrics_history
938            .back()
939            .expect("metrics_history should not be empty after empty check")
940            .batch_size;
941        let batch_size_history = self.build_batch_size_history();
942
943        let batch_size_efficiency = self.calculate_batch_size_efficiency(&batch_size_history);
944        let gradient_noise_level = self.estimate_gradient_noise_level();
945        let convergence_speed = self.estimate_convergence_speed();
946        let memory_utilization = self.estimate_memory_utilization(current_batch_size);
947        let optimal_batch_size_estimate = self.estimate_optimal_batch_size(&batch_size_history);
948        let recommendations =
949            self.generate_batch_size_recommendations(current_batch_size, &batch_size_history);
950
951        Ok(BatchSizeAnalysis {
952            current_batch_size,
953            batch_size_efficiency,
954            gradient_noise_level,
955            convergence_speed,
956            memory_utilization,
957            optimal_batch_size_estimate,
958            batch_size_history,
959            recommendations,
960        })
961    }
962
963    /// Build batch size history
964    fn build_batch_size_history(&self) -> Vec<BatchSizePoint> {
965        let mut history = Vec::new();
966
967        for (i, metrics) in self.metrics_history.iter().enumerate() {
968            let loss_improvement = if i > 0 {
969                self.metrics_history[i - 1].train_loss - metrics.train_loss
970            } else {
971                0.0
972            };
973
974            let gradient_stability =
975                metrics.gradient_norm.map(|gn| 1.0 / (1.0 + gn)).unwrap_or(0.5);
976            let throughput = 1.0; // Simplified throughput metric
977
978            history.push(BatchSizePoint {
979                epoch: metrics.epoch,
980                batch_size: metrics.batch_size,
981                loss_improvement,
982                gradient_stability,
983                throughput,
984            });
985        }
986
987        history
988    }
989
990    /// Calculate batch size efficiency
991    fn calculate_batch_size_efficiency(&self, batch_history: &[BatchSizePoint]) -> f32 {
992        if batch_history.is_empty() {
993            return 0.0;
994        }
995
996        let avg_improvement =
997            batch_history.iter().map(|p| p.loss_improvement.max(0.0)).sum::<f32>()
998                / batch_history.len() as f32;
999
1000        let avg_stability = batch_history.iter().map(|p| p.gradient_stability).sum::<f32>()
1001            / batch_history.len() as f32;
1002
1003        (avg_improvement * 0.6 + avg_stability * 0.4).min(1.0)
1004    }
1005
1006    /// Estimate gradient noise level
1007    fn estimate_gradient_noise_level(&self) -> f32 {
1008        let gradient_norms: Vec<f32> =
1009            self.metrics_history.iter().filter_map(|m| m.gradient_norm).collect();
1010
1011        if gradient_norms.is_empty() {
1012            return 0.5;
1013        }
1014
1015        let std = self.calculate_std(&gradient_norms);
1016        let mean = gradient_norms.iter().sum::<f32>() / gradient_norms.len() as f32;
1017
1018        if mean > 1e-8 {
1019            (std / mean).min(1.0)
1020        } else {
1021            0.5
1022        }
1023    }
1024
1025    /// Estimate convergence speed
1026    fn estimate_convergence_speed(&self) -> f32 {
1027        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1028
1029        if losses.len() < 2 {
1030            return 0.0;
1031        }
1032
1033        let improvement_per_epoch = (losses[0] - losses[losses.len() - 1]) / losses.len() as f32;
1034        improvement_per_epoch.max(0.0).min(1.0)
1035    }
1036
1037    /// Estimate memory utilization
1038    fn estimate_memory_utilization(&self, batch_size: usize) -> f32 {
1039        // Simplified memory utilization based on batch size
1040        let normalized_batch_size = batch_size as f32 / 1024.0; // Normalize by typical large batch size
1041        normalized_batch_size.min(1.0)
1042    }
1043
1044    /// Estimate optimal batch size
1045    fn estimate_optimal_batch_size(&self, batch_history: &[BatchSizePoint]) -> usize {
1046        if batch_history.is_empty() {
1047            return 32;
1048        }
1049
1050        // Find batch size with best balance of improvement and stability
1051        batch_history
1052            .iter()
1053            .max_by(|a, b| {
1054                let score_a = a.loss_improvement * 0.6 + a.gradient_stability * 0.4;
1055                let score_b = b.loss_improvement * 0.6 + b.gradient_stability * 0.4;
1056                score_a.partial_cmp(&score_b).unwrap_or(std::cmp::Ordering::Equal)
1057            })
1058            .map(|p| p.batch_size)
1059            .unwrap_or(32)
1060    }
1061
1062    /// Generate batch size recommendations
1063    fn generate_batch_size_recommendations(
1064        &self,
1065        current_batch_size: usize,
1066        _batch_history: &[BatchSizePoint],
1067    ) -> Vec<BatchSizeRecommendation> {
1068        let mut recommendations = Vec::new();
1069
1070        if current_batch_size < 16 {
1071            recommendations.push(BatchSizeRecommendation {
1072                suggested_batch_size: 32,
1073                confidence: 0.7,
1074                rationale: "Very small batch size may lead to noisy gradients".to_string(),
1075                expected_benefits: vec![
1076                    "More stable gradients".to_string(),
1077                    "Better convergence".to_string(),
1078                ],
1079            });
1080        } else if current_batch_size > 512 {
1081            recommendations.push(BatchSizeRecommendation {
1082                suggested_batch_size: 256,
1083                confidence: 0.6,
1084                rationale: "Large batch size may slow convergence".to_string(),
1085                expected_benefits: vec![
1086                    "Faster convergence".to_string(),
1087                    "Lower memory usage".to_string(),
1088                ],
1089            });
1090        }
1091
1092        recommendations
1093    }
1094
1095    /// Detect convergence
1096    async fn detect_convergence(&self) -> Result<ConvergenceAnalysis> {
1097        if self.metrics_history.len() < self.config.min_epochs_for_convergence {
1098            return Ok(ConvergenceAnalysis {
1099                convergence_status: ConvergenceStatus::TooEarly,
1100                convergence_probability: 0.0,
1101                epochs_to_convergence_estimate: None,
1102                convergence_criteria: Vec::new(),
1103                early_stopping_recommendation: None,
1104            });
1105        }
1106
1107        let convergence_criteria = self.evaluate_convergence_criteria();
1108        let convergence_status = self.determine_convergence_status(&convergence_criteria);
1109        let convergence_probability = self.calculate_convergence_probability(&convergence_criteria);
1110        let epochs_to_convergence_estimate = self.estimate_epochs_to_convergence();
1111        let early_stopping_recommendation =
1112            self.generate_early_stopping_recommendation(&convergence_criteria);
1113
1114        Ok(ConvergenceAnalysis {
1115            convergence_status,
1116            convergence_probability,
1117            epochs_to_convergence_estimate,
1118            convergence_criteria,
1119            early_stopping_recommendation,
1120        })
1121    }
1122
1123    /// Evaluate convergence criteria
1124    fn evaluate_convergence_criteria(&self) -> Vec<ConvergenceCriterion> {
1125        let mut criteria = Vec::new();
1126
1127        // Loss stability criterion
1128        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1129        let recent_window = 10.min(losses.len());
1130        let recent_losses = &losses[losses.len() - recent_window..];
1131        let loss_std = self.calculate_std(recent_losses);
1132        let loss_mean = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
1133        let loss_stability = loss_std / loss_mean.abs().max(1e-8);
1134
1135        criteria.push(ConvergenceCriterion {
1136            criterion_type: ConvergenceCriterionType::LossStability,
1137            current_value: loss_stability,
1138            threshold: self.config.convergence_tolerance,
1139            satisfied: loss_stability < self.config.convergence_tolerance,
1140            confidence: 0.8,
1141        });
1142
1143        // Gradient magnitude criterion
1144        if let Some(recent_grad_norm) = self.metrics_history.back().and_then(|m| m.gradient_norm) {
1145            criteria.push(ConvergenceCriterion {
1146                criterion_type: ConvergenceCriterionType::GradientMagnitude,
1147                current_value: recent_grad_norm,
1148                threshold: 1e-4,
1149                satisfied: recent_grad_norm < 1e-4,
1150                confidence: 0.7,
1151            });
1152        }
1153
1154        // Loss improvement criterion
1155        if losses.len() >= 10 {
1156            let old_window = &losses[losses.len() - 20..losses.len() - 10];
1157            let new_window = &losses[losses.len() - 10..];
1158            let old_mean = old_window.iter().sum::<f32>() / old_window.len() as f32;
1159            let new_mean = new_window.iter().sum::<f32>() / new_window.len() as f32;
1160            let improvement = (old_mean - new_mean) / old_mean.abs().max(1e-8);
1161
1162            criteria.push(ConvergenceCriterion {
1163                criterion_type: ConvergenceCriterionType::LossImprovement,
1164                current_value: improvement,
1165                threshold: 1e-3,
1166                satisfied: improvement < 1e-3,
1167                confidence: 0.6,
1168            });
1169        }
1170
1171        criteria
1172    }
1173
1174    /// Determine convergence status
1175    fn determine_convergence_status(&self, criteria: &[ConvergenceCriterion]) -> ConvergenceStatus {
1176        let satisfied_count = criteria.iter().filter(|c| c.satisfied).count();
1177        let total_count = criteria.len();
1178
1179        if total_count == 0 {
1180            return ConvergenceStatus::TooEarly;
1181        }
1182
1183        let satisfaction_rate = satisfied_count as f32 / total_count as f32;
1184
1185        if satisfaction_rate > 0.8 {
1186            ConvergenceStatus::Converged
1187        } else if satisfaction_rate > 0.5 {
1188            ConvergenceStatus::Converging
1189        } else {
1190            // Check for divergence
1191            let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1192            let recent_trend =
1193                self.calculate_linear_trend(&losses[losses.len().saturating_sub(20)..]);
1194
1195            if recent_trend > 0.01 {
1196                ConvergenceStatus::Diverging
1197            } else {
1198                ConvergenceStatus::Oscillating
1199            }
1200        }
1201    }
1202
1203    /// Calculate convergence probability
1204    fn calculate_convergence_probability(&self, criteria: &[ConvergenceCriterion]) -> f32 {
1205        if criteria.is_empty() {
1206            return 0.0;
1207        }
1208
1209        let weighted_satisfaction: f32 =
1210            criteria.iter().map(|c| if c.satisfied { c.confidence } else { 0.0 }).sum();
1211
1212        let total_weight: f32 = criteria.iter().map(|c| c.confidence).sum();
1213
1214        if total_weight > 0.0 {
1215            weighted_satisfaction / total_weight
1216        } else {
1217            0.0
1218        }
1219    }
1220
1221    /// Estimate epochs to convergence
1222    fn estimate_epochs_to_convergence(&self) -> Option<usize> {
1223        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1224
1225        if losses.len() < 5 {
1226            return None;
1227        }
1228
1229        let improvement_rate = self.calculate_improvement_rate(&losses);
1230
1231        if improvement_rate <= 0.0 {
1232            return None;
1233        }
1234
1235        let current_loss = *losses.last().expect("losses has at least 5 elements after len check");
1236        let target_loss = current_loss * (1.0 - self.config.convergence_tolerance);
1237        let remaining_improvement = current_loss - target_loss;
1238
1239        let epochs_needed = (remaining_improvement / improvement_rate).ceil() as usize;
1240
1241        Some(epochs_needed.min(1000)) // Cap at reasonable number
1242    }
1243
1244    /// Generate early stopping recommendation
1245    fn generate_early_stopping_recommendation(
1246        &self,
1247        criteria: &[ConvergenceCriterion],
1248    ) -> Option<EarlyStoppingRecommendation> {
1249        let convergence_probability = self.calculate_convergence_probability(criteria);
1250
1251        if convergence_probability > 0.9 {
1252            Some(EarlyStoppingRecommendation {
1253                should_stop: true,
1254                confidence: convergence_probability,
1255                rationale: "High convergence probability detected".to_string(),
1256                suggested_epochs_remaining: 0,
1257            })
1258        } else if convergence_probability > 0.7 {
1259            Some(EarlyStoppingRecommendation {
1260                should_stop: false,
1261                confidence: convergence_probability,
1262                rationale: "Approaching convergence, continue for a few more epochs".to_string(),
1263                suggested_epochs_remaining: 5,
1264            })
1265        } else {
1266            None
1267        }
1268    }
1269
1270    /// Identify plateaus
1271    async fn identify_plateau(&self) -> Result<PlateauAnalysis> {
1272        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1273
1274        if losses.len() < 10 {
1275            return Ok(PlateauAnalysis {
1276                plateau_detected: false,
1277                plateau_duration: 0,
1278                plateau_level: 0.0,
1279                plateau_type: PlateauType::LossPlayteau,
1280                escape_probability: 0.0,
1281                plateau_characteristics: PlateauCharacteristics {
1282                    stability: 0.0,
1283                    noise_level: 0.0,
1284                    gradient_magnitude: 0.0,
1285                    overfitting_risk: 0.0,
1286                },
1287                recommendations: Vec::new(),
1288            });
1289        }
1290
1291        let window_size = 10.min(losses.len());
1292        let recent_losses = &losses[losses.len() - window_size..];
1293
1294        let plateau_detected = self.detect_plateau_in_window(recent_losses);
1295        let plateau_duration =
1296            if plateau_detected { self.calculate_plateau_duration(&losses) } else { 0 };
1297
1298        let plateau_level = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
1299        let plateau_type = PlateauType::LossPlayteau; // Simplified
1300        let escape_probability =
1301            self.estimate_plateau_escape_probability(&losses, plateau_duration);
1302        let plateau_characteristics = self.analyze_plateau_characteristics(recent_losses);
1303        let recommendations =
1304            self.generate_plateau_recommendations(plateau_detected, plateau_duration);
1305
1306        Ok(PlateauAnalysis {
1307            plateau_detected,
1308            plateau_duration,
1309            plateau_level,
1310            plateau_type,
1311            escape_probability,
1312            plateau_characteristics,
1313            recommendations,
1314        })
1315    }
1316
1317    /// Detect plateau in a window of values
1318    fn detect_plateau_in_window(&self, values: &[f32]) -> bool {
1319        if values.len() < 3 {
1320            return false;
1321        }
1322
1323        let std = self.calculate_std(values);
1324        let mean = values.iter().sum::<f32>() / values.len() as f32;
1325
1326        std / mean.abs().max(1e-8) < self.config.plateau_threshold
1327    }
1328
1329    /// Calculate plateau duration
1330    fn calculate_plateau_duration(&self, losses: &[f32]) -> usize {
1331        let threshold = self.config.plateau_threshold;
1332        let mut duration = 0;
1333
1334        for window in losses.windows(10).rev() {
1335            let std = self.calculate_std(window);
1336            let mean = window.iter().sum::<f32>() / window.len() as f32;
1337
1338            if std / mean.abs().max(1e-8) < threshold {
1339                duration += 1;
1340            } else {
1341                break;
1342            }
1343        }
1344
1345        duration
1346    }
1347
1348    /// Estimate plateau escape probability
1349    fn estimate_plateau_escape_probability(&self, losses: &[f32], plateau_duration: usize) -> f32 {
1350        if plateau_duration == 0 {
1351            return 1.0;
1352        }
1353
1354        // Longer plateaus are harder to escape
1355        let duration_factor = 1.0 / (1.0 + plateau_duration as f32 * 0.1);
1356
1357        // Recent trend might indicate escape potential
1358        let recent_trend = if losses.len() >= 5 {
1359            self.calculate_linear_trend(&losses[losses.len() - 5..])
1360        } else {
1361            0.0
1362        };
1363
1364        let trend_factor = if recent_trend < 0.0 { 0.8 } else { 0.3 };
1365
1366        (duration_factor * trend_factor).max(0.1).min(0.9)
1367    }
1368
1369    /// Analyze plateau characteristics
1370    fn analyze_plateau_characteristics(&self, plateau_values: &[f32]) -> PlateauCharacteristics {
1371        let stability = 1.0 - self.calculate_std(plateau_values);
1372        let noise_level = self.calculate_std(plateau_values);
1373
1374        let gradient_magnitude =
1375            self.metrics_history.back().and_then(|m| m.gradient_norm).unwrap_or(0.0);
1376
1377        // Simple overfitting risk estimation
1378        let overfitting_risk =
1379            if let Some(val_loss) = self.metrics_history.back().and_then(|m| m.validation_loss) {
1380                let train_loss = self
1381                    .metrics_history
1382                    .back()
1383                    .expect("metrics_history should not be empty in this branch")
1384                    .train_loss;
1385                ((val_loss - train_loss) / train_loss.abs().max(1e-8)).max(0.0).min(1.0)
1386            } else {
1387                0.5
1388            };
1389
1390        PlateauCharacteristics {
1391            stability: stability.max(0.0).min(1.0),
1392            noise_level: noise_level.min(1.0),
1393            gradient_magnitude,
1394            overfitting_risk,
1395        }
1396    }
1397
1398    /// Generate plateau recommendations
1399    fn generate_plateau_recommendations(
1400        &self,
1401        plateau_detected: bool,
1402        plateau_duration: usize,
1403    ) -> Vec<PlateauRecommendation> {
1404        let mut recommendations = Vec::new();
1405
1406        if !plateau_detected {
1407            return recommendations;
1408        }
1409
1410        if plateau_duration > 20 {
1411            recommendations.push(PlateauRecommendation {
1412                action: PlateauAction::IncreaseLearningRate,
1413                priority: Priority::High,
1414                description: "Long plateau detected, consider increasing learning rate".to_string(),
1415                implementation: "Multiply current learning rate by 2-5x temporarily".to_string(),
1416            });
1417        } else if plateau_duration > 10 {
1418            recommendations.push(PlateauRecommendation {
1419                action: PlateauAction::ChangeBatchSize,
1420                priority: Priority::Medium,
1421                description: "Moderate plateau detected, try changing batch size".to_string(),
1422                implementation: "Increase or decrease batch size by 50%".to_string(),
1423            });
1424        }
1425
1426        if plateau_duration > 30 {
1427            recommendations.push(PlateauRecommendation {
1428                action: PlateauAction::EarlyStopping,
1429                priority: Priority::Critical,
1430                description: "Very long plateau, consider early stopping".to_string(),
1431                implementation: "Stop training and use best checkpoint".to_string(),
1432            });
1433        }
1434
1435        recommendations
1436    }
1437
1438    /// Generate training summary
1439    fn generate_training_summary(&self, report: &mut TrainingDynamicsReport) {
1440        let total_epochs = self.metrics_history.back().map(|m| m.epoch).unwrap_or(0);
1441        let total_steps = self.metrics_history.back().map(|m| m.step).unwrap_or(0);
1442
1443        let training_efficiency = if let Some(loss_analysis) = &report.loss_curve_analysis {
1444            loss_analysis.improvement_rate.max(0.0).min(1.0)
1445        } else {
1446            0.0
1447        };
1448
1449        let convergence_health = if let Some(conv_analysis) = &report.convergence_analysis {
1450            conv_analysis.convergence_probability
1451        } else {
1452            0.0
1453        };
1454
1455        let stability_score = if let Some(loss_analysis) = &report.loss_curve_analysis {
1456            loss_analysis.smoothness
1457        } else {
1458            0.0
1459        };
1460
1461        let overall_progress =
1462            (training_efficiency * 0.4 + convergence_health * 0.3 + stability_score * 0.3)
1463                .max(0.0)
1464                .min(1.0);
1465
1466        report.training_summary = TrainingSummary {
1467            total_epochs,
1468            total_steps,
1469            training_efficiency,
1470            convergence_health,
1471            stability_score,
1472            overall_progress,
1473        };
1474    }
1475
1476    /// Generate training recommendations
1477    fn generate_training_recommendations(&self, report: &mut TrainingDynamicsReport) {
1478        let mut recommendations = Vec::new();
1479
1480        // Learning rate recommendations
1481        if let Some(lr_analysis) = &report.learning_rate_analysis {
1482            for lr_rec in &lr_analysis.recommendations {
1483                recommendations.push(TrainingRecommendation {
1484                    category: TrainingCategory::LearningRate,
1485                    priority: if lr_rec.confidence > 0.8 {
1486                        Priority::High
1487                    } else {
1488                        Priority::Medium
1489                    },
1490                    description: lr_rec.rationale.clone(),
1491                    implementation: format!("{:?} learning rate", lr_rec.action),
1492                    expected_impact: lr_rec.expected_improvement,
1493                });
1494            }
1495        }
1496
1497        // Plateau recommendations
1498        if let Some(plateau_analysis) = &report.plateau_analysis {
1499            for plateau_rec in &plateau_analysis.recommendations {
1500                recommendations.push(TrainingRecommendation {
1501                    category: TrainingCategory::Optimization,
1502                    priority: plateau_rec.priority.clone(),
1503                    description: plateau_rec.description.clone(),
1504                    implementation: plateau_rec.implementation.clone(),
1505                    expected_impact: 0.5, // Default impact
1506                });
1507            }
1508        }
1509
1510        // Convergence recommendations
1511        if let Some(conv_analysis) = &report.convergence_analysis {
1512            if let Some(early_stop) = &conv_analysis.early_stopping_recommendation {
1513                if early_stop.should_stop {
1514                    recommendations.push(TrainingRecommendation {
1515                        category: TrainingCategory::EarlyStopping,
1516                        priority: Priority::High,
1517                        description: early_stop.rationale.clone(),
1518                        implementation: "Stop training and save current model".to_string(),
1519                        expected_impact: 0.8,
1520                    });
1521                }
1522            }
1523        }
1524
1525        report.recommendations = recommendations;
1526    }
1527
1528    /// Generate a comprehensive report
1529    pub async fn generate_report(&self) -> Result<TrainingDynamicsReport> {
1530        let mut temp_analyzer = TrainingDynamicsAnalyzer {
1531            config: self.config.clone(),
1532            metrics_history: self.metrics_history.clone(),
1533            analysis_cache: HashMap::new(),
1534        };
1535
1536        temp_analyzer.analyze().await
1537    }
1538
1539    /// Clear all recorded metrics
1540    pub fn clear(&mut self) {
1541        self.metrics_history.clear();
1542        self.analysis_cache.clear();
1543    }
1544
1545    /// Get summary of current training state
1546    pub fn get_training_summary(&self) -> TrainingStateSummary {
1547        let current_metrics = self.metrics_history.back();
1548
1549        TrainingStateSummary {
1550            total_epochs: current_metrics.map(|m| m.epoch).unwrap_or(0),
1551            total_steps: current_metrics.map(|m| m.step).unwrap_or(0),
1552            current_loss: current_metrics.map(|m| m.train_loss).unwrap_or(0.0),
1553            current_lr: current_metrics.map(|m| m.learning_rate).unwrap_or(0.0),
1554            metrics_collected: self.metrics_history.len(),
1555        }
1556    }
1557}
1558
1559/// Summary of current training state
1560#[derive(Debug, Clone, Serialize, Deserialize)]
1561pub struct TrainingStateSummary {
1562    pub total_epochs: usize,
1563    pub total_steps: usize,
1564    pub current_loss: f32,
1565    pub current_lr: f32,
1566    pub metrics_collected: usize,
1567}