Skip to main content

tensorlogic_train/hyperparameter/
space.rs

1//! Hyperparameter space definition for search algorithms.
2
3use crate::{TrainError, TrainResult};
4use scirs2_core::random::{RngExt, StdRng};
5
6use super::value::HyperparamValue;
7
/// Hyperparameter space definition.
///
/// Describes the domain a single hyperparameter may be drawn from. Search
/// algorithms use it for random sampling (`sample`) and grid enumeration
/// (`grid_values`). Prefer the validating constructors (`discrete`,
/// `continuous`, `log_uniform`, `int_range`) over building variants directly.
#[derive(Debug, Clone)]
pub enum HyperparamSpace {
    /// Discrete choices; sampling picks one uniformly at random.
    Discrete(Vec<HyperparamValue>),
    /// Continuous range [min, max], sampled uniformly.
    Continuous { min: f64, max: f64 },
    /// Log-uniform distribution over [min, max]; bounds must be positive.
    LogUniform { min: f64, max: f64 },
    /// Inclusive integer range [min, max].
    IntRange { min: i64, max: i64 },
}
20
21impl HyperparamSpace {
22    /// Create a discrete choice space.
23    pub fn discrete(values: Vec<HyperparamValue>) -> TrainResult<Self> {
24        if values.is_empty() {
25            return Err(TrainError::InvalidParameter(
26                "Discrete space cannot be empty".to_string(),
27            ));
28        }
29        Ok(Self::Discrete(values))
30    }
31
32    /// Create a continuous range space.
33    pub fn continuous(min: f64, max: f64) -> TrainResult<Self> {
34        if min >= max {
35            return Err(TrainError::InvalidParameter(
36                "min must be less than max".to_string(),
37            ));
38        }
39        Ok(Self::Continuous { min, max })
40    }
41
42    /// Create a log-uniform distribution space.
43    pub fn log_uniform(min: f64, max: f64) -> TrainResult<Self> {
44        if min <= 0.0 || max <= 0.0 || min >= max {
45            return Err(TrainError::InvalidParameter(
46                "min and max must be positive and min < max".to_string(),
47            ));
48        }
49        Ok(Self::LogUniform { min, max })
50    }
51
52    /// Create an integer range space.
53    pub fn int_range(min: i64, max: i64) -> TrainResult<Self> {
54        if min >= max {
55            return Err(TrainError::InvalidParameter(
56                "min must be less than max".to_string(),
57            ));
58        }
59        Ok(Self::IntRange { min, max })
60    }
61
62    /// Sample a value from this space.
63    pub fn sample(&self, rng: &mut StdRng) -> HyperparamValue {
64        match self {
65            HyperparamSpace::Discrete(values) => {
66                let idx = rng.gen_range(0..values.len());
67                values[idx].clone()
68            }
69            HyperparamSpace::Continuous { min, max } => {
70                let value = min + (max - min) * rng.random::<f64>();
71                HyperparamValue::Float(value)
72            }
73            HyperparamSpace::LogUniform { min, max } => {
74                let log_min = min.ln();
75                let log_max = max.ln();
76                let log_value = log_min + (log_max - log_min) * rng.random::<f64>();
77                HyperparamValue::Float(log_value.exp())
78            }
79            HyperparamSpace::IntRange { min, max } => {
80                let value = rng.gen_range(*min..=*max);
81                HyperparamValue::Int(value)
82            }
83        }
84    }
85
86    /// Get all possible values for grid search (for discrete/int spaces).
87    pub fn grid_values(&self, num_samples: usize) -> Vec<HyperparamValue> {
88        match self {
89            HyperparamSpace::Discrete(values) => values.clone(),
90            HyperparamSpace::IntRange { min, max } => {
91                let range_size = (max - min + 1) as usize;
92                let step = (range_size / num_samples).max(1);
93                (*min..=*max)
94                    .step_by(step)
95                    .map(HyperparamValue::Int)
96                    .collect()
97            }
98            HyperparamSpace::Continuous { min, max } => {
99                let step = (max - min) / (num_samples as f64);
100                (0..num_samples)
101                    .map(|i| HyperparamValue::Float(min + step * i as f64))
102                    .collect()
103            }
104            HyperparamSpace::LogUniform { min, max } => {
105                let log_min = min.ln();
106                let log_max = max.ln();
107                let log_step = (log_max - log_min) / (num_samples as f64);
108                (0..num_samples)
109                    .map(|i| HyperparamValue::Float((log_min + log_step * i as f64).exp()))
110                    .collect()
111            }
112        }
113    }
114}