1use crate::{TrainError, TrainResult};
10use scirs2_core::random::{Rng, SeedableRng, StdRng};
11use std::collections::HashMap;
12
/// A single concrete hyperparameter value.
#[derive(Debug, Clone, PartialEq)]
pub enum HyperparamValue {
    /// Floating-point value.
    Float(f64),
    /// Signed integer value.
    Int(i64),
    /// Boolean flag.
    Bool(bool),
    /// Text value.
    String(String),
}

impl HyperparamValue {
    /// Returns the value as an `f64`.
    ///
    /// `Float` is returned unchanged and `Int` is converted with an `as` cast;
    /// `Bool` and `String` yield `None`.
    pub fn as_float(&self) -> Option<f64> {
        match *self {
            Self::Float(number) => Some(number),
            Self::Int(number) => Some(number as f64),
            Self::Bool(_) | Self::String(_) => None,
        }
    }

    /// Returns the value as an `i64`.
    ///
    /// `Int` is returned unchanged and `Float` is converted with an `as` cast
    /// (fractional part dropped); `Bool` and `String` yield `None`.
    pub fn as_int(&self) -> Option<i64> {
        match *self {
            Self::Int(number) => Some(number),
            Self::Float(number) => Some(number as i64),
            Self::Bool(_) | Self::String(_) => None,
        }
    }

    /// Returns the wrapped flag for `Bool`, `None` for every other variant.
    pub fn as_bool(&self) -> Option<bool> {
        if let Self::Bool(flag) = self {
            Some(*flag)
        } else {
            None
        }
    }

