scirs2_metrics/integration/optim/scheduler.rs

//! Learning rate scheduler integration
//!
//! This module provides utilities for integrating learning rate schedulers
//! with metrics.
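//!
//! A minimal usage sketch. The import path below is assumed from this module's
//! location, and the values are illustrative:
//!
//! ```ignore
//! use scirs2_metrics::integration::optim::scheduler::MetricLRScheduler;
//!
//! // Halve the learning rate after 2 consecutive non-improving steps,
//! // never going below 1e-6, while minimizing "val_loss".
//! let mut scheduler = MetricLRScheduler::new(0.1_f64, 0.5, 2, 1e-6, "val_loss", false);
//!
//! for loss in [0.9, 0.8, 0.81, 0.82, 0.83] {
//!     let lr = scheduler.step_with_metric(loss);
//!     println!("lr = {lr}");
//! }
//! assert!(scheduler.get_learning_rate() < 0.1);
//! ```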

#[allow(unused_imports)]
use crate::error::Result;
use crate::integration::optim::OptimizationMode;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt;
#[allow(unused_imports)]
use std::marker::PhantomData;

/// A metric-based learning rate scheduler
#[derive(Debug, Clone)]
pub struct MetricLRScheduler<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
    /// Current learning rate
    current_lr: F,
    /// Initial learning rate
    initial_lr: F,
    /// Factor by which the learning rate will be reduced
    factor: F,
    /// Number of epochs with no improvement after which learning rate will be reduced
    patience: usize,
    /// Minimum learning rate
    min_lr: F,
    /// Counter for steps with no improvement
    stagnation_count: usize,
    /// Best metric value seen so far
    best_metric: Option<F>,
    /// Threshold for measuring improvement
    threshold: F,
    /// Optimization mode
    mode: OptimizationMode,
    /// Metric name
    metric_name: String,
    /// History of learning rates
    history: Vec<F>,
    /// History of metric values
    metric_history: Vec<F>,
}

impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> MetricLRScheduler<F> {
    /// Create a new metric-based learning rate scheduler
    pub fn new<S: Into<String>>(
        initial_lr: F,
        factor: F,
        patience: usize,
        min_lr: F,
        metric_name: S,
        maximize: bool,
    ) -> Self {
        Self {
            current_lr: initial_lr,
            initial_lr,
            factor,
            patience,
            min_lr,
            stagnation_count: 0,
            best_metric: None,
            threshold: F::from(1e-4).unwrap(),
            mode: if maximize {
                OptimizationMode::Maximize
            } else {
                OptimizationMode::Minimize
            },
            metric_name: metric_name.into(),
            history: vec![initial_lr],
            metric_history: Vec::new(),
        }
    }

    /// Set the threshold for measuring improvement
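    ///
    /// The threshold is relative: in minimize mode a new value only counts as an
    /// improvement when it is below `best * (1 - threshold)`; in maximize mode,
    /// when it is above `best * (1 + threshold)`. For example, with the default
    /// `threshold` of `1e-4` and a best loss of `0.5`, the next value must fall
    /// below `0.49995` to reset the patience counter.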
    pub fn set_threshold(&mut self, threshold: F) -> &mut Self {
        self.threshold = threshold;
        self
    }

    /// Update the scheduler with a new metric value
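    ///
    /// Records the metric, updates the best value seen so far, and, after
    /// `patience` consecutive non-improving steps, multiplies the learning rate
    /// by `factor` (but never below `min_lr`). Returns the current learning rate.
    ///
    /// A minimal sketch of the plateau behavior (assumed import path; illustrative
    /// values):
    ///
    /// ```ignore
    /// use scirs2_metrics::integration::optim::scheduler::MetricLRScheduler;
    ///
    /// // patience = 1: a single non-improving step triggers a reduction.
    /// let mut scheduler = MetricLRScheduler::new(0.1_f64, 0.1, 1, 1e-6, "val_loss", false);
    /// assert_eq!(scheduler.step_with_metric(0.5), 0.1);        // first value: improvement
    /// assert_eq!(scheduler.step_with_metric(0.6), 0.1 * 0.1);  // plateau: lr reduced
    /// ```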
    pub fn step_with_metric(&mut self, metric: F) -> F {
        // Record metric
        self.metric_history.push(metric);

        let is_improvement = match self.best_metric {
            None => true, // First metric value is always an improvement
            Some(best) => {
                match self.mode {
                    OptimizationMode::Minimize => {
                        // Mode is 'min', improvement means metric < best * (1 - threshold)
                        metric < best * (F::one() - self.threshold)
                    }
                    OptimizationMode::Maximize => {
                        // Mode is 'max', improvement means metric > best * (1 + threshold)
                        metric > best * (F::one() + self.threshold)
                    }
                }
            }
        };

        if is_improvement {
            self.best_metric = Some(metric);
            self.stagnation_count = 0;
        } else {
            self.stagnation_count += 1;

            if self.stagnation_count >= self.patience {
                // Reduce learning rate
                self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
                // Add to history
                self.history.push(self.current_lr);
                // Reset stagnation count
                self.stagnation_count = 0;
            }
        }

        self.current_lr
    }

    /// Get the current learning rate
    pub fn get_learning_rate(&self) -> F {
        self.current_lr
    }

    /// Reset the scheduler
    pub fn reset(&mut self) {
        self.current_lr = self.initial_lr;
        self.stagnation_count = 0;
        self.best_metric = None;
        self.history = vec![self.initial_lr];
        self.metric_history.clear();
    }

    /// Get the history of learning rates
    pub fn history(&self) -> &[F] {
        &self.history
    }

    /// Get the history of metric values
    pub fn metric_history(&self) -> &[F] {
        &self.metric_history
    }

    /// Get the best metric value
    pub fn best_metric(&self) -> Option<F> {
        self.best_metric
    }

    /// Create a scheduler configuration for use with external optimizers
    ///
    /// This returns the current state as a configuration that can be used
    /// to create or update external schedulers from scirs2-optim.
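    ///
    /// A minimal sketch (assumed import path; illustrative values):
    ///
    /// ```ignore
    /// use scirs2_metrics::integration::optim::scheduler::MetricLRScheduler;
    ///
    /// let scheduler = MetricLRScheduler::new(0.1_f64, 0.5, 3, 1e-6, "val_loss", false);
    /// let config = scheduler.to_scheduler_config();
    /// // The config mirrors initial_lr, factor, patience, min_lr, mode and metric_name.
    /// ```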
    pub fn to_scheduler_config(&self) -> crate::integration::optim::SchedulerConfig<F> {
        use crate::integration::optim::SchedulerConfig;

        SchedulerConfig {
            initial_lr: self.initial_lr,
            factor: self.factor,
            patience: self.patience,
            min_lr: self.min_lr,
            mode: self.mode,
            metric_name: self.metric_name.clone(),
        }
    }

    /// Get the current scheduler state for external integration
    pub fn get_state(&self) -> SchedulerState<F> {
        SchedulerState {
            current_lr: self.current_lr,
            best_metric: self.best_metric,
            stagnation_count: self.stagnation_count,
            threshold: self.threshold,
            mode: self.mode,
        }
    }
}

/// Current state of a metric-based scheduler
#[derive(Debug, Clone)]
pub struct SchedulerState<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
    /// Current learning rate
    pub current_lr: F,
    /// Best metric value seen so far
    pub best_metric: Option<F>,
    /// Counter for steps with no improvement
    pub stagnation_count: usize,
    /// Threshold for measuring improvement
    pub threshold: F,
    /// Optimization mode
    pub mode: OptimizationMode,
}

/// Trait for external scheduler integration
///
/// This trait can be implemented by external schedulers (like those in scirs2-optim)
/// to provide seamless integration with metrics.
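///
/// A sketch of an adapter around a hypothetical external scheduler type
/// (`MyExternalScheduler` and its methods are illustrative, not part of any crate):
///
/// ```ignore
/// struct MyAdapter {
///     inner: MyExternalScheduler, // hypothetical external type
///     mode: OptimizationMode,
/// }
///
/// impl MetricScheduler<f64> for MyAdapter {
///     fn step_with_metric(&mut self, metric: f64) -> f64 {
///         // Forward the metric and report the (possibly updated) learning rate.
///         self.inner.observe(metric); // hypothetical method
///         self.inner.lr()             // hypothetical method
///     }
///
///     fn get_learning_rate(&self) -> f64 {
///         self.inner.lr()
///     }
///
///     fn reset(&mut self) {
///         self.inner.reset();
///     }
///
///     fn set_mode(&mut self, mode: OptimizationMode) {
///         self.mode = mode;
///     }
/// }
/// ```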
pub trait MetricScheduler<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
    /// Update the scheduler with a new metric value
    fn step_with_metric(&mut self, metric: F) -> F;

    /// Get the current learning rate
    fn get_learning_rate(&self) -> F;

    /// Reset the scheduler to initial state
    fn reset(&mut self);

    /// Set the mode (minimize or maximize)
    fn set_mode(&mut self, mode: OptimizationMode);
}

/// Bridge adapter for external scheduler integration
///
/// This provides a standardized interface for metric-based scheduling
/// without depending on specific external implementations.
pub struct SchedulerBridge<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
    /// Wrapped scheduler implementing [`MetricScheduler`]
    inner: Box<dyn MetricScheduler<F>>,
    /// Metric name
    metric_name: String,
    /// Metric history
    metric_history: Vec<F>,
    /// Learning rate history
    lr_history: Vec<F>,
}