1use anyhow::Result;
7use log;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10
/// Configuration for the adaptive learning-rate scheduler.
///
/// Learning rates produced by the scheduler are always clamped into
/// `[min_lr, max_lr]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveLearningRateConfig {
    /// Learning rate used at construction (and as the reset value when the
    /// computed LR becomes non-finite).
    pub initial_lr: f32,
    /// Lower clamp applied to every adapted learning rate.
    pub min_lr: f32,
    /// Upper clamp applied to every adapted learning rate.
    pub max_lr: f32,
    /// Toggle for loss-driven adaptation.
    /// NOTE(review): not read by the scheduler logic in view — confirm intended use.
    pub loss_based_adaptation: bool,
    /// Toggle for gradient-norm-driven adaptation.
    /// NOTE(review): not read by the scheduler logic in view — confirm intended use.
    pub gradient_based_adaptation: bool,
    /// Toggle for plateau detection.
    /// NOTE(review): not read by the scheduler logic in view — confirm intended use.
    pub plateau_detection: bool,
    /// Steps without sufficient loss improvement before a plateau is declared.
    pub plateau_patience: usize,
    /// Minimum loss improvement (absolute) that resets the plateau counter.
    pub plateau_threshold: f32,
    /// Multiplicative factor (< 1) used when strategies reduce the LR.
    pub reduction_factor: f32,
    /// Multiplicative factor (> 1) used when strategies increase the LR.
    pub increase_factor: f32,
    /// Number of recent steps examined for trend analysis and the
    /// adaptation-ratio history.
    pub trend_window: usize,
    /// EMA momentum for the loss and gradient-norm trackers.
    pub momentum: f32,
    /// Enable triangular cyclical learning rates as the base LR.
    pub cyclical_lr: bool,
    /// Length of one full LR cycle, in steps (cyclical mode only).
    pub cycle_length: usize,
    /// LR range-test flag.
    /// NOTE(review): not read by the scheduler logic in view — confirm intended use.
    pub lr_range_test: bool,
}
45
46impl Default for AdaptiveLearningRateConfig {
47 fn default() -> Self {
48 Self {
49 initial_lr: 1e-3,
50 min_lr: 1e-7,
51 max_lr: 1e-1,
52 loss_based_adaptation: true,
53 gradient_based_adaptation: true,
54 plateau_detection: true,
55 plateau_patience: 50,
56 plateau_threshold: 1e-4,
57 reduction_factor: 0.5,
58 increase_factor: 1.1,
59 trend_window: 20,
60 momentum: 0.9,
61 cyclical_lr: false,
62 cycle_length: 1000,
63 lr_range_test: false,
64 }
65 }
66}
67
/// One step's worth of observed training signals fed into the scheduler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingDynamics {
    /// Training step index as reported by the caller.
    pub step: usize,
    /// Loss at this step; `step()` rejects non-finite values.
    pub loss: f32,
    /// Gradient norm at this step; `step()` rejects non-finite values.
    pub gradient_norm: f32,
    /// Learning rate that was in effect when this step ran.
    pub learning_rate: f32,
    /// Optional evaluation accuracy, when available.
    pub accuracy: Option<f32>,
}
77
/// Strategies a scheduler may consult when adapting the learning rate.
///
/// Only `ReduceOnPlateau`, `GradientNormAdaptive` and `LossVarianceAdaptive`
/// are instantiated by `AdaptiveLearningRateScheduler::new`; the remaining
/// variants fall through to a neutral factor in
/// `compute_strategy_contribution`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AdaptationStrategy {
    /// Reduce the LR after `plateau_patience` steps without improvement.
    ReduceOnPlateau,
    /// Cosine-shaped decay schedule (not implemented in view).
    CosineAnnealing,
    /// Exponential decay schedule (not implemented in view).
    ExponentialDecay,
    /// Polynomial decay schedule (not implemented in view).
    PolynomialDecay,
    /// Triangular cyclical schedule (handled via `compute_cyclical_lr`).
    CyclicalLR,
    /// One-cycle schedule (not implemented in view).
    OneCycleLR,
    /// Adapt based on the gradient norm relative to its EMA.
    GradientNormAdaptive,
    /// Adapt based on the coefficient of variation of recent losses.
    LossVarianceAdaptive,
    /// Adapt based on the detected performance trend.
    PerformanceBasedAdaptive,
}
91
/// Mutable scheduling state carried across `step()` calls.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerState {
    /// Learning rate currently in effect (clamped into config bounds).
    pub current_lr: f32,
    /// Best (lowest) loss seen so far; starts at infinity.
    pub best_loss: f32,
    /// Consecutive steps without sufficient loss improvement.
    pub plateau_counter: usize,
    /// Total number of `step()` calls processed.
    pub step_count: usize,
    /// Position within the current cycle when cyclical LR is enabled.
    pub cycle_position: usize,
    /// Recent `new_lr / old_lr` ratios (bounded by `trend_window`).
    pub adaptation_history: VecDeque<f32>,
    /// Most recent classification of the loss trajectory.
    pub performance_trend: PerformanceTrend,
}
103
/// Classification of the recent loss trajectory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PerformanceTrend {
    /// Loss slope is clearly negative.
    Improving,
    /// Loss slope is close to zero with low variance.
    Stable,
    /// Loss slope is clearly positive.
    Deteriorating,
    /// Loss variance over the window is high.
    Oscillating,
    /// Not enough history to classify yet.
    Unknown,
}
113
/// Adaptive learning-rate scheduler combining several adaptation strategies.
///
/// Feed it one `TrainingDynamics` per step via `step()`; read the resulting
/// learning rate via `get_lr()`.
pub struct AdaptiveLearningRateScheduler {
    /// Static configuration (bounds, factors, windows).
    config: AdaptiveLearningRateConfig,
    /// Mutable state: current LR, counters, adaptation-ratio history.
    state: SchedulerState,
    /// Recent training dynamics used for trend/variance analysis.
    dynamics_history: VecDeque<TrainingDynamics>,
    /// Exponential moving average of the loss.
    loss_ema: f32,
    /// Exponential moving average of the gradient norm.
    gradient_norm_ema: f32,
    /// Strategies consulted each step.
    strategies: Vec<AdaptationStrategy>,
    /// Per-strategy mixing weight (uniform at construction).
    strategy_weights: HashMap<AdaptationStrategy, f32>,
    /// Per-strategy effectiveness EMA in [0, 1] (starts at 0.5).
    strategy_effectiveness: HashMap<AdaptationStrategy, f32>,
    /// True while the scheduler is in conservative emergency mode.
    emergency_mode: bool,
}
128
129impl AdaptiveLearningRateScheduler {
130 pub fn new(config: AdaptiveLearningRateConfig) -> Self {
131 let state = SchedulerState {
132 current_lr: config.initial_lr,
133 best_loss: f32::INFINITY,
134 plateau_counter: 0,
135 step_count: 0,
136 cycle_position: 0,
137 adaptation_history: VecDeque::with_capacity(config.trend_window),
138 performance_trend: PerformanceTrend::Unknown,
139 };
140
141 let strategies = vec![
142 AdaptationStrategy::ReduceOnPlateau,
143 AdaptationStrategy::GradientNormAdaptive,
144 AdaptationStrategy::LossVarianceAdaptive,
145 ];
146
147 let strategy_weights =
148 strategies.iter().map(|s| (s.clone(), 1.0 / strategies.len() as f32)).collect();
149
150 let strategy_effectiveness = strategies.iter()
151 .map(|s| (s.clone(), 0.5)) .collect();
153
154 Self {
155 config,
156 state,
157 dynamics_history: VecDeque::with_capacity(1000),
158 loss_ema: 0.0,
159 gradient_norm_ema: 0.0,
160 strategies,
161 strategy_weights,
162 strategy_effectiveness,
163 emergency_mode: false,
164 }
165 }
166
167 pub fn step(&mut self, dynamics: TrainingDynamics) -> Result<LearningRateUpdate> {
169 if !dynamics.loss.is_finite() || !dynamics.gradient_norm.is_finite() {
171 log::warn!(
172 "Invalid training dynamics: loss={}, grad_norm={}",
173 dynamics.loss,
174 dynamics.gradient_norm
175 );
176 return Err(anyhow::anyhow!("Invalid training dynamics detected"));
177 }
178
179 self.state.step_count += 1;
180
181 if self.state.step_count == 1 {
183 self.loss_ema = dynamics.loss;
184 self.gradient_norm_ema = dynamics.gradient_norm;
185 } else {
186 let loss_update = (1.0 - self.config.momentum) * dynamics.loss;
188 let grad_update = (1.0 - self.config.momentum) * dynamics.gradient_norm;
189
190 if loss_update.is_finite() {
191 self.loss_ema = self.config.momentum * self.loss_ema + loss_update;
192 }
193
194 if grad_update.is_finite() {
195 self.gradient_norm_ema =
196 self.config.momentum * self.gradient_norm_ema + grad_update;
197 }
198 }
199
200 self.dynamics_history.push_back(dynamics.clone());
202 if self.dynamics_history.len() > self.dynamics_history.capacity() {
203 self.dynamics_history.pop_front();
204 }
205
206 self.state.performance_trend = self.analyze_performance_trend();
208
209 let new_lr = match self.compute_adaptive_learning_rate(&dynamics) {
211 Ok(lr) if lr.is_finite() && lr > 0.0 => lr,
212 Ok(lr) => {
213 log::warn!("Invalid learning rate computed: {}. Using current LR.", lr);
214 self.state.current_lr
215 },
216 Err(e) => {
217 log::error!(
218 "Failed to compute adaptive learning rate: {}. Using current LR.",
219 e
220 );
221 self.state.current_lr
222 },
223 };
224
225 let old_lr = self.state.current_lr;
226 self.state.current_lr = new_lr.clamp(self.config.min_lr, self.config.max_lr);
227
228 if !self.state.current_lr.is_finite() {
230 log::error!("Learning rate became non-finite. Resetting to initial LR.");
231 self.state.current_lr = self.config.initial_lr;
232 }
233
234 let adaptation_ratio = self.state.current_lr / old_lr;
236 self.state.adaptation_history.push_back(adaptation_ratio);
237 if self.state.adaptation_history.len() > self.config.trend_window {
238 self.state.adaptation_history.pop_front();
239 }
240
241 if dynamics.loss < self.state.best_loss - self.config.plateau_threshold {
243 self.state.best_loss = dynamics.loss;
244 self.state.plateau_counter = 0;
245 } else {
246 self.state.plateau_counter += 1;
247 }
248
249 if self.config.cyclical_lr {
251 self.state.cycle_position = (self.state.cycle_position + 1) % self.config.cycle_length;
252 }
253
254 self.update_strategy_effectiveness(&dynamics, old_lr, self.state.current_lr);
256
257 if self.should_enter_emergency_mode() {
259 self.emergency_mode = true;
260 self.state.current_lr = self.config.initial_lr * 0.1; log::warn!("Entering emergency mode - using conservative learning rate");
262 } else if self.emergency_mode && self.can_exit_emergency_mode() {
263 self.emergency_mode = false;
264 log::info!("Exiting emergency mode - performance stabilized");
265 }
266
267 Ok(LearningRateUpdate {
268 old_lr,
269 new_lr: self.state.current_lr,
270 adaptation_reason: self.get_adaptation_reason(),
271 strategy_contributions: self.compute_strategy_contributions(&dynamics)?,
272 confidence: self.compute_adaptation_confidence(),
273 dynamics: dynamics.clone(),
274 })
275 }
276
    /// Current learning rate (after clamping and any emergency-mode override).
    pub fn get_lr(&self) -> f32 {
        self.state.current_lr
    }
281
    /// Borrow the full scheduler state (counters, histories, trend).
    pub fn get_state(&self) -> &SchedulerState {
        &self.state
    }
286
287 pub fn get_statistics(&self) -> AdaptiveLRStatistics {
289 AdaptiveLRStatistics {
290 current_lr: self.state.current_lr,
291 steps_taken: self.state.step_count,
292 adaptations_made: self.count_adaptations(),
293 performance_trend: self.state.performance_trend.clone(),
294 plateau_detected: self.state.plateau_counter >= self.config.plateau_patience,
295 loss_ema: self.loss_ema,
296 gradient_norm_ema: self.gradient_norm_ema,
297 adaptation_frequency: self.compute_adaptation_frequency(),
298 stability_score: self.compute_stability_score(),
299 }
300 }
301
302 fn compute_adaptive_learning_rate(&mut self, dynamics: &TrainingDynamics) -> Result<f32> {
304 let mut contributions = Vec::new();
305
306 for strategy in &self.strategies {
307 if let Some(weight) = self.strategy_weights.get(strategy) {
308 let contribution = self.compute_strategy_contribution(strategy, dynamics)? * weight;
309 contributions.push(contribution);
310 }
311 }
312
313 let adaptive_factor: f32 = contributions.iter().sum::<f32>() / contributions.len() as f32;
315
316 let base_lr = if self.config.cyclical_lr {
318 self.compute_cyclical_lr()
319 } else {
320 self.state.current_lr
321 };
322
323 Ok(base_lr * adaptive_factor)
324 }
325
    /// Multiplicative LR factor proposed by a single strategy (>1 raises the
    /// LR, <1 lowers it, 1.0 is neutral).
    ///
    /// Non-finite or non-positive results are replaced by 1.0; valid results
    /// are clamped into [0.1, 10.0].
    fn compute_strategy_contribution(
        &self,
        strategy: &AdaptationStrategy,
        dynamics: &TrainingDynamics,
    ) -> Result<f32> {
        let contribution = match strategy {
            AdaptationStrategy::ReduceOnPlateau => {
                // Severity grows with how far past the patience window we are,
                // capped at 2x patience.
                let plateau_severity = (self.state.plateau_counter as f32
                    / self.config.plateau_patience as f32)
                    .min(2.0);
                if self.state.plateau_counter >= self.config.plateau_patience {
                    self.config.reduction_factor.powf(plateau_severity * 0.5)
                } else {
                    1.0
                }
            },
            AdaptationStrategy::GradientNormAdaptive => {
                // Instantaneous gradient norm relative to its EMA; the max()
                // guards against division by a near-zero EMA.
                let grad_ratio = dynamics.gradient_norm / self.gradient_norm_ema.max(1e-8);

                if grad_ratio > 2.0 {
                    // Gradient spike: reduce in proportion to the overshoot.
                    let severity = (grad_ratio / 2.0 - 1.0).min(2.0);
                    self.config.reduction_factor.powf(severity * 0.3)
                } else if grad_ratio < 0.5 {
                    // Shrinking gradients: nudge the LR upward.
                    let boost = (1.0 - grad_ratio * 2.0).min(1.0);
                    self.config.increase_factor.powf(boost * 0.2)
                } else {
                    // Mild linear adjustment around the neutral ratio of 1.0.
                    1.0 + (grad_ratio - 1.0) * 0.1
                }
            },
            AdaptationStrategy::LossVarianceAdaptive => {
                // Neutral until a minimal window of history exists. Note this
                // early return bypasses the clamp below (1.0 is already valid).
                if self.dynamics_history.len() < 10 {
                    return Ok(1.0);
                }

                let recent_losses: Vec<f32> =
                    self.dynamics_history.iter().rev().take(10).map(|d| d.loss).collect();

                // Coefficient of variation of recent losses relative to the
                // loss EMA (guarded against zero).
                let variance = self.compute_variance(&recent_losses);
                let cv = variance.sqrt() / self.loss_ema.max(1e-8);

                if cv > 0.1 {
                    // Noisy loss: damp the LR in proportion to the instability.
                    let instability = (cv - 0.1) / 0.1;
                    self.config.reduction_factor.powf(instability.min(1.0) * 0.5)
                } else if cv < 0.05 {
                    // Very stable loss: allow a small increase.
                    let stability = (0.05 - cv) / 0.05;
                    self.config.increase_factor.powf(stability * 0.1)
                } else {
                    1.0
                }
            },
            AdaptationStrategy::PerformanceBasedAdaptive => {
                match self.state.performance_trend {
                    PerformanceTrend::Improving => {
                        self.config.increase_factor.powf(0.3)
                    },
                    PerformanceTrend::Deteriorating => {
                        self.config.reduction_factor.powf(0.7)
                    },
                    PerformanceTrend::Oscillating => {
                        self.config.reduction_factor.powf(0.2)
                    },
                    _ => 1.0,
                }
            },
            // Remaining variants are schedule-based and handled elsewhere.
            _ => 1.0, };

        // Sanitize: a strategy must never zero-out or explode the LR.
        if contribution.is_finite() && contribution > 0.0 {
            Ok(contribution.clamp(0.1, 10.0))
        } else {
            log::warn!(
                "Invalid strategy contribution computed for {:?}: {}",
                strategy,
                contribution
            );
            Ok(1.0)
        }
    }
413
414 fn compute_cyclical_lr(&self) -> f32 {
415 let cycle_progress = self.state.cycle_position as f32 / self.config.cycle_length as f32;
416 let lr_range = self.config.max_lr - self.config.min_lr;
417
418 if cycle_progress < 0.5 {
420 self.config.min_lr + lr_range * (2.0 * cycle_progress)
421 } else {
422 self.config.max_lr - lr_range * (2.0 * (cycle_progress - 0.5))
423 }
424 }
425
426 fn analyze_performance_trend(&self) -> PerformanceTrend {
427 if self.dynamics_history.len() < self.config.trend_window {
428 return PerformanceTrend::Unknown;
429 }
430
431 let mut recent_losses: Vec<f32> = self
434 .dynamics_history
435 .iter()
436 .rev()
437 .take(self.config.trend_window)
438 .map(|d| d.loss)
439 .collect();
440 recent_losses.reverse(); let slope = self.compute_slope(&recent_losses);
443 let variance = self.compute_variance(&recent_losses);
444
445 if variance > 0.1 {
446 PerformanceTrend::Oscillating
447 } else if slope < -0.01 {
448 PerformanceTrend::Improving
449 } else if slope > 0.01 {
450 PerformanceTrend::Deteriorating
451 } else {
452 PerformanceTrend::Stable
453 }
454 }
455
456 fn compute_slope(&self, values: &[f32]) -> f32 {
457 if values.len() < 2 {
458 return 0.0;
459 }
460
461 let valid_pairs: Vec<(f32, f32)> = values
463 .iter()
464 .enumerate()
465 .filter_map(
466 |(i, &y)| {
467 if y.is_finite() {
468 Some((i as f32, y))
469 } else {
470 None
471 }
472 },
473 )
474 .collect();
475
476 if valid_pairs.len() < 2 {
477 return 0.0;
478 }
479
480 let n = valid_pairs.len() as f32;
481 let sum_x: f32 = valid_pairs.iter().map(|(x, _)| x).sum();
482 let sum_y: f32 = valid_pairs.iter().map(|(_, y)| y).sum();
483 let sum_xy: f32 = valid_pairs.iter().map(|(x, y)| x * y).sum();
484 let sum_x2: f32 = valid_pairs.iter().map(|(x, _)| x * x).sum();
485
486 let denominator = n * sum_x2 - sum_x * sum_x;
487
488 if denominator.abs() < 1e-10 {
489 return 0.0; }
491
492 (n * sum_xy - sum_x * sum_y) / denominator
493 }
494
495 fn compute_variance(&self, values: &[f32]) -> f32 {
496 if values.len() <= 1 {
497 return 0.0;
498 }
499
500 let mut mean = 0.0;
502 let mut m2 = 0.0;
503
504 for (i, &value) in values.iter().enumerate() {
505 if !value.is_finite() {
506 continue; }
508
509 let delta = value - mean;
510 mean += delta / (i + 1) as f32;
511 let delta2 = value - mean;
512 m2 += delta * delta2;
513 }
514
515 if values.len() > 1 {
516 m2 / (values.len() - 1) as f32
517 } else {
518 0.0
519 }
520 }
521
522 fn get_adaptation_reason(&self) -> String {
523 if self.state.plateau_counter >= self.config.plateau_patience {
524 "Plateau detected".to_string()
525 } else if matches!(
526 self.state.performance_trend,
527 PerformanceTrend::Deteriorating
528 ) {
529 "Performance deteriorating".to_string()
530 } else if matches!(self.state.performance_trend, PerformanceTrend::Improving) {
531 "Performance improving".to_string()
532 } else {
533 "Routine adaptation".to_string()
534 }
535 }
536
537 fn compute_strategy_contributions(
538 &self,
539 dynamics: &TrainingDynamics,
540 ) -> Result<HashMap<AdaptationStrategy, f32>> {
541 let mut contributions = HashMap::new();
542
543 for strategy in &self.strategies {
544 let contribution = self.compute_strategy_contribution(strategy, dynamics)?;
545 contributions.insert(strategy.clone(), contribution);
546 }
547
548 Ok(contributions)
549 }
550
551 fn compute_adaptation_confidence(&self) -> f32 {
552 let trend_consistency = if self.dynamics_history.len() >= self.config.trend_window {
554 0.8
555 } else {
556 self.dynamics_history.len() as f32 / self.config.trend_window as f32
557 };
558
559 let data_quality =
560 if self.loss_ema > 0.0 && !self.loss_ema.is_infinite() { 0.9 } else { 0.5 };
561
562 (trend_consistency * data_quality).min(1.0)
563 }
564
565 fn count_adaptations(&self) -> usize {
566 self.state
567 .adaptation_history
568 .iter()
569 .filter(|&&ratio| (ratio - 1.0).abs() > 0.01)
570 .count()
571 }
572
573 fn compute_adaptation_frequency(&self) -> f32 {
574 if self.state.step_count == 0 {
575 return 0.0;
576 }
577
578 self.count_adaptations() as f32 / self.state.step_count as f32
579 }
580
581 fn compute_stability_score(&self) -> f32 {
582 if self.state.adaptation_history.is_empty() {
583 return 1.0;
584 }
585
586 let variance = self
587 .compute_variance(&self.state.adaptation_history.iter().cloned().collect::<Vec<_>>());
588 (1.0 / (1.0 + variance)).clamp(0.0, 1.0)
589 }
590
591 fn update_strategy_effectiveness(
593 &mut self,
594 dynamics: &TrainingDynamics,
595 old_lr: f32,
596 new_lr: f32,
597 ) {
598 if self.dynamics_history.len() >= 2 {
600 let prev_loss = self.dynamics_history.back().map(|d| d.loss).unwrap_or(dynamics.loss);
601 let loss_improvement =
602 if prev_loss > 0.0 { (prev_loss - dynamics.loss) / prev_loss } else { 0.0 };
603
604 let lr_change_magnitude = (new_lr / old_lr - 1.0).abs();
605
606 let base_effectiveness = if loss_improvement > 0.0 {
608 (loss_improvement * 10.0).min(1.0)
609 } else {
610 0.2 };
612
613 let stability_bonus =
615 if lr_change_magnitude < 0.1 { 0.1 } else { -lr_change_magnitude * 0.5 };
616
617 let overall_effectiveness = (base_effectiveness + stability_bonus).clamp(0.0, 1.0);
618
619 for (_strategy, effectiveness) in self.strategy_effectiveness.iter_mut() {
621 let learning_rate = 0.1;
622 *effectiveness =
623 learning_rate * overall_effectiveness + (1.0 - learning_rate) * *effectiveness;
624 }
625 }
626 }
627
    /// Decide whether to enter the conservative emergency mode.
    ///
    /// Triggers on a sudden loss explosion (any adjacent pair in the last 3
    /// entries where the newer loss exceeds 5x the older one), or on every
    /// strategy scoring poorly while the recent loss variance is high.
    fn should_enter_emergency_mode(&self) -> bool {
        // Entry is only evaluated from normal mode.
        if self.emergency_mode {
            return false; }

        // `rev()` yields newest-first, so in each window w[0] is the more
        // recent loss and w[1] the older one.
        let recent_loss_explosion = self.dynamics_history.len() >= 2 && {
            let recent_losses: Vec<f32> =
                self.dynamics_history.iter().rev().take(3).map(|d| d.loss).collect();
            recent_losses.windows(2).any(|w| w[0] > w[1] * 5.0)
        };

        // Every strategy's effectiveness EMA has fallen below 0.3.
        let poor_strategy_performance = self.strategy_effectiveness.values().all(|&eff| eff < 0.3);

        // Coefficient of variation of the last 10 losses relative to the EMA.
        let high_variance = if self.dynamics_history.len() >= 10 {
            let recent_losses: Vec<f32> =
                self.dynamics_history.iter().rev().take(10).map(|d| d.loss).collect();
            let variance = self.compute_variance(&recent_losses);
            let cv = variance.sqrt() / self.loss_ema.max(1e-8);
            cv > 0.5
        } else {
            false
        };

        recent_loss_explosion || (poor_strategy_performance && high_variance)
    }
659
660 fn can_exit_emergency_mode(&self) -> bool {
662 if !self.emergency_mode {
663 return false;
664 }
665
666 let stable_loss = if self.dynamics_history.len() >= 5 {
668 let recent_losses: Vec<f32> =
669 self.dynamics_history.iter().rev().take(5).map(|d| d.loss).collect();
670 let variance = self.compute_variance(&recent_losses);
671 let cv = variance.sqrt() / self.loss_ema.max(1e-8);
672 cv < 0.1
673 } else {
674 false
675 };
676
677 let improving_trend = matches!(
678 self.state.performance_trend,
679 PerformanceTrend::Improving | PerformanceTrend::Stable
680 );
681
682 stable_loss && improving_trend
683 }
684}
685
/// Result of a single `step()` call: the LR transition plus diagnostics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningRateUpdate {
    /// Learning rate before this step's adaptation.
    pub old_lr: f32,
    /// Learning rate now in effect.
    pub new_lr: f32,
    /// Human-readable explanation for the change.
    pub adaptation_reason: String,
    /// Raw (unweighted) factor each strategy proposed.
    pub strategy_contributions: HashMap<AdaptationStrategy, f32>,
    /// Confidence in this adaptation, in [0, 1].
    pub confidence: f32,
    /// Copy of the dynamics that produced this update.
    pub dynamics: TrainingDynamics,
}
696
/// Aggregate scheduler diagnostics returned by `get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveLRStatistics {
    /// Learning rate currently in effect.
    pub current_lr: f32,
    /// Total `step()` calls processed.
    pub steps_taken: usize,
    /// Recorded steps whose LR ratio deviated from 1.0 by more than 1%.
    pub adaptations_made: usize,
    /// Latest classification of the loss trajectory.
    pub performance_trend: PerformanceTrend,
    /// True when the plateau counter has reached the patience limit.
    pub plateau_detected: bool,
    /// Exponential moving average of the loss.
    pub loss_ema: f32,
    /// Exponential moving average of the gradient norm.
    pub gradient_norm_ema: f32,
    /// Fraction of steps that produced a meaningful LR change.
    pub adaptation_frequency: f32,
    /// Stability in (0, 1]; 1.0 means LR ratios never varied.
    pub stability_score: f32,
}
710
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh scheduler starts at the configured initial learning rate.
    #[test]
    fn test_adaptive_lr_scheduler_creation() {
        let config = AdaptiveLearningRateConfig::default();
        let scheduler = AdaptiveLearningRateScheduler::new(config.clone());
        assert_eq!(scheduler.get_lr(), config.initial_lr);
    }

    // A single valid step yields a positive LR and a non-empty reason string.
    #[test]
    fn test_learning_rate_adaptation() {
        let config = AdaptiveLearningRateConfig::default();
        let mut scheduler = AdaptiveLearningRateScheduler::new(config);

        let dynamics = TrainingDynamics {
            step: 1,
            loss: 1.0,
            gradient_norm: 0.5,
            learning_rate: 1e-3,
            accuracy: Some(0.8),
        };

        let update = scheduler.step(dynamics).expect("operation failed in test");
        assert!(update.new_lr > 0.0);
        assert!(!update.adaptation_reason.is_empty());
    }

    // Constant loss for more steps than the patience should flag a plateau
    // (the first step sets best_loss; the next 4 increment the counter past 3).
    #[test]
    fn test_plateau_detection() {
        let config = AdaptiveLearningRateConfig {
            plateau_patience: 3,
            ..AdaptiveLearningRateConfig::default()
        };
        let mut scheduler = AdaptiveLearningRateScheduler::new(config);

        for i in 1..=5 {
            let dynamics = TrainingDynamics {
                step: i,
                loss: 1.0, gradient_norm: 0.5,
                learning_rate: scheduler.get_lr(),
                accuracy: None,
            };
            scheduler.step(dynamics).expect("operation failed in test");
        }

        let stats = scheduler.get_statistics();
        assert!(stats.plateau_detected);
    }

    // A steadily decreasing loss (more steps than the default trend window of
    // 20) should be classified as Improving.
    #[test]
    fn test_performance_trend_analysis() {
        let config = AdaptiveLearningRateConfig::default();
        let mut scheduler = AdaptiveLearningRateScheduler::new(config);

        for i in 1..=25 {
            let dynamics = TrainingDynamics {
                step: i,
                loss: 2.0 - (i as f32) * 0.05, gradient_norm: 0.5,
                learning_rate: scheduler.get_lr(),
                accuracy: None,
            };
            scheduler.step(dynamics).expect("operation failed in test");
        }

        let stats = scheduler.get_statistics();
        assert!(matches!(
            stats.performance_trend,
            PerformanceTrend::Improving
        ));
    }

    // The triangular cyclical LR always stays within the configured bounds.
    #[test]
    fn test_cyclical_learning_rate() {
        let config = AdaptiveLearningRateConfig {
            cyclical_lr: true,
            cycle_length: 10,
            ..AdaptiveLearningRateConfig::default()
        };
        let scheduler = AdaptiveLearningRateScheduler::new(config);

        let cyclical_lr = scheduler.compute_cyclical_lr();
        assert!(cyclical_lr >= scheduler.config.min_lr);
        assert!(cyclical_lr <= scheduler.config.max_lr);
    }
}