    /// Borrows the wrapped text for `String`, `None` for every other variant.
    pub fn as_string(&self) -> Option<&str> {
        if let Self::String(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }
}
61
62#[derive(Debug, Clone)]
64pub enum HyperparamSpace {
65 Discrete(Vec<HyperparamValue>),
67 Continuous { min: f64, max: f64 },
69 LogUniform { min: f64, max: f64 },
71 IntRange { min: i64, max: i64 },
73}
74
75impl HyperparamSpace {
76 pub fn discrete(values: Vec<HyperparamValue>) -> TrainResult<Self> {
78 if values.is_empty() {
79 return Err(TrainError::InvalidParameter(
80 "Discrete space cannot be empty".to_string(),
81 ));
82 }
83 Ok(Self::Discrete(values))
84 }
85
86 pub fn continuous(min: f64, max: f64) -> TrainResult<Self> {
88 if min >= max {
89 return Err(TrainError::InvalidParameter(
90 "min must be less than max".to_string(),
91 ));
92 }
93 Ok(Self::Continuous { min, max })
94 }
95
96 pub fn log_uniform(min: f64, max: f64) -> TrainResult<Self> {
98 if min <= 0.0 || max <= 0.0 || min >= max {
99 return Err(TrainError::InvalidParameter(
100 "min and max must be positive and min < max".to_string(),
101 ));
102 }
103 Ok(Self::LogUniform { min, max })
104 }
105
106 pub fn int_range(min: i64, max: i64) -> TrainResult<Self> {
108 if min >= max {
109 return Err(TrainError::InvalidParameter(
110 "min must be less than max".to_string(),
111 ));
112 }
113 Ok(Self::IntRange { min, max })
114 }
115
116 pub fn sample(&self, rng: &mut StdRng) -> HyperparamValue {
118 match self {
119 HyperparamSpace::Discrete(values) => {
120 let idx = rng.gen_range(0..values.len());
121 values[idx].clone()
122 }
123 HyperparamSpace::Continuous { min, max } => {
124 let value = min + (max - min) * rng.random::<f64>();
125 HyperparamValue::Float(value)
126 }
127 HyperparamSpace::LogUniform { min, max } => {
128 let log_min = min.ln();
129 let log_max = max.ln();
130 let log_value = log_min + (log_max - log_min) * rng.random::<f64>();
131 HyperparamValue::Float(log_value.exp())
132 }
133 HyperparamSpace::IntRange { min, max } => {
134 let value = rng.gen_range(*min..=*max);
135 HyperparamValue::Int(value)
136 }
137 }
138 }
139
140 pub fn grid_values(&self, num_samples: usize) -> Vec<HyperparamValue> {
142 match self {
143 HyperparamSpace::Discrete(values) => values.clone(),
144 HyperparamSpace::IntRange { min, max } => {
145 let range_size = (max - min + 1) as usize;
146 let step = (range_size / num_samples).max(1);
147 (*min..=*max)
148 .step_by(step)
149 .map(HyperparamValue::Int)
150 .collect()
151 }
152 HyperparamSpace::Continuous { min, max } => {
153 let step = (max - min) / (num_samples as f64);
154 (0..num_samples)
155 .map(|i| HyperparamValue::Float(min + step * i as f64))
156 .collect()
157 }
158 HyperparamSpace::LogUniform { min, max } => {
159 let log_min = min.ln();
160 let log_max = max.ln();
161 let log_step = (log_max - log_min) / (num_samples as f64);
162 (0..num_samples)
163 .map(|i| HyperparamValue::Float((log_min + log_step * i as f64).exp()))
164 .collect()
165 }
166 }
167 }
168}
169
/// One hyperparameter configuration: a map from parameter name to its chosen value.
pub type HyperparamConfig = HashMap<String, HyperparamValue>;
172
173#[derive(Debug, Clone)]
175pub struct HyperparamResult {
176 pub config: HyperparamConfig,
178 pub score: f64,
180 pub metrics: HashMap<String, f64>,
182}
183
184impl HyperparamResult {
185 pub fn new(config: HyperparamConfig, score: f64) -> Self {
187 Self {
188 config,
189 score,
190 metrics: HashMap::new(),
191 }
192 }
193
194 pub fn with_metric(mut self, name: String, value: f64) -> Self {
196 self.metrics.insert(name, value);
197 self
198 }
199}
200
201#[derive(Debug)]
205pub struct GridSearch {
206 param_space: HashMap<String, HyperparamSpace>,
208 num_grid_points: usize,
210 results: Vec<HyperparamResult>,
212}
213
214impl GridSearch {
215 pub fn new(param_space: HashMap<String, HyperparamSpace>, num_grid_points: usize) -> Self {
221 Self {
222 param_space,
223 num_grid_points,
224 results: Vec::new(),
225 }
226 }
227
228 pub fn generate_configs(&self) -> Vec<HyperparamConfig> {
230 if self.param_space.is_empty() {
231 return vec![HashMap::new()];
232 }
233
234 let mut param_names: Vec<String> = self.param_space.keys().cloned().collect();
235 param_names.sort(); let mut all_values: Vec<Vec<HyperparamValue>> = Vec::new();
238 for name in ¶m_names {
239 let space = &self.param_space[name];
240 all_values.push(space.grid_values(self.num_grid_points));
241 }
242
243 let mut configs = Vec::new();
245 self.generate_cartesian_product(
246 ¶m_names,
247 &all_values,
248 0,
249 &mut HashMap::new(),
250 &mut configs,
251 );
252
253 configs
254 }
255
256 #[allow(clippy::only_used_in_recursion)]
258 fn generate_cartesian_product(
259 &self,
260 param_names: &[String],
261 all_values: &[Vec<HyperparamValue>],
262 depth: usize,
263 current_config: &mut HyperparamConfig,
264 configs: &mut Vec<HyperparamConfig>,
265 ) {
266 if depth == param_names.len() {
267 configs.push(current_config.clone());
268 return;
269 }
270
271 let param_name = ¶m_names[depth];
272 let values = &all_values[depth];
273
274 for value in values {
275 current_config.insert(param_name.clone(), value.clone());
276 self.generate_cartesian_product(
277 param_names,
278 all_values,
279 depth + 1,
280 current_config,
281 configs,
282 );
283 }
284
285 current_config.remove(param_name);
286 }
287
288 pub fn add_result(&mut self, result: HyperparamResult) {
290 self.results.push(result);
291 }
292
293 pub fn best_result(&self) -> Option<&HyperparamResult> {
295 self.results.iter().max_by(|a, b| {
296 a.score
297 .partial_cmp(&b.score)
298 .unwrap_or(std::cmp::Ordering::Equal)
299 })
300 }
301
302 pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
304 let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
305 results.sort_by(|a, b| {
306 b.score
307 .partial_cmp(&a.score)
308 .unwrap_or(std::cmp::Ordering::Equal)
309 });
310 results
311 }
312
313 pub fn results(&self) -> &[HyperparamResult] {
315 &self.results
316 }
317
318 pub fn total_configs(&self) -> usize {
320 self.generate_configs().len()
321 }
322}
323
324#[derive(Debug)]
328pub struct RandomSearch {
329 param_space: HashMap<String, HyperparamSpace>,
331 num_samples: usize,
333 rng: StdRng,
335 results: Vec<HyperparamResult>,
337}
338
339impl RandomSearch {
340 pub fn new(
347 param_space: HashMap<String, HyperparamSpace>,
348 num_samples: usize,
349 seed: u64,
350 ) -> Self {
351 Self {
352 param_space,
353 num_samples,
354 rng: StdRng::seed_from_u64(seed),
355 results: Vec::new(),
356 }
357 }
358
359 pub fn generate_configs(&mut self) -> Vec<HyperparamConfig> {
361 let mut configs = Vec::with_capacity(self.num_samples);
362
363 for _ in 0..self.num_samples {
364 let mut config = HashMap::new();
365
366 for (name, space) in &self.param_space {
367 let value = space.sample(&mut self.rng);
368 config.insert(name.clone(), value);
369 }
370
371 configs.push(config);
372 }
373
374 configs
375 }
376
377 pub fn add_result(&mut self, result: HyperparamResult) {
379 self.results.push(result);
380 }
381
382 pub fn best_result(&self) -> Option<&HyperparamResult> {
384 self.results.iter().max_by(|a, b| {
385 a.score
386 .partial_cmp(&b.score)
387 .unwrap_or(std::cmp::Ordering::Equal)
388 })
389 }
390
391 pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
393 let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
394 results.sort_by(|a, b| {
395 b.score
396 .partial_cmp(&a.score)
397 .unwrap_or(std::cmp::Ordering::Equal)
398 });
399 results
400 }
401
402 pub fn results(&self) -> &[HyperparamResult] {
404 &self.results
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Each accessor returns its own payload; numeric variants convert both ways.
    #[test]
    fn test_hyperparam_value() {
        let float_value = HyperparamValue::Float(3.5);
        assert_eq!(float_value.as_float(), Some(3.5));
        assert_eq!(float_value.as_int(), Some(3));

        let int_value = HyperparamValue::Int(42);
        assert_eq!(int_value.as_int(), Some(42));
        assert_eq!(int_value.as_float(), Some(42.0));

        assert_eq!(HyperparamValue::Bool(true).as_bool(), Some(true));

        let text_value = HyperparamValue::String("test".to_string());
        assert_eq!(text_value.as_string(), Some("test"));
    }

    /// A discrete grid ignores the requested size; samples come from the list.
    #[test]
    fn test_hyperparam_space_discrete() {
        let candidates = vec![HyperparamValue::Float(0.1), HyperparamValue::Float(0.01)];
        let space = HyperparamSpace::discrete(candidates).unwrap();

        assert_eq!(space.grid_values(10).len(), 2);

        let mut rng = StdRng::seed_from_u64(42);
        assert!(matches!(space.sample(&mut rng), HyperparamValue::Float(_)));
    }

    /// A continuous grid has the requested size; samples stay inside the bounds.
    #[test]
    fn test_hyperparam_space_continuous() {
        let space = HyperparamSpace::continuous(0.0, 1.0).unwrap();

        assert_eq!(space.grid_values(5).len(), 5);

        let mut rng = StdRng::seed_from_u64(42);
        match space.sample(&mut rng) {
            HyperparamValue::Float(v) => assert!((0.0..=1.0).contains(&v)),
            _ => panic!("Expected Float value"),
        }
    }

    /// A log-uniform grid has the requested size; samples stay inside the bounds.
    #[test]
    fn test_hyperparam_space_log_uniform() {
        let space = HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap();

        assert_eq!(space.grid_values(3).len(), 3);

        let mut rng = StdRng::seed_from_u64(42);
        match space.sample(&mut rng) {
            HyperparamValue::Float(v) => assert!((1e-4..=1e-1).contains(&v)),
            _ => panic!("Expected Float value"),
        }
    }

    /// An integer grid is non-empty; samples stay inside the inclusive range.
    #[test]
    fn test_hyperparam_space_int_range() {
        let space = HyperparamSpace::int_range(1, 10).unwrap();

        assert!(!space.grid_values(5).is_empty());

        let mut rng = StdRng::seed_from_u64(42);
        match space.sample(&mut rng) {
            HyperparamValue::Int(v) => assert!((1..=10).contains(&v)),
            _ => panic!("Expected Int value"),
        }
    }

    /// Every constructor rejects its invalid input.
    #[test]
    fn test_hyperparam_space_invalid() {
        assert!(HyperparamSpace::discrete(vec![]).is_err());
        assert!(HyperparamSpace::continuous(1.0, 0.0).is_err());
        assert!(HyperparamSpace::log_uniform(0.0, 1.0).is_err());
        assert!(HyperparamSpace::log_uniform(1.0, 0.5).is_err());
        assert!(HyperparamSpace::int_range(10, 5).is_err());
    }

    /// A two-parameter grid yields at least two configurations.
    #[test]
    fn test_grid_search() {
        let lr_space = HyperparamSpace::discrete(vec![
            HyperparamValue::Float(0.1),
            HyperparamValue::Float(0.01),
        ])
        .unwrap();
        let batch_space = HyperparamSpace::int_range(16, 64).unwrap();

        let param_space = HashMap::from([
            ("lr".to_string(), lr_space),
            ("batch_size".to_string(), batch_space),
        ]);

        let configs = GridSearch::new(param_space, 3).generate_configs();
        assert!(!configs.is_empty());

        assert!(configs.len() >= 2);
    }

    /// The best result and the sorted results follow descending score order.
    #[test]
    fn test_grid_search_results() {
        let lr_space = HyperparamSpace::discrete(vec![HyperparamValue::Float(0.1)]).unwrap();
        let param_space = HashMap::from([("lr".to_string(), lr_space)]);

        let mut grid_search = GridSearch::new(param_space, 3);

        let config: HyperparamConfig =
            HashMap::from([("lr".to_string(), HyperparamValue::Float(0.1))]);

        for score in [0.9, 0.95, 0.85] {
            grid_search.add_result(HyperparamResult::new(config.clone(), score));
        }

        let best = grid_search.best_result().unwrap();
        assert_eq!(best.score, 0.95);

        let sorted = grid_search.sorted_results();
        assert_eq!(sorted[0].score, 0.95);
        assert_eq!(sorted[1].score, 0.9);
        assert_eq!(sorted[2].score, 0.85);
    }

    /// Random search yields the requested number of complete configurations.
    #[test]
    fn test_random_search() {
        let param_space = HashMap::from([
            (
                "lr".to_string(),
                HyperparamSpace::continuous(1e-4, 1e-1).unwrap(),
            ),
            (
                "dropout".to_string(),
                HyperparamSpace::continuous(0.0, 0.5).unwrap(),
            ),
        ]);

        let mut random_search = RandomSearch::new(param_space, 10, 42);

        let configs = random_search.generate_configs();
        assert_eq!(configs.len(), 10);

        for config in &configs {
            for key in ["lr", "dropout"] {
                assert!(config.contains_key(key));
            }
        }
    }

    /// Random search tracks recorded results and reports the best one.
    #[test]
    fn test_random_search_results() {
        let lr_space = HyperparamSpace::discrete(vec![HyperparamValue::Float(0.1)]).unwrap();
        let param_space = HashMap::from([("lr".to_string(), lr_space)]);

        let mut random_search = RandomSearch::new(param_space, 5, 42);

        let config: HyperparamConfig =
            HashMap::from([("lr".to_string(), HyperparamValue::Float(0.1))]);

        random_search.add_result(HyperparamResult::new(config.clone(), 0.8));
        random_search.add_result(HyperparamResult::new(config, 0.9));

        let best = random_search.best_result().unwrap();
        assert_eq!(best.score, 0.9);

        assert_eq!(random_search.results().len(), 2);
    }

    /// Metrics attached through the builder are retrievable by name.
    #[test]
    fn test_hyperparam_result_with_metrics() {
        let config: HyperparamConfig =
            HashMap::from([("lr".to_string(), HyperparamValue::Float(0.1))]);

        let result = HyperparamResult::new(config, 0.95)
            .with_metric("accuracy".to_string(), 0.95)
            .with_metric("loss".to_string(), 0.05);

        assert_eq!(result.score, 0.95);
        assert_eq!(result.metrics.get("accuracy"), Some(&0.95));
        assert_eq!(result.metrics.get("loss"), Some(&0.05));
    }

    /// An empty parameter space yields exactly one empty configuration.
    #[test]
    fn test_grid_search_empty_space() {
        let configs = GridSearch::new(HashMap::new(), 3).generate_configs();
        assert_eq!(configs.len(), 1);
        assert!(configs[0].is_empty());
    }

    /// The total count matches the size of the single discrete parameter.
    #[test]
    fn test_grid_search_total_configs() {
        let lr_space = HyperparamSpace::discrete(vec![
            HyperparamValue::Float(0.1),
            HyperparamValue::Float(0.01),
        ])
        .unwrap();
        let param_space = HashMap::from([("lr".to_string(), lr_space)]);

        let grid_search = GridSearch::new(param_space, 3);
        assert_eq!(grid_search.total_configs(), 2);
    }
}