scirs2_metrics/integration/optim/
hyperparameter.rs

1//! Hyperparameter tuning utilities
2//!
3//! This module provides utilities for hyperparameter tuning using metrics.
4
5use crate::error::{MetricsError, Result};
6use crate::integration::optim::OptimizationMode;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::random::Rng;
9use std::collections::HashMap;
10use std::fmt;
11use std::marker::PhantomData;
12
13/// A hyperparameter with its range
14#[derive(Debug, Clone)]
15pub struct HyperParameter<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
16    /// Name of the hyperparameter
17    name: String,
18    /// Current value
19    value: F,
20    /// Minimum value (inclusive)
21    min_value: F,
22    /// Maximum value (inclusive)
23    maxvalue: F,
24    /// Step size for discrete search
25    step: Option<F>,
26    /// Is the parameter categorical
27    is_categorical: bool,
28    /// Possible categorical values
29    categorical_values: Option<Vec<F>>,
30}
31
32impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> HyperParameter<F> {
33    /// Create a new continuous hyperparameter
34    pub fn new<S: Into<String>>(name: S, value: F, min_value: F, maxvalue: F) -> Self {
35        Self {
36            name: name.into(),
37            value,
38            min_value,
39            maxvalue,
40            step: None,
41            is_categorical: false,
42            categorical_values: None,
43        }
44    }
45
46    /// Create a new discrete hyperparameter
47    pub fn discrete<S: Into<String>>(
48        name: S,
49        value: F,
50        min_value: F,
51        maxvalue: F,
52        step: F,
53    ) -> Self {
54        Self {
55            name: name.into(),
56            value,
57            min_value,
58            maxvalue,
59            step: Some(step),
60            is_categorical: false,
61            categorical_values: None,
62        }
63    }
64
65    /// Create a new categorical hyperparameter
66    pub fn categorical<S: Into<String>>(name: S, value: F, values: Vec<F>) -> Result<Self> {
67        if values.is_empty() {
68            return Err(MetricsError::InvalidArgument(
69                "Categorical values cannot be empty".to_string(),
70            ));
71        }
72        if !values.contains(&value) {
73            return Err(MetricsError::InvalidArgument(format!(
74                "Current value {} must be one of the categorical values",
75                value
76            )));
77        }
78
79        Ok(Self {
80            name: name.into(),
81            value,
82            min_value: F::zero(),
83            maxvalue: F::from(values.len() - 1).unwrap(),
84            step: Some(F::one()),
85            is_categorical: true,
86            categorical_values: Some(values),
87        })
88    }
89
90    /// Get the name
91    pub fn name(&self) -> &str {
92        &self.name
93    }
94
95    /// Get the current value
96    pub fn value(&self) -> F {
97        self.value
98    }
99
100    /// Set the value
101    pub fn set_value(&mut self, value: F) -> Result<()> {
102        if self.is_categorical {
103            if let Some(values) = &self.categorical_values {
104                if !values.contains(&value) {
105                    return Err(MetricsError::InvalidArgument(format!(
106                        "Value {} is not a valid categorical value for parameter {}",
107                        value, self.name
108                    )));
109                }
110            }
111        } else if value < self.min_value || value > self.maxvalue {
112            return Err(MetricsError::InvalidArgument(format!(
113                "Value {} out of range [{}, {}] for parameter {}",
114                value, self.min_value, self.maxvalue, self.name
115            )));
116        }
117
118        self.value = value;
119        Ok(())
120    }
121
122    /// Get a random value within the parameter's range
123    pub fn random_value(&self) -> F {
124        if self.is_categorical {
125            if let Some(values) = &self.categorical_values {
126                let mut rng = scirs2_core::random::rng();
127                let idx = rng.random_range(0..values.len());
128                return values[idx];
129            }
130        }
131
132        let range = self.maxvalue - self.min_value;
133        let mut rng = scirs2_core::random::rng();
134        let rand_val = F::from(rng.random::<f64>()).unwrap() * range + self.min_value;
135
136        if let Some(step) = self.step {
137            // For discrete parameters..round to the nearest step
138            let steps = ((rand_val - self.min_value) / step).round();
139            self.min_value + steps * step
140        } else {
141            rand_val
142        }
143    }
144
145    /// Validate that the current parameter configuration is valid
146    pub fn validate(&self) -> Result<()> {
147        if self.is_categorical {
148            if let Some(values) = &self.categorical_values {
149                if values.is_empty() {
150                    return Err(MetricsError::InvalidArgument(
151                        "Categorical values cannot be empty".to_string(),
152                    ));
153                }
154                if !values.contains(&self.value) {
155                    return Err(MetricsError::InvalidArgument(format!(
156                        "Current value {} is not in categorical values for parameter {}",
157                        self.value, self.name
158                    )));
159                }
160            } else {
161                return Err(MetricsError::InvalidArgument(format!(
162                    "Categorical parameter {} missing values",
163                    self.name
164                )));
165            }
166        } else {
167            if self.min_value > self.maxvalue {
168                return Err(MetricsError::InvalidArgument(format!(
169                    "Min value {} cannot be greater than max value {} for parameter {}",
170                    self.min_value, self.maxvalue, self.name
171                )));
172            }
173            if self.value < self.min_value || self.value > self.maxvalue {
174                return Err(MetricsError::InvalidArgument(format!(
175                    "Current value {} is out of range [{}, {}] for parameter {}",
176                    self.value, self.min_value, self.maxvalue, self.name
177                )));
178            }
179            if let Some(step) = self.step {
180                if step <= F::zero() {
181                    return Err(MetricsError::InvalidArgument(format!(
182                        "Step size must be positive for parameter {}",
183                        self.name
184                    )));
185                }
186            }
187        }
188        Ok(())
189    }
190
191    /// Get the valid range for this parameter
192    pub fn get_range(&self) -> (F, F) {
193        (self.min_value, self.maxvalue)
194    }
195
196    /// Get the step size (if discrete)
197    pub fn get_step(&self) -> Option<F> {
198        self.step
199    }
200
201    /// Check if parameter is categorical
202    pub fn is_categorical(&self) -> bool {
203        self.is_categorical
204    }
205
206    /// Get categorical values (if categorical)
207    pub fn get_categorical_values(&self) -> Option<&Vec<F>> {
208        self.categorical_values.as_ref()
209    }
210}
211
212/// A hyperparameter search result
213#[derive(Debug, Clone)]
214pub struct HyperParameterSearchResult<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
215    /// Metric name that was optimized
216    #[allow(dead_code)]
217    metric_name: String,
218    /// Optimization mode used
219    mode: OptimizationMode,
220    /// Best metric value found
221    best_metric: F,
222    /// Best hyperparameter values found
223    best_params: HashMap<String, F>,
224    /// History of all evaluations
225    history: Vec<(HashMap<String, F>, F)>,
226}
227
228impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> HyperParameterSearchResult<F> {
229    /// Create a new hyperparameter search result
230    pub fn new<S: Into<String>>(
231        metric_name: S,
232        mode: OptimizationMode,
233        best_metric: F,
234        best_params: HashMap<String, F>,
235    ) -> Self {
236        Self {
237            metric_name: metric_name.into(),
238            mode,
239            best_metric,
240            best_params,
241            history: Vec::new(),
242        }
243    }
244
245    /// Add an evaluation to the history
246    pub fn add_evaluation(&mut self, params: HashMap<String, F>, metric: F) {
247        self.history.push((params.clone(), metric));
248
249        // Update best if better
250        let is_better = match self.mode {
251            OptimizationMode::Maximize => metric > self.best_metric,
252            OptimizationMode::Minimize => metric < self.best_metric,
253        };
254
255        if is_better {
256            self.best_metric = metric;
257            self.best_params = params;
258        }
259    }
260
261    /// Get the best metric value
262    pub fn best_metric(&self) -> F {
263        self.best_metric
264    }
265
266    /// Get the best hyperparameter values
267    pub fn best_params(&self) -> &HashMap<String, F> {
268        &self.best_params
269    }
270
271    /// Get the history of evaluations
272    pub fn history(&self) -> &[(HashMap<String, F>, F)] {
273        &self.history
274    }
275}
276
277/// A hyperparameter tuner
278#[derive(Debug)]
279pub struct HyperParameterTuner<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
280    /// Hyperparameters to tune
281    params: Vec<HyperParameter<F>>,
282    /// Metric name to optimize
283    metric_name: String,
284    /// Optimization mode
285    mode: OptimizationMode,
286    /// Maximum number of evaluations
287    max_evals: usize,
288    /// Current best value
289    best_value: Option<F>,
290    /// Current best parameters
291    best_params: HashMap<String, F>,
292    /// History of evaluations
293    history: Vec<(HashMap<String, F>, F)>,
294    /// Phantom data for F type
295    _phantom: PhantomData<F>,
296}
297
298impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> HyperParameterTuner<F> {
299    /// Create a new hyperparameter tuner
300    pub fn new<S: Into<String>>(
301        params: Vec<HyperParameter<F>>,
302        metric_name: S,
303        maximize: bool,
304        max_evals: usize,
305    ) -> Result<Self> {
306        if params.is_empty() {
307            return Err(MetricsError::InvalidArgument(
308                "Cannot create tuner with empty parameter list".to_string(),
309            ));
310        }
311
312        if max_evals == 0 {
313            return Err(MetricsError::InvalidArgument(
314                "Maximum evaluations must be greater than 0".to_string(),
315            ));
316        }
317
318        // Validate all parameters
319        for param in &params {
320            param.validate()?;
321        }
322
323        // Check for duplicate parameter names
324        let mut names = std::collections::HashSet::new();
325        for param in &params {
326            if !names.insert(param.name()) {
327                return Err(MetricsError::InvalidArgument(format!(
328                    "Duplicate parameter name: {}",
329                    param.name()
330                )));
331            }
332        }
333
334        Ok(Self {
335            params,
336            metric_name: metric_name.into(),
337            mode: if maximize {
338                OptimizationMode::Maximize
339            } else {
340                OptimizationMode::Minimize
341            },
342            max_evals,
343            best_value: None,
344            best_params: HashMap::new(),
345            history: Vec::new(),
346            _phantom: PhantomData,
347        })
348    }
349
350    /// Get the current hyperparameter values
351    pub fn get_current_params(&self) -> HashMap<String, F> {
352        self.params
353            .iter()
354            .map(|p| (p.name().to_string(), p.value()))
355            .collect()
356    }
357
358    /// Set hyperparameter values
359    pub fn set_params(&mut self, params: &HashMap<String, F>) -> Result<()> {
360        for (name, value) in params {
361            if let Some(param) = self.params.iter_mut().find(|p| p.name() == name) {
362                param.set_value(*value)?;
363            }
364        }
365        Ok(())
366    }
367
368    /// Update the tuner with an evaluation result
369    pub fn update(&mut self, metricvalue: F) -> Result<bool> {
370        let current_params = self.get_current_params();
371
372        // Check if this is the best _value so far
373        let is_best = match (self.best_value, self.mode) {
374            (None, _) => true,
375            (Some(best), OptimizationMode::Maximize) => metricvalue > best,
376            (Some(best), OptimizationMode::Minimize) => metricvalue < best,
377        };
378
379        // Update history
380        self.history.push((current_params.clone(), metricvalue));
381
382        // Update best if this is the best so far
383        if is_best {
384            self.best_value = Some(metricvalue);
385            self.best_params = current_params;
386        }
387
388        Ok(is_best)
389    }
390
391    /// Get a random set of hyperparameters
392    pub fn random_params(&self) -> HashMap<String, F> {
393        self.params
394            .iter()
395            .map(|p| (p.name().to_string(), p.random_value()))
396            .collect()
397    }
398
399    /// Run random search for hyperparameter tuning
400    pub fn random_search<FnEval>(
401        &mut self,
402        eval_fn: FnEval,
403    ) -> Result<HyperParameterSearchResult<F>>
404    where
405        FnEval: Fn(&HashMap<String, F>) -> Result<F>,
406    {
407        // Reset history
408        self.history.clear();
409        self.best_value = None;
410
411        for _ in 0..self.max_evals {
412            // Get random parameters
413            let params = self.random_params();
414
415            // Set parameters
416            self.set_params(&params)?;
417
418            // Evaluate
419            let metric = eval_fn(&params)?;
420
421            // Update
422            self.update(metric)?;
423        }
424
425        // Create result
426        let result = HyperParameterSearchResult::new(
427            self.metric_name.clone(),
428            self.mode,
429            self.best_value.unwrap_or_else(|| match self.mode {
430                OptimizationMode::Maximize => F::neg_infinity(),
431                OptimizationMode::Minimize => F::infinity(),
432            }),
433            self.best_params.clone(),
434        );
435
436        Ok(result)
437    }
438
439    /// Get the best parameters found so far
440    pub fn best_params(&self) -> &HashMap<String, F> {
441        &self.best_params
442    }
443
444    /// Get the best metric value found so far
445    pub fn best_value(&self) -> Option<F> {
446        self.best_value
447    }
448
449    /// Get the history of evaluations
450    pub fn history(&self) -> &[(HashMap<String, F>, F)] {
451        &self.history
452    }
453}