scirs2_metrics/integration/optim/
adapter.rs1#[allow(unused_imports)]
6use crate::error::MetricsError;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::HashMap;
9use std::fmt;
10use std::marker::PhantomData;
11
/// Direction in which a tracked metric must move for a new value to count as
/// progress.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationMode {
    /// Smaller values are better.
    Minimize,
    /// Larger values are better.
    Maximize,
}
20
21impl fmt::Display for OptimizationMode {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 match self {
24 OptimizationMode::Minimize => write!(f, "minimize"),
25 OptimizationMode::Maximize => write!(f, "maximize"),
26 }
27 }
28}
29
30#[derive(Debug, Clone)]
32pub struct MetricOptimizer<F: Float + fmt::Debug + fmt::Display + FromPrimitive = f64> {
33 metric_name: String,
35 mode: OptimizationMode,
37 history: Vec<F>,
39 best_value: Option<F>,
41 additional_metrics: HashMap<String, Vec<F>>,
43 _phantom: PhantomData<F>,
45}
46
47impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> MetricOptimizer<F> {
48 pub fn new<S: Into<String>>(name: S, maximize: bool) -> Self {
55 Self {
56 metric_name: name.into(),
57 mode: if maximize {
58 OptimizationMode::Maximize
59 } else {
60 OptimizationMode::Minimize
61 },
62 history: Vec::new(),
63 best_value: None,
64 additional_metrics: HashMap::new(),
65 _phantom: PhantomData,
66 }
67 }
68
69 pub fn metric_name(&self) -> &str {
71 &self.metric_name
72 }
73
74 pub fn mode(&self) -> OptimizationMode {
76 self.mode
77 }
78
79 pub fn history(&self) -> &[F] {
81 &self.history
82 }
83
84 pub fn best_value(&self) -> Option<F> {
86 self.best_value
87 }
88
89 pub fn add_value(&mut self, value: F) {
91 self.history.push(value);
92
93 self.best_value = match (self.best_value, self.mode) {
95 (None, _) => Some(value),
96 (Some(best), OptimizationMode::Maximize) if value > best => Some(value),
97 (Some(best), OptimizationMode::Minimize) if value < best => Some(value),
98 (Some(best), _) => Some(best),
99 };
100 }
101
102 pub fn add_additional_value(&mut self, metricname: &str, value: F) {
104 self.additional_metrics
105 .entry(metricname.to_string())
106 .or_default()
107 .push(value);
108 }
109
110 pub fn additional_metric_history(&self, metricname: &str) -> Option<&[F]> {
112 self.additional_metrics
113 .get(metricname)
114 .map(|v| v.as_slice())
115 }
116
117 pub fn reset(&mut self) {
119 self.history.clear();
120 self.best_value = None;
121 self.additional_metrics.clear();
122 }
123
124 pub fn is_better(&self, current: F, previous: F) -> bool {
126 match self.mode {
127 OptimizationMode::Maximize => current > previous,
128 OptimizationMode::Minimize => current < previous,
129 }
130 }
131
132 pub fn is_improvement(&self, value: F) -> bool {
134 match self.best_value {
135 None => true,
136 Some(best) => self.is_better(value, best),
137 }
138 }
139
140 pub fn create_scheduler_config(
145 &self,
146 initial_lr: F,
147 factor: F,
148 patience: usize,
149 min_lr: F,
150 ) -> SchedulerConfig<F> {
151 SchedulerConfig {
152 initial_lr,
153 factor,
154 patience,
155 min_lr,
156 mode: self.mode,
157 metric_name: self.metric_name.clone(),
158 }
159 }
160}
161
162#[derive(Debug, Clone)]
164pub struct SchedulerConfig<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
165 pub initial_lr: F,
167 pub factor: F,
169 pub patience: usize,
171 pub min_lr: F,
173 pub mode: OptimizationMode,
175 pub metric_name: String,
177}
178
179impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> SchedulerConfig<F> {
180 pub fn as_tuple(&self) -> (F, F, usize, F, OptimizationMode) {
182 (
183 self.initial_lr,
184 self.factor,
185 self.patience,
186 self.min_lr,
187 self.mode,
188 )
189 }
190
191 pub fn new(
193 initial_lr: F,
194 factor: F,
195 patience: usize,
196 min_lr: F,
197 mode: OptimizationMode,
198 metric_name: String,
199 ) -> Self {
200 Self {
201 initial_lr,
202 factor,
203 patience,
204 min_lr,
205 mode,
206 metric_name,
207 }
208 }
209}