Skip to main content

scirs2_optimize/multi_fidelity/
hyperband.rs

//! Hyperband algorithm.
//!
//! Hyperband (Li et al., 2017) is an early-stopping-based approach to
//! hyperparameter optimization.  It extends Successive Halving by running
//! multiple brackets with different budget/configuration trade-offs, thereby
//! hedging against the unknown optimal aggressiveness of early stopping.
//!
//! # Reference
//!
//! Li, Jamieson, DeSalvo, Rostamizadeh & Talwalkar (2017).  *Hyperband: A
//! Novel Bandit-Based Approach to Hyperparameter Optimization.* JMLR 18(185).
12
13use crate::error::{OptimizeError, OptimizeResult};
14
15use super::successive_halving::SuccessiveHalving;
16use super::types::{ConfigSampler, EvaluationResult, MultiFidelityConfig, MultiFidelityResult};
17
18// ---------------------------------------------------------------------------
19// Bracket specification
20// ---------------------------------------------------------------------------
21
/// Parameters describing one bracket of Hyperband.
///
/// Every bracket shares the same `max_budget`; brackets differ in how many
/// configurations they start with and how small the starting budget is.
#[derive(Clone, Debug)]
pub(crate) struct BracketConfig {
    /// How many configurations are sampled when the bracket starts.
    pub n_initial: usize,
    /// Budget given to each configuration at the start of the bracket.
    pub min_budget: f64,
    /// Budget ceiling; identical for every bracket.
    pub max_budget: f64,
    /// How many Successive Halving rounds the bracket spans.
    pub n_rounds: usize,
}
34
35// ---------------------------------------------------------------------------
36// Hyperband
37// ---------------------------------------------------------------------------
38
39/// Hyperband optimizer.
40///
41/// Hyperband runs `s_max + 1` brackets of Successive Halving, where
42/// `s_max = floor(log_eta(R))` and `R = max_budget / min_budget`.
43///
44/// Each bracket `s` (from `s_max` down to `0`) trades off the number of
45/// initial configurations against the starting budget:
46///
47/// ```text
48/// bracket s:
49///   n  = ceil((s_max+1)/(s+1)) * eta^s
50///   r  = max_budget / eta^s
51/// ```
52#[derive(Debug, Clone)]
53pub struct Hyperband {
54    config: MultiFidelityConfig,
55}
56
57impl Hyperband {
58    /// Create a new Hyperband instance.
59    pub fn new(config: MultiFidelityConfig) -> OptimizeResult<Self> {
60        config.validate()?;
61        Ok(Self { config })
62    }
63
64    /// Compute the bracket configurations.
65    pub(crate) fn compute_brackets(&self) -> Vec<BracketConfig> {
66        let s_max = self.config.s_max();
67        let eta = self.config.eta;
68        let eta_f = eta as f64;
69        let mut brackets = Vec::with_capacity(s_max + 1);
70
71        for s in (0..=s_max).rev() {
72            let n_initial =
73                ((s_max + 1) as f64 / (s + 1) as f64 * eta_f.powi(s as i32)).ceil() as usize;
74            let start_budget = self.config.max_budget / eta_f.powi(s as i32);
75            brackets.push(BracketConfig {
76                n_initial,
77                min_budget: start_budget,
78                max_budget: self.config.max_budget,
79                n_rounds: s + 1,
80            });
81        }
82
83        brackets
84    }
85
86    /// Run Hyperband.
87    ///
88    /// Iterates over all brackets, runs Successive Halving for each, and
89    /// returns the best configuration found across all brackets.
90    pub fn run<F>(
91        &self,
92        objective: &F,
93        bounds: &[(f64, f64)],
94        sampler: &ConfigSampler,
95        rng_state: &mut u64,
96    ) -> OptimizeResult<MultiFidelityResult>
97    where
98        F: Fn(&[f64], f64) -> OptimizeResult<f64>,
99    {
100        if bounds.is_empty() {
101            return Err(OptimizeError::InvalidParameter(
102                "bounds must not be empty".into(),
103            ));
104        }
105
106        let brackets = self.compute_brackets();
107        let n_brackets = brackets.len();
108
109        let sh = SuccessiveHalving::new(self.config.clone())?;
110
111        let mut all_evals: Vec<EvaluationResult> = Vec::new();
112        let mut total_budget = 0.0;
113        let mut global_best_obj = f64::INFINITY;
114        let mut global_best_cfg: Vec<f64> = Vec::new();
115        let mut eval_id_offset = 0usize;
116
117        for bracket in &brackets {
118            let result = sh.run_with(
119                objective,
120                bounds,
121                sampler,
122                rng_state,
123                bracket.n_initial,
124                bracket.min_budget,
125            )?;
126
127            // Re-number config IDs to be globally unique
128            for mut e in result.evaluations {
129                e.config_id += eval_id_offset;
130                if e.objective < global_best_obj {
131                    global_best_obj = e.objective;
132                    global_best_cfg = e.config.clone();
133                }
134                all_evals.push(e);
135            }
136            eval_id_offset = all_evals.iter().map(|e| e.config_id).max().unwrap_or(0) + 1;
137
138            total_budget += result.total_budget_used;
139        }
140
141        if global_best_cfg.is_empty() {
142            return Err(OptimizeError::ComputationError(
143                "no evaluations performed across brackets".into(),
144            ));
145        }
146
147        Ok(MultiFidelityResult {
148            best_config: global_best_cfg,
149            best_objective: global_best_obj,
150            total_budget_used: total_budget,
151            evaluations: all_evals,
152            n_brackets,
153        })
154    }
155}
156
157// ---------------------------------------------------------------------------
158// Tests
159// ---------------------------------------------------------------------------
160
#[cfg(test)]
mod tests {
    use super::*;

