scirs2_metrics/integration/optim/
hyperparameter.rs1use crate::error::{MetricsError, Result};
6use crate::integration::optim::OptimizationMode;
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::random::Rng;
9use std::collections::HashMap;
10use std::fmt;
11use std::marker::PhantomData;
12
13#[derive(Debug, Clone)]
15pub struct HyperParameter<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
16 name: String,
18 value: F,
20 min_value: F,
22 maxvalue: F,
24 step: Option<F>,
26 is_categorical: bool,
28 categorical_values: Option<Vec<F>>,
30}
31
32impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> HyperParameter<F> {
33 pub fn new<S: Into<String>>(name: S, value: F, min_value: F, maxvalue: F) -> Self {
35 Self {
36 name: name.into(),
37 value,
38 min_value,
39 maxvalue,
40 step: None,
41 is_categorical: false,
42 categorical_values: None,
43 }
44 }
45
46 pub fn discrete<S: Into<String>>(
48 name: S,
49 value: F,
50 min_value: F,
51 maxvalue: F,
52 step: F,
53 ) -> Self {
54 Self {
55 name: name.into(),
56 value,
57 min_value,
58 maxvalue,
59 step: Some(step),
60 is_categorical: false,
61 categorical_values: None,
62 }
63 }
64
65 pub fn categorical<S: Into<String>>(name: S, value: F, values: Vec<F>) -> Result<Self> {
67 if values.is_empty() {
68 return Err(MetricsError::InvalidArgument(
69 "Categorical values cannot be empty".to_string(),
70 ));
71 }
72 if !values.contains(&value) {
73 return Err(MetricsError::InvalidArgument(format!(
74 "Current value {} must be one of the categorical values",
75 value
76 )));
77 }
78
79 Ok(Self {
80 name: name.into(),
81 value,
82 min_value: F::zero(),
83 maxvalue: F::from(values.len() - 1).unwrap(),
84 step: Some(F::one()),
85 is_categorical: true,
86 categorical_values: Some(values),
87 })
88 }
89
90 pub fn name(&self) -> &str {
92 &self.name
93 }
94
95 pub fn value(&self) -> F {
97 self.value
98 }
99
100 pub fn set_value(&mut self, value: F) -> Result<()> {
102 if self.is_categorical {
103 if let Some(values) = &self.categorical_values {
104 if !values.contains(&value) {
105 return Err(MetricsError::InvalidArgument(format!(
106 "Value {} is not a valid categorical value for parameter {}",
107 value, self.name
108 )));
109 }
110 }
111 } else if value < self.min_value || value > self.maxvalue {
112 return Err(MetricsError::InvalidArgument(format!(
113 "Value {} out of range [{}, {}] for parameter {}",
114 value, self.min_value, self.maxvalue, self.name
115 )));
116 }
117
118 self.value = value;
119 Ok(())
120 }
121
122 pub fn random_value(&self) -> F {
124 if self.is_categorical {
125 if let Some(values) = &self.categorical_values {
126 let mut rng = scirs2_core::random::rng();
127 let idx = rng.random_range(0..values.len());
128 return values[idx];
129 }
130 }
131
132 let range = self.maxvalue - self.min_value;
133 let mut rng = scirs2_core::random::rng();
134 let rand_val = F::from(rng.random::<f64>()).unwrap() * range + self.min_value;
135
136 if let Some(step) = self.step {
137 let steps = ((rand_val - self.min_value) / step).round();
139 self.min_value + steps * step
140 } else {
141 rand_val
142 }
143 }
144
145 pub fn validate(&self) -> Result<()> {
147 if self.is_categorical {
148 if let Some(values) = &self.categorical_values {
149 if values.is_empty() {
150 return Err(MetricsError::InvalidArgument(
151 "Categorical values cannot be empty".to_string(),
152 ));
153 }
154 if !values.contains(&self.value) {
155 return Err(MetricsError::InvalidArgument(format!(
156 "Current value {} is not in categorical values for parameter {}",
157 self.value, self.name
158 )));
159 }
160 } else {
161 return Err(MetricsError::InvalidArgument(format!(
162 "Categorical parameter {} missing values",
163 self.name
164 )));
165 }
166 } else {
167 if self.min_value > self.maxvalue {
168 return Err(MetricsError::InvalidArgument(format!(
169 "Min value {} cannot be greater than max value {} for parameter {}",
170 self.min_value, self.maxvalue, self.name
171 )));
172 }
173 if self.value < self.min_value || self.value > self.maxvalue {
174 return Err(MetricsError::InvalidArgument(format!(
175 "Current value {} is out of range [{}, {}] for parameter {}",
176 self.value, self.min_value, self.maxvalue, self.name
177 )));
178 }
179 if let Some(step) = self.step {
180 if step <= F::zero() {
181 return Err(MetricsError::InvalidArgument(format!(
182 "Step size must be positive for parameter {}",
183 self.name
184 )));
185 }
186 }
187 }
188 Ok(())
189 }
190
191 pub fn get_range(&self) -> (F, F) {
193 (self.min_value, self.maxvalue)
194 }
195
196 pub fn get_step(&self) -> Option<F> {
198 self.step
199 }
200
201 pub fn is_categorical(&self) -> bool {
203 self.is_categorical
204 }
205
206 pub fn get_categorical_values(&self) -> Option<&Vec<F>> {
208 self.categorical_values.as_ref()
209 }
210}
211
212#[derive(Debug, Clone)]
214pub struct HyperParameterSearchResult<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
215 #[allow(dead_code)]
217 metric_name: String,
218 mode: OptimizationMode,
220 best_metric: F,
222 best_params: HashMap<String, F>,
224 history: Vec<(HashMap<String, F>, F)>,
226}
227
228impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> HyperParameterSearchResult<F> {
229 pub fn new<S: Into<String>>(
231 metric_name: S,
232 mode: OptimizationMode,
233 best_metric: F,
234 best_params: HashMap<String, F>,
235 ) -> Self {
236 Self {
237 metric_name: metric_name.into(),
238 mode,
239 best_metric,
240 best_params,
241 history: Vec::new(),
242 }
243 }
244
245 pub fn add_evaluation(&mut self, params: HashMap<String, F>, metric: F) {
247 self.history.push((params.clone(), metric));
248
249 let is_better = match self.mode {
251 OptimizationMode::Maximize => metric > self.best_metric,
252 OptimizationMode::Minimize => metric < self.best_metric,
253 };
254
255 if is_better {
256 self.best_metric = metric;
257 self.best_params = params;
258 }
259 }
260
261 pub fn best_metric(&self) -> F {
263 self.best_metric
264 }
265
266 pub fn best_params(&self) -> &HashMap<String, F> {
268 &self.best_params
269 }
270
271 pub fn history(&self) -> &[(HashMap<String, F>, F)] {
273 &self.history
274 }
275}
276
277#[derive(Debug)]
279pub struct HyperParameterTuner<F: Float + fmt::Debug + fmt::Display + FromPrimitive> {
280 params: Vec<HyperParameter<F>>,
282 metric_name: String,
284 mode: OptimizationMode,
286 max_evals: usize,
288 best_value: Option<F>,
290 best_params: HashMap<String, F>,
292 history: Vec<(HashMap<String, F>, F)>,
294 _phantom: PhantomData<F>,
296}
297
298impl<F: Float + fmt::Debug + fmt::Display + FromPrimitive> HyperParameterTuner<F> {
299 pub fn new<S: Into<String>>(
301 params: Vec<HyperParameter<F>>,
302 metric_name: S,
303 maximize: bool,
304 max_evals: usize,
305 ) -> Result<Self> {
306 if params.is_empty() {
307 return Err(MetricsError::InvalidArgument(
308 "Cannot create tuner with empty parameter list".to_string(),
309 ));
310 }
311
312 if max_evals == 0 {
313 return Err(MetricsError::InvalidArgument(
314 "Maximum evaluations must be greater than 0".to_string(),
315 ));
316 }
317
318 for param in ¶ms {
320 param.validate()?;
321 }
322
323 let mut names = std::collections::HashSet::new();
325 for param in ¶ms {
326 if !names.insert(param.name()) {
327 return Err(MetricsError::InvalidArgument(format!(
328 "Duplicate parameter name: {}",
329 param.name()
330 )));
331 }
332 }
333
334 Ok(Self {
335 params,
336 metric_name: metric_name.into(),
337 mode: if maximize {
338 OptimizationMode::Maximize
339 } else {
340 OptimizationMode::Minimize
341 },
342 max_evals,
343 best_value: None,
344 best_params: HashMap::new(),
345 history: Vec::new(),
346 _phantom: PhantomData,
347 })
348 }
349
350 pub fn get_current_params(&self) -> HashMap<String, F> {
352 self.params
353 .iter()
354 .map(|p| (p.name().to_string(), p.value()))
355 .collect()
356 }
357
358 pub fn set_params(&mut self, params: &HashMap<String, F>) -> Result<()> {
360 for (name, value) in params {
361 if let Some(param) = self.params.iter_mut().find(|p| p.name() == name) {
362 param.set_value(*value)?;
363 }
364 }
365 Ok(())
366 }
367
368 pub fn update(&mut self, metricvalue: F) -> Result<bool> {
370 let current_params = self.get_current_params();
371
372 let is_best = match (self.best_value, self.mode) {
374 (None, _) => true,
375 (Some(best), OptimizationMode::Maximize) => metricvalue > best,
376 (Some(best), OptimizationMode::Minimize) => metricvalue < best,
377 };
378
379 self.history.push((current_params.clone(), metricvalue));
381
382 if is_best {
384 self.best_value = Some(metricvalue);
385 self.best_params = current_params;
386 }
387
388 Ok(is_best)
389 }
390
391 pub fn random_params(&self) -> HashMap<String, F> {
393 self.params
394 .iter()
395 .map(|p| (p.name().to_string(), p.random_value()))
396 .collect()
397 }
398
399 pub fn random_search<FnEval>(
401 &mut self,
402 eval_fn: FnEval,
403 ) -> Result<HyperParameterSearchResult<F>>
404 where
405 FnEval: Fn(&HashMap<String, F>) -> Result<F>,
406 {
407 self.history.clear();
409 self.best_value = None;
410
411 for _ in 0..self.max_evals {
412 let params = self.random_params();
414
415 self.set_params(¶ms)?;
417
418 let metric = eval_fn(¶ms)?;
420
421 self.update(metric)?;
423 }
424
425 let result = HyperParameterSearchResult::new(
427 self.metric_name.clone(),
428 self.mode,
429 self.best_value.unwrap_or_else(|| match self.mode {
430 OptimizationMode::Maximize => F::neg_infinity(),
431 OptimizationMode::Minimize => F::infinity(),
432 }),
433 self.best_params.clone(),
434 );
435
436 Ok(result)
437 }
438
439 pub fn best_params(&self) -> &HashMap<String, F> {
441 &self.best_params
442 }
443
444 pub fn best_value(&self) -> Option<F> {
446 self.best_value
447 }
448
449 pub fn history(&self) -> &[(HashMap<String, F>, F)] {
451 &self.history
452 }
453}