scirs2_metrics/integration/optim/
scheduler.rs1#[allow(unused_imports)]
7use crate::error::Result;
8use crate::integration::optim::OptimizationMode;
9use scirs2_core::numeric::{Float, FromPrimitive};
10use std::fmt;
11#[allow(unused_imports)]
12use std::marker::PhantomData;
13
14#[derive(Debug, Clone)]
16pub struct MetricLRScheduler<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
17 current_lr: F,
19 initial_lr: F,
21 factor: F,
23 patience: usize,
25 min_lr: F,
27 stagnation_count: usize,
29 best_metric: Option<F>,
31 threshold: F,
33 mode: OptimizationMode,
35 metric_name: String,
37 history: Vec<F>,
39 metric_history: Vec<F>,
41}
42
43impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> MetricLRScheduler<F> {
44 pub fn new<S: Into<String>>(
46 initial_lr: F,
47 factor: F,
48 patience: usize,
49 min_lr: F,
50 metric_name: S,
51 maximize: bool,
52 ) -> Self {
53 Self {
54 current_lr: initial_lr,
55 initial_lr,
56 factor,
57 patience,
58 min_lr,
59 stagnation_count: 0,
60 best_metric: None,
61 threshold: F::from(1e-4).unwrap(),
62 mode: if maximize {
63 OptimizationMode::Maximize
64 } else {
65 OptimizationMode::Minimize
66 },
67 metric_name: metric_name.into(),
68 history: vec![initial_lr],
69 metric_history: Vec::new(),
70 }
71 }
72
73 pub fn set_threshold(&mut self, threshold: F) -> &mut Self {
75 self.threshold = threshold;
76 self
77 }
78
79 pub fn step_with_metric(&mut self, metric: F) -> F {
81 self.metric_history.push(metric);
83
84 let is_improvement = match self.best_metric {
85 None => true, Some(best) => {
87 match self.mode {
88 OptimizationMode::Minimize => {
89 metric < best * (F::one() - self.threshold)
91 }
92 OptimizationMode::Maximize => {
93 metric > best * (F::one() + self.threshold)
95 }
96 }
97 }
98 };
99
100 if is_improvement {
101 self.best_metric = Some(metric);
102 self.stagnation_count = 0;
103 } else {
104 self.stagnation_count += 1;
105
106 if self.stagnation_count >= self.patience {
107 self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
109 self.history.push(self.current_lr);
111 self.stagnation_count = 0;
113 }
114 }
115
116 self.current_lr
117 }
118
119 pub fn get_learning_rate(&self) -> F {
121 self.current_lr
122 }
123
124 pub fn reset(&mut self) {
126 self.current_lr = self.initial_lr;
127 self.stagnation_count = 0;
128 self.best_metric = None;
129 self.history = vec![self.initial_lr];
130 self.metric_history.clear();
131 }
132
133 pub fn history(&self) -> &[F] {
135 &self.history
136 }
137
138 pub fn metric_history(&self) -> &[F] {
140 &self.metric_history
141 }
142
143 pub fn best_metric(&self) -> Option<F> {
145 self.best_metric
146 }
147
148 pub fn to_scheduler_config(&self) -> crate::integration::optim::SchedulerConfig<F> {
153 use crate::integration::optim::SchedulerConfig;
154
155 SchedulerConfig {
156 initial_lr: self.initial_lr,
157 factor: self.factor,
158 patience: self.patience,
159 min_lr: self.min_lr,
160 mode: self.mode,
161 metric_name: self.metric_name.clone(),
162 }
163 }
164
165 pub fn get_state(&self) -> SchedulerState<F> {
167 SchedulerState {
168 current_lr: self.current_lr,
169 best_metric: self.best_metric,
170 stagnation_count: self.stagnation_count,
171 threshold: self.threshold,
172 mode: self.mode,
173 }
174 }
175}
176
177#[derive(Debug, Clone)]
179pub struct SchedulerState<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
180 pub current_lr: F,
182 pub best_metric: Option<F>,
184 pub stagnation_count: usize,
186 pub threshold: F,
188 pub mode: OptimizationMode,
190}
191
192pub trait MetricScheduler<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
197 fn step_with_metric(&mut self, metric: F) -> F;
199
200 fn get_learning_rate(&self) -> F;
202
203 fn reset(&mut self);
205
206 fn set_mode(&mut self, mode: OptimizationMode);
208}
209
210pub struct SchedulerBridge<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
215 inner: Box<dyn MetricScheduler<F>>,
217 metric_name: String,
219 metric_history: Vec<F>,
221 lr_history: Vec<F>,
223}