    fn quadratic(x: &[f64], _budget: f64) -> OptimizeResult<f64> {
        Ok(x.iter().map(|xi| xi * xi).sum())
    }

    /// Budget-aware: penalize with `1/sqrt(budget)` noise-like term.
    fn budget_aware_quadratic(x: &[f64], budget: f64) -> OptimizeResult<f64> {
        let base: f64 = x.iter().map(|xi| xi * xi).sum();
        // Higher budget => more accurate (smaller perturbation)
        Ok(base + 1.0 / budget.sqrt())
    }

    #[test]
    fn test_multiple_brackets_generated() {
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let brackets = hb.compute_brackets();
        // s_max = 4, so 5 brackets
        assert_eq!(brackets.len(), 5);
    }

    #[test]
    fn test_bracket_table_matches_paper_example() {
        // R = 81, eta = 3 is the worked example from the Hyperband paper:
        // bracket sizes 81, 34, 15, 8, 5 with starting budgets 1, 3, 9, 27, 81.
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let brackets = hb.compute_brackets();
        let expected: [(usize, f64, usize); 5] = [
            (81, 1.0, 5),
            (34, 3.0, 4),
            (15, 9.0, 3),
            (8, 27.0, 2),
            (5, 81.0, 1),
        ];
        assert_eq!(brackets.len(), expected.len());
        for (b, &(n, r, rounds)) in brackets.iter().zip(expected.iter()) {
            assert_eq!(b.n_initial, n, "unexpected n_initial");
            assert!(
                (b.min_budget - r).abs() < 1e-9,
                "min_budget {} should be {}",
                b.min_budget,
                r
            );
            assert_eq!(b.n_rounds, rounds, "unexpected n_rounds");
        }
    }

