tensorlogic_train/hyperparameter/
space.rs1use crate::{TrainError, TrainResult};
4use scirs2_core::random::{RngExt, StdRng};
5
6use super::value::HyperparamValue;
7
8#[derive(Debug, Clone)]
10pub enum HyperparamSpace {
11 Discrete(Vec<HyperparamValue>),
13 Continuous { min: f64, max: f64 },
15 LogUniform { min: f64, max: f64 },
17 IntRange { min: i64, max: i64 },
19}
20
21impl HyperparamSpace {
22 pub fn discrete(values: Vec<HyperparamValue>) -> TrainResult<Self> {
24 if values.is_empty() {
25 return Err(TrainError::InvalidParameter(
26 "Discrete space cannot be empty".to_string(),
27 ));
28 }
29 Ok(Self::Discrete(values))
30 }
31
32 pub fn continuous(min: f64, max: f64) -> TrainResult<Self> {
34 if min >= max {
35 return Err(TrainError::InvalidParameter(
36 "min must be less than max".to_string(),
37 ));
38 }
39 Ok(Self::Continuous { min, max })
40 }
41
42 pub fn log_uniform(min: f64, max: f64) -> TrainResult<Self> {
44 if min <= 0.0 || max <= 0.0 || min >= max {
45 return Err(TrainError::InvalidParameter(
46 "min and max must be positive and min < max".to_string(),
47 ));
48 }
49 Ok(Self::LogUniform { min, max })
50 }
51
52 pub fn int_range(min: i64, max: i64) -> TrainResult<Self> {
54 if min >= max {
55 return Err(TrainError::InvalidParameter(
56 "min must be less than max".to_string(),
57 ));
58 }
59 Ok(Self::IntRange { min, max })
60 }
61
62 pub fn sample(&self, rng: &mut StdRng) -> HyperparamValue {
64 match self {
65 HyperparamSpace::Discrete(values) => {
66 let idx = rng.gen_range(0..values.len());
67 values[idx].clone()
68 }
69 HyperparamSpace::Continuous { min, max } => {
70 let value = min + (max - min) * rng.random::<f64>();
71 HyperparamValue::Float(value)
72 }
73 HyperparamSpace::LogUniform { min, max } => {
74 let log_min = min.ln();
75 let log_max = max.ln();
76 let log_value = log_min + (log_max - log_min) * rng.random::<f64>();
77 HyperparamValue::Float(log_value.exp())
78 }
79 HyperparamSpace::IntRange { min, max } => {
80 let value = rng.gen_range(*min..=*max);
81 HyperparamValue::Int(value)
82 }
83 }
84 }
85
86 pub fn grid_values(&self, num_samples: usize) -> Vec<HyperparamValue> {
88 match self {
89 HyperparamSpace::Discrete(values) => values.clone(),
90 HyperparamSpace::IntRange { min, max } => {
91 let range_size = (max - min + 1) as usize;
92 let step = (range_size / num_samples).max(1);
93 (*min..=*max)
94 .step_by(step)
95 .map(HyperparamValue::Int)
96 .collect()
97 }
98 HyperparamSpace::Continuous { min, max } => {
99 let step = (max - min) / (num_samples as f64);
100 (0..num_samples)
101 .map(|i| HyperparamValue::Float(min + step * i as f64))
102 .collect()
103 }
104 HyperparamSpace::LogUniform { min, max } => {
105 let log_min = min.ln();
106 let log_max = max.ln();
107 let log_step = (log_max - log_min) / (num_samples as f64);
108 (0..num_samples)
109 .map(|i| HyperparamValue::Float((log_min + log_step * i as f64).exp()))
110 .collect()
111 }
112 }
113 }
114}