tensorlogic_train/
hyperparameter.rs

//! Hyperparameter optimization utilities.
//!
//! This module provides hyperparameter search strategies and their supporting types:
//! - Grid search (exhaustive search over a parameter grid)
//! - Random search (random sampling from the parameter space)
//! - Parameter space definitions (discrete, continuous, log-uniform and integer ranges)
//! - Result tracking and comparison
8
9use crate::{TrainError, TrainResult};
10use scirs2_core::random::{Rng, SeedableRng, StdRng};
11use std::collections::HashMap;
12
/// A single concrete hyperparameter setting.
///
/// Values are dynamically typed so that one configuration map can mix
/// real-valued, integer, boolean and named settings.
#[derive(Clone, Debug, PartialEq)]
pub enum HyperparamValue {
    /// A real-valued setting (e.g. a learning rate).
    Float(f64),
    /// A whole-number setting (e.g. a batch size).
    Int(i64),
    /// An on/off switch.
    Bool(bool),
    /// A named choice.
    String(String),
}
25
26impl HyperparamValue {
27    /// Get as f64, if possible.
28    pub fn as_float(&self) -> Option<f64> {
29        match self {
30            HyperparamValue::Float(v) => Some(*v),
31            HyperparamValue::Int(v) => Some(*v as f64),
32            _ => None,
33        }
34    }
35
36    /// Get as i64, if possible.
37    pub fn as_int(&self) -> Option<i64> {
38        match self {
39            HyperparamValue::Int(v) => Some(*v),
40            HyperparamValue::Float(v) => Some(*v as i64),
41            _ => None,
42        }
43    }
44
45    /// Get as bool, if possible.
46    pub fn as_bool(&self) -> Option<bool> {
47        match self {
48            HyperparamValue::Bool(v) => Some(*v),
49            _ => None,
50        }
51    }
52
53    /// Get as string, if possible.
54    pub fn as_string(&self) -> Option<&str> {
55        match self {
56            HyperparamValue::String(v) => Some(v),
57            _ => None,
58        }
59    }
60}
61
62/// Hyperparameter space definition.
63#[derive(Debug, Clone)]
64pub enum HyperparamSpace {
65    /// Discrete choices.
66    Discrete(Vec<HyperparamValue>),
67    /// Continuous range [min, max].
68    Continuous { min: f64, max: f64 },
69    /// Log-uniform distribution [min, max].
70    LogUniform { min: f64, max: f64 },
71    /// Integer range [min, max].
72    IntRange { min: i64, max: i64 },
73}
74
75impl HyperparamSpace {
76    /// Create a discrete choice space.
77    pub fn discrete(values: Vec<HyperparamValue>) -> TrainResult<Self> {
78        if values.is_empty() {
79            return Err(TrainError::InvalidParameter(
80                "Discrete space cannot be empty".to_string(),
81            ));
82        }
83        Ok(Self::Discrete(values))
84    }
85
86    /// Create a continuous range space.
87    pub fn continuous(min: f64, max: f64) -> TrainResult<Self> {
88        if min >= max {
89            return Err(TrainError::InvalidParameter(
90                "min must be less than max".to_string(),
91            ));
92        }
93        Ok(Self::Continuous { min, max })
94    }
95
96    /// Create a log-uniform distribution space.
97    pub fn log_uniform(min: f64, max: f64) -> TrainResult<Self> {
98        if min <= 0.0 || max <= 0.0 || min >= max {
99            return Err(TrainError::InvalidParameter(
100                "min and max must be positive and min < max".to_string(),
101            ));
102        }
103        Ok(Self::LogUniform { min, max })
104    }
105
106    /// Create an integer range space.
107    pub fn int_range(min: i64, max: i64) -> TrainResult<Self> {
108        if min >= max {
109            return Err(TrainError::InvalidParameter(
110                "min must be less than max".to_string(),
111            ));
112        }
113        Ok(Self::IntRange { min, max })
114    }
115
116    /// Sample a value from this space.
117    pub fn sample(&self, rng: &mut StdRng) -> HyperparamValue {
118        match self {
119            HyperparamSpace::Discrete(values) => {
120                let idx = rng.gen_range(0..values.len());
121                values[idx].clone()
122            }
123            HyperparamSpace::Continuous { min, max } => {
124                let value = min + (max - min) * rng.random::<f64>();
125                HyperparamValue::Float(value)
126            }
127            HyperparamSpace::LogUniform { min, max } => {
128                let log_min = min.ln();
129                let log_max = max.ln();
130                let log_value = log_min + (log_max - log_min) * rng.random::<f64>();
131                HyperparamValue::Float(log_value.exp())
132            }
133            HyperparamSpace::IntRange { min, max } => {
134                let value = rng.gen_range(*min..=*max);
135                HyperparamValue::Int(value)
136            }
137        }
138    }
139
140    /// Get all possible values for grid search (for discrete/int spaces).
141    pub fn grid_values(&self, num_samples: usize) -> Vec<HyperparamValue> {
142        match self {
143            HyperparamSpace::Discrete(values) => values.clone(),
144            HyperparamSpace::IntRange { min, max } => {
145                let range_size = (max - min + 1) as usize;
146                let step = (range_size / num_samples).max(1);
147                (*min..=*max)
148                    .step_by(step)
149                    .map(HyperparamValue::Int)
150                    .collect()
151            }
152            HyperparamSpace::Continuous { min, max } => {
153                let step = (max - min) / (num_samples as f64);
154                (0..num_samples)
155                    .map(|i| HyperparamValue::Float(min + step * i as f64))
156                    .collect()
157            }
158            HyperparamSpace::LogUniform { min, max } => {
159                let log_min = min.ln();
160                let log_max = max.ln();
161                let log_step = (log_max - log_min) / (num_samples as f64);
162                (0..num_samples)
163                    .map(|i| HyperparamValue::Float((log_min + log_step * i as f64).exp()))
164                    .collect()
165            }
166        }
167    }
168}
169
170/// Hyperparameter configuration (a single point in parameter space).
171pub type HyperparamConfig = HashMap<String, HyperparamValue>;
172
173/// Result of a hyperparameter evaluation.
174#[derive(Debug, Clone)]
175pub struct HyperparamResult {
176    /// Hyperparameter configuration used.
177    pub config: HyperparamConfig,
178    /// Evaluation score (higher is better).
179    pub score: f64,
180    /// Additional metrics.
181    pub metrics: HashMap<String, f64>,
182}
183
184impl HyperparamResult {
185    /// Create a new result.
186    pub fn new(config: HyperparamConfig, score: f64) -> Self {
187        Self {
188            config,
189            score,
190            metrics: HashMap::new(),
191        }
192    }
193
194    /// Add a metric to the result.
195    pub fn with_metric(mut self, name: String, value: f64) -> Self {
196        self.metrics.insert(name, value);
197        self
198    }
199}
200
201/// Grid search strategy for hyperparameter optimization.
202///
203/// Exhaustively searches over a grid of hyperparameter values.
204#[derive(Debug)]
205pub struct GridSearch {
206    /// Parameter space definition.
207    param_space: HashMap<String, HyperparamSpace>,
208    /// Number of grid points per continuous parameter.
209    num_grid_points: usize,
210    /// Results from all evaluations.
211    results: Vec<HyperparamResult>,
212}
213
214impl GridSearch {
215    /// Create a new grid search.
216    ///
217    /// # Arguments
218    /// * `param_space` - Hyperparameter space definition
219    /// * `num_grid_points` - Number of points for continuous parameters
220    pub fn new(param_space: HashMap<String, HyperparamSpace>, num_grid_points: usize) -> Self {
221        Self {
222            param_space,
223            num_grid_points,
224            results: Vec::new(),
225        }
226    }
227
228    /// Generate all parameter configurations for grid search.
229    pub fn generate_configs(&self) -> Vec<HyperparamConfig> {
230        if self.param_space.is_empty() {
231            return vec![HashMap::new()];
232        }
233
234        let mut param_names: Vec<String> = self.param_space.keys().cloned().collect();
235        param_names.sort(); // Ensure deterministic order
236
237        let mut all_values: Vec<Vec<HyperparamValue>> = Vec::new();
238        for name in &param_names {
239            let space = &self.param_space[name];
240            all_values.push(space.grid_values(self.num_grid_points));
241        }
242
243        // Generate Cartesian product
244        let mut configs = Vec::new();
245        self.generate_cartesian_product(
246            &param_names,
247            &all_values,
248            0,
249            &mut HashMap::new(),
250            &mut configs,
251        );
252
253        configs
254    }
255
256    /// Recursively generate Cartesian product of parameter values.
257    #[allow(clippy::only_used_in_recursion)]
258    fn generate_cartesian_product(
259        &self,
260        param_names: &[String],
261        all_values: &[Vec<HyperparamValue>],
262        depth: usize,
263        current_config: &mut HyperparamConfig,
264        configs: &mut Vec<HyperparamConfig>,
265    ) {
266        if depth == param_names.len() {
267            configs.push(current_config.clone());
268            return;
269        }
270
271        let param_name = &param_names[depth];
272        let values = &all_values[depth];
273
274        for value in values {
275            current_config.insert(param_name.clone(), value.clone());
276            self.generate_cartesian_product(
277                param_names,
278                all_values,
279                depth + 1,
280                current_config,
281                configs,
282            );
283        }
284
285        current_config.remove(param_name);
286    }
287
288    /// Add a result from evaluating a configuration.
289    pub fn add_result(&mut self, result: HyperparamResult) {
290        self.results.push(result);
291    }
292
293    /// Get the best result found so far.
294    pub fn best_result(&self) -> Option<&HyperparamResult> {
295        self.results.iter().max_by(|a, b| {
296            a.score
297                .partial_cmp(&b.score)
298                .unwrap_or(std::cmp::Ordering::Equal)
299        })
300    }
301
302    /// Get all results sorted by score (descending).
303    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
304        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
305        results.sort_by(|a, b| {
306            b.score
307                .partial_cmp(&a.score)
308                .unwrap_or(std::cmp::Ordering::Equal)
309        });
310        results
311    }
312
313    /// Get all results.
314    pub fn results(&self) -> &[HyperparamResult] {
315        &self.results
316    }
317
318    /// Get total number of configurations to evaluate.
319    pub fn total_configs(&self) -> usize {
320        self.generate_configs().len()
321    }
322}
323
324/// Random search strategy for hyperparameter optimization.
325///
326/// Randomly samples from the hyperparameter space.
327#[derive(Debug)]
328pub struct RandomSearch {
329    /// Parameter space definition.
330    param_space: HashMap<String, HyperparamSpace>,
331    /// Number of random samples to evaluate.
332    num_samples: usize,
333    /// Random number generator.
334    rng: StdRng,
335    /// Results from all evaluations.
336    results: Vec<HyperparamResult>,
337}
338
339impl RandomSearch {
340    /// Create a new random search.
341    ///
342    /// # Arguments
343    /// * `param_space` - Hyperparameter space definition
344    /// * `num_samples` - Number of random configurations to try
345    /// * `seed` - Random seed for reproducibility
346    pub fn new(
347        param_space: HashMap<String, HyperparamSpace>,
348        num_samples: usize,
349        seed: u64,
350    ) -> Self {
351        Self {
352            param_space,
353            num_samples,
354            rng: StdRng::seed_from_u64(seed),
355            results: Vec::new(),
356        }
357    }
358
359    /// Generate random parameter configurations.
360    pub fn generate_configs(&mut self) -> Vec<HyperparamConfig> {
361        let mut configs = Vec::with_capacity(self.num_samples);
362
363        for _ in 0..self.num_samples {
364            let mut config = HashMap::new();
365
366            for (name, space) in &self.param_space {
367                let value = space.sample(&mut self.rng);
368                config.insert(name.clone(), value);
369            }
370
371            configs.push(config);
372        }
373
374        configs
375    }
376
377    /// Add a result from evaluating a configuration.
378    pub fn add_result(&mut self, result: HyperparamResult) {
379        self.results.push(result);
380    }
381
382    /// Get the best result found so far.
383    pub fn best_result(&self) -> Option<&HyperparamResult> {
384        self.results.iter().max_by(|a, b| {
385            a.score
386                .partial_cmp(&b.score)
387                .unwrap_or(std::cmp::Ordering::Equal)
388        })
389    }
390
391    /// Get all results sorted by score (descending).
392    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
393        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
394        results.sort_by(|a, b| {
395            b.score
396                .partial_cmp(&a.score)
397                .unwrap_or(std::cmp::Ordering::Equal)
398        });
399        results
400    }
401
402    /// Get all results.
403    pub fn results(&self) -> &[HyperparamResult] {
404        &self.results
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministically seeded generator shared by the sampling tests.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Builds a parameter-space map from `(name, space)` pairs.
    fn space_of(entries: Vec<(&str, HyperparamSpace)>) -> HashMap<String, HyperparamSpace> {
        entries
            .into_iter()
            .map(|(name, space)| (name.to_string(), space))
            .collect()
    }

    /// Two learning-rate candidates: 0.1 and 0.01.
    fn two_learning_rates() -> HyperparamSpace {
        HyperparamSpace::discrete(vec![
            HyperparamValue::Float(0.1),
            HyperparamValue::Float(0.01),
        ])
        .unwrap()
    }

    /// A single learning-rate candidate: 0.1.
    fn one_learning_rate() -> HyperparamSpace {
        HyperparamSpace::discrete(vec![HyperparamValue::Float(0.1)]).unwrap()
    }

    /// The configuration `{ "lr": 0.1 }`.
    fn lr_config() -> HyperparamConfig {
        let mut config = HashMap::new();
        config.insert("lr".to_string(), HyperparamValue::Float(0.1));
        config
    }

    #[test]
    fn test_hyperparam_value() {
        let float_value = HyperparamValue::Float(3.5);
        assert_eq!(float_value.as_float(), Some(3.5));
        assert_eq!(float_value.as_int(), Some(3));

        let int_value = HyperparamValue::Int(42);
        assert_eq!(int_value.as_int(), Some(42));
        assert_eq!(int_value.as_float(), Some(42.0));

        assert_eq!(HyperparamValue::Bool(true).as_bool(), Some(true));
        assert_eq!(
            HyperparamValue::String("test".to_string()).as_string(),
            Some("test")
        );
    }

    #[test]
    fn test_hyperparam_space_discrete() {
        let space = two_learning_rates();
        assert_eq!(space.grid_values(10).len(), 2);

        let drawn = space.sample(&mut seeded_rng());
        assert!(matches!(drawn, HyperparamValue::Float(_)));
    }

    #[test]
    fn test_hyperparam_space_continuous() {
        let space = HyperparamSpace::continuous(0.0, 1.0).unwrap();
        assert_eq!(space.grid_values(5).len(), 5);

        match space.sample(&mut seeded_rng()) {
            HyperparamValue::Float(v) => assert!((0.0..=1.0).contains(&v)),
            _ => panic!("Expected Float value"),
        }
    }

    #[test]
    fn test_hyperparam_space_log_uniform() {
        let space = HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap();
        assert_eq!(space.grid_values(3).len(), 3);

        match space.sample(&mut seeded_rng()) {
            HyperparamValue::Float(v) => assert!((1e-4..=1e-1).contains(&v)),
            _ => panic!("Expected Float value"),
        }
    }

    #[test]
    fn test_hyperparam_space_int_range() {
        let space = HyperparamSpace::int_range(1, 10).unwrap();
        assert!(!space.grid_values(5).is_empty());

        match space.sample(&mut seeded_rng()) {
            HyperparamValue::Int(v) => assert!((1..=10).contains(&v)),
            _ => panic!("Expected Int value"),
        }
    }

    #[test]
    fn test_hyperparam_space_invalid() {
        assert!(HyperparamSpace::discrete(Vec::new()).is_err());
        assert!(HyperparamSpace::continuous(1.0, 0.0).is_err());
        assert!(HyperparamSpace::log_uniform(0.0, 1.0).is_err());
        assert!(HyperparamSpace::log_uniform(1.0, 0.5).is_err());
        assert!(HyperparamSpace::int_range(10, 5).is_err());
    }

    #[test]
    fn test_grid_search() {
        let grid_search = GridSearch::new(
            space_of(vec![
                ("lr", two_learning_rates()),
                ("batch_size", HyperparamSpace::int_range(16, 64).unwrap()),
            ]),
            3,
        );

        let configs = grid_search.generate_configs();
        assert!(!configs.is_empty());
        // 2 learning rates times however many batch sizes the integer grid yields.
        assert!(configs.len() >= 2);
    }

    #[test]
    fn test_grid_search_results() {
        let mut grid_search = GridSearch::new(space_of(vec![("lr", one_learning_rate())]), 3);
        for score in [0.9, 0.95, 0.85] {
            grid_search.add_result(HyperparamResult::new(lr_config(), score));
        }

        assert_eq!(grid_search.best_result().unwrap().score, 0.95);

        let ranked: Vec<f64> = grid_search
            .sorted_results()
            .iter()
            .map(|r| r.score)
            .collect();
        assert_eq!(ranked, vec![0.95, 0.9, 0.85]);
    }

    #[test]
    fn test_random_search() {
        let mut random_search = RandomSearch::new(
            space_of(vec![
                ("lr", HyperparamSpace::continuous(1e-4, 1e-1).unwrap()),
                ("dropout", HyperparamSpace::continuous(0.0, 0.5).unwrap()),
            ]),
            10,
            42,
        );

        let configs = random_search.generate_configs();
        assert_eq!(configs.len(), 10);

        // Every sampled configuration must assign a value to each parameter.
        for config in &configs {
            for key in ["lr", "dropout"] {
                assert!(config.contains_key(key));
            }
        }
    }

    #[test]
    fn test_random_search_results() {
        let mut random_search =
            RandomSearch::new(space_of(vec![("lr", one_learning_rate())]), 5, 42);
        random_search.add_result(HyperparamResult::new(lr_config(), 0.8));
        random_search.add_result(HyperparamResult::new(lr_config(), 0.9));

        assert_eq!(random_search.best_result().unwrap().score, 0.9);
        assert_eq!(random_search.results().len(), 2);
    }

    #[test]
    fn test_hyperparam_result_with_metrics() {
        let result = HyperparamResult::new(lr_config(), 0.95)
            .with_metric("accuracy".to_string(), 0.95)
            .with_metric("loss".to_string(), 0.05);

        assert_eq!(result.score, 0.95);
        assert_eq!(result.metrics.get("accuracy"), Some(&0.95));
        assert_eq!(result.metrics.get("loss"), Some(&0.05));
    }

    #[test]
    fn test_grid_search_empty_space() {
        let configs = GridSearch::new(HashMap::new(), 3).generate_configs();
        // An empty space still has exactly one configuration: the empty one.
        assert_eq!(configs.len(), 1);
        assert!(configs[0].is_empty());
    }

    #[test]
    fn test_grid_search_total_configs() {
        let grid_search = GridSearch::new(space_of(vec![("lr", two_learning_rates())]), 3);
        assert_eq!(grid_search.total_configs(), 2);
    }
}