scirs2_metrics/integration/optim/
adapter.rs

1//! Adapters for scirs2-optim integration
2//!
3//! This module provides adapters for using scirs2-metrics with scirs2-optim.
4
5#[allow(unused_imports)]
6use crate::error::MetricsError;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use std::collections::HashMap;
9use std::fmt;
10use std::marker::PhantomData;
11
/// Direction in which a tracked metric has to move to count as progress.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum OptimizationMode {
    /// Lower values are better (the metric should be minimized).
    Minimize,
    /// Higher values are better (the metric should be maximized).
    Maximize,
}
20
21impl fmt::Display for OptimizationMode {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        match self {
24            OptimizationMode::Minimize => write!(f, "minimize"),
25            OptimizationMode::Maximize => write!(f, "maximize"),
26        }
27    }
28}
29
30/// Adapter for using scirs2-metrics with scirs2-optim
31#[derive(Debug, Clone)]
32pub struct MetricOptimizer<F: Float + fmt::Debug + fmt::Display + FromPrimitive = f64> {
33    /// Metric name
34    metric_name: String,
35    /// Optimization mode
36    mode: OptimizationMode,
37    /// History of metric values
38    history: Vec<F>,
39    /// Best metric value seen so far
40    best_value: Option<F>,
41    /// Additional metrics to track
42    additional_metrics: HashMap<String, Vec<F>>,
43    /// Phantom data for F type
44    _phantom: PhantomData<F>,
45}
46
47impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> MetricOptimizer<F> {
48    /// Create a new metric optimizer
49    ///
50    /// # Arguments
51    ///
52    /// * `metric_name` - Name of the metric to optimize
53    /// * `maximize` - Whether to maximize (true) or minimize (false) the metric
54    pub fn new<S: Into<String>>(name: S, maximize: bool) -> Self {
55        Self {
56            metric_name: name.into(),
57            mode: if maximize {
58                OptimizationMode::Maximize
59            } else {
60                OptimizationMode::Minimize
61            },
62            history: Vec::new(),
63            best_value: None,
64            additional_metrics: HashMap::new(),
65            _phantom: PhantomData,
66        }
67    }
68
69    /// Get the metric name
70    pub fn metric_name(&self) -> &str {
71        &self.metric_name
72    }
73
74    /// Get the optimization mode
75    pub fn mode(&self) -> OptimizationMode {
76        self.mode
77    }
78
79    /// Get the metric history
80    pub fn history(&self) -> &[F] {
81        &self.history
82    }
83
84    /// Get the best metric value seen so far
85    pub fn best_value(&self) -> Option<F> {
86        self.best_value
87    }
88
89    /// Add a metric value to the history
90    pub fn add_value(&mut self, value: F) {
91        self.history.push(value);
92
93        // Update best value
94        self.best_value = match (self.best_value, self.mode) {
95            (None, _) => Some(value),
96            (Some(best), OptimizationMode::Maximize) if value > best => Some(value),
97            (Some(best), OptimizationMode::Minimize) if value < best => Some(value),
98            (Some(best), _) => Some(best),
99        };
100    }
101
102    /// Add a value for an additional metric to track
103    pub fn add_additional_value(&mut self, metricname: &str, value: F) {
104        self.additional_metrics
105            .entry(metricname.to_string())
106            .or_default()
107            .push(value);
108    }
109
110    /// Get the history of an additional metric
111    pub fn additional_metric_history(&self, metricname: &str) -> Option<&[F]> {
112        self.additional_metrics
113            .get(metricname)
114            .map(|v| v.as_slice())
115    }
116
117    /// Reset the optimizer state
118    pub fn reset(&mut self) {
119        self.history.clear();
120        self.best_value = None;
121        self.additional_metrics.clear();
122    }
123
124    /// Check if the current value is better than the previous best
125    pub fn is_better(&self, current: F, previous: F) -> bool {
126        match self.mode {
127            OptimizationMode::Maximize => current > previous,
128            OptimizationMode::Minimize => current < previous,
129        }
130    }
131
132    /// Check if the current metric value is better than the best so far
133    pub fn is_improvement(&self, value: F) -> bool {
134        match self.best_value {
135            None => true,
136            Some(best) => self.is_better(value, best),
137        }
138    }
139
140    /// Create scheduler configuration for this metric
141    ///
142    /// Returns a configuration that can be used to create an external scheduler.
143    /// This provides a bridge to scirs2-optim schedulers without circular dependencies.
144    pub fn create_scheduler_config(
145        &self,
146        initial_lr: F,
147        factor: F,
148        patience: usize,
149        min_lr: F,
150    ) -> SchedulerConfig<F> {
151        SchedulerConfig {
152            initial_lr,
153            factor,
154            patience,
155            min_lr,
156            mode: self.mode,
157            metric_name: self.metric_name.clone(),
158        }
159    }
160}
161
162/// Configuration for external scheduler creation
163#[derive(Debug, Clone)]
164pub struct SchedulerConfig<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
165    /// Initial learning rate
166    pub initial_lr: F,
167    /// Factor by which to reduce learning rate
168    pub factor: F,
169    /// Number of epochs with no improvement before reduction
170    pub patience: usize,
171    /// Minimum learning rate
172    pub min_lr: F,
173    /// Optimization mode (minimize or maximize)
174    pub mode: OptimizationMode,
175    /// Metric name for tracking
176    pub metric_name: String,
177}
178
179impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> SchedulerConfig<F> {
180    /// Get the configuration as a tuple for easy destructuring
181    pub fn as_tuple(&self) -> (F, F, usize, F, OptimizationMode) {
182        (
183            self.initial_lr,
184            self.factor,
185            self.patience,
186            self.min_lr,
187            self.mode,
188        )
189    }
190
191    /// Create a new scheduler configuration
192    pub fn new(
193        initial_lr: F,
194        factor: F,
195        patience: usize,
196        min_lr: F,
197        mode: OptimizationMode,
198        metric_name: String,
199    ) -> Self {
200        Self {
201            initial_lr,
202            factor,
203            patience,
204            min_lr,
205            mode,
206            metric_name,
207        }
208    }
209}