Skip to main content

tensorlogic_train/hyperparameter/
search.rs

//! Grid and random search strategies for hyperparameter optimization.

use scirs2_core::random::{SeedableRng, StdRng};
use std::cmp::Ordering;
use std::collections::HashMap;

use super::space::HyperparamSpace;
use super::value::{HyperparamConfig, HyperparamResult, HyperparamValue};

9/// Grid search strategy for hyperparameter optimization.
10///
11/// Exhaustively searches over a grid of hyperparameter values.
12#[derive(Debug)]
13pub struct GridSearch {
14    /// Parameter space definition.
15    param_space: HashMap<String, HyperparamSpace>,
16    /// Number of grid points per continuous parameter.
17    num_grid_points: usize,
18    /// Results from all evaluations.
19    results: Vec<HyperparamResult>,
20}
21
22impl GridSearch {
23    /// Create a new grid search.
24    ///
25    /// # Arguments
26    /// * `param_space` - Hyperparameter space definition
27    /// * `num_grid_points` - Number of points for continuous parameters
28    pub fn new(param_space: HashMap<String, HyperparamSpace>, num_grid_points: usize) -> Self {
29        Self {
30            param_space,
31            num_grid_points,
32            results: Vec::new(),
33        }
34    }
35
36    /// Generate all parameter configurations for grid search.
37    pub fn generate_configs(&self) -> Vec<HyperparamConfig> {
38        if self.param_space.is_empty() {
39            return vec![HashMap::new()];
40        }
41        let mut param_names: Vec<String> = self.param_space.keys().cloned().collect();
42        param_names.sort();
43        let mut all_values: Vec<Vec<HyperparamValue>> = Vec::new();
44        for name in &param_names {
45            let space = &self.param_space[name];
46            all_values.push(space.grid_values(self.num_grid_points));
47        }
48        let mut configs = Vec::new();
49        self.generate_cartesian_product(
50            &param_names,
51            &all_values,
52            0,
53            &mut HashMap::new(),
54            &mut configs,
55        );
56        configs
57    }
58
59    /// Recursively generate Cartesian product of parameter values.
60    #[allow(clippy::only_used_in_recursion)]
61    fn generate_cartesian_product(
62        &self,
63        param_names: &[String],
64        all_values: &[Vec<HyperparamValue>],
65        depth: usize,
66        current_config: &mut HyperparamConfig,
67        configs: &mut Vec<HyperparamConfig>,
68    ) {
69        if depth == param_names.len() {
70            configs.push(current_config.clone());
71            return;
72        }
73        let param_name = &param_names[depth];
74        let values = &all_values[depth];
75        for value in values {
76            current_config.insert(param_name.clone(), value.clone());
77            self.generate_cartesian_product(
78                param_names,
79                all_values,
80                depth + 1,
81                current_config,
82                configs,
83            );
84        }
85        current_config.remove(param_name);
86    }
87
88    /// Add a result from evaluating a configuration.
89    pub fn add_result(&mut self, result: HyperparamResult) {
90        self.results.push(result);
91    }
92
93    /// Get the best result found so far.
94    pub fn best_result(&self) -> Option<&HyperparamResult> {
95        self.results.iter().max_by(|a, b| {
96            a.score
97                .partial_cmp(&b.score)
98                .unwrap_or(std::cmp::Ordering::Equal)
99        })
100    }
101
102    /// Get all results sorted by score (descending).
103    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
104        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
105        results.sort_by(|a, b| {
106            b.score
107                .partial_cmp(&a.score)
108                .unwrap_or(std::cmp::Ordering::Equal)
109        });
110        results
111    }
112
113    /// Get all results.
114    pub fn results(&self) -> &[HyperparamResult] {
115        &self.results
116    }
117
118    /// Get total number of configurations to evaluate.
119    pub fn total_configs(&self) -> usize {
120        self.generate_configs().len()
121    }
122}
123
124/// Random search strategy for hyperparameter optimization.
125///
126/// Randomly samples from the hyperparameter space.
127#[derive(Debug)]
128pub struct RandomSearch {
129    /// Parameter space definition.
130    param_space: HashMap<String, HyperparamSpace>,
131    /// Number of random samples to evaluate.
132    num_samples: usize,
133    /// Random number generator.
134    rng: StdRng,
135    /// Results from all evaluations.
136    results: Vec<HyperparamResult>,
137}
138
139impl RandomSearch {
140    /// Create a new random search.
141    ///
142    /// # Arguments
143    /// * `param_space` - Hyperparameter space definition
144    /// * `num_samples` - Number of random configurations to try
145    /// * `seed` - Random seed for reproducibility
146    pub fn new(
147        param_space: HashMap<String, HyperparamSpace>,
148        num_samples: usize,
149        seed: u64,
150    ) -> Self {
151        Self {
152            param_space,
153            num_samples,
154            rng: StdRng::seed_from_u64(seed),
155            results: Vec::new(),
156        }
157    }
158
159    /// Generate random parameter configurations.
160    pub fn generate_configs(&mut self) -> Vec<HyperparamConfig> {
161        let mut configs = Vec::with_capacity(self.num_samples);
162        for _ in 0..self.num_samples {
163            let mut config = HashMap::new();
164            for (name, space) in &self.param_space {
165                let value = space.sample(&mut self.rng);
166                config.insert(name.clone(), value);
167            }
168            configs.push(config);
169        }
170        configs
171    }
172
173    /// Add a result from evaluating a configuration.
174    pub fn add_result(&mut self, result: HyperparamResult) {
175        self.results.push(result);
176    }
177
178    /// Get the best result found so far.
179    pub fn best_result(&self) -> Option<&HyperparamResult> {
180        self.results.iter().max_by(|a, b| {
181            a.score
182                .partial_cmp(&b.score)
183                .unwrap_or(std::cmp::Ordering::Equal)
184        })
185    }
186
187    /// Get all results sorted by score (descending).
188    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
189        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
190        results.sort_by(|a, b| {
191            b.score
192                .partial_cmp(&a.score)
193                .unwrap_or(std::cmp::Ordering::Equal)
194        });
195        results
196    }
197
198    /// Get all results.
199    pub fn results(&self) -> &[HyperparamResult] {
200        &self.results
201    }
202}