scirs2_optimize/multi_fidelity/
hyperband.rs1use crate::error::{OptimizeError, OptimizeResult};
14
15use super::successive_halving::SuccessiveHalving;
16use super::types::{ConfigSampler, EvaluationResult, MultiFidelityConfig, MultiFidelityResult};
17
/// Parameters of one Successive Halving bracket inside a Hyperband run.
///
/// `Hyperband::compute_brackets` produces one of these per bracket index `s`.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct BracketConfig {
    /// Number of configurations sampled at the start of the bracket.
    pub n_initial: usize,
    /// Budget given to every configuration in the bracket's first round
    /// (`max_budget / eta^s`).
    pub min_budget: f64,
    /// Budget reached by the surviving configurations in the final round.
    pub max_budget: f64,
    /// Number of halving rounds in the bracket (`s + 1`).
    pub n_rounds: usize,
}
34
35#[derive(Debug, Clone)]
53pub struct Hyperband {
54 config: MultiFidelityConfig,
55}
56
57impl Hyperband {
58 pub fn new(config: MultiFidelityConfig) -> OptimizeResult<Self> {
60 config.validate()?;
61 Ok(Self { config })
62 }
63
64 pub(crate) fn compute_brackets(&self) -> Vec<BracketConfig> {
66 let s_max = self.config.s_max();
67 let eta = self.config.eta;
68 let eta_f = eta as f64;
69 let mut brackets = Vec::with_capacity(s_max + 1);
70
71 for s in (0..=s_max).rev() {
72 let n_initial =
73 ((s_max + 1) as f64 / (s + 1) as f64 * eta_f.powi(s as i32)).ceil() as usize;
74 let start_budget = self.config.max_budget / eta_f.powi(s as i32);
75 brackets.push(BracketConfig {
76 n_initial,
77 min_budget: start_budget,
78 max_budget: self.config.max_budget,
79 n_rounds: s + 1,
80 });
81 }
82
83 brackets
84 }
85
86 pub fn run<F>(
91 &self,
92 objective: &F,
93 bounds: &[(f64, f64)],
94 sampler: &ConfigSampler,
95 rng_state: &mut u64,
96 ) -> OptimizeResult<MultiFidelityResult>
97 where
98 F: Fn(&[f64], f64) -> OptimizeResult<f64>,
99 {
100 if bounds.is_empty() {
101 return Err(OptimizeError::InvalidParameter(
102 "bounds must not be empty".into(),
103 ));
104 }
105
106 let brackets = self.compute_brackets();
107 let n_brackets = brackets.len();
108
109 let sh = SuccessiveHalving::new(self.config.clone())?;
110
111 let mut all_evals: Vec<EvaluationResult> = Vec::new();
112 let mut total_budget = 0.0;
113 let mut global_best_obj = f64::INFINITY;
114 let mut global_best_cfg: Vec<f64> = Vec::new();
115 let mut eval_id_offset = 0usize;
116
117 for bracket in &brackets {
118 let result = sh.run_with(
119 objective,
120 bounds,
121 sampler,
122 rng_state,
123 bracket.n_initial,
124 bracket.min_budget,
125 )?;
126
127 for mut e in result.evaluations {
129 e.config_id += eval_id_offset;
130 if e.objective < global_best_obj {
131 global_best_obj = e.objective;
132 global_best_cfg = e.config.clone();
133 }
134 all_evals.push(e);
135 }
136 eval_id_offset = all_evals.iter().map(|e| e.config_id).max().unwrap_or(0) + 1;
137
138 total_budget += result.total_budget_used;
139 }
140
141 if global_best_cfg.is_empty() {
142 return Err(OptimizeError::ComputationError(
143 "no evaluations performed across brackets".into(),
144 ));
145 }
146
147 Ok(MultiFidelityResult {
148 best_config: global_best_cfg,
149 best_objective: global_best_obj,
150 total_budget_used: total_budget,
151 evaluations: all_evals,
152 n_brackets,
153 })
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Budget-independent sphere function; minimum 0 at the origin.
    fn quadratic(x: &[f64], _budget: f64) -> OptimizeResult<f64> {
        Ok(x.iter().map(|xi| xi * xi).sum())
    }

    /// Sphere function plus a noise term that shrinks as the budget grows,
    /// mimicking a low-fidelity approximation.
    fn budget_aware_quadratic(x: &[f64], budget: f64) -> OptimizeResult<f64> {
        let base: f64 = x.iter().map(|xi| xi * xi).sum();
        Ok(base + 1.0 / budget.sqrt())
    }

    #[test]
    fn test_multiple_brackets_generated() {
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let brackets = hb.compute_brackets();
        assert_eq!(brackets.len(), 5);
    }

    #[test]
    fn test_bracket_schedule_matches_hyperband_table() {
        // R = 81, eta = 3 is the worked example in Li et al. (2018), Table 1.
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let brackets = hb.compute_brackets();

        let sizes: Vec<usize> = brackets.iter().map(|b| b.n_initial).collect();
        assert_eq!(sizes, vec![81, 34, 15, 8, 5]);

        let rounds: Vec<usize> = brackets.iter().map(|b| b.n_rounds).collect();
        assert_eq!(rounds, vec![5, 4, 3, 2, 1]);

        let expected_min = [1.0, 3.0, 9.0, 27.0, 81.0];
        for (b, want) in brackets.iter().zip(expected_min.iter()) {
            assert!(
                (b.min_budget - *want).abs() < 1e-9,
                "bracket min_budget {} should be {}",
                b.min_budget,
                want
            );
        }
    }

    #[test]
    fn test_best_across_brackets_selected() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-5.0, 5.0), (-5.0, 5.0)];
        let mut rng = 42u64;
        let result = hb
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        let true_min = result
            .evaluations
            .iter()
            .map(|e| e.objective)
            .fold(f64::INFINITY, f64::min);
        assert!(
            (result.best_objective - true_min).abs() < 1e-12,
            "best_objective {} should match minimum evaluation {}",
            result.best_objective,
            true_min
        );
    }

    #[test]
    fn test_total_budget_bounded() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-1.0, 1.0)];
        let mut rng = 77u64;
        let result = hb
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        assert!(result.total_budget_used > 0.0);
        assert!(result.total_budget_used.is_finite());
    }

    #[test]
    fn test_converges_to_optimum() {
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-5.0, 5.0), (-5.0, 5.0)];
        let mut rng = 12345u64;
        let result = hb
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        assert!(
            result.best_objective < 5.0,
            "best objective {} should be < 5",
            result.best_objective
        );
    }

    #[test]
    fn test_eta2_vs_eta3_different_brackets() {
        let cfg2 = MultiFidelityConfig {
            max_budget: 64.0,
            min_budget: 1.0,
            eta: 2,
            n_initial: 0,
        };
        let cfg3 = MultiFidelityConfig {
            max_budget: 64.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb2 = Hyperband::new(cfg2).expect("valid");
        let hb3 = Hyperband::new(cfg3).expect("valid");
        let brackets2 = hb2.compute_brackets();
        let brackets3 = hb3.compute_brackets();
        assert_eq!(brackets2.len(), 7, "eta=2 should have 7 brackets");
        assert_eq!(brackets3.len(), 4, "eta=3 should have 4 brackets");
    }

    #[test]
    fn test_budget_aware_objective() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-3.0, 3.0)];
        let mut rng = 55u64;
        let result = hb
            .run(
                &budget_aware_quadratic,
                &bounds,
                &ConfigSampler::LatinHypercube,
                &mut rng,
            )
            .expect("run ok");
        assert!(result.best_objective.is_finite());
        assert!(result.n_brackets > 1);
    }

    #[test]
    fn test_empty_bounds_error() {
        let cfg = MultiFidelityConfig::default();
        let hb = Hyperband::new(cfg).expect("valid");
        let result = hb.run(&quadratic, &[], &ConfigSampler::Random, &mut 1u64);
        assert!(result.is_err());
    }

    #[test]
    fn test_bracket_budgets_reach_max() {
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let brackets = hb.compute_brackets();
        for b in &brackets {
            assert!(
                (b.max_budget - 81.0).abs() < 1e-9,
                "all brackets should share the same max_budget"
            );
        }
    }

    #[test]
    fn test_n_brackets_in_result() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-1.0, 1.0)];
        let mut rng = 1u64;
        let result = hb
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        let expected = hb.compute_brackets().len();
        assert_eq!(result.n_brackets, expected);
    }
}