Skip to main content

trustformers_debug/
training_dynamics.rs

//! Training Dynamics Analysis
//!
//! Advanced analysis tools for understanding training dynamics, including
//! loss curve analysis, convergence detection, and learning rate impact assessment.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9
10/// Configuration for training dynamics analysis
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct TrainingDynamicsConfig {
13    /// Enable loss curve analysis
14    pub enable_loss_curve_analysis: bool,
15    /// Enable learning rate impact analysis
16    pub enable_learning_rate_analysis: bool,
17    /// Enable batch size effects analysis
18    pub enable_batch_size_analysis: bool,
19    /// Enable convergence detection
20    pub enable_convergence_detection: bool,
21    /// Enable plateau identification
22    pub enable_plateau_identification: bool,
23    /// Window size for moving averages
24    pub moving_average_window: usize,
25    /// Convergence tolerance
26    pub convergence_tolerance: f32,
27    /// Plateau detection threshold
28    pub plateau_threshold: f32,
29    /// Minimum epochs for convergence detection
30    pub min_epochs_for_convergence: usize,
31    /// Maximum history length
32    pub max_history_length: usize,
33}
34
35impl Default for TrainingDynamicsConfig {
36    fn default() -> Self {
37        Self {
38            enable_loss_curve_analysis: true,
39            enable_learning_rate_analysis: true,
40            enable_batch_size_analysis: true,
41            enable_convergence_detection: true,
42            enable_plateau_identification: true,
43            moving_average_window: 10,
44            convergence_tolerance: 1e-6,
45            plateau_threshold: 1e-4,
46            min_epochs_for_convergence: 20,
47            max_history_length: 10000,
48        }
49    }
50}
51
52/// Training metrics at a specific point in time
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct TrainingMetrics {
55    pub epoch: usize,
56    pub step: usize,
57    pub train_loss: f32,
58    pub validation_loss: Option<f32>,
59    pub learning_rate: f32,
60    pub batch_size: usize,
61    pub gradient_norm: Option<f32>,
62    pub accuracy: Option<f32>,
63    pub timestamp: f64,
64}
65
66/// Loss curve analysis results
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct LossCurveAnalysis {
69    pub trend: LossTrend,
70    pub smoothness: f32,
71    pub volatility: f32,
72    pub improvement_rate: f32,
73    pub best_loss: f32,
74    pub current_loss: f32,
75    pub loss_reduction_percentage: f32,
76    pub epochs_since_improvement: usize,
77    pub moving_averages: MovingAverages,
78    pub loss_statistics: LossStatistics,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub enum LossTrend {
83    Decreasing,
84    Increasing,
85    Oscillating,
86    Plateaued,
87    Unknown,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MovingAverages {
92    pub short_term: f32,  // Last 5-10 epochs
93    pub medium_term: f32, // Last 20-50 epochs
94    pub long_term: f32,   // Last 100+ epochs
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct LossStatistics {
99    pub mean: f32,
100    pub std: f32,
101    pub min: f32,
102    pub max: f32,
103    pub median: f32,
104    pub percentile_25: f32,
105    pub percentile_75: f32,
106    pub autocorrelation: f32,
107}
108
109/// Learning rate impact analysis
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct LearningRateAnalysis {
112    pub current_lr: f32,
113    pub lr_schedule_type: LRScheduleType,
114    pub lr_impact_score: f32,
115    pub optimal_lr_estimate: f32,
116    pub lr_sensitivity: f32,
117    pub lr_history: Vec<LearningRatePoint>,
118    pub recommendations: Vec<LRRecommendation>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub enum LRScheduleType {
123    Constant,
124    StepDecay,
125    ExponentialDecay,
126    CosineAnnealing,
127    ReduceOnPlateau,
128    Warmup,
129    Cyclical,
130    Unknown,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct LearningRatePoint {
135    pub epoch: usize,
136    pub learning_rate: f32,
137    pub loss_change: f32,
138    pub gradient_norm: Option<f32>,
139    pub effectiveness: f32,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct LRRecommendation {
144    pub action: LRAction,
145    pub confidence: f32,
146    pub rationale: String,
147    pub expected_improvement: f32,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub enum LRAction {
152    Increase,
153    Decrease,
154    KeepCurrent,
155    AddScheduler,
156    ChangeScheduler,
157    AddWarmup,
158}
159
160/// Batch size effects analysis
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct BatchSizeAnalysis {
163    pub current_batch_size: usize,
164    pub batch_size_efficiency: f32,
165    pub gradient_noise_level: f32,
166    pub convergence_speed: f32,
167    pub memory_utilization: f32,
168    pub optimal_batch_size_estimate: usize,
169    pub batch_size_history: Vec<BatchSizePoint>,
170    pub recommendations: Vec<BatchSizeRecommendation>,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct BatchSizePoint {
175    pub epoch: usize,
176    pub batch_size: usize,
177    pub loss_improvement: f32,
178    pub gradient_stability: f32,
179    pub throughput: f32,
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct BatchSizeRecommendation {
184    pub suggested_batch_size: usize,
185    pub confidence: f32,
186    pub rationale: String,
187    pub expected_benefits: Vec<String>,
188}
189
190/// Convergence detection results
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct ConvergenceAnalysis {
193    pub convergence_status: ConvergenceStatus,
194    pub convergence_probability: f32,
195    pub epochs_to_convergence_estimate: Option<usize>,
196    pub convergence_criteria: Vec<ConvergenceCriterion>,
197    pub early_stopping_recommendation: Option<EarlyStoppingRecommendation>,
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub enum ConvergenceStatus {
202    Converging,
203    Converged,
204    Diverging,
205    Oscillating,
206    TooEarly,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct ConvergenceCriterion {
211    pub criterion_type: ConvergenceCriterionType,
212    pub current_value: f32,
213    pub threshold: f32,
214    pub satisfied: bool,
215    pub confidence: f32,
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub enum ConvergenceCriterionType {
220    LossStability,
221    GradientMagnitude,
222    LossImprovement,
223    ValidationGap,
224    LearningRateDecay,
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct EarlyStoppingRecommendation {
229    pub should_stop: bool,
230    pub confidence: f32,
231    pub rationale: String,
232    pub suggested_epochs_remaining: usize,
233}
234
235/// Plateau identification results
236#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct PlateauAnalysis {
238    pub plateau_detected: bool,
239    pub plateau_duration: usize,
240    pub plateau_level: f32,
241    pub plateau_type: PlateauType,
242    pub escape_probability: f32,
243    pub plateau_characteristics: PlateauCharacteristics,
244    pub recommendations: Vec<PlateauRecommendation>,
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub enum PlateauType {
249    LossPlayteau,
250    GradientPlateau,
251    AccuracyPlateau,
252    LearningRatePlateau,
253}
254
255#[derive(Debug, Clone, Serialize, Deserialize)]
256pub struct PlateauCharacteristics {
257    pub stability: f32,
258    pub noise_level: f32,
259    pub gradient_magnitude: f32,
260    pub overfitting_risk: f32,
261}
262
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct PlateauRecommendation {
265    pub action: PlateauAction,
266    pub priority: Priority,
267    pub description: String,
268    pub implementation: String,
269}
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub enum PlateauAction {
273    IncreaseLearningRate,
274    DecreaseLearningRate,
275    ChangeBatchSize,
276    AddRegularization,
277    RemoveRegularization,
278    ChangeOptimizer,
279    AddNoise,
280    EarlyStopping,
281    ContinueTraining,
282}
283
284#[derive(Debug, Clone, Serialize, Deserialize)]
285pub enum Priority {
286    Critical,
287    High,
288    Medium,
289    Low,
290}
291
292/// Comprehensive training dynamics report
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct TrainingDynamicsReport {
295    pub loss_curve_analysis: Option<LossCurveAnalysis>,
296    pub learning_rate_analysis: Option<LearningRateAnalysis>,
297    pub batch_size_analysis: Option<BatchSizeAnalysis>,
298    pub convergence_analysis: Option<ConvergenceAnalysis>,
299    pub plateau_analysis: Option<PlateauAnalysis>,
300    pub training_summary: TrainingSummary,
301    pub recommendations: Vec<TrainingRecommendation>,
302}
303
304#[derive(Debug, Clone, Serialize, Deserialize)]
305pub struct TrainingSummary {
306    pub total_epochs: usize,
307    pub total_steps: usize,
308    pub training_efficiency: f32,
309    pub convergence_health: f32,
310    pub stability_score: f32,
311    pub overall_progress: f32,
312}
313
314#[derive(Debug, Clone, Serialize, Deserialize)]
315pub struct TrainingRecommendation {
316    pub category: TrainingCategory,
317    pub priority: Priority,
318    pub description: String,
319    pub implementation: String,
320    pub expected_impact: f32,
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub enum TrainingCategory {
325    LearningRate,
326    BatchSize,
327    Optimization,
328    Regularization,
329    EarlyStopping,
330    Architecture,
331}
332
333/// Training dynamics analyzer
334#[derive(Debug)]
335pub struct TrainingDynamicsAnalyzer {
336    config: TrainingDynamicsConfig,
337    metrics_history: VecDeque<TrainingMetrics>,
338    analysis_cache: HashMap<String, TrainingDynamicsReport>,
339}
340
341impl TrainingDynamicsAnalyzer {
342    /// Create a new training dynamics analyzer
343    pub fn new(config: TrainingDynamicsConfig) -> Self {
344        Self {
345            config,
346            metrics_history: VecDeque::new(),
347            analysis_cache: HashMap::new(),
348        }
349    }
350
351    /// Record training metrics
352    pub fn record_metrics(&mut self, metrics: TrainingMetrics) {
353        self.metrics_history.push_back(metrics);
354
355        // Limit history size
356        while self.metrics_history.len() > self.config.max_history_length {
357            self.metrics_history.pop_front();
358        }
359    }
360
361    /// Perform comprehensive training dynamics analysis
362    pub async fn analyze(&mut self) -> Result<TrainingDynamicsReport> {
363        let mut report = TrainingDynamicsReport {
364            loss_curve_analysis: None,
365            learning_rate_analysis: None,
366            batch_size_analysis: None,
367            convergence_analysis: None,
368            plateau_analysis: None,
369            training_summary: TrainingSummary {
370                total_epochs: 0,
371                total_steps: 0,
372                training_efficiency: 0.0,
373                convergence_health: 0.0,
374                stability_score: 0.0,
375                overall_progress: 0.0,
376            },
377            recommendations: Vec::new(),
378        };
379
380        if self.config.enable_loss_curve_analysis {
381            report.loss_curve_analysis = Some(self.analyze_loss_curve().await?);
382        }
383
384        if self.config.enable_learning_rate_analysis {
385            report.learning_rate_analysis = Some(self.analyze_learning_rate().await?);
386        }
387
388        if self.config.enable_batch_size_analysis {
389            report.batch_size_analysis = Some(self.analyze_batch_size().await?);
390        }
391
392        if self.config.enable_convergence_detection {
393            report.convergence_analysis = Some(self.detect_convergence().await?);
394        }
395
396        if self.config.enable_plateau_identification {
397            report.plateau_analysis = Some(self.identify_plateau().await?);
398        }
399
400        self.generate_training_summary(&mut report);
401        self.generate_training_recommendations(&mut report);
402
403        Ok(report)
404    }
405
406    /// Analyze loss curve patterns
407    async fn analyze_loss_curve(&self) -> Result<LossCurveAnalysis> {
408        if self.metrics_history.is_empty() {
409            return Ok(LossCurveAnalysis {
410                trend: LossTrend::Unknown,
411                smoothness: 0.0,
412                volatility: 0.0,
413                improvement_rate: 0.0,
414                best_loss: 0.0,
415                current_loss: 0.0,
416                loss_reduction_percentage: 0.0,
417                epochs_since_improvement: 0,
418                moving_averages: MovingAverages {
419                    short_term: 0.0,
420                    medium_term: 0.0,
421                    long_term: 0.0,
422                },
423                loss_statistics: LossStatistics {
424                    mean: 0.0,
425                    std: 0.0,
426                    min: 0.0,
427                    max: 0.0,
428                    median: 0.0,
429                    percentile_25: 0.0,
430                    percentile_75: 0.0,
431                    autocorrelation: 0.0,
432                },
433            });
434        }
435
436        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
437
438        let trend = self.detect_loss_trend(&losses);
439        let smoothness = self.calculate_smoothness(&losses);
440        let volatility = self.calculate_volatility(&losses);
441        let improvement_rate = self.calculate_improvement_rate(&losses);
442
443        let best_loss = losses.iter().fold(f32::INFINITY, |a, &b| a.min(b));
444        let current_loss = *losses.last().unwrap();
445        let loss_reduction_percentage = if losses.len() > 1 {
446            ((losses[0] - current_loss) / losses[0].abs()) * 100.0
447        } else {
448            0.0
449        };
450
451        let epochs_since_improvement = self.calculate_epochs_since_improvement(&losses, best_loss);
452        let moving_averages = self.calculate_moving_averages(&losses);
453        let loss_statistics = self.calculate_loss_statistics(&losses);
454
455        Ok(LossCurveAnalysis {
456            trend,
457            smoothness,
458            volatility,
459            improvement_rate,
460            best_loss,
461            current_loss,
462            loss_reduction_percentage,
463            epochs_since_improvement,
464            moving_averages,
465            loss_statistics,
466        })
467    }
468
469    /// Detect overall trend in loss curve
470    fn detect_loss_trend(&self, losses: &[f32]) -> LossTrend {
471        if losses.len() < 3 {
472            return LossTrend::Unknown;
473        }
474
475        let window_size = (losses.len() / 4).max(5).min(20);
476        let recent_losses = &losses[losses.len().saturating_sub(window_size)..];
477        let early_losses = &losses[..window_size.min(losses.len())];
478
479        let recent_mean = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
480        let early_mean = early_losses.iter().sum::<f32>() / early_losses.len() as f32;
481
482        let improvement = (early_mean - recent_mean) / early_mean.abs();
483
484        // Check for plateau
485        let recent_std = self.calculate_std(recent_losses);
486        let recent_mean_abs = recent_mean.abs();
487
488        if recent_std / recent_mean_abs.max(1e-8) < self.config.plateau_threshold {
489            return LossTrend::Plateaued;
490        }
491
492        // Check for oscillation
493        let oscillation_score = self.detect_oscillation(losses);
494        if oscillation_score > 0.5 {
495            return LossTrend::Oscillating;
496        }
497
498        if improvement > 0.01 {
499            LossTrend::Decreasing
500        } else if improvement < -0.01 {
501            LossTrend::Increasing
502        } else {
503            LossTrend::Plateaued
504        }
505    }
506
507    /// Calculate smoothness of loss curve
508    fn calculate_smoothness(&self, losses: &[f32]) -> f32 {
509        if losses.len() < 2 {
510            return 1.0;
511        }
512
513        let differences: Vec<f32> = losses.windows(2).map(|w| (w[1] - w[0]).abs()).collect();
514
515        let mean_diff = differences.iter().sum::<f32>() / differences.len() as f32;
516        let mean_loss = losses.iter().sum::<f32>() / losses.len() as f32;
517
518        // Smoothness is inverse of relative variation
519        1.0 / (1.0 + mean_diff / mean_loss.abs().max(1e-8))
520    }
521
522    /// Calculate volatility of loss curve
523    fn calculate_volatility(&self, losses: &[f32]) -> f32 {
524        if losses.len() < 2 {
525            return 0.0;
526        }
527
528        let returns: Vec<f32> =
529            losses.windows(2).map(|w| (w[1] - w[0]) / w[0].abs().max(1e-8)).collect();
530
531        self.calculate_std(&returns)
532    }
533
534    /// Calculate improvement rate
535    fn calculate_improvement_rate(&self, losses: &[f32]) -> f32 {
536        if losses.len() < 2 {
537            return 0.0;
538        }
539
540        let total_improvement = losses[0] - losses[losses.len() - 1];
541        let epochs = losses.len() as f32;
542
543        total_improvement / epochs
544    }
545
546    /// Calculate epochs since last improvement
547    fn calculate_epochs_since_improvement(&self, losses: &[f32], best_loss: f32) -> usize {
548        for (i, &loss) in losses.iter().rev().enumerate() {
549            if (loss - best_loss).abs() < 1e-8 {
550                return i;
551            }
552        }
553        losses.len()
554    }
555
556    /// Calculate moving averages
557    fn calculate_moving_averages(&self, losses: &[f32]) -> MovingAverages {
558        let short_window = 5.min(losses.len());
559        let medium_window = 20.min(losses.len());
560        let long_window = 100.min(losses.len());
561
562        let short_term = if short_window > 0 {
563            losses[losses.len() - short_window..].iter().sum::<f32>() / short_window as f32
564        } else {
565            0.0
566        };
567
568        let medium_term = if medium_window > 0 {
569            losses[losses.len() - medium_window..].iter().sum::<f32>() / medium_window as f32
570        } else {
571            0.0
572        };
573
574        let long_term = if long_window > 0 {
575            losses[losses.len() - long_window..].iter().sum::<f32>() / long_window as f32
576        } else {
577            0.0
578        };
579
580        MovingAverages {
581            short_term,
582            medium_term,
583            long_term,
584        }
585    }
586
587    /// Calculate comprehensive loss statistics
588    fn calculate_loss_statistics(&self, losses: &[f32]) -> LossStatistics {
589        if losses.is_empty() {
590            return LossStatistics {
591                mean: 0.0,
592                std: 0.0,
593                min: 0.0,
594                max: 0.0,
595                median: 0.0,
596                percentile_25: 0.0,
597                percentile_75: 0.0,
598                autocorrelation: 0.0,
599            };
600        }
601
602        let mean = losses.iter().sum::<f32>() / losses.len() as f32;
603        let std = self.calculate_std(losses);
604
605        let mut sorted_losses = losses.to_vec();
606        sorted_losses.sort_by(|a, b| a.partial_cmp(b).unwrap());
607
608        let min = sorted_losses[0];
609        let max = sorted_losses[sorted_losses.len() - 1];
610        let median = sorted_losses[sorted_losses.len() / 2];
611        let percentile_25 = sorted_losses[sorted_losses.len() / 4];
612        let percentile_75 = sorted_losses[3 * sorted_losses.len() / 4];
613
614        let autocorrelation = self.calculate_autocorrelation(losses, 1);
615
616        LossStatistics {
617            mean,
618            std,
619            min,
620            max,
621            median,
622            percentile_25,
623            percentile_75,
624            autocorrelation,
625        }
626    }
627
628    /// Calculate standard deviation
629    fn calculate_std(&self, values: &[f32]) -> f32 {
630        if values.len() < 2 {
631            return 0.0;
632        }
633
634        let mean = values.iter().sum::<f32>() / values.len() as f32;
635        let variance =
636            values.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
637
638        variance.sqrt()
639    }
640
641    /// Detect oscillation in loss curve
642    fn detect_oscillation(&self, losses: &[f32]) -> f32 {
643        if losses.len() < 4 {
644            return 0.0;
645        }
646
647        let mut direction_changes = 0;
648        let mut total_comparisons = 0;
649
650        for i in 1..losses.len() - 1 {
651            let prev_direction = losses[i] > losses[i - 1];
652            let next_direction = losses[i + 1] > losses[i];
653
654            if prev_direction != next_direction {
655                direction_changes += 1;
656            }
657            total_comparisons += 1;
658        }
659
660        direction_changes as f32 / total_comparisons as f32
661    }
662
663    /// Calculate autocorrelation
664    fn calculate_autocorrelation(&self, values: &[f32], lag: usize) -> f32 {
665        if values.len() <= lag {
666            return 0.0;
667        }
668
669        let mean = values.iter().sum::<f32>() / values.len() as f32;
670
671        let mut numerator = 0.0;
672        let mut denominator = 0.0;
673
674        for i in 0..values.len() - lag {
675            numerator += (values[i] - mean) * (values[i + lag] - mean);
676        }
677
678        for &value in values {
679            denominator += (value - mean).powi(2);
680        }
681
682        if denominator > 1e-8 {
683            numerator / denominator
684        } else {
685            0.0
686        }
687    }
688
689    /// Analyze learning rate impact
690    async fn analyze_learning_rate(&self) -> Result<LearningRateAnalysis> {
691        if self.metrics_history.is_empty() {
692            return Ok(LearningRateAnalysis {
693                current_lr: 0.0,
694                lr_schedule_type: LRScheduleType::Unknown,
695                lr_impact_score: 0.0,
696                optimal_lr_estimate: 0.0,
697                lr_sensitivity: 0.0,
698                lr_history: Vec::new(),
699                recommendations: Vec::new(),
700            });
701        }
702
703        let current_lr = self.metrics_history.back().unwrap().learning_rate;
704        let lr_schedule_type = self.detect_lr_schedule_type();
705
706        let lr_history = self.build_lr_history();
707        let lr_impact_score = self.calculate_lr_impact_score(&lr_history);
708        let optimal_lr_estimate = self.estimate_optimal_lr(&lr_history);
709        let lr_sensitivity = self.calculate_lr_sensitivity(&lr_history);
710        let recommendations = self.generate_lr_recommendations(current_lr, &lr_history);
711
712        Ok(LearningRateAnalysis {
713            current_lr,
714            lr_schedule_type,
715            lr_impact_score,
716            optimal_lr_estimate,
717            lr_sensitivity,
718            lr_history,
719            recommendations,
720        })
721    }
722
723    /// Detect learning rate schedule type
724    fn detect_lr_schedule_type(&self) -> LRScheduleType {
725        let lrs: Vec<f32> = self.metrics_history.iter().map(|m| m.learning_rate).collect();
726
727        if lrs.len() < 3 {
728            return LRScheduleType::Unknown;
729        }
730
731        // Check for constant LR
732        let lr_std = self.calculate_std(&lrs);
733        if lr_std < 1e-8 {
734            return LRScheduleType::Constant;
735        }
736
737        // Check for step decay (sudden drops)
738        let mut step_drops = 0;
739        for window in lrs.windows(2) {
740            if window[1] < window[0] * 0.9 {
741                step_drops += 1;
742            }
743        }
744
745        if step_drops > lrs.len() / 20 {
746            return LRScheduleType::StepDecay;
747        }
748
749        // Check for exponential decay
750        let log_lrs: Vec<f32> = lrs.iter().map(|&lr| lr.ln()).collect();
751        let exponential_trend = self.calculate_linear_trend(&log_lrs);
752        if exponential_trend < -0.01 {
753            return LRScheduleType::ExponentialDecay;
754        }
755
756        // Check for cyclical patterns
757        let cyclical_score = self.detect_cyclical_pattern(&lrs);
758        if cyclical_score > 0.3 {
759            return LRScheduleType::Cyclical;
760        }
761
762        LRScheduleType::Unknown
763    }
764
765    /// Calculate linear trend
766    fn calculate_linear_trend(&self, values: &[f32]) -> f32 {
767        if values.len() < 2 {
768            return 0.0;
769        }
770
771        let n = values.len() as f32;
772        let x_mean = (n - 1.0) / 2.0;
773        let y_mean = values.iter().sum::<f32>() / n;
774
775        let mut numerator = 0.0;
776        let mut denominator = 0.0;
777
778        for (i, &y) in values.iter().enumerate() {
779            let x = i as f32;
780            numerator += (x - x_mean) * (y - y_mean);
781            denominator += (x - x_mean).powi(2);
782        }
783
784        if denominator > 1e-8 {
785            numerator / denominator
786        } else {
787            0.0
788        }
789    }
790
791    /// Detect cyclical patterns
792    fn detect_cyclical_pattern(&self, values: &[f32]) -> f32 {
793        // Simplified cyclical detection using autocorrelation
794        let mut max_autocorr: f32 = 0.0;
795        for lag in 2..=values.len() / 4 {
796            let autocorr = self.calculate_autocorrelation(values, lag).abs();
797            max_autocorr = max_autocorr.max(autocorr);
798        }
799        max_autocorr
800    }
801
802    /// Build learning rate history with effectiveness scores
803    fn build_lr_history(&self) -> Vec<LearningRatePoint> {
804        let mut history = Vec::new();
805
806        for (i, metrics) in self.metrics_history.iter().enumerate() {
807            let loss_change = if i > 0 {
808                self.metrics_history[i - 1].train_loss - metrics.train_loss
809            } else {
810                0.0
811            };
812
813            let effectiveness = if loss_change > 0.0 {
814                loss_change / metrics.learning_rate.max(1e-8)
815            } else {
816                0.0
817            };
818
819            history.push(LearningRatePoint {
820                epoch: metrics.epoch,
821                learning_rate: metrics.learning_rate,
822                loss_change,
823                gradient_norm: metrics.gradient_norm,
824                effectiveness,
825            });
826        }
827
828        history
829    }
830
831    /// Calculate learning rate impact score
832    fn calculate_lr_impact_score(&self, lr_history: &[LearningRatePoint]) -> f32 {
833        if lr_history.is_empty() {
834            return 0.0;
835        }
836
837        let avg_effectiveness =
838            lr_history.iter().map(|p| p.effectiveness).sum::<f32>() / lr_history.len() as f32;
839
840        avg_effectiveness.max(0.0).min(1.0)
841    }
842
843    /// Estimate optimal learning rate
844    fn estimate_optimal_lr(&self, lr_history: &[LearningRatePoint]) -> f32 {
845        if lr_history.is_empty() {
846            return 0.001; // Default
847        }
848
849        // Find LR with highest effectiveness
850        lr_history
851            .iter()
852            .max_by(|a, b| a.effectiveness.partial_cmp(&b.effectiveness).unwrap())
853            .map(|p| p.learning_rate)
854            .unwrap_or(0.001)
855    }
856
857    /// Calculate learning rate sensitivity
858    fn calculate_lr_sensitivity(&self, lr_history: &[LearningRatePoint]) -> f32 {
859        if lr_history.len() < 2 {
860            return 0.0;
861        }
862
863        let effectiveness_values: Vec<f32> = lr_history.iter().map(|p| p.effectiveness).collect();
864
865        self.calculate_std(&effectiveness_values)
866    }
867
868    /// Generate learning rate recommendations
869    fn generate_lr_recommendations(
870        &self,
871        current_lr: f32,
872        lr_history: &[LearningRatePoint],
873    ) -> Vec<LRRecommendation> {
874        let mut recommendations = Vec::new();
875
876        if lr_history.is_empty() {
877            return recommendations;
878        }
879
880        let recent_effectiveness =
881            lr_history.iter().rev().take(5).map(|p| p.effectiveness).sum::<f32>()
882                / 5.0f32.min(lr_history.len() as f32);
883
884        if recent_effectiveness < 0.1 {
885            recommendations.push(LRRecommendation {
886                action: LRAction::Decrease,
887                confidence: 0.7,
888                rationale: "Low learning effectiveness detected".to_string(),
889                expected_improvement: 0.3,
890            });
891        }
892
893        let optimal_lr = self.estimate_optimal_lr(lr_history);
894        if current_lr > optimal_lr * 2.0 {
895            recommendations.push(LRRecommendation {
896                action: LRAction::Decrease,
897                confidence: 0.8,
898                rationale: "Current LR significantly higher than estimated optimal".to_string(),
899                expected_improvement: 0.4,
900            });
901        } else if current_lr < optimal_lr * 0.5 {
902            recommendations.push(LRRecommendation {
903                action: LRAction::Increase,
904                confidence: 0.6,
905                rationale: "Current LR significantly lower than estimated optimal".to_string(),
906                expected_improvement: 0.3,
907            });
908        }
909
910        recommendations
911    }
912
913    /// Analyze batch size effects
914    async fn analyze_batch_size(&self) -> Result<BatchSizeAnalysis> {
915        if self.metrics_history.is_empty() {
916            return Ok(BatchSizeAnalysis {
917                current_batch_size: 0,
918                batch_size_efficiency: 0.0,
919                gradient_noise_level: 0.0,
920                convergence_speed: 0.0,
921                memory_utilization: 0.0,
922                optimal_batch_size_estimate: 32,
923                batch_size_history: Vec::new(),
924                recommendations: Vec::new(),
925            });
926        }
927
928        let current_batch_size = self.metrics_history.back().unwrap().batch_size;
929        let batch_size_history = self.build_batch_size_history();
930
931        let batch_size_efficiency = self.calculate_batch_size_efficiency(&batch_size_history);
932        let gradient_noise_level = self.estimate_gradient_noise_level();
933        let convergence_speed = self.estimate_convergence_speed();
934        let memory_utilization = self.estimate_memory_utilization(current_batch_size);
935        let optimal_batch_size_estimate = self.estimate_optimal_batch_size(&batch_size_history);
936        let recommendations =
937            self.generate_batch_size_recommendations(current_batch_size, &batch_size_history);
938
939        Ok(BatchSizeAnalysis {
940            current_batch_size,
941            batch_size_efficiency,
942            gradient_noise_level,
943            convergence_speed,
944            memory_utilization,
945            optimal_batch_size_estimate,
946            batch_size_history,
947            recommendations,
948        })
949    }
950
951    /// Build batch size history
952    fn build_batch_size_history(&self) -> Vec<BatchSizePoint> {
953        let mut history = Vec::new();
954
955        for (i, metrics) in self.metrics_history.iter().enumerate() {
956            let loss_improvement = if i > 0 {
957                self.metrics_history[i - 1].train_loss - metrics.train_loss
958            } else {
959                0.0
960            };
961
962            let gradient_stability =
963                metrics.gradient_norm.map(|gn| 1.0 / (1.0 + gn)).unwrap_or(0.5);
964            let throughput = 1.0; // Simplified throughput metric
965
966            history.push(BatchSizePoint {
967                epoch: metrics.epoch,
968                batch_size: metrics.batch_size,
969                loss_improvement,
970                gradient_stability,
971                throughput,
972            });
973        }
974
975        history
976    }
977
978    /// Calculate batch size efficiency
979    fn calculate_batch_size_efficiency(&self, batch_history: &[BatchSizePoint]) -> f32 {
980        if batch_history.is_empty() {
981            return 0.0;
982        }
983
984        let avg_improvement =
985            batch_history.iter().map(|p| p.loss_improvement.max(0.0)).sum::<f32>()
986                / batch_history.len() as f32;
987
988        let avg_stability = batch_history.iter().map(|p| p.gradient_stability).sum::<f32>()
989            / batch_history.len() as f32;
990
991        (avg_improvement * 0.6 + avg_stability * 0.4).min(1.0)
992    }
993
994    /// Estimate gradient noise level
995    fn estimate_gradient_noise_level(&self) -> f32 {
996        let gradient_norms: Vec<f32> =
997            self.metrics_history.iter().filter_map(|m| m.gradient_norm).collect();
998
999        if gradient_norms.is_empty() {
1000            return 0.5;
1001        }
1002
1003        let std = self.calculate_std(&gradient_norms);
1004        let mean = gradient_norms.iter().sum::<f32>() / gradient_norms.len() as f32;
1005
1006        if mean > 1e-8 {
1007            (std / mean).min(1.0)
1008        } else {
1009            0.5
1010        }
1011    }
1012
1013    /// Estimate convergence speed
1014    fn estimate_convergence_speed(&self) -> f32 {
1015        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1016
1017        if losses.len() < 2 {
1018            return 0.0;
1019        }
1020
1021        let improvement_per_epoch = (losses[0] - losses[losses.len() - 1]) / losses.len() as f32;
1022        improvement_per_epoch.max(0.0).min(1.0)
1023    }
1024
1025    /// Estimate memory utilization
1026    fn estimate_memory_utilization(&self, batch_size: usize) -> f32 {
1027        // Simplified memory utilization based on batch size
1028        let normalized_batch_size = batch_size as f32 / 1024.0; // Normalize by typical large batch size
1029        normalized_batch_size.min(1.0)
1030    }
1031
1032    /// Estimate optimal batch size
1033    fn estimate_optimal_batch_size(&self, batch_history: &[BatchSizePoint]) -> usize {
1034        if batch_history.is_empty() {
1035            return 32;
1036        }
1037
1038        // Find batch size with best balance of improvement and stability
1039        batch_history
1040            .iter()
1041            .max_by(|a, b| {
1042                let score_a = a.loss_improvement * 0.6 + a.gradient_stability * 0.4;
1043                let score_b = b.loss_improvement * 0.6 + b.gradient_stability * 0.4;
1044                score_a.partial_cmp(&score_b).unwrap()
1045            })
1046            .map(|p| p.batch_size)
1047            .unwrap_or(32)
1048    }
1049
1050    /// Generate batch size recommendations
1051    fn generate_batch_size_recommendations(
1052        &self,
1053        current_batch_size: usize,
1054        _batch_history: &[BatchSizePoint],
1055    ) -> Vec<BatchSizeRecommendation> {
1056        let mut recommendations = Vec::new();
1057
1058        if current_batch_size < 16 {
1059            recommendations.push(BatchSizeRecommendation {
1060                suggested_batch_size: 32,
1061                confidence: 0.7,
1062                rationale: "Very small batch size may lead to noisy gradients".to_string(),
1063                expected_benefits: vec![
1064                    "More stable gradients".to_string(),
1065                    "Better convergence".to_string(),
1066                ],
1067            });
1068        } else if current_batch_size > 512 {
1069            recommendations.push(BatchSizeRecommendation {
1070                suggested_batch_size: 256,
1071                confidence: 0.6,
1072                rationale: "Large batch size may slow convergence".to_string(),
1073                expected_benefits: vec![
1074                    "Faster convergence".to_string(),
1075                    "Lower memory usage".to_string(),
1076                ],
1077            });
1078        }
1079
1080        recommendations
1081    }
1082
1083    /// Detect convergence
1084    async fn detect_convergence(&self) -> Result<ConvergenceAnalysis> {
1085        if self.metrics_history.len() < self.config.min_epochs_for_convergence {
1086            return Ok(ConvergenceAnalysis {
1087                convergence_status: ConvergenceStatus::TooEarly,
1088                convergence_probability: 0.0,
1089                epochs_to_convergence_estimate: None,
1090                convergence_criteria: Vec::new(),
1091                early_stopping_recommendation: None,
1092            });
1093        }
1094
1095        let convergence_criteria = self.evaluate_convergence_criteria();
1096        let convergence_status = self.determine_convergence_status(&convergence_criteria);
1097        let convergence_probability = self.calculate_convergence_probability(&convergence_criteria);
1098        let epochs_to_convergence_estimate = self.estimate_epochs_to_convergence();
1099        let early_stopping_recommendation =
1100            self.generate_early_stopping_recommendation(&convergence_criteria);
1101
1102        Ok(ConvergenceAnalysis {
1103            convergence_status,
1104            convergence_probability,
1105            epochs_to_convergence_estimate,
1106            convergence_criteria,
1107            early_stopping_recommendation,
1108        })
1109    }
1110
1111    /// Evaluate convergence criteria
1112    fn evaluate_convergence_criteria(&self) -> Vec<ConvergenceCriterion> {
1113        let mut criteria = Vec::new();
1114
1115        // Loss stability criterion
1116        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1117        let recent_window = 10.min(losses.len());
1118        let recent_losses = &losses[losses.len() - recent_window..];
1119        let loss_std = self.calculate_std(recent_losses);
1120        let loss_mean = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
1121        let loss_stability = loss_std / loss_mean.abs().max(1e-8);
1122
1123        criteria.push(ConvergenceCriterion {
1124            criterion_type: ConvergenceCriterionType::LossStability,
1125            current_value: loss_stability,
1126            threshold: self.config.convergence_tolerance,
1127            satisfied: loss_stability < self.config.convergence_tolerance,
1128            confidence: 0.8,
1129        });
1130
1131        // Gradient magnitude criterion
1132        if let Some(recent_grad_norm) = self.metrics_history.back().and_then(|m| m.gradient_norm) {
1133            criteria.push(ConvergenceCriterion {
1134                criterion_type: ConvergenceCriterionType::GradientMagnitude,
1135                current_value: recent_grad_norm,
1136                threshold: 1e-4,
1137                satisfied: recent_grad_norm < 1e-4,
1138                confidence: 0.7,
1139            });
1140        }
1141
1142        // Loss improvement criterion
1143        if losses.len() >= 10 {
1144            let old_window = &losses[losses.len() - 20..losses.len() - 10];
1145            let new_window = &losses[losses.len() - 10..];
1146            let old_mean = old_window.iter().sum::<f32>() / old_window.len() as f32;
1147            let new_mean = new_window.iter().sum::<f32>() / new_window.len() as f32;
1148            let improvement = (old_mean - new_mean) / old_mean.abs().max(1e-8);
1149
1150            criteria.push(ConvergenceCriterion {
1151                criterion_type: ConvergenceCriterionType::LossImprovement,
1152                current_value: improvement,
1153                threshold: 1e-3,
1154                satisfied: improvement < 1e-3,
1155                confidence: 0.6,
1156            });
1157        }
1158
1159        criteria
1160    }
1161
1162    /// Determine convergence status
1163    fn determine_convergence_status(&self, criteria: &[ConvergenceCriterion]) -> ConvergenceStatus {
1164        let satisfied_count = criteria.iter().filter(|c| c.satisfied).count();
1165        let total_count = criteria.len();
1166
1167        if total_count == 0 {
1168            return ConvergenceStatus::TooEarly;
1169        }
1170
1171        let satisfaction_rate = satisfied_count as f32 / total_count as f32;
1172
1173        if satisfaction_rate > 0.8 {
1174            ConvergenceStatus::Converged
1175        } else if satisfaction_rate > 0.5 {
1176            ConvergenceStatus::Converging
1177        } else {
1178            // Check for divergence
1179            let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1180            let recent_trend =
1181                self.calculate_linear_trend(&losses[losses.len().saturating_sub(20)..]);
1182
1183            if recent_trend > 0.01 {
1184                ConvergenceStatus::Diverging
1185            } else {
1186                ConvergenceStatus::Oscillating
1187            }
1188        }
1189    }
1190
1191    /// Calculate convergence probability
1192    fn calculate_convergence_probability(&self, criteria: &[ConvergenceCriterion]) -> f32 {
1193        if criteria.is_empty() {
1194            return 0.0;
1195        }
1196
1197        let weighted_satisfaction: f32 =
1198            criteria.iter().map(|c| if c.satisfied { c.confidence } else { 0.0 }).sum();
1199
1200        let total_weight: f32 = criteria.iter().map(|c| c.confidence).sum();
1201
1202        if total_weight > 0.0 {
1203            weighted_satisfaction / total_weight
1204        } else {
1205            0.0
1206        }
1207    }
1208
1209    /// Estimate epochs to convergence
1210    fn estimate_epochs_to_convergence(&self) -> Option<usize> {
1211        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1212
1213        if losses.len() < 5 {
1214            return None;
1215        }
1216
1217        let improvement_rate = self.calculate_improvement_rate(&losses);
1218
1219        if improvement_rate <= 0.0 {
1220            return None;
1221        }
1222
1223        let current_loss = *losses.last().unwrap();
1224        let target_loss = current_loss * (1.0 - self.config.convergence_tolerance);
1225        let remaining_improvement = current_loss - target_loss;
1226
1227        let epochs_needed = (remaining_improvement / improvement_rate).ceil() as usize;
1228
1229        Some(epochs_needed.min(1000)) // Cap at reasonable number
1230    }
1231
1232    /// Generate early stopping recommendation
1233    fn generate_early_stopping_recommendation(
1234        &self,
1235        criteria: &[ConvergenceCriterion],
1236    ) -> Option<EarlyStoppingRecommendation> {
1237        let convergence_probability = self.calculate_convergence_probability(criteria);
1238
1239        if convergence_probability > 0.9 {
1240            Some(EarlyStoppingRecommendation {
1241                should_stop: true,
1242                confidence: convergence_probability,
1243                rationale: "High convergence probability detected".to_string(),
1244                suggested_epochs_remaining: 0,
1245            })
1246        } else if convergence_probability > 0.7 {
1247            Some(EarlyStoppingRecommendation {
1248                should_stop: false,
1249                confidence: convergence_probability,
1250                rationale: "Approaching convergence, continue for a few more epochs".to_string(),
1251                suggested_epochs_remaining: 5,
1252            })
1253        } else {
1254            None
1255        }
1256    }
1257
1258    /// Identify plateaus
1259    async fn identify_plateau(&self) -> Result<PlateauAnalysis> {
1260        let losses: Vec<f32> = self.metrics_history.iter().map(|m| m.train_loss).collect();
1261
1262        if losses.len() < 10 {
1263            return Ok(PlateauAnalysis {
1264                plateau_detected: false,
1265                plateau_duration: 0,
1266                plateau_level: 0.0,
1267                plateau_type: PlateauType::LossPlayteau,
1268                escape_probability: 0.0,
1269                plateau_characteristics: PlateauCharacteristics {
1270                    stability: 0.0,
1271                    noise_level: 0.0,
1272                    gradient_magnitude: 0.0,
1273                    overfitting_risk: 0.0,
1274                },
1275                recommendations: Vec::new(),
1276            });
1277        }
1278
1279        let window_size = 10.min(losses.len());
1280        let recent_losses = &losses[losses.len() - window_size..];
1281
1282        let plateau_detected = self.detect_plateau_in_window(recent_losses);
1283        let plateau_duration =
1284            if plateau_detected { self.calculate_plateau_duration(&losses) } else { 0 };
1285
1286        let plateau_level = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
1287        let plateau_type = PlateauType::LossPlayteau; // Simplified
1288        let escape_probability =
1289            self.estimate_plateau_escape_probability(&losses, plateau_duration);
1290        let plateau_characteristics = self.analyze_plateau_characteristics(recent_losses);
1291        let recommendations =
1292            self.generate_plateau_recommendations(plateau_detected, plateau_duration);
1293
1294        Ok(PlateauAnalysis {
1295            plateau_detected,
1296            plateau_duration,
1297            plateau_level,
1298            plateau_type,
1299            escape_probability,
1300            plateau_characteristics,
1301            recommendations,
1302        })
1303    }
1304
1305    /// Detect plateau in a window of values
1306    fn detect_plateau_in_window(&self, values: &[f32]) -> bool {
1307        if values.len() < 3 {
1308            return false;
1309        }
1310
1311        let std = self.calculate_std(values);
1312        let mean = values.iter().sum::<f32>() / values.len() as f32;
1313
1314        std / mean.abs().max(1e-8) < self.config.plateau_threshold
1315    }
1316
1317    /// Calculate plateau duration
1318    fn calculate_plateau_duration(&self, losses: &[f32]) -> usize {
1319        let threshold = self.config.plateau_threshold;
1320        let mut duration = 0;
1321
1322        for window in losses.windows(10).rev() {
1323            let std = self.calculate_std(window);
1324            let mean = window.iter().sum::<f32>() / window.len() as f32;
1325
1326            if std / mean.abs().max(1e-8) < threshold {
1327                duration += 1;
1328            } else {
1329                break;
1330            }
1331        }
1332
1333        duration
1334    }
1335
1336    /// Estimate plateau escape probability
1337    fn estimate_plateau_escape_probability(&self, losses: &[f32], plateau_duration: usize) -> f32 {
1338        if plateau_duration == 0 {
1339            return 1.0;
1340        }
1341
1342        // Longer plateaus are harder to escape
1343        let duration_factor = 1.0 / (1.0 + plateau_duration as f32 * 0.1);
1344
1345        // Recent trend might indicate escape potential
1346        let recent_trend = if losses.len() >= 5 {
1347            self.calculate_linear_trend(&losses[losses.len() - 5..])
1348        } else {
1349            0.0
1350        };
1351
1352        let trend_factor = if recent_trend < 0.0 { 0.8 } else { 0.3 };
1353
1354        (duration_factor * trend_factor).max(0.1).min(0.9)
1355    }
1356
1357    /// Analyze plateau characteristics
1358    fn analyze_plateau_characteristics(&self, plateau_values: &[f32]) -> PlateauCharacteristics {
1359        let stability = 1.0 - self.calculate_std(plateau_values);
1360        let noise_level = self.calculate_std(plateau_values);
1361
1362        let gradient_magnitude =
1363            self.metrics_history.back().and_then(|m| m.gradient_norm).unwrap_or(0.0);
1364
1365        // Simple overfitting risk estimation
1366        let overfitting_risk =
1367            if let Some(val_loss) = self.metrics_history.back().and_then(|m| m.validation_loss) {
1368                let train_loss = self.metrics_history.back().unwrap().train_loss;
1369                ((val_loss - train_loss) / train_loss.abs().max(1e-8)).max(0.0).min(1.0)
1370            } else {
1371                0.5
1372            };
1373
1374        PlateauCharacteristics {
1375            stability: stability.max(0.0).min(1.0),
1376            noise_level: noise_level.min(1.0),
1377            gradient_magnitude,
1378            overfitting_risk,
1379        }
1380    }
1381
1382    /// Generate plateau recommendations
1383    fn generate_plateau_recommendations(
1384        &self,
1385        plateau_detected: bool,
1386        plateau_duration: usize,
1387    ) -> Vec<PlateauRecommendation> {
1388        let mut recommendations = Vec::new();
1389
1390        if !plateau_detected {
1391            return recommendations;
1392        }
1393
1394        if plateau_duration > 20 {
1395            recommendations.push(PlateauRecommendation {
1396                action: PlateauAction::IncreaseLearningRate,
1397                priority: Priority::High,
1398                description: "Long plateau detected, consider increasing learning rate".to_string(),
1399                implementation: "Multiply current learning rate by 2-5x temporarily".to_string(),
1400            });
1401        } else if plateau_duration > 10 {
1402            recommendations.push(PlateauRecommendation {
1403                action: PlateauAction::ChangeBatchSize,
1404                priority: Priority::Medium,
1405                description: "Moderate plateau detected, try changing batch size".to_string(),
1406                implementation: "Increase or decrease batch size by 50%".to_string(),
1407            });
1408        }
1409
1410        if plateau_duration > 30 {
1411            recommendations.push(PlateauRecommendation {
1412                action: PlateauAction::EarlyStopping,
1413                priority: Priority::Critical,
1414                description: "Very long plateau, consider early stopping".to_string(),
1415                implementation: "Stop training and use best checkpoint".to_string(),
1416            });
1417        }
1418
1419        recommendations
1420    }
1421
1422    /// Generate training summary
1423    fn generate_training_summary(&self, report: &mut TrainingDynamicsReport) {
1424        let total_epochs = self.metrics_history.back().map(|m| m.epoch).unwrap_or(0);
1425        let total_steps = self.metrics_history.back().map(|m| m.step).unwrap_or(0);
1426
1427        let training_efficiency = if let Some(loss_analysis) = &report.loss_curve_analysis {
1428            loss_analysis.improvement_rate.max(0.0).min(1.0)
1429        } else {
1430            0.0
1431        };
1432
1433        let convergence_health = if let Some(conv_analysis) = &report.convergence_analysis {
1434            conv_analysis.convergence_probability
1435        } else {
1436            0.0
1437        };
1438
1439        let stability_score = if let Some(loss_analysis) = &report.loss_curve_analysis {
1440            loss_analysis.smoothness
1441        } else {
1442            0.0
1443        };
1444
1445        let overall_progress =
1446            (training_efficiency * 0.4 + convergence_health * 0.3 + stability_score * 0.3)
1447                .max(0.0)
1448                .min(1.0);
1449
1450        report.training_summary = TrainingSummary {
1451            total_epochs,
1452            total_steps,
1453            training_efficiency,
1454            convergence_health,
1455            stability_score,
1456            overall_progress,
1457        };
1458    }
1459
1460    /// Generate training recommendations
1461    fn generate_training_recommendations(&self, report: &mut TrainingDynamicsReport) {
1462        let mut recommendations = Vec::new();
1463
1464        // Learning rate recommendations
1465        if let Some(lr_analysis) = &report.learning_rate_analysis {
1466            for lr_rec in &lr_analysis.recommendations {
1467                recommendations.push(TrainingRecommendation {
1468                    category: TrainingCategory::LearningRate,
1469                    priority: if lr_rec.confidence > 0.8 {
1470                        Priority::High
1471                    } else {
1472                        Priority::Medium
1473                    },
1474                    description: lr_rec.rationale.clone(),
1475                    implementation: format!("{:?} learning rate", lr_rec.action),
1476                    expected_impact: lr_rec.expected_improvement,
1477                });
1478            }
1479        }
1480
1481        // Plateau recommendations
1482        if let Some(plateau_analysis) = &report.plateau_analysis {
1483            for plateau_rec in &plateau_analysis.recommendations {
1484                recommendations.push(TrainingRecommendation {
1485                    category: TrainingCategory::Optimization,
1486                    priority: plateau_rec.priority.clone(),
1487                    description: plateau_rec.description.clone(),
1488                    implementation: plateau_rec.implementation.clone(),
1489                    expected_impact: 0.5, // Default impact
1490                });
1491            }
1492        }
1493
1494        // Convergence recommendations
1495        if let Some(conv_analysis) = &report.convergence_analysis {
1496            if let Some(early_stop) = &conv_analysis.early_stopping_recommendation {
1497                if early_stop.should_stop {
1498                    recommendations.push(TrainingRecommendation {
1499                        category: TrainingCategory::EarlyStopping,
1500                        priority: Priority::High,
1501                        description: early_stop.rationale.clone(),
1502                        implementation: "Stop training and save current model".to_string(),
1503                        expected_impact: 0.8,
1504                    });
1505                }
1506            }
1507        }
1508
1509        report.recommendations = recommendations;
1510    }
1511
1512    /// Generate a comprehensive report
1513    pub async fn generate_report(&self) -> Result<TrainingDynamicsReport> {
1514        let mut temp_analyzer = TrainingDynamicsAnalyzer {
1515            config: self.config.clone(),
1516            metrics_history: self.metrics_history.clone(),
1517            analysis_cache: HashMap::new(),
1518        };
1519
1520        temp_analyzer.analyze().await
1521    }
1522
1523    /// Clear all recorded metrics
1524    pub fn clear(&mut self) {
1525        self.metrics_history.clear();
1526        self.analysis_cache.clear();
1527    }
1528
1529    /// Get summary of current training state
1530    pub fn get_training_summary(&self) -> TrainingStateSummary {
1531        let current_metrics = self.metrics_history.back();
1532
1533        TrainingStateSummary {
1534            total_epochs: current_metrics.map(|m| m.epoch).unwrap_or(0),
1535            total_steps: current_metrics.map(|m| m.step).unwrap_or(0),
1536            current_loss: current_metrics.map(|m| m.train_loss).unwrap_or(0.0),
1537            current_lr: current_metrics.map(|m| m.learning_rate).unwrap_or(0.0),
1538            metrics_collected: self.metrics_history.len(),
1539        }
1540    }
1541}
1542
1543/// Summary of current training state
1544#[derive(Debug, Clone, Serialize, Deserialize)]
1545pub struct TrainingStateSummary {
1546    pub total_epochs: usize,
1547    pub total_steps: usize,
1548    pub current_loss: f32,
1549    pub current_lr: f32,
1550    pub metrics_collected: usize,
1551}