// trustformers_training/adaptive_learning_rate.rs
1//! Adaptive Learning Rate Schedulers for Dynamic Training Optimization
2//!
3//! This module implements advanced learning rate scheduling techniques that automatically
4//! adjust learning rates based on real-time training dynamics and performance metrics.
5
6use anyhow::Result;
7use log;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10
/// Configuration for adaptive learning rate scheduling.
///
/// Learning rates are raw values (not log-scale); adjustment factors are
/// multiplicative.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveLearningRateConfig {
    /// Initial learning rate, used until the first adaptation.
    pub initial_lr: f32,
    /// Lower clamp for every adapted learning rate (also the cyclical-LR floor).
    pub min_lr: f32,
    /// Upper clamp for every adapted learning rate (also the cyclical-LR peak).
    pub max_lr: f32,
    /// Enable loss-based adaptation.
    /// NOTE(review): not consulted by the current scheduler implementation.
    pub loss_based_adaptation: bool,
    /// Enable gradient-based adaptation.
    /// NOTE(review): not consulted by the current scheduler implementation.
    pub gradient_based_adaptation: bool,
    /// Enable plateau detection.
    /// NOTE(review): not consulted by the current scheduler implementation.
    pub plateau_detection: bool,
    /// Steps without sufficient improvement before a plateau is declared.
    pub plateau_patience: usize,
    /// Minimum loss improvement that resets the plateau counter.
    /// NOTE(review): documented as a *relative* improvement but applied as an
    /// absolute delta by the scheduler — confirm intended semantics.
    pub plateau_threshold: f32,
    /// Multiplicative factor (< 1.0) applied when reducing the LR.
    pub reduction_factor: f32,
    /// Multiplicative factor (> 1.0) applied when increasing the LR.
    pub increase_factor: f32,
    /// Number of recent steps considered for trend analysis.
    pub trend_window: usize,
    /// Momentum for the loss / gradient-norm exponential moving averages.
    pub momentum: f32,
    /// Enable cyclical (triangular) learning rates.
    pub cyclical_lr: bool,
    /// Cycle length (in steps) for cyclical LR.
    pub cycle_length: usize,
    /// Enable learning rate range test mode.
    /// NOTE(review): not consulted by the current scheduler implementation.
    pub lr_range_test: bool,
}
45
46impl Default for AdaptiveLearningRateConfig {
47    fn default() -> Self {
48        Self {
49            initial_lr: 1e-3,
50            min_lr: 1e-7,
51            max_lr: 1e-1,
52            loss_based_adaptation: true,
53            gradient_based_adaptation: true,
54            plateau_detection: true,
55            plateau_patience: 50,
56            plateau_threshold: 1e-4,
57            reduction_factor: 0.5,
58            increase_factor: 1.1,
59            trend_window: 20,
60            momentum: 0.9,
61            cyclical_lr: false,
62            cycle_length: 1000,
63            lr_range_test: false,
64        }
65    }
66}
67
/// Training dynamics for learning rate adaptation.
///
/// One snapshot of the training loop's observable state, fed to the
/// scheduler on every step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamics {
    /// Global optimizer step the snapshot was taken at.
    pub step: usize,
    /// Training loss at this step.
    pub loss: f32,
    /// Gradient norm at this step (norm type is determined by the caller).
    pub gradient_norm: f32,
    /// Learning rate that was in effect when this step ran.
    pub learning_rate: f32,
    /// Optional evaluation accuracy, when available.
    pub accuracy: Option<f32>,
}
77
/// Learning rate adaptation strategy.
///
/// Only `ReduceOnPlateau`, `GradientNormAdaptive`, `LossVarianceAdaptive`
/// and `PerformanceBasedAdaptive` have dedicated contribution logic in the
/// scheduler; the remaining variants currently yield a neutral factor (1.0).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AdaptationStrategy {
    /// Multiplicatively reduce the LR once a loss plateau is detected.
    ReduceOnPlateau,
    CosineAnnealing,
    ExponentialDecay,
    PolynomialDecay,
    CyclicalLR,
    OneCycleLR,
    /// Scale the LR by how the current gradient norm compares to its EMA.
    GradientNormAdaptive,
    /// Scale the LR by the coefficient of variation of recent losses.
    LossVarianceAdaptive,
    /// Scale the LR according to the classified performance trend.
    PerformanceBasedAdaptive,
}
91
/// Learning rate scheduler state.
///
/// Serializable snapshot of everything the scheduler mutates between steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerState {
    /// Learning rate currently in effect.
    pub current_lr: f32,
    /// Best (lowest) loss observed so far; starts at +infinity.
    pub best_loss: f32,
    /// Consecutive steps without sufficient improvement over `best_loss`.
    pub plateau_counter: usize,
    /// Total number of `step()` calls processed.
    pub step_count: usize,
    /// Phase within the cyclical-LR cycle (only advanced when enabled).
    pub cycle_position: usize,
    /// Recent LR change ratios (new_lr / old_lr), bounded by the trend window.
    pub adaptation_history: VecDeque<f32>,
    /// Latest classified loss trend.
    pub performance_trend: PerformanceTrend,
}
103
/// Performance trend analysis.
///
/// Classification of the recent loss trajectory over the trend window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PerformanceTrend {
    /// Loss is decreasing (negative fitted slope).
    Improving,
    /// Loss is flat within the slope thresholds.
    Stable,
    /// Loss is increasing (positive fitted slope).
    Deteriorating,
    /// Loss variance is high enough to mask any slope signal.
    Oscillating,
    /// Not enough history to classify yet.
    Unknown,
}
113
/// Adaptive learning rate scheduler.
///
/// Combines several adaptation strategies (plateau detection, gradient-norm
/// and loss-variance signals) into a single multiplicative LR adjustment,
/// with an emergency fallback for apparently divergent training.
pub struct AdaptiveLearningRateScheduler {
    /// Static configuration supplied at construction time.
    config: AdaptiveLearningRateConfig,
    /// Mutable scheduler state (current LR, plateau counters, trend, ...).
    state: SchedulerState,
    /// Recent training snapshots used for trend / variance analysis.
    dynamics_history: VecDeque<TrainingDynamics>,
    /// Exponential moving average of the loss.
    loss_ema: f32,
    /// Exponential moving average of the gradient norm.
    gradient_norm_ema: f32,
    /// Active adaptation strategies (a subset of `AdaptationStrategy`).
    strategies: Vec<AdaptationStrategy>,
    /// Per-strategy ensemble weights.
    strategy_weights: HashMap<AdaptationStrategy, f32>,
    /// Track strategy effectiveness over time.
    strategy_effectiveness: HashMap<AdaptationStrategy, f32>,
    /// Emergency fallback when all strategies fail.
    emergency_mode: bool,
}
128
129impl AdaptiveLearningRateScheduler {
130    pub fn new(config: AdaptiveLearningRateConfig) -> Self {
131        let state = SchedulerState {
132            current_lr: config.initial_lr,
133            best_loss: f32::INFINITY,
134            plateau_counter: 0,
135            step_count: 0,
136            cycle_position: 0,
137            adaptation_history: VecDeque::with_capacity(config.trend_window),
138            performance_trend: PerformanceTrend::Unknown,
139        };
140
141        let strategies = vec![
142            AdaptationStrategy::ReduceOnPlateau,
143            AdaptationStrategy::GradientNormAdaptive,
144            AdaptationStrategy::LossVarianceAdaptive,
145        ];
146
147        let strategy_weights =
148            strategies.iter().map(|s| (s.clone(), 1.0 / strategies.len() as f32)).collect();
149
150        let strategy_effectiveness = strategies.iter()
151            .map(|s| (s.clone(), 0.5)) // Initialize with neutral effectiveness
152            .collect();
153
154        Self {
155            config,
156            state,
157            dynamics_history: VecDeque::with_capacity(1000),
158            loss_ema: 0.0,
159            gradient_norm_ema: 0.0,
160            strategies,
161            strategy_weights,
162            strategy_effectiveness,
163            emergency_mode: false,
164        }
165    }
166
167    /// Update learning rate based on current training dynamics
168    pub fn step(&mut self, dynamics: TrainingDynamics) -> Result<LearningRateUpdate> {
169        // Validate input dynamics
170        if !dynamics.loss.is_finite() || !dynamics.gradient_norm.is_finite() {
171            log::warn!(
172                "Invalid training dynamics: loss={}, grad_norm={}",
173                dynamics.loss,
174                dynamics.gradient_norm
175            );
176            return Err(anyhow::anyhow!("Invalid training dynamics detected"));
177        }
178
179        self.state.step_count += 1;
180
181        // Update exponential moving averages with validation
182        if self.state.step_count == 1 {
183            self.loss_ema = dynamics.loss;
184            self.gradient_norm_ema = dynamics.gradient_norm;
185        } else {
186            // Robust EMA update with bounds checking
187            let loss_update = (1.0 - self.config.momentum) * dynamics.loss;
188            let grad_update = (1.0 - self.config.momentum) * dynamics.gradient_norm;
189
190            if loss_update.is_finite() {
191                self.loss_ema = self.config.momentum * self.loss_ema + loss_update;
192            }
193
194            if grad_update.is_finite() {
195                self.gradient_norm_ema =
196                    self.config.momentum * self.gradient_norm_ema + grad_update;
197            }
198        }
199
200        // Store dynamics
201        self.dynamics_history.push_back(dynamics.clone());
202        if self.dynamics_history.len() > self.dynamics_history.capacity() {
203            self.dynamics_history.pop_front();
204        }
205
206        // Analyze performance trend
207        self.state.performance_trend = self.analyze_performance_trend();
208
209        // Compute adaptive learning rate with safety checks
210        let new_lr = match self.compute_adaptive_learning_rate(&dynamics) {
211            Ok(lr) if lr.is_finite() && lr > 0.0 => lr,
212            Ok(lr) => {
213                log::warn!("Invalid learning rate computed: {}. Using current LR.", lr);
214                self.state.current_lr
215            },
216            Err(e) => {
217                log::error!(
218                    "Failed to compute adaptive learning rate: {}. Using current LR.",
219                    e
220                );
221                self.state.current_lr
222            },
223        };
224
225        let old_lr = self.state.current_lr;
226        self.state.current_lr = new_lr.clamp(self.config.min_lr, self.config.max_lr);
227
228        // Additional safety check
229        if !self.state.current_lr.is_finite() {
230            log::error!("Learning rate became non-finite. Resetting to initial LR.");
231            self.state.current_lr = self.config.initial_lr;
232        }
233
234        // Update adaptation history
235        let adaptation_ratio = self.state.current_lr / old_lr;
236        self.state.adaptation_history.push_back(adaptation_ratio);
237        if self.state.adaptation_history.len() > self.config.trend_window {
238            self.state.adaptation_history.pop_front();
239        }
240
241        // Update plateau detection
242        if dynamics.loss < self.state.best_loss - self.config.plateau_threshold {
243            self.state.best_loss = dynamics.loss;
244            self.state.plateau_counter = 0;
245        } else {
246            self.state.plateau_counter += 1;
247        }
248
249        // Update cycle position for cyclical learning rates
250        if self.config.cyclical_lr {
251            self.state.cycle_position = (self.state.cycle_position + 1) % self.config.cycle_length;
252        }
253
254        // Update strategy effectiveness based on performance improvement
255        self.update_strategy_effectiveness(&dynamics, old_lr, self.state.current_lr);
256
257        // Check if we need emergency mode
258        if self.should_enter_emergency_mode() {
259            self.emergency_mode = true;
260            self.state.current_lr = self.config.initial_lr * 0.1; // Conservative emergency LR
261            log::warn!("Entering emergency mode - using conservative learning rate");
262        } else if self.emergency_mode && self.can_exit_emergency_mode() {
263            self.emergency_mode = false;
264            log::info!("Exiting emergency mode - performance stabilized");
265        }
266
267        Ok(LearningRateUpdate {
268            old_lr,
269            new_lr: self.state.current_lr,
270            adaptation_reason: self.get_adaptation_reason(),
271            strategy_contributions: self.compute_strategy_contributions(&dynamics)?,
272            confidence: self.compute_adaptation_confidence(),
273            dynamics: dynamics.clone(),
274        })
275    }
276
277    /// Get current learning rate
278    pub fn get_lr(&self) -> f32 {
279        self.state.current_lr
280    }
281
282    /// Get scheduler state
283    pub fn get_state(&self) -> &SchedulerState {
284        &self.state
285    }
286
287    /// Get comprehensive statistics
288    pub fn get_statistics(&self) -> AdaptiveLRStatistics {
289        AdaptiveLRStatistics {
290            current_lr: self.state.current_lr,
291            steps_taken: self.state.step_count,
292            adaptations_made: self.count_adaptations(),
293            performance_trend: self.state.performance_trend.clone(),
294            plateau_detected: self.state.plateau_counter >= self.config.plateau_patience,
295            loss_ema: self.loss_ema,
296            gradient_norm_ema: self.gradient_norm_ema,
297            adaptation_frequency: self.compute_adaptation_frequency(),
298            stability_score: self.compute_stability_score(),
299        }
300    }
301
302    // Private helper methods
303    fn compute_adaptive_learning_rate(&mut self, dynamics: &TrainingDynamics) -> Result<f32> {
304        let mut contributions = Vec::new();
305
306        for strategy in &self.strategies {
307            if let Some(weight) = self.strategy_weights.get(strategy) {
308                let contribution = self.compute_strategy_contribution(strategy, dynamics)? * weight;
309                contributions.push(contribution);
310            }
311        }
312
313        // Weighted average of strategy contributions
314        let adaptive_factor: f32 = contributions.iter().sum::<f32>() / contributions.len() as f32;
315
316        // Apply cyclical learning rate if enabled
317        let base_lr = if self.config.cyclical_lr {
318            self.compute_cyclical_lr()
319        } else {
320            self.state.current_lr
321        };
322
323        Ok(base_lr * adaptive_factor)
324    }
325
326    fn compute_strategy_contribution(
327        &self,
328        strategy: &AdaptationStrategy,
329        dynamics: &TrainingDynamics,
330    ) -> Result<f32> {
331        let contribution = match strategy {
332            AdaptationStrategy::ReduceOnPlateau => {
333                let plateau_severity = (self.state.plateau_counter as f32
334                    / self.config.plateau_patience as f32)
335                    .min(2.0);
336                if self.state.plateau_counter >= self.config.plateau_patience {
337                    // Gradual reduction based on plateau severity
338                    self.config.reduction_factor.powf(plateau_severity * 0.5)
339                } else {
340                    1.0
341                }
342            },
343            AdaptationStrategy::GradientNormAdaptive => {
344                let grad_ratio = dynamics.gradient_norm / self.gradient_norm_ema.max(1e-8);
345
346                // Smooth adaptation based on gradient ratio
347                if grad_ratio > 2.0 {
348                    let severity = (grad_ratio / 2.0 - 1.0).min(2.0);
349                    self.config.reduction_factor.powf(severity * 0.3)
350                } else if grad_ratio < 0.5 {
351                    let boost = (1.0 - grad_ratio * 2.0).min(1.0);
352                    self.config.increase_factor.powf(boost * 0.2)
353                } else {
354                    // Smooth transition in the middle range
355                    1.0 + (grad_ratio - 1.0) * 0.1
356                }
357            },
358            AdaptationStrategy::LossVarianceAdaptive => {
359                if self.dynamics_history.len() < 10 {
360                    return Ok(1.0);
361                }
362
363                let recent_losses: Vec<f32> =
364                    self.dynamics_history.iter().rev().take(10).map(|d| d.loss).collect();
365
366                let variance = self.compute_variance(&recent_losses);
367                let cv = variance.sqrt() / self.loss_ema.max(1e-8);
368
369                // Smoother variance-based adaptation
370                if cv > 0.1 {
371                    let instability = (cv - 0.1) / 0.1;
372                    self.config.reduction_factor.powf(instability.min(1.0) * 0.5)
373                } else if cv < 0.05 {
374                    // Very stable -> can increase slightly
375                    let stability = (0.05 - cv) / 0.05;
376                    self.config.increase_factor.powf(stability * 0.1)
377                } else {
378                    1.0
379                }
380            },
381            AdaptationStrategy::PerformanceBasedAdaptive => {
382                match self.state.performance_trend {
383                    PerformanceTrend::Improving => {
384                        // Conservative increase for improving performance
385                        self.config.increase_factor.powf(0.3)
386                    },
387                    PerformanceTrend::Deteriorating => {
388                        // More aggressive reduction for deteriorating performance
389                        self.config.reduction_factor.powf(0.7)
390                    },
391                    PerformanceTrend::Oscillating => {
392                        // Stabilize oscillations with slight reduction
393                        self.config.reduction_factor.powf(0.2)
394                    },
395                    _ => 1.0,
396                }
397            },
398            _ => 1.0, // Default: no change
399        };
400
401        // Ensure contribution is valid and within reasonable bounds
402        if contribution.is_finite() && contribution > 0.0 {
403            Ok(contribution.clamp(0.1, 10.0))
404        } else {
405            log::warn!(
406                "Invalid strategy contribution computed for {:?}: {}",
407                strategy,
408                contribution
409            );
410            Ok(1.0)
411        }
412    }
413
414    fn compute_cyclical_lr(&self) -> f32 {
415        let cycle_progress = self.state.cycle_position as f32 / self.config.cycle_length as f32;
416        let lr_range = self.config.max_lr - self.config.min_lr;
417
418        // Triangular cyclical learning rate
419        if cycle_progress < 0.5 {
420            self.config.min_lr + lr_range * (2.0 * cycle_progress)
421        } else {
422            self.config.max_lr - lr_range * (2.0 * (cycle_progress - 0.5))
423        }
424    }
425
426    fn analyze_performance_trend(&self) -> PerformanceTrend {
427        if self.dynamics_history.len() < self.config.trend_window {
428            return PerformanceTrend::Unknown;
429        }
430
431        // Take the most recent entries and restore chronological order (oldest to newest)
432        // This ensures slope is negative when loss is decreasing over time
433        let mut recent_losses: Vec<f32> = self
434            .dynamics_history
435            .iter()
436            .rev()
437            .take(self.config.trend_window)
438            .map(|d| d.loss)
439            .collect();
440        recent_losses.reverse(); // Restore chronological order
441
442        let slope = self.compute_slope(&recent_losses);
443        let variance = self.compute_variance(&recent_losses);
444
445        if variance > 0.1 {
446            PerformanceTrend::Oscillating
447        } else if slope < -0.01 {
448            PerformanceTrend::Improving
449        } else if slope > 0.01 {
450            PerformanceTrend::Deteriorating
451        } else {
452            PerformanceTrend::Stable
453        }
454    }
455
456    fn compute_slope(&self, values: &[f32]) -> f32 {
457        if values.len() < 2 {
458            return 0.0;
459        }
460
461        // Filter out invalid values first
462        let valid_pairs: Vec<(f32, f32)> = values
463            .iter()
464            .enumerate()
465            .filter_map(
466                |(i, &y)| {
467                    if y.is_finite() {
468                        Some((i as f32, y))
469                    } else {
470                        None
471                    }
472                },
473            )
474            .collect();
475
476        if valid_pairs.len() < 2 {
477            return 0.0;
478        }
479
480        let n = valid_pairs.len() as f32;
481        let sum_x: f32 = valid_pairs.iter().map(|(x, _)| x).sum();
482        let sum_y: f32 = valid_pairs.iter().map(|(_, y)| y).sum();
483        let sum_xy: f32 = valid_pairs.iter().map(|(x, y)| x * y).sum();
484        let sum_x2: f32 = valid_pairs.iter().map(|(x, _)| x * x).sum();
485
486        let denominator = n * sum_x2 - sum_x * sum_x;
487
488        if denominator.abs() < 1e-10 {
489            return 0.0; // Avoid division by zero
490        }
491
492        (n * sum_xy - sum_x * sum_y) / denominator
493    }
494
495    fn compute_variance(&self, values: &[f32]) -> f32 {
496        if values.len() <= 1 {
497            return 0.0;
498        }
499
500        // Use Welford's online algorithm for numerical stability
501        let mut mean = 0.0;
502        let mut m2 = 0.0;
503
504        for (i, &value) in values.iter().enumerate() {
505            if !value.is_finite() {
506                continue; // Skip invalid values
507            }
508
509            let delta = value - mean;
510            mean += delta / (i + 1) as f32;
511            let delta2 = value - mean;
512            m2 += delta * delta2;
513        }
514
515        if values.len() > 1 {
516            m2 / (values.len() - 1) as f32
517        } else {
518            0.0
519        }
520    }
521
522    fn get_adaptation_reason(&self) -> String {
523        if self.state.plateau_counter >= self.config.plateau_patience {
524            "Plateau detected".to_string()
525        } else if matches!(
526            self.state.performance_trend,
527            PerformanceTrend::Deteriorating
528        ) {
529            "Performance deteriorating".to_string()
530        } else if matches!(self.state.performance_trend, PerformanceTrend::Improving) {
531            "Performance improving".to_string()
532        } else {
533            "Routine adaptation".to_string()
534        }
535    }
536
537    fn compute_strategy_contributions(
538        &self,
539        dynamics: &TrainingDynamics,
540    ) -> Result<HashMap<AdaptationStrategy, f32>> {
541        let mut contributions = HashMap::new();
542
543        for strategy in &self.strategies {
544            let contribution = self.compute_strategy_contribution(strategy, dynamics)?;
545            contributions.insert(strategy.clone(), contribution);
546        }
547
548        Ok(contributions)
549    }
550
551    fn compute_adaptation_confidence(&self) -> f32 {
552        // Confidence based on trend consistency and data quality
553        let trend_consistency = if self.dynamics_history.len() >= self.config.trend_window {
554            0.8
555        } else {
556            self.dynamics_history.len() as f32 / self.config.trend_window as f32
557        };
558
559        let data_quality =
560            if self.loss_ema > 0.0 && !self.loss_ema.is_infinite() { 0.9 } else { 0.5 };
561
562        (trend_consistency * data_quality).min(1.0)
563    }
564
565    fn count_adaptations(&self) -> usize {
566        self.state
567            .adaptation_history
568            .iter()
569            .filter(|&&ratio| (ratio - 1.0).abs() > 0.01)
570            .count()
571    }
572
573    fn compute_adaptation_frequency(&self) -> f32 {
574        if self.state.step_count == 0 {
575            return 0.0;
576        }
577
578        self.count_adaptations() as f32 / self.state.step_count as f32
579    }
580
581    fn compute_stability_score(&self) -> f32 {
582        if self.state.adaptation_history.is_empty() {
583            return 1.0;
584        }
585
586        let variance = self
587            .compute_variance(&self.state.adaptation_history.iter().cloned().collect::<Vec<_>>());
588        (1.0 / (1.0 + variance)).clamp(0.0, 1.0)
589    }
590
591    /// Update effectiveness scores for strategies based on performance
592    fn update_strategy_effectiveness(
593        &mut self,
594        dynamics: &TrainingDynamics,
595        old_lr: f32,
596        new_lr: f32,
597    ) {
598        // Simple effectiveness metric based on loss improvement and LR change correlation
599        if self.dynamics_history.len() >= 2 {
600            let prev_loss = self.dynamics_history.back().map(|d| d.loss).unwrap_or(dynamics.loss);
601            let loss_improvement =
602                if prev_loss > 0.0 { (prev_loss - dynamics.loss) / prev_loss } else { 0.0 };
603
604            let lr_change_magnitude = (new_lr / old_lr - 1.0).abs();
605
606            // Reward strategies that contribute to improvements without excessive LR changes
607            let base_effectiveness = if loss_improvement > 0.0 {
608                (loss_improvement * 10.0).min(1.0)
609            } else {
610                0.2 // Small penalty for no improvement
611            };
612
613            // Penalize excessive LR changes
614            let stability_bonus =
615                if lr_change_magnitude < 0.1 { 0.1 } else { -lr_change_magnitude * 0.5 };
616
617            let overall_effectiveness = (base_effectiveness + stability_bonus).clamp(0.0, 1.0);
618
619            // Update strategy effectiveness with exponential moving average
620            for (_strategy, effectiveness) in self.strategy_effectiveness.iter_mut() {
621                let learning_rate = 0.1;
622                *effectiveness =
623                    learning_rate * overall_effectiveness + (1.0 - learning_rate) * *effectiveness;
624            }
625        }
626    }
627
628    /// Check if we should enter emergency mode
629    fn should_enter_emergency_mode(&self) -> bool {
630        if self.emergency_mode {
631            return false; // Already in emergency mode
632        }
633
634        // Enter emergency mode if:
635        // 1. Recent loss has exploded
636        // 2. Multiple consecutive bad adaptations
637        // 3. All strategies are performing poorly
638
639        let recent_loss_explosion = self.dynamics_history.len() >= 2 && {
640            let recent_losses: Vec<f32> =
641                self.dynamics_history.iter().rev().take(3).map(|d| d.loss).collect();
642            recent_losses.windows(2).any(|w| w[0] > w[1] * 5.0)
643        };
644
645        let poor_strategy_performance = self.strategy_effectiveness.values().all(|&eff| eff < 0.3);
646
647        let high_variance = if self.dynamics_history.len() >= 10 {
648            let recent_losses: Vec<f32> =
649                self.dynamics_history.iter().rev().take(10).map(|d| d.loss).collect();
650            let variance = self.compute_variance(&recent_losses);
651            let cv = variance.sqrt() / self.loss_ema.max(1e-8);
652            cv > 0.5
653        } else {
654            false
655        };
656
657        recent_loss_explosion || (poor_strategy_performance && high_variance)
658    }
659
660    /// Check if we can exit emergency mode
661    fn can_exit_emergency_mode(&self) -> bool {
662        if !self.emergency_mode {
663            return false;
664        }
665
666        // Exit emergency mode if performance has stabilized
667        let stable_loss = if self.dynamics_history.len() >= 5 {
668            let recent_losses: Vec<f32> =
669                self.dynamics_history.iter().rev().take(5).map(|d| d.loss).collect();
670            let variance = self.compute_variance(&recent_losses);
671            let cv = variance.sqrt() / self.loss_ema.max(1e-8);
672            cv < 0.1
673        } else {
674            false
675        };
676
677        let improving_trend = matches!(
678            self.state.performance_trend,
679            PerformanceTrend::Improving | PerformanceTrend::Stable
680        );
681
682        stable_loss && improving_trend
683    }
684}
685
/// Learning rate update result.
///
/// Returned by the scheduler's `step()`; describes one LR adjustment and the
/// evidence behind it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningRateUpdate {
    /// Learning rate before this step's adaptation.
    pub old_lr: f32,
    /// Learning rate after adaptation and clamping.
    pub new_lr: f32,
    /// Human-readable explanation of why the LR changed (or didn't).
    pub adaptation_reason: String,
    /// Raw multiplicative factor computed for each active strategy.
    pub strategy_contributions: HashMap<AdaptationStrategy, f32>,
    /// Confidence in this adaptation, in [0, 1].
    pub confidence: f32,
    /// The training dynamics snapshot that triggered this update.
    pub dynamics: TrainingDynamics,
}
696
/// Comprehensive learning rate statistics.
///
/// Snapshot of scheduler behavior, produced by `get_statistics()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveLRStatistics {
    /// Learning rate currently in effect.
    pub current_lr: f32,
    /// Total `step()` calls processed.
    pub steps_taken: usize,
    /// Steps whose LR changed by more than 1%.
    pub adaptations_made: usize,
    /// Latest classified loss trend.
    pub performance_trend: PerformanceTrend,
    /// True when the plateau counter has reached the configured patience.
    pub plateau_detected: bool,
    /// Exponential moving average of the loss.
    pub loss_ema: f32,
    /// Exponential moving average of the gradient norm.
    pub gradient_norm_ema: f32,
    /// Fraction of steps that produced a meaningful LR change.
    pub adaptation_frequency: f32,
    /// Stability score in [0, 1]; higher means steadier LR ratios.
    pub stability_score: f32,
}
710
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a dynamics snapshot with a fixed gradient norm of 0.5.
    fn make_dynamics(step: usize, loss: f32, lr: f32, accuracy: Option<f32>) -> TrainingDynamics {
        TrainingDynamics { step, loss, gradient_norm: 0.5, learning_rate: lr, accuracy }
    }

    #[test]
    fn test_adaptive_lr_scheduler_creation() {
        // A fresh scheduler must start exactly at the configured initial LR.
        let cfg = AdaptiveLearningRateConfig::default();
        let sched = AdaptiveLearningRateScheduler::new(cfg.clone());
        assert_eq!(sched.get_lr(), cfg.initial_lr);
    }

    #[test]
    fn test_learning_rate_adaptation() {
        let mut sched = AdaptiveLearningRateScheduler::new(AdaptiveLearningRateConfig::default());

        let update = sched
            .step(make_dynamics(1, 1.0, 1e-3, Some(0.8)))
            .expect("operation failed in test");

        // The adapted LR must stay positive and come with an explanation.
        assert!(update.new_lr > 0.0);
        assert!(!update.adaptation_reason.is_empty());
    }

    #[test]
    fn test_plateau_detection() {
        let cfg = AdaptiveLearningRateConfig {
            plateau_patience: 3,
            ..AdaptiveLearningRateConfig::default()
        };
        let mut sched = AdaptiveLearningRateScheduler::new(cfg);

        // Feeding an identical loss repeatedly must trip plateau detection.
        for step in 1..=5 {
            let lr = sched.get_lr();
            sched.step(make_dynamics(step, 1.0, lr, None)).expect("operation failed in test");
        }

        assert!(sched.get_statistics().plateau_detected);
    }

    #[test]
    fn test_performance_trend_analysis() {
        let mut sched = AdaptiveLearningRateScheduler::new(AdaptiveLearningRateConfig::default());

        // A steadily decreasing loss should classify as `Improving`.
        for step in 1..=25 {
            let loss = 2.0 - (step as f32) * 0.05;
            let lr = sched.get_lr();
            sched.step(make_dynamics(step, loss, lr, None)).expect("operation failed in test");
        }

        assert!(matches!(
            sched.get_statistics().performance_trend,
            PerformanceTrend::Improving
        ));
    }

    #[test]
    fn test_cyclical_learning_rate() {
        let cfg = AdaptiveLearningRateConfig {
            cyclical_lr: true,
            cycle_length: 10,
            ..AdaptiveLearningRateConfig::default()
        };
        let sched = AdaptiveLearningRateScheduler::new(cfg);

        // The triangular schedule must stay inside [min_lr, max_lr].
        let lr = sched.compute_cyclical_lr();
        assert!(lr >= sched.config.min_lr);
        assert!(lr <= sched.config.max_lr);
    }
}
801}