    #[test]
    fn test_bracket_invariants_non_power_of_eta() {
        // 64 is not a power of 3, so the most aggressive bracket starts
        // strictly above min_budget.
        let cfg = MultiFidelityConfig {
            max_budget: 64.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let brackets = hb.compute_brackets();
        for b in &brackets {
            assert!(b.min_budget >= 1.0 - 1e-12, "start budget below min_budget");
            assert!(b.min_budget <= b.max_budget + 1e-12);
        }
        for w in brackets.windows(2) {
            assert!(w[0].min_budget < w[1].min_budget);
            assert!(w[0].n_initial >= w[1].n_initial);
        }
    }

    #[test]
    fn test_best_across_brackets_selected() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-5.0, 5.0), (-5.0, 5.0)];
        let mut rng = 42u64;
        let result = hb
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        // The reported best should equal the minimum across all evaluations
        let true_min = result
            .evaluations
            .iter()
            .map(|e| e.objective)
            .fold(f64::INFINITY, f64::min);
        assert!(
            (result.best_objective - true_min).abs() < 1e-12,
            "best_objective {} should match minimum evaluation {}",
            result.best_objective,
            true_min
        );
    }

    #[test]
    fn test_total_budget_bounded() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-1.0, 1.0)];
        let mut rng = 77u64;
        let result = hb
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        // Total budget should be finite and positive
        assert!(result.total_budget_used > 0.0);
        assert!(result.total_budget_used.is_finite());
    }

    #[test]
    fn test_converges_to_optimum() {
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-5.0, 5.0), (-5.0, 5.0)];
        let mut rng = 12345u64;
        let result = hb
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        // With enough configs, best should be reasonably close to 0
        assert!(
            result.best_objective < 5.0,
            "best objective {} should be < 5",
            result.best_objective
        );
    }

    #[test]
    fn test_eta2_vs_eta3_different_brackets() {
        let cfg2 = MultiFidelityConfig {
            max_budget: 64.0,
            min_budget: 1.0,
            eta: 2,
            n_initial: 0,
        };
        let cfg3 = MultiFidelityConfig {
            max_budget: 64.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb2 = Hyperband::new(cfg2).expect("valid");
        let hb3 = Hyperband::new(cfg3).expect("valid");
        let brackets2 = hb2.compute_brackets();
        let brackets3 = hb3.compute_brackets();
        // eta=2: s_max = log_2(64) = 6, so 7 brackets
        // eta=3: s_max = floor(log_3(64)) = 3, so 4 brackets
        assert_eq!(brackets2.len(), 7, "eta=2 should have 7 brackets");
        assert_eq!(brackets3.len(), 4, "eta=3 should have 4 brackets");
    }

    #[test]
    fn test_budget_aware_objective() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-3.0, 3.0)];
        let mut rng = 55u64;
        let result = hb
            .run(
                &budget_aware_quadratic,
                &bounds,
                &ConfigSampler::LatinHypercube,
                &mut rng,
            )
            .expect("run ok");
        assert!(result.best_objective.is_finite());
        assert!(result.n_brackets > 1);
    }

    #[test]
    fn test_empty_bounds_error() {
        let cfg = MultiFidelityConfig::default();
        let hb = Hyperband::new(cfg).expect("valid");
        let result = hb.run(&quadratic, &[], &ConfigSampler::Random, &mut 1u64);
        assert!(result.is_err());
    }

    #[test]
    fn test_bracket_budgets_reach_max() {
        let cfg = MultiFidelityConfig {
            max_budget: 81.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let brackets = hb.compute_brackets();
        for b in &brackets {
            assert!(
                (b.max_budget - 81.0).abs() < 1e-9,
                "all brackets should share the same max_budget"
            );
        }
    }

    #[test]
    fn test_n_brackets_in_result() {
        let cfg = MultiFidelityConfig {
            max_budget: 27.0,
            min_budget: 1.0,
            eta: 3,
            n_initial: 0,
        };
        let hb = Hyperband::new(cfg).expect("valid");
        let bounds = vec![(-1.0, 1.0)];
        let mut rng = 1u64;
        let result = hb
            .run(&quadratic, &bounds, &ConfigSampler::Random, &mut rng)
            .expect("run ok");
        let expected = hb.compute_brackets().len();
        assert_eq!(result.n_brackets, expected);
    }
}