tensorlogic_train/hyperparameter/
search.rs1use scirs2_core::random::{SeedableRng, StdRng};
4use std::collections::HashMap;
5
6use super::space::HyperparamSpace;
7use super::value::{HyperparamConfig, HyperparamResult, HyperparamValue};
8
9#[derive(Debug)]
13pub struct GridSearch {
14 param_space: HashMap<String, HyperparamSpace>,
16 num_grid_points: usize,
18 results: Vec<HyperparamResult>,
20}
21
22impl GridSearch {
23 pub fn new(param_space: HashMap<String, HyperparamSpace>, num_grid_points: usize) -> Self {
29 Self {
30 param_space,
31 num_grid_points,
32 results: Vec::new(),
33 }
34 }
35
36 pub fn generate_configs(&self) -> Vec<HyperparamConfig> {
38 if self.param_space.is_empty() {
39 return vec![HashMap::new()];
40 }
41 let mut param_names: Vec<String> = self.param_space.keys().cloned().collect();
42 param_names.sort();
43 let mut all_values: Vec<Vec<HyperparamValue>> = Vec::new();
44 for name in ¶m_names {
45 let space = &self.param_space[name];
46 all_values.push(space.grid_values(self.num_grid_points));
47 }
48 let mut configs = Vec::new();
49 self.generate_cartesian_product(
50 ¶m_names,
51 &all_values,
52 0,
53 &mut HashMap::new(),
54 &mut configs,
55 );
56 configs
57 }
58
59 #[allow(clippy::only_used_in_recursion)]
61 fn generate_cartesian_product(
62 &self,
63 param_names: &[String],
64 all_values: &[Vec<HyperparamValue>],
65 depth: usize,
66 current_config: &mut HyperparamConfig,
67 configs: &mut Vec<HyperparamConfig>,
68 ) {
69 if depth == param_names.len() {
70 configs.push(current_config.clone());
71 return;
72 }
73 let param_name = ¶m_names[depth];
74 let values = &all_values[depth];
75 for value in values {
76 current_config.insert(param_name.clone(), value.clone());
77 self.generate_cartesian_product(
78 param_names,
79 all_values,
80 depth + 1,
81 current_config,
82 configs,
83 );
84 }
85 current_config.remove(param_name);
86 }
87
88 pub fn add_result(&mut self, result: HyperparamResult) {
90 self.results.push(result);
91 }
92
93 pub fn best_result(&self) -> Option<&HyperparamResult> {
95 self.results.iter().max_by(|a, b| {
96 a.score
97 .partial_cmp(&b.score)
98 .unwrap_or(std::cmp::Ordering::Equal)
99 })
100 }
101
102 pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
104 let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
105 results.sort_by(|a, b| {
106 b.score
107 .partial_cmp(&a.score)
108 .unwrap_or(std::cmp::Ordering::Equal)
109 });
110 results
111 }
112
113 pub fn results(&self) -> &[HyperparamResult] {
115 &self.results
116 }
117
118 pub fn total_configs(&self) -> usize {
120 self.generate_configs().len()
121 }
122}
123
124#[derive(Debug)]
128pub struct RandomSearch {
129 param_space: HashMap<String, HyperparamSpace>,
131 num_samples: usize,
133 rng: StdRng,
135 results: Vec<HyperparamResult>,
137}
138
139impl RandomSearch {
140 pub fn new(
147 param_space: HashMap<String, HyperparamSpace>,
148 num_samples: usize,
149 seed: u64,
150 ) -> Self {
151 Self {
152 param_space,
153 num_samples,
154 rng: StdRng::seed_from_u64(seed),
155 results: Vec::new(),
156 }
157 }
158
159 pub fn generate_configs(&mut self) -> Vec<HyperparamConfig> {
161 let mut configs = Vec::with_capacity(self.num_samples);
162 for _ in 0..self.num_samples {
163 let mut config = HashMap::new();
164 for (name, space) in &self.param_space {
165 let value = space.sample(&mut self.rng);
166 config.insert(name.clone(), value);
167 }
168 configs.push(config);
169 }
170 configs
171 }
172
173 pub fn add_result(&mut self, result: HyperparamResult) {
175 self.results.push(result);
176 }
177
178 pub fn best_result(&self) -> Option<&HyperparamResult> {
180 self.results.iter().max_by(|a, b| {
181 a.score
182 .partial_cmp(&b.score)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 })
185 }
186
187 pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
189 let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
190 results.sort_by(|a, b| {
191 b.score
192 .partial_cmp(&a.score)
193 .unwrap_or(std::cmp::Ordering::Equal)
194 });
195 results
196 }
197
198 pub fn results(&self) -> &[HyperparamResult] {
200 &self.results
201 }
202}