Skip to main content

tenflowers_neural/hyperparameter_optimization/
mod.rs

1//! Advanced Hyperparameter Optimization (HPO) Methods.
2//!
3//! This module provides production-grade HPO methods beyond the basic grid/random
4//! search in `hparam/`. It includes:
5//!
6//! - [`HpSpace`] / [`HyperParameter`] — flexible parameter space definitions
7//! - \[`BayesianOptimizer`\] (aliased as [`HpoBayesianOptimizer`]) — GP-based Bayesian optimization
8//! - [`HyperBandScheduler`] — multi-fidelity successive halving
9//! - [`BohbOptimizer`] — BOHB (Bayesian Optimization + HyperBand)
10//! - [`PopulationBasedTraining`] — evolutionary HPO (PBT)
11//! - [`EvolutionaryStrategy`] — CMA-ES for HPO
12//! - [`MultiObjectiveHpo`] — Pareto-front HPO with NSGA-II
13//! - [`MedianStopping`] / [`SuccessiveHalving`] / [`PercentileStop`] — early termination
14//! - [`HpoStudy`] / [`HpoLogger`] — experiment tracking
15//! - [`WarmStartSampler`] — transfer-learning-based warm starting
16//!
17//! All randomness uses `scirs2_core::random` (no `rand` crate). No `unsafe` code.
18//! No `unwrap()`.
19
20use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
21use scirs2_core::RngExt;
22use tenflowers_core::{Result, TensorError};
23
24// ─────────────────────────────────────────────────────────────────────────────
25// §1. HyperParameter and HpSpace
26// ─────────────────────────────────────────────────────────────────────────────
27
/// The kind of a hyperparameter, which decides how values are sampled and
/// how they are mapped to and from the unit interval.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HpType {
    /// Continuous real-valued parameter in `[bounds.0, bounds.1]`.
    Continuous,
    /// Categorical parameter. Its values are members of `choices` — the stored
    /// value is the choice itself, not its index.
    Categorical,
    /// Integer parameter taking whole-number values in `[bounds.0, bounds.1]`.
    Integer,
    /// Log-scale continuous parameter (sampled uniformly in log space).
    Log,
}
40
41/// A single hyperparameter definition.
42#[derive(Debug, Clone)]
43pub struct HyperParameter {
44    /// Name of the parameter.
45    pub name: String,
46    /// Parameter type.
47    pub hp_type: HpType,
48    /// `(low, high)` bounds for Continuous / Integer / Log types.
49    pub bounds: (f64, f64),
50    /// Valid discrete values for Categorical type.
51    pub choices: Vec<f64>,
52    /// If true, sample/transform in log space (applicable to Continuous and Integer).
53    pub log_scale: bool,
54}
55
56impl HyperParameter {
57    /// Construct a continuous parameter.
58    pub fn continuous(name: impl Into<String>, low: f64, high: f64) -> Self {
59        Self {
60            name: name.into(),
61            hp_type: HpType::Continuous,
62            bounds: (low, high),
63            choices: vec![],
64            log_scale: false,
65        }
66    }
67
68    /// Construct a log-scale continuous parameter.
69    pub fn log_continuous(name: impl Into<String>, low: f64, high: f64) -> Self {
70        Self {
71            name: name.into(),
72            hp_type: HpType::Log,
73            bounds: (low, high),
74            choices: vec![],
75            log_scale: true,
76        }
77    }
78
79    /// Construct an integer parameter.
80    pub fn integer(name: impl Into<String>, low: i64, high: i64) -> Self {
81        Self {
82            name: name.into(),
83            hp_type: HpType::Integer,
84            bounds: (low as f64, high as f64),
85            choices: vec![],
86            log_scale: false,
87        }
88    }
89
90    /// Construct a categorical parameter from float-encoded choices.
91    pub fn categorical(name: impl Into<String>, choices: Vec<f64>) -> Self {
92        Self {
93            name: name.into(),
94            hp_type: HpType::Categorical,
95            bounds: (0.0, choices.len().saturating_sub(1) as f64),
96            choices,
97            log_scale: false,
98        }
99    }
100
101    /// Sample a single value from this parameter using the provided RNG.
102    pub fn sample(&self, rng: &mut StdRng) -> f64 {
103        match self.hp_type {
104            HpType::Continuous => {
105                let u: f64 = rng.random();
106                let (lo, hi) = self.bounds;
107                if self.log_scale && lo > 0.0 {
108                    let log_lo = lo.ln();
109                    let log_hi = hi.ln();
110                    (log_lo + u * (log_hi - log_lo)).exp()
111                } else {
112                    lo + u * (hi - lo)
113                }
114            }
115            HpType::Log => {
116                let u: f64 = rng.random();
117                let (lo, hi) = self.bounds;
118                let log_lo = lo.max(1e-300).ln();
119                let log_hi = hi.max(1e-300).ln();
120                (log_lo + u * (log_hi - log_lo)).exp()
121            }
122            HpType::Integer => {
123                let (lo, hi) = self.bounds;
124                let range = (hi - lo + 1.0).max(1.0) as usize;
125                let idx: usize = rng.random_range(0..range);
126                lo + idx as f64
127            }
128            HpType::Categorical => {
129                if self.choices.is_empty() {
130                    return 0.0;
131                }
132                let idx: usize = rng.random_range(0..self.choices.len());
133                self.choices[idx]
134            }
135        }
136    }
137
138    /// Transform a raw value to [0, 1] (for GP input normalisation).
139    pub fn to_unit(&self, v: f64) -> f64 {
140        let (lo, hi) = self.bounds;
141        let range = (hi - lo).max(1e-300);
142        match self.hp_type {
143            HpType::Log => {
144                let log_lo = lo.max(1e-300).ln();
145                let log_hi = hi.max(1e-300).ln();
146                let log_range = (log_hi - log_lo).max(1e-300);
147                (v.max(1e-300).ln() - log_lo) / log_range
148            }
149            HpType::Categorical => {
150                if self.choices.len() <= 1 {
151                    return 0.0;
152                }
153                let idx = self
154                    .choices
155                    .iter()
156                    .position(|&c| (c - v).abs() < 1e-12)
157                    .unwrap_or(0);
158                idx as f64 / (self.choices.len() - 1) as f64
159            }
160            _ => (v - lo) / range,
161        }
162    }
163
164    /// Reverse of `to_unit`: from \[0,1\] back to parameter space.
165    pub fn from_unit(&self, u: f64) -> f64 {
166        let u = u.clamp(0.0, 1.0);
167        let (lo, hi) = self.bounds;
168        match self.hp_type {
169            HpType::Log => {
170                let log_lo = lo.max(1e-300).ln();
171                let log_hi = hi.max(1e-300).ln();
172                (log_lo + u * (log_hi - log_lo)).exp()
173            }
174            HpType::Integer => {
175                let raw = lo + u * (hi - lo);
176                raw.round().clamp(lo, hi)
177            }
178            HpType::Categorical => {
179                if self.choices.is_empty() {
180                    return 0.0;
181                }
182                let idx = (u * (self.choices.len() - 1) as f64).round() as usize;
183                self.choices[idx.min(self.choices.len() - 1)]
184            }
185            _ => lo + u * (hi - lo),
186        }
187    }
188}
189
190/// A collection of hyperparameters forming the search space.
191#[derive(Debug, Clone, Default)]
192pub struct HpSpace {
193    /// Parameter definitions.
194    pub params: Vec<HyperParameter>,
195}
196
197impl HpSpace {
198    /// Create an empty space.
199    pub fn new() -> Self {
200        Self { params: vec![] }
201    }
202
203    /// Add a parameter definition.
204    #[allow(clippy::should_implement_trait)]
205    pub fn add(mut self, hp: HyperParameter) -> Self {
206        self.params.push(hp);
207        self
208    }
209
210    /// Number of parameters in the space.
211    pub fn ndim(&self) -> usize {
212        self.params.len()
213    }
214
215    /// Sample a random point from the space.
216    pub fn sample_random(&self, rng: &mut StdRng) -> Vec<f64> {
217        self.params.iter().map(|p| p.sample(rng)).collect()
218    }
219
220    /// Normalise a raw parameter vector to \[0,1\]^d.
221    pub fn transform_to_unit(&self, values: &[f64]) -> Vec<f64> {
222        self.params
223            .iter()
224            .zip(values.iter())
225            .map(|(p, &v)| p.to_unit(v))
226            .collect()
227    }
228
229    /// Denormalise a unit-cube vector back to parameter space.
230    pub fn transform_from_unit(&self, unit: &[f64]) -> Vec<f64> {
231        self.params
232            .iter()
233            .zip(unit.iter())
234            .map(|(p, &u)| p.from_unit(u))
235            .collect()
236    }
237}
238
239// ─────────────────────────────────────────────────────────────────────────────
240// §2. GpHpo — lightweight GP surrogate
241// ─────────────────────────────────────────────────────────────────────────────
242
/// Minimal Gaussian-process surrogate used by the HPO Bayesian optimiser.
///
/// Uses a unit-variance RBF kernel and exact inference through a Cholesky
/// factorisation that [`GpHpo::fit`] caches.
#[derive(Clone, Debug)]
pub struct GpHpo {
    /// Training inputs: one unit-normalised row per observation.
    pub x_train: Vec<Vec<f64>>,
    /// Training targets, aligned with `x_train`.
    pub y_train: Vec<f64>,
    /// Length scale `ℓ` of the kernel `exp(-‖a - b‖² / (2ℓ²))`.
    pub length_scale: f64,
    /// Observation-noise variance added to the kernel diagonal.
    pub noise: f64,
    /// Cached lower-triangular Cholesky factor `L` of `K + noise·I`, one `Vec` per row.
    chol: Vec<Vec<f64>>,
    /// Cached `alpha = L^{-T} L^{-1} y`, i.e. `(K + noise·I)^{-1} y`.
    alpha: Vec<f64>,
}
259
260impl GpHpo {
261    /// Create a new GP with given hyper-parameters.
262    pub fn new(length_scale: f64, noise: f64) -> Self {
263        Self {
264            x_train: vec![],
265            y_train: vec![],
266            length_scale,
267            noise,
268            chol: vec![],
269            alpha: vec![],
270        }
271    }
272
273    /// Compute RBF kernel between two vectors.
274    fn rbf(&self, a: &[f64], b: &[f64]) -> f64 {
275        let sq_dist: f64 = a
276            .iter()
277            .zip(b.iter())
278            .map(|(&ai, &bi)| {
279                let d = ai - bi;
280                d * d
281            })
282            .sum();
283        (-sq_dist / (2.0 * self.length_scale * self.length_scale)).exp()
284    }
285
286    /// Fit the GP to the current `x_train` / `y_train` by computing Cholesky.
287    pub fn fit(&mut self) -> Result<()> {
288        let n = self.x_train.len();
289        if n == 0 {
290            return Ok(());
291        }
292        // Build K + noise*I
293        let mut k = vec![vec![0.0_f64; n]; n];
294        for i in 0..n {
295            for j in 0..n {
296                k[i][j] = self.rbf(&self.x_train[i], &self.x_train[j]);
297            }
298            k[i][i] += self.noise;
299        }
300        // Cholesky decomposition
301        let l = cholesky(&k).map_err(|e| TensorError::ComputeError {
302            operation: "GpHpo::fit".into(),
303            details: e,
304            retry_possible: false,
305            context: None,
306        })?;
307        // alpha = K^{-1} y via forward/backward substitution
308        let alpha = chol_solve(&l, &self.y_train)?;
309        self.chol = l;
310        self.alpha = alpha;
311        Ok(())
312    }
313
314    /// Predict mean and variance at a test point `x_star`.
315    pub fn predict(&self, x_star: &[f64]) -> (f64, f64) {
316        let n = self.x_train.len();
317        if n == 0 {
318            return (0.0, 1.0);
319        }
320        // k_star = [k(x*, x_i)]
321        let k_star: Vec<f64> = self.x_train.iter().map(|xi| self.rbf(x_star, xi)).collect();
322        // mean = k_star^T alpha
323        let mean: f64 = k_star
324            .iter()
325            .zip(self.alpha.iter())
326            .map(|(ks, a)| ks * a)
327            .sum();
328        // variance = k(x*,x*) - k_star^T K^{-1} k_star
329        let k_ss = self.rbf(x_star, x_star);
330        // v = L^{-1} k_star
331        let v = forward_sub(&self.chol, &k_star);
332        let var: f64 = k_ss - v.iter().map(|vi| vi * vi).sum::<f64>();
333        (mean, var.max(1e-10))
334    }
335}
336
337// ─────────────────────────────────────────────────────────────────────────────
338// §2b. Acquisition Functions
339// ─────────────────────────────────────────────────────────────────────────────
340
/// Acquisition function variants for Bayesian optimisation (maximisation).
#[derive(Debug, Clone, PartialEq)]
pub enum HpoAcqFunction {
    /// Expected Improvement over `best_y + xi`; `xi` is the required
    /// improvement margin (larger values favour exploration).
    ExpectedImprovement { xi: f64 },
    /// Probability that the objective exceeds `best_y + xi`.
    ProbabilityOfImprovement { xi: f64 },
    /// Upper Confidence Bound `mean + kappa·std`.
    UpperConfidenceBound { kappa: f64 },
    /// Natural log of Expected Improvement. EI is floored at `1e-300` before
    /// the log, so zero improvement maps to a finite value (≈ -690.8). The
    /// ordering matches EI wherever EI exceeds that floor.
    LogExpectedImprovement { xi: f64 },
}
353
354impl HpoAcqFunction {
355    /// Evaluate acquisition at (mean, std) given `best_y` (the current best observed value).
356    pub fn evaluate(&self, mean: f64, std: f64, best_y: f64) -> f64 {
357        match self {
358            HpoAcqFunction::ExpectedImprovement { xi } => {
359                let z = (mean - best_y - xi) / std.max(1e-9);
360                std * (z * normal_cdf(z) + normal_pdf(z))
361            }
362            HpoAcqFunction::ProbabilityOfImprovement { xi } => {
363                let z = (mean - best_y - xi) / std.max(1e-9);
364                normal_cdf(z)
365            }
366            HpoAcqFunction::UpperConfidenceBound { kappa } => mean + kappa * std,
367            HpoAcqFunction::LogExpectedImprovement { xi } => {
368                let z = (mean - best_y - xi) / std.max(1e-9);
369                let ei = std * (z * normal_cdf(z) + normal_pdf(z));
370                ei.max(1e-300).ln()
371            }
372        }
373    }
374}
375
376// ─────────────────────────────────────────────────────────────────────────────
377// §2c. BayesianOptimizer (HPO variant)
378// ─────────────────────────────────────────────────────────────────────────────
379
380/// Configuration for the HPO Bayesian optimiser.
381#[derive(Debug, Clone)]
382pub struct HpoBayesianConfig {
383    /// Number of random initial evaluations before using the GP.
384    pub n_initial: usize,
385    /// Number of Bayesian optimisation iterations after warm-up.
386    pub n_iter: usize,
387    /// Exploration parameter `xi` for EI/PI.
388    pub xi: f64,
389    /// Exploration parameter `kappa` for UCB.
390    pub kappa: f64,
391    /// Acquisition function variant.
392    pub acq: HpoAcqFunction,
393    /// Number of random candidates evaluated per suggestion step.
394    pub n_candidates: usize,
395}
396
397impl Default for HpoBayesianConfig {
398    fn default() -> Self {
399        Self {
400            n_initial: 5,
401            n_iter: 25,
402            xi: 0.01,
403            kappa: 2.576,
404            acq: HpoAcqFunction::ExpectedImprovement { xi: 0.01 },
405            n_candidates: 1000,
406        }
407    }
408}
409
410/// A full Bayesian optimisation loop backed by a GP surrogate.
411///
412/// Aliased as `HpoBayesianOptimizer` at the module level to avoid name
413/// collision with `bayesian_opt::BayesianOptimizer`.
414#[derive(Debug, Clone)]
415pub struct HpoBayesianOptimizer {
416    /// The hyperparameter space.
417    pub space: HpSpace,
418    /// GP surrogate.
419    pub gp: GpHpo,
420    /// Observed (params, value) pairs — params stored in *raw* (not unit) space.
421    pub observations: Vec<(Vec<f64>, f64)>,
422    /// Configuration.
423    pub config: HpoBayesianConfig,
424}
425
426impl HpoBayesianOptimizer {
427    /// Construct a new Bayesian optimiser.
428    pub fn new(space: HpSpace, config: HpoBayesianConfig) -> Self {
429        let gp = GpHpo::new(1.0, 1e-3);
430        Self {
431            space,
432            gp,
433            observations: vec![],
434            config,
435        }
436    }
437
438    /// Record an observation.
439    pub fn observe(&mut self, params: &[f64], value: f64) {
440        self.observations.push((params.to_vec(), value));
441        // Update GP training data
442        let unit = self.space.transform_to_unit(params);
443        self.gp.x_train.push(unit);
444        self.gp.y_train.push(value);
445        let _ = self.gp.fit();
446    }
447
448    /// Suggest the next point to evaluate.
449    ///
450    /// Returns random samples during warm-up, then maximises the acquisition function.
451    pub fn suggest(&self, rng: &mut StdRng) -> Vec<f64> {
452        if self.observations.len() < self.config.n_initial {
453            return self.space.sample_random(rng);
454        }
455        // Find best observed value (maximisation convention)
456        let best_y = self
457            .observations
458            .iter()
459            .map(|(_, v)| *v)
460            .fold(f64::NEG_INFINITY, f64::max);
461
462        let mut best_acq = f64::NEG_INFINITY;
463        let mut best_candidate = self.space.sample_random(rng);
464
465        for _ in 0..self.config.n_candidates {
466            let candidate = self.space.sample_random(rng);
467            let unit = self.space.transform_to_unit(&candidate);
468            let (mean, var) = self.gp.predict(&unit);
469            let std = var.sqrt();
470            let acq = self.config.acq.evaluate(mean, std, best_y);
471            if acq > best_acq {
472                best_acq = acq;
473                best_candidate = candidate;
474            }
475        }
476        best_candidate
477    }
478
479    /// Return the best observed (params, value) pair.
480    pub fn best(&self) -> Option<(&[f64], f64)> {
481        self.observations
482            .iter()
483            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
484            .map(|(p, v)| (p.as_slice(), *v))
485    }
486}
487
488// ─────────────────────────────────────────────────────────────────────────────
489// §3. HyperBandScheduler
490// ─────────────────────────────────────────────────────────────────────────────
491
/// Settings for the HyperBand scheduler.
#[derive(Clone, Debug)]
pub struct HyperBandConfig {
    /// Largest resource `R` (e.g. iterations or epochs) one configuration may receive.
    pub max_iter: f64,
    /// Downsampling rate `η` between successive-halving rounds (typically 3 or 4).
    pub eta: f64,
    /// Smallest resource allocation used when sizing the schedule.
    pub min_resource: f64,
}

impl Default for HyperBandConfig {
    /// `max_iter = 81`, `eta = 3`, `min_resource = 1`.
    fn default() -> Self {
        let (max_iter, eta, min_resource) = (81.0, 3.0, 1.0);
        HyperBandConfig {
            max_iter,
            eta,
            min_resource,
        }
    }
}
512
/// One HyperBand bracket, i.e. a single successive-halving run.
#[derive(Clone, Debug)]
pub struct HbBracket {
    /// Position of the bracket in the planned schedule.
    pub bracket_id: usize,
    /// Configurations started in the bracket's first round.
    pub n: usize,
    /// Resource given to each configuration in the first round.
    pub r: f64,
    /// Bracket index `s`; the bracket runs `s + 1` rounds (`0..=s`).
    pub s: usize,
}
525
526/// The full HyperBand scheduler managing multiple brackets.
527#[derive(Debug, Clone)]
528pub struct HyperBandScheduler {
529    /// Computed bracket schedule.
530    pub brackets: Vec<HbBracket>,
531    /// Index of the currently active bracket.
532    pub current: usize,
533    /// Configuration.
534    pub config: HyperBandConfig,
535}
536
537impl HyperBandScheduler {
538    /// Create a scheduler and plan the bracket schedule.
539    pub fn new(config: HyperBandConfig) -> Self {
540        let brackets = Self::plan_schedule(&config);
541        Self {
542            brackets,
543            current: 0,
544            config,
545        }
546    }
547
548    /// Plan the full schedule of brackets for the given HyperBand config.
549    pub fn plan_schedule(config: &HyperBandConfig) -> Vec<HbBracket> {
550        // s_max = floor(log_{eta}(max_iter / min_resource))
551        let s_max = (config.max_iter / config.min_resource)
552            .max(1.0)
553            .log(config.eta)
554            .floor() as usize;
555        let mut brackets = Vec::with_capacity(s_max + 1);
556        for s in (0..=s_max).rev() {
557            let n =
558                ((s_max + 1) as f64 / (s + 1) as f64 * config.eta.powi(s as i32)).ceil() as usize;
559            let r = config.max_iter / config.eta.powi(s as i32);
560            brackets.push(HbBracket {
561                bracket_id: s_max - s,
562                n,
563                r,
564                s,
565            });
566        }
567        brackets
568    }
569
570    /// Generate random configurations for a given bracket.
571    pub fn get_configurations(
572        &self,
573        bracket: &HbBracket,
574        space: &HpSpace,
575        rng: &mut StdRng,
576    ) -> Vec<Vec<f64>> {
577        (0..bracket.n).map(|_| space.sample_random(rng)).collect()
578    }
579
580    /// Keep the top `n_keep` configurations by score (descending).
581    pub fn promote(scores: &[(Vec<f64>, f64)], n_keep: usize) -> Vec<Vec<f64>> {
582        let mut sorted = scores.to_vec();
583        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
584        sorted.into_iter().take(n_keep).map(|(p, _)| p).collect()
585    }
586
587    /// Run a single successive-halving round within bracket `b_idx`.
588    /// Returns configurations to evaluate at each resource level.
589    pub fn successive_halving_rounds(&self, bracket_idx: usize) -> Vec<(usize, f64)> {
590        if bracket_idx >= self.brackets.len() {
591            return vec![];
592        }
593        let b = &self.brackets[bracket_idx];
594        (0..=b.s)
595            .map(|i| {
596                let n_i = (b.n as f64 / self.config.eta.powi(i as i32)).floor() as usize;
597                let r_i = b.r * self.config.eta.powi(i as i32);
598                (n_i.max(1), r_i)
599            })
600            .collect()
601    }
602}
603
604// ─────────────────────────────────────────────────────────────────────────────
605// §4. BOHB — Bayesian Optimization + HyperBand
606// ─────────────────────────────────────────────────────────────────────────────
607
/// Settings for [`BohbOptimizer`].
#[derive(Clone, Debug)]
pub struct BohbConfig {
    /// Observations gathered by pure random sampling before the KDE model is used.
    pub n_initial: usize,
    /// Width of the Gaussian kernels, in unit-normalised parameter space.
    pub bandwidth: f64,
    /// Share of observations (best first) assigned to the "good" density.
    pub top_frac: f64,
}

impl Default for BohbConfig {
    /// 10 warm-up samples, bandwidth `1.0`, top 15 % treated as "good".
    fn default() -> Self {
        BohbConfig {
            top_frac: 0.15,
            bandwidth: 1.0,
            n_initial: 10,
        }
    }
}
628
/// Parzen-window (kernel density) model for one dimension, as used by BOHB.
///
/// Splits past observations into a "good" density `l(x)` and a "bad"
/// density `g(x)`.
#[derive(Clone, Debug)]
pub struct KdeSampler {
    /// Unit-space values of the top-scoring observations.
    pub good_obs: Vec<f64>,
    /// Unit-space values of all remaining observations.
    pub bad_obs: Vec<f64>,
    /// Kernel bandwidth.
    pub bandwidth: f64,
}
639
640impl KdeSampler {
641    /// Evaluate the KDE density at `x` using Gaussian kernel.
642    pub fn kde_pdf(x: f64, samples: &[f64], bw: f64) -> f64 {
643        if samples.is_empty() {
644            return 1.0;
645        }
646        let n = samples.len() as f64;
647        let sum: f64 = samples.iter().map(|&s| normal_pdf((x - s) / bw) / bw).sum();
648        sum / n
649    }
650
651    /// Sample a candidate by maximising the ratio l(x)/g(x) over random draws.
652    pub fn sample_from_kde(
653        good: &[f64],
654        bad: &[f64],
655        bw: f64,
656        rng: &mut StdRng,
657        n_candidates: usize,
658    ) -> f64 {
659        if good.is_empty() {
660            // fall back to uniform
661            return rng.random();
662        }
663        let mut best_ratio = f64::NEG_INFINITY;
664        let mut best_x = 0.5_f64;
665        for _ in 0..n_candidates {
666            // sample from good KDE
667            let idx: usize = rng.random_range(0..good.len());
668            let noise: f64 = rng.random::<f64>() * 2.0 - 1.0; // uniform [-1,1]
669            let x = (good[idx] + noise * bw).clamp(0.0, 1.0);
670            let lx = Self::kde_pdf(x, good, bw).max(1e-300);
671            let gx = Self::kde_pdf(x, bad, bw).max(1e-300);
672            let ratio = lx / gx;
673            if ratio > best_ratio {
674                best_ratio = ratio;
675                best_x = x;
676            }
677        }
678        best_x
679    }
680
681    /// Fit the KDE sampler from a list of (value, score) observations.
682    pub fn fit_from_observations(obs: &[(f64, f64)], top_frac: f64, bw: f64) -> Self {
683        if obs.is_empty() {
684            return Self {
685                good_obs: vec![],
686                bad_obs: vec![],
687                bandwidth: bw,
688            };
689        }
690        let mut sorted = obs.to_vec();
691        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
692        let n_good = ((obs.len() as f64 * top_frac).ceil() as usize).max(1);
693        let good_obs: Vec<f64> = sorted[..n_good].iter().map(|(v, _)| *v).collect();
694        let bad_obs: Vec<f64> = sorted[n_good..].iter().map(|(v, _)| *v).collect();
695        Self {
696            good_obs,
697            bad_obs,
698            bandwidth: bw,
699        }
700    }
701}
702
703/// BOHB optimiser — combines Bayesian Optimization with HyperBand.
704#[derive(Debug, Clone)]
705pub struct BohbOptimizer {
706    /// Hyperparameter space.
707    pub space: HpSpace,
708    /// Underlying HyperBand scheduler.
709    pub hb: HyperBandScheduler,
710    /// Per-dimension KDE samplers (one per parameter).
711    pub samplers: Vec<KdeSampler>,
712    /// All recorded observations: (params, score).
713    observations: Vec<(Vec<f64>, f64)>,
714    /// Config.
715    pub config: BohbConfig,
716}
717
718impl BohbOptimizer {
719    /// Create a new BOHB optimiser.
720    pub fn new(space: HpSpace, hb_config: HyperBandConfig, bohb_config: BohbConfig) -> Self {
721        let hb = HyperBandScheduler::new(hb_config);
722        let samplers = vec![
723            KdeSampler {
724                good_obs: vec![],
725                bad_obs: vec![],
726                bandwidth: bohb_config.bandwidth
727            };
728            space.ndim()
729        ];
730        Self {
731            space,
732            hb,
733            samplers,
734            observations: vec![],
735            config: bohb_config,
736        }
737    }
738
739    /// Record an observation and refit KDE samplers.
740    pub fn observe(&mut self, params: &[f64], score: f64) {
741        self.observations.push((params.to_vec(), score));
742        // Refit per-dimension KDE samplers
743        for (dim, sampler) in self.samplers.iter_mut().enumerate() {
744            let dim_obs: Vec<(f64, f64)> = self
745                .observations
746                .iter()
747                .map(|(p, s)| (self.space.params[dim].to_unit(p[dim]), *s))
748                .collect();
749            *sampler = KdeSampler::fit_from_observations(
750                &dim_obs,
751                self.config.top_frac,
752                self.config.bandwidth,
753            );
754        }
755    }
756
757    /// Suggest a new configuration.
758    pub fn suggest_bohb(&self, rng: &mut StdRng) -> Vec<f64> {
759        if self.observations.len() < self.config.n_initial {
760            return self.space.sample_random(rng);
761        }
762        // Sample each dimension independently via its KDE sampler
763        let unit: Vec<f64> = self
764            .samplers
765            .iter()
766            .map(|s| KdeSampler::sample_from_kde(&s.good_obs, &s.bad_obs, s.bandwidth, rng, 64))
767            .collect();
768        self.space.transform_from_unit(&unit)
769    }
770}
771
772// ─────────────────────────────────────────────────────────────────────────────
773// §5. Population-Based Training (PBT)
774// ─────────────────────────────────────────────────────────────────────────────
775
/// Settings for Population-Based Training.
#[derive(Clone, Debug)]
pub struct PbtConfig {
    /// Number of workers in the population.
    pub population_size: usize,
    /// Fraction of the population overwritten during each exploit step.
    pub exploit_frac: f64,
    /// Relative perturbation magnitude applied during exploration.
    pub explore_noise: f64,
}

impl Default for PbtConfig {
    /// 10 workers, 20 % exploited per step, ±20 % exploration noise.
    fn default() -> Self {
        PbtConfig {
            explore_noise: 0.2,
            exploit_frac: 0.2,
            population_size: 10,
        }
    }
}
796
/// One worker of a PBT population.
#[derive(Clone, Debug)]
pub struct PbtMember {
    /// Hyperparameter values the worker currently trains with.
    pub params: Vec<f64>,
    /// Most recently reported score (higher is better).
    pub score: f64,
    /// Number of PBT steps this worker has completed.
    pub steps_trained: usize,
}

impl PbtMember {
    /// Wrap `params` in a fresh member with no score yet (`-∞`) and zero steps.
    pub fn new(params: Vec<f64>) -> Self {
        let score = f64::NEG_INFINITY;
        let steps_trained = 0;
        PbtMember {
            params,
            score,
            steps_trained,
        }
    }
}
818
819/// The PBT population.
820#[derive(Debug, Clone)]
821pub struct PbtPopulation {
822    /// Workers.
823    pub members: Vec<PbtMember>,
824}
825
826impl PbtPopulation {
827    /// Construct a population by random sampling.
828    pub fn random(space: &HpSpace, size: usize, rng: &mut StdRng) -> Self {
829        let members = (0..size)
830            .map(|_| PbtMember::new(space.sample_random(rng)))
831            .collect();
832        Self { members }
833    }
834
835    /// Exploit: replace the bottom-`exploit_frac` workers' params with top workers' params.
836    pub fn exploit(&mut self, rng: &mut StdRng, exploit_frac: f64) {
837        let n = self.members.len();
838        if n < 2 {
839            return;
840        }
841        let n_replace = ((n as f64 * exploit_frac).ceil() as usize).min(n / 2);
842        // Sort indices by score descending
843        let mut order: Vec<usize> = (0..n).collect();
844        order.sort_by(|&a, &b| {
845            self.members[b]
846                .score
847                .partial_cmp(&self.members[a].score)
848                .unwrap_or(std::cmp::Ordering::Equal)
849        });
850        // Replace bottom-n_replace with a random top-n_replace
851        let top_indices: Vec<usize> = order[..n_replace].to_vec();
852        let bottom_indices: Vec<usize> = order[n - n_replace..].to_vec();
853        for &bot in &bottom_indices {
854            let src_idx: usize = rng.random_range(0..top_indices.len());
855            let src = top_indices[src_idx];
856            let new_params = self.members[src].params.clone();
857            self.members[bot].params = new_params;
858        }
859    }
860
861    /// Explore: perturb each parameter by ±noise or resample.
862    pub fn explore_member(
863        params: &[f64],
864        space: &HpSpace,
865        noise: f64,
866        rng: &mut StdRng,
867    ) -> Vec<f64> {
868        params
869            .iter()
870            .zip(space.params.iter())
871            .map(|(&v, hp)| {
872                let perturb: f64 = rng.random();
873                if perturb < 0.1 {
874                    // resample
875                    hp.sample(rng)
876                } else {
877                    let delta: f64 = (rng.random::<f64>() * 2.0 - 1.0) * noise;
878                    let (lo, hi) = hp.bounds;
879                    (v * (1.0 + delta)).clamp(lo, hi)
880                }
881            })
882            .collect()
883    }
884
885    /// Run a single PBT step: update scores, exploit, then explore.
886    pub fn pbt_step(
887        &mut self,
888        scores: &[f64],
889        space: &HpSpace,
890        config: &PbtConfig,
891        rng: &mut StdRng,
892    ) {
893        // Update scores and step counts
894        for (member, &score) in self.members.iter_mut().zip(scores.iter()) {
895            member.score = score;
896            member.steps_trained += 1;
897        }
898        // Exploit
899        self.exploit(rng, config.exploit_frac);
900        // Explore
901        let new_params: Vec<Vec<f64>> = self
902            .members
903            .iter()
904            .map(|m| Self::explore_member(&m.params, space, config.explore_noise, rng))
905            .collect();
906        for (member, params) in self.members.iter_mut().zip(new_params) {
907            member.params = params;
908        }
909    }
910}
911
912/// Standalone population-based training runner.
913#[derive(Debug, Clone)]
914pub struct PopulationBasedTraining {
915    /// The population.
916    pub population: PbtPopulation,
917    /// Configuration.
918    pub config: PbtConfig,
919    /// Hyperparameter space.
920    pub space: HpSpace,
921}
922
923impl PopulationBasedTraining {
924    /// Create a new PBT instance with a randomly initialised population.
925    pub fn new(space: HpSpace, config: PbtConfig, rng: &mut StdRng) -> Self {
926        let population = PbtPopulation::random(&space, config.population_size, rng);
927        Self {
928            population,
929            config,
930            space,
931        }
932    }
933
934    /// Run one PBT step given the current scores for each worker.
935    pub fn step(&mut self, scores: &[f64], rng: &mut StdRng) {
936        let config = self.config.clone();
937        let space = self.space.clone();
938        self.population.pbt_step(scores, &space, &config, rng);
939    }
940
941    /// Return the best member in the current population.
942    pub fn best(&self) -> Option<&PbtMember> {
943        self.population.members.iter().max_by(|a, b| {
944            a.score
945                .partial_cmp(&b.score)
946                .unwrap_or(std::cmp::Ordering::Equal)
947        })
948    }
949}
950
951// ─────────────────────────────────────────────────────────────────────────────
952// §6. CMA-ES for HPO
953// ─────────────────────────────────────────────────────────────────────────────
954
/// Settings for the CMA-ES hyperparameter optimiser.
#[derive(Debug, Clone)]
pub struct CmaHpoConfig {
    /// Number of search dimensions `d`.
    pub n_dims: usize,
    /// Initial global step size σ₀.
    pub sigma0: f64,
    /// Offspring per generation λ; [`CmaHpoConfig::new`] uses `max(4, 4 + ⌊3 ln d⌋)`.
    pub lambda: usize,
}

impl CmaHpoConfig {
    /// Build a config whose λ follows the usual default `4 + ⌊3 ln d⌋`,
    /// never going below 4 (this also covers `n_dims == 0`).
    pub fn new(n_dims: usize, sigma0: f64) -> Self {
        let log_term = (3.0 * (n_dims as f64).ln()).floor();
        let default_lambda = (4.0 + log_term) as usize;
        Self {
            lambda: default_lambda.max(4),
            n_dims,
            sigma0,
        }
    }
}
977
/// Mutable state of a CMA-ES run: the search distribution, both evolution
/// paths and the cached eigen-decomposition of the covariance.
#[derive(Debug, Clone)]
pub struct CmaHpoState {
    /// Distribution mean m.
    pub mean: Vec<f64>,
    /// Global step size σ.
    pub sigma: f64,
    /// Covariance matrix C, flattened row-major (`cov[r * n_dims + c]`).
    pub cov: Vec<f64>,
    /// Evolution path p_σ (step-size adaptation).
    pub p_sigma: Vec<f64>,
    /// Evolution path p_c (rank-one covariance update).
    pub p_c: Vec<f64>,
    /// Eigenvalues of C.
    pub eigenvalues: Vec<f64>,
    /// Eigenvectors of C as an `n_dims × n_dims` nested vector.
    pub eigenvectors: Vec<Vec<f64>>,
    /// Generation counter.
    pub generation: usize,
    /// Dimensionality of the search space.
    pub n_dims: usize,
}

impl CmaHpoState {
    /// Start a run at `mean` with step size `sigma0`.
    ///
    /// The covariance and eigenbasis start as the identity, eigenvalues at 1,
    /// both evolution paths at zero and the generation counter at 0.
    pub fn new(mean: Vec<f64>, sigma0: f64) -> Self {
        let n_dims = mean.len();
        let identity_row = |r: usize| -> Vec<f64> {
            (0..n_dims).map(|c| if c == r { 1.0 } else { 0.0 }).collect()
        };
        let eigenvectors: Vec<Vec<f64>> = (0..n_dims).map(identity_row).collect();
        // Flattening the identity rows gives the row-major identity covariance.
        let cov: Vec<f64> = eigenvectors.iter().flatten().copied().collect();
        Self {
            mean,
            sigma: sigma0,
            cov,
            p_sigma: vec![0.0; n_dims],
            p_c: vec![0.0; n_dims],
            eigenvalues: vec![1.0; n_dims],
            eigenvectors,
            generation: 0,
            n_dims,
        }
    }
}
1029
1030/// CMA-ES optimiser for hyperparameter optimisation.
1031#[derive(Debug, Clone)]
1032pub struct EvolutionaryStrategy {
1033    /// CMA-ES state.
1034    pub state: CmaHpoState,
1035    /// Configuration.
1036    pub config: CmaHpoConfig,
1037    /// Recombination weights.
1038    weights: Vec<f64>,
1039    /// Effective number of parents μ_eff.
1040    mu_eff: f64,
1041}
1042
1043impl EvolutionaryStrategy {
1044    /// Create a new CMA-ES optimiser.
1045    pub fn new(config: CmaHpoConfig, initial_mean: Vec<f64>) -> Self {
1046        assert_eq!(config.n_dims, initial_mean.len());
1047        let lambda = config.lambda;
1048        let mu = lambda / 2;
1049        // Recombination weights (log-linear)
1050        let weights: Vec<f64> = (0..mu).map(|i| (mu as f64 + 0.5 - i as f64).ln()).collect();
1051        let w_sum: f64 = weights.iter().sum();
1052        let weights: Vec<f64> = weights.iter().map(|w| w / w_sum).collect();
1053        let mu_eff = 1.0 / weights.iter().map(|w| w * w).sum::<f64>();
1054        let state = CmaHpoState::new(initial_mean, config.sigma0);
1055        Self {
1056            state,
1057            config,
1058            weights,
1059            mu_eff,
1060        }
1061    }
1062
1063    /// Sample λ candidate solutions.
1064    pub fn sample(&self, rng: &mut StdRng) -> Vec<Vec<f64>> {
1065        let d = self.state.n_dims;
1066        let lambda = self.config.lambda;
1067        (0..lambda).map(|_| self.sample_one(rng, d)).collect()
1068    }
1069
1070    fn sample_one(&self, rng: &mut StdRng, d: usize) -> Vec<f64> {
1071        // z ~ N(0, I)
1072        let z: Vec<f64> = (0..d).map(|_| standard_normal(rng)).collect();
1073        // x = mean + sigma * B D z
1074        let mut bd_z = vec![0.0_f64; d];
1075        for i in 0..d {
1076            let mut sum = 0.0;
1077            for j in 0..d {
1078                sum += self.state.eigenvectors[j][i] * self.state.eigenvalues[j].sqrt() * z[j];
1079            }
1080            bd_z[i] = sum;
1081        }
1082        (0..d)
1083            .map(|i| self.state.mean[i] + self.state.sigma * bd_z[i])
1084            .collect()
1085    }
1086
1087    /// Update the CMA-ES state given `selected` (mu best solutions) and `weights`.
1088    pub fn update(&mut self, selected: &[Vec<f64>]) {
1089        let d = self.state.n_dims;
1090        let mu = self.weights.len().min(selected.len());
1091        if mu == 0 {
1092            return;
1093        }
1094
1095        // Hansen (2016) CMA-ES constants
1096        let cc = (4.0 + self.mu_eff / d as f64) / (d as f64 + 4.0 + 2.0 * self.mu_eff / d as f64);
1097        let c_sigma = (self.mu_eff + 2.0) / (d as f64 + self.mu_eff + 5.0);
1098        let c1 = 2.0 / ((d as f64 + 1.3).powi(2) + self.mu_eff);
1099        let c_mu = (2.0 * (self.mu_eff - 2.0 + 1.0 / self.mu_eff))
1100            / ((d as f64 + 2.0).powi(2) + self.mu_eff);
1101        let d_sigma =
1102            1.0 + 2.0 * (((self.mu_eff - 1.0) / (d as f64 + 1.0)).sqrt() - 1.0).max(0.0) + c_sigma;
1103        let chi_n =
1104            (d as f64).sqrt() * (1.0 - 1.0 / (4.0 * d as f64) + 1.0 / (21.0 * d as f64 * d as f64));
1105
1106        // New mean
1107        let mut new_mean = vec![0.0_f64; d];
1108        for (w, x) in self.weights.iter().zip(selected[..mu].iter()) {
1109            for k in 0..d {
1110                new_mean[k] += w * x[k];
1111            }
1112        }
1113
1114        // Step δ = (new_mean - old_mean) / sigma
1115        let delta: Vec<f64> = (0..d)
1116            .map(|i| (new_mean[i] - self.state.mean[i]) / self.state.sigma)
1117            .collect();
1118
1119        // B^T delta (in eigenvector basis)
1120        let bt_delta: Vec<f64> = (0..d)
1121            .map(|i| {
1122                (0..d)
1123                    .map(|j| self.state.eigenvectors[i][j] * delta[j])
1124                    .sum::<f64>()
1125            })
1126            .collect();
1127
1128        // Update evolution path p_sigma
1129        let sq_mu_eff = self.mu_eff.sqrt();
1130        let h_sigma_val = {
1131            let norm_p: f64 = self.state.p_sigma.iter().map(|v| v * v).sum::<f64>().sqrt();
1132            let threshold = (1.4 + 2.0 / (d as f64 + 1.0)) * chi_n;
1133            if norm_p / (1.0 - (1.0 - c_sigma).powi(2 * (self.state.generation + 1) as i32)).sqrt()
1134                < threshold
1135            {
1136                1.0
1137            } else {
1138                0.0
1139            }
1140        };
1141        for i in 0..d {
1142            let d_inv_bt = bt_delta[i] / self.state.eigenvalues[i].sqrt().max(1e-15);
1143            self.state.p_sigma[i] = (1.0 - c_sigma) * self.state.p_sigma[i]
1144                + (c_sigma * (2.0 - c_sigma) * self.mu_eff).sqrt() * d_inv_bt;
1145        }
1146        // Update p_c
1147        for i in 0..d {
1148            self.state.p_c[i] = (1.0 - cc) * self.state.p_c[i]
1149                + h_sigma_val * (cc * (2.0 - cc) * self.mu_eff).sqrt() * delta[i];
1150        }
1151        // Update covariance
1152        let c1_term: Vec<f64> = (0..d * d)
1153            .map(|idx| {
1154                let r = idx / d;
1155                let c = idx % d;
1156                self.state.p_c[r] * self.state.p_c[c]
1157            })
1158            .collect();
1159        for idx in 0..d * d {
1160            let r = idx / d;
1161            let c = idx % d;
1162            let c_mu_sum: f64 = (0..mu)
1163                .map(|k| {
1164                    let di = (selected[k][r] - self.state.mean[r]) / self.state.sigma;
1165                    let dj = (selected[k][c] - self.state.mean[c]) / self.state.sigma;
1166                    self.weights[k] * di * dj
1167                })
1168                .sum();
1169            self.state.cov[idx] =
1170                (1.0 - c1 - c_mu) * self.state.cov[idx] + c1 * c1_term[idx] + c_mu * c_mu_sum;
1171        }
1172        // Update sigma via CSA
1173        let norm_ps: f64 = self.state.p_sigma.iter().map(|v| v * v).sum::<f64>().sqrt();
1174        self.state.sigma *= ((c_sigma / d_sigma) * (norm_ps / chi_n - 1.0)).exp();
1175
1176        // Eigen-decomposition (symmetric power iteration — simplified)
1177        let (evecs, evals) = eigen_decompose_sym(&self.state.cov, d);
1178        self.state.eigenvectors = evecs;
1179        self.state.eigenvalues = evals;
1180        self.state.mean = new_mean;
1181        self.state.generation += 1;
1182    }
1183}
1184
1185// ─────────────────────────────────────────────────────────────────────────────
1186// §7. Multi-Objective HPO
1187// ─────────────────────────────────────────────────────────────────────────────
1188
/// Settings for a multi-objective HPO run.
#[derive(Clone, Debug)]
pub struct MoHpoConfig {
    /// How many objectives each observation carries.
    pub n_objectives: usize,
    /// Evaluation budget: total number of iterations to run.
    pub n_iter: usize,
}
1197
/// One evaluated configuration in a multi-objective study.
#[derive(Clone, Debug)]
pub struct MoObservation {
    /// Hyperparameter vector that was evaluated.
    pub params: Vec<f64>,
    /// Objective values, one per objective; every objective is minimised.
    pub objectives: Vec<f64>,
}
1206
1207/// Return indices of Pareto-optimal observations (non-dominated set).
1208pub fn pareto_front(observations: &[MoObservation]) -> Vec<usize> {
1209    let n = observations.len();
1210    let mut dominated = vec![false; n];
1211    for i in 0..n {
1212        for j in 0..n {
1213            if i == j {
1214                continue;
1215            }
1216            // Check if j dominates i (j ≤ i in all objectives, strictly in at least one)
1217            let all_le = observations[j]
1218                .objectives
1219                .iter()
1220                .zip(observations[i].objectives.iter())
1221                .all(|(&oj, &oi)| oj <= oi);
1222            let any_lt = observations[j]
1223                .objectives
1224                .iter()
1225                .zip(observations[i].objectives.iter())
1226                .any(|(&oj, &oi)| oj < oi);
1227            if all_le && any_lt {
1228                dominated[i] = true;
1229                break;
1230            }
1231        }
1232    }
1233    (0..n).filter(|&i| !dominated[i]).collect()
1234}
1235
1236/// Compute approximate hypervolume contribution of each observation w.r.t. a reference point.
1237///
1238/// Uses a simple bounding-box approximation: contribution of point i is
1239/// the product of (ref - obj_i) over all objectives (bounded from below at 0).
1240pub fn hypervolume_contribution(obs: &[MoObservation], ref_point: &[f64]) -> Vec<f64> {
1241    obs.iter()
1242        .map(|o| {
1243            o.objectives
1244                .iter()
1245                .zip(ref_point.iter())
1246                .map(|(&obj, &r)| (r - obj).max(0.0))
1247                .product()
1248        })
1249        .collect()
1250}
1251
1252/// NSGA-II crowding-distance selection: return `n` indices.
1253pub fn nsga2_select(obs: &[MoObservation], n: usize) -> Vec<usize> {
1254    if obs.is_empty() || n == 0 {
1255        return vec![];
1256    }
1257    // Non-dominated sorting
1258    let mut remaining: Vec<usize> = (0..obs.len()).collect();
1259    let mut result: Vec<usize> = Vec::with_capacity(n);
1260    while result.len() < n && !remaining.is_empty() {
1261        // Find Pareto front within remaining
1262        let front_indices = {
1263            let subset: Vec<MoObservation> = remaining.iter().map(|&i| obs[i].clone()).collect();
1264            pareto_front(&subset)
1265                .into_iter()
1266                .map(|local_i| remaining[local_i])
1267                .collect::<Vec<_>>()
1268        };
1269        let need = n - result.len();
1270        if front_indices.len() <= need {
1271            result.extend_from_slice(&front_indices);
1272        } else {
1273            // Sort front by crowding distance (descending) and take `need`
1274            let mut cd = crowding_distance(obs, &front_indices);
1275            cd.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
1276            result.extend(cd.into_iter().take(need).map(|(i, _)| i));
1277        }
1278        let front_set: std::collections::HashSet<usize> = front_indices.iter().cloned().collect();
1279        remaining.retain(|i| !front_set.contains(i));
1280    }
1281    result
1282}
1283
1284fn crowding_distance(obs: &[MoObservation], indices: &[usize]) -> Vec<(usize, f64)> {
1285    let m = obs.first().map(|o| o.objectives.len()).unwrap_or(0);
1286    let n = indices.len();
1287    let mut distances = vec![0.0_f64; n];
1288    for obj_idx in 0..m {
1289        let mut sorted: Vec<(usize, f64)> = indices
1290            .iter()
1291            .enumerate()
1292            .map(|(local, &global)| (local, obs[global].objectives[obj_idx]))
1293            .collect();
1294        sorted.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1295        distances[sorted[0].0] = f64::INFINITY;
1296        distances[sorted[n - 1].0] = f64::INFINITY;
1297        let f_min = sorted[0].1;
1298        let f_max = sorted[n - 1].1;
1299        let range = (f_max - f_min).max(1e-15);
1300        for k in 1..n - 1 {
1301            distances[sorted[k].0] += (sorted[k + 1].1 - sorted[k - 1].1) / range;
1302        }
1303    }
1304    indices
1305        .iter()
1306        .enumerate()
1307        .map(|(local, &global)| (global, distances[local]))
1308        .collect()
1309}
1310
1311/// Multi-objective HPO optimiser.
1312#[derive(Debug, Clone)]
1313pub struct MultiObjectiveHpo {
1314    /// Configuration.
1315    pub config: MoHpoConfig,
1316    /// Hyperparameter space.
1317    pub space: HpSpace,
1318    /// All recorded observations.
1319    pub observations: Vec<MoObservation>,
1320}
1321
1322impl MultiObjectiveHpo {
1323    /// Create a new multi-objective HPO instance.
1324    pub fn new(space: HpSpace, config: MoHpoConfig) -> Self {
1325        Self {
1326            config,
1327            space,
1328            observations: vec![],
1329        }
1330    }
1331
1332    /// Add an observation.
1333    pub fn observe(&mut self, params: Vec<f64>, objectives: Vec<f64>) {
1334        self.observations.push(MoObservation { params, objectives });
1335    }
1336
1337    /// Return the current Pareto front.
1338    pub fn pareto_front(&self) -> Vec<usize> {
1339        pareto_front(&self.observations)
1340    }
1341
1342    /// Suggest the next candidate (random for now — can be extended with SMS-EGO etc.).
1343    pub fn suggest(&self, rng: &mut StdRng) -> Vec<f64> {
1344        self.space.sample_random(rng)
1345    }
1346}
1347
1348// ─────────────────────────────────────────────────────────────────────────────
1349// §8. Early Termination
1350// ─────────────────────────────────────────────────────────────────────────────
1351
1352/// Median-stopping rule: stop a trial if its value is below the median of
1353/// all other trials at the same step.
1354#[derive(Debug, Clone)]
1355pub struct MedianStopping {
1356    /// Minimum number of steps before considering stopping.
1357    pub patience: usize,
1358    /// Minimum number of completed trials required before applying the rule.
1359    pub min_trials: usize,
1360}
1361
1362impl MedianStopping {
1363    /// Create a new median-stopping instance.
1364    pub fn new(patience: usize, min_trials: usize) -> Self {
1365        Self {
1366            patience,
1367            min_trials,
1368        }
1369    }
1370
1371    /// Return `true` if the trial should be stopped.
1372    ///
1373    /// `trial_curve` — sequence of metric values for the current trial (step-by-step).
1374    /// `all_curves` — sequences for all other completed trials.
1375    pub fn should_stop(&self, trial_curve: &[f64], all_curves: &[Vec<f64>]) -> bool {
1376        let step = trial_curve.len();
1377        if step < self.patience || all_curves.len() < self.min_trials {
1378            return false;
1379        }
1380        let current_val = match trial_curve.last() {
1381            Some(&v) => v,
1382            None => return false,
1383        };
1384        // Collect the best value at this step from all other trials
1385        let mut other_bests: Vec<f64> = all_curves
1386            .iter()
1387            .filter_map(|c| c.get(step - 1).copied())
1388            .collect();
1389        if other_bests.is_empty() {
1390            return false;
1391        }
1392        other_bests.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1393        let median = percentile_sorted(&other_bests, 50.0);
1394        // Stop if current value is below median (assuming higher is better)
1395        current_val < median
1396    }
1397}
1398
/// Successive halving early-termination rule.
///
/// Round `r` runs with budget `min_budget · ηʳ`, capped at `max_budget`.
#[derive(Debug, Clone)]
pub struct SuccessiveHalving {
    /// Minimum budget allocated per trial in round 0.
    pub min_budget: f64,
    /// Maximum budget.
    pub max_budget: f64,
    /// Halving factor η.
    pub eta: f64,
}

impl SuccessiveHalving {
    /// Create a new successive halving instance.
    pub fn new(min_budget: f64, max_budget: f64, eta: f64) -> Self {
        Self {
            min_budget,
            max_budget,
            eta,
        }
    }

    /// Return the budget for a given round index.
    pub fn budget_for_round(&self, round: usize) -> f64 {
        (self.min_budget * self.eta.powi(round as i32)).min(self.max_budget)
    }

    /// Number of η-fold budget increases needed to go from `min_budget` to
    /// `max_budget`: the smallest `r` with `min_budget · ηʳ ≥ max_budget`,
    /// i.e. `⌈log_η(max / min)⌉` (0 if `max <= min`).
    ///
    /// A small tolerance absorbs floating-point error in the logarithm. Without
    /// it, exact powers such as `max / min = 125, η = 5` (where the computed
    /// log is `3.0000000000000004`) would yield 4 instead of 3. Degenerate
    /// settings return 0 instead of a saturated `usize::MAX`: `η ≤ 1`, a
    /// non-positive `min_budget`, non-finite budgets, or NaN.
    pub fn n_rounds(&self) -> usize {
        const LOG_TOLERANCE: f64 = 1e-9;
        let valid = self.eta > 1.0
            && self.min_budget > 0.0
            && self.min_budget.is_finite()
            && self.max_budget.is_finite();
        if !valid {
            return 0;
        }
        let ratio = (self.max_budget / self.min_budget).max(1.0);
        let exact = ratio.ln() / self.eta.ln();
        (exact - LOG_TOLERANCE).ceil().max(0.0) as usize
    }
}
1433
1434/// Percentile-based early stopping: stop if the current value is below
1435/// the `percentile`-th percentile of historical values at the same step.
1436#[derive(Debug, Clone)]
1437pub struct PercentileStop {
1438    /// Percentile threshold (0-100).
1439    pub percentile: f64,
1440    /// Minimum number of steps before this rule applies.
1441    pub min_steps: usize,
1442}
1443
1444impl PercentileStop {
1445    /// Create a new percentile-stop instance.
1446    pub fn new(percentile: f64, min_steps: usize) -> Self {
1447        Self {
1448            percentile,
1449            min_steps,
1450        }
1451    }
1452
1453    /// Return `true` if `value` at `step` is below the `percentile` of `history`.
1454    pub fn should_stop_percentile(&self, value: f64, step: usize, history: &[f64]) -> bool {
1455        if step < self.min_steps || history.is_empty() {
1456            return false;
1457        }
1458        let mut sorted = history.to_vec();
1459        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1460        let threshold = percentile_sorted(&sorted, self.percentile);
1461        value < threshold
1462    }
1463}
1464
1465// ─────────────────────────────────────────────────────────────────────────────
1466// §9. HpoLogger / HpoStudy
1467// ─────────────────────────────────────────────────────────────────────────────
1468
/// Lifecycle state of an HPO trial.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrialStatus {
    /// Still being evaluated.
    Running,
    /// Finished successfully.
    Complete,
    /// Stopped early by an early-termination rule.
    Pruned,
    /// Aborted because of an error.
    Failed,
}

/// Record of one HPO trial: its configuration, outcome and timing.
#[derive(Debug, Clone)]
pub struct HpoTrial {
    /// Trial identifier (intended to be unique within a study).
    pub id: usize,
    /// Hyperparameter values of this trial.
    pub params: Vec<f64>,
    /// Parameter names, position-aligned with `params`.
    pub param_names: Vec<String>,
    /// Objective value observed for this trial.
    pub value: f64,
    /// Current lifecycle state.
    pub status: TrialStatus,
    /// Start timestamp (milliseconds since epoch, or an arbitrary counter).
    pub start_ms: u64,
    /// End timestamp, in the same unit as `start_ms`.
    pub end_ms: u64,
}

impl HpoTrial {
    /// Build a trial record; both timestamps start at `0`.
    pub fn new(
        id: usize,
        params: Vec<f64>,
        param_names: Vec<String>,
        value: f64,
        status: TrialStatus,
    ) -> Self {
        let (start_ms, end_ms) = (0, 0);
        HpoTrial {
            id,
            params,
            param_names,
            value,
            status,
            start_ms,
            end_ms,
        }
    }
}
1521
1522/// Direction of optimisation.
1523#[derive(Debug, Clone, PartialEq, Eq)]
1524pub enum OptDirection {
1525    /// Minimise the objective.
1526    Minimize,
1527    /// Maximise the objective.
1528    Maximize,
1529}
1530
1531/// An HPO study, collecting trials and providing analysis.
1532#[derive(Debug, Clone)]
1533pub struct HpoStudy {
1534    /// Study name.
1535    pub name: String,
1536    /// All recorded trials.
1537    pub trials: Vec<HpoTrial>,
1538    /// Optimisation direction.
1539    pub direction: OptDirection,
1540}
1541
1542impl HpoStudy {
1543    /// Create a new study.
1544    pub fn new(name: impl Into<String>, direction: OptDirection) -> Self {
1545        Self {
1546            name: name.into(),
1547            trials: vec![],
1548            direction,
1549        }
1550    }
1551
1552    /// Add a completed trial.
1553    pub fn add_trial(&mut self, trial: HpoTrial) {
1554        self.trials.push(trial);
1555    }
1556
1557    /// Return the best trial.
1558    pub fn best_trial(&self) -> Option<&HpoTrial> {
1559        let complete: Vec<&HpoTrial> = self
1560            .trials
1561            .iter()
1562            .filter(|t| t.status == TrialStatus::Complete)
1563            .collect();
1564        match self.direction {
1565            OptDirection::Minimize => complete.into_iter().min_by(|a, b| {
1566                a.value
1567                    .partial_cmp(&b.value)
1568                    .unwrap_or(std::cmp::Ordering::Equal)
1569            }),
1570            OptDirection::Maximize => complete.into_iter().max_by(|a, b| {
1571                a.value
1572                    .partial_cmp(&b.value)
1573                    .unwrap_or(std::cmp::Ordering::Equal)
1574            }),
1575        }
1576    }
1577}
1578
1579/// Stand-alone helper: return the best trial from a study.
1580pub fn best_trial(study: &HpoStudy) -> Option<&HpoTrial> {
1581    study.best_trial()
1582}
1583
1584/// Compute simplified fANOVA-style importances by measuring the variance of
1585/// the objective when each parameter is varied independently.
1586///
1587/// For each parameter dim `k`, the importance is:
1588///   Var[ E[y | x_k] ] / Var\[y\]
1589///
1590/// Approximated by binning parameter `k` into 5 equal-width buckets and
1591/// computing the variance of bucket means.
1592pub fn importance_by_fanova(study: &HpoStudy) -> Vec<(String, f64)> {
1593    let complete: Vec<&HpoTrial> = study
1594        .trials
1595        .iter()
1596        .filter(|t| t.status == TrialStatus::Complete)
1597        .collect();
1598    if complete.is_empty() {
1599        return vec![];
1600    }
1601    let n_params = complete[0].params.len();
1602    let values: Vec<f64> = complete.iter().map(|t| t.value).collect();
1603    let total_var = variance(&values);
1604    if total_var < 1e-15 {
1605        return complete[0]
1606            .param_names
1607            .iter()
1608            .map(|n| (n.clone(), 0.0))
1609            .collect();
1610    }
1611    let n_buckets = 5usize;
1612    (0..n_params)
1613        .map(|k| {
1614            let name = complete
1615                .first()
1616                .and_then(|t| t.param_names.get(k))
1617                .cloned()
1618                .unwrap_or_else(|| format!("param_{k}"));
1619            let param_vals: Vec<f64> = complete.iter().map(|t| t.params[k]).collect();
1620            let p_min = param_vals.iter().cloned().fold(f64::INFINITY, f64::min);
1621            let p_max = param_vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
1622            let range = (p_max - p_min).max(1e-15);
1623            let mut bucket_means = vec![];
1624            for b in 0..n_buckets {
1625                let lo = p_min + (b as f64 / n_buckets as f64) * range;
1626                let hi = p_min + ((b + 1) as f64 / n_buckets as f64) * range;
1627                let bucket_vals: Vec<f64> = complete
1628                    .iter()
1629                    .filter(|t| {
1630                        let pv = t.params[k];
1631                        pv >= lo && (pv < hi || b == n_buckets - 1)
1632                    })
1633                    .map(|t| t.value)
1634                    .collect();
1635                if !bucket_vals.is_empty() {
1636                    bucket_means.push(mean(&bucket_vals));
1637                }
1638            }
1639            let importance = if bucket_means.len() > 1 {
1640                variance(&bucket_means) / total_var
1641            } else {
1642                0.0
1643            };
1644            (name, importance)
1645        })
1646        .collect()
1647}
1648
1649/// Convenience struct wrapping study-level logging functionality.
1650#[derive(Debug, Clone, Default)]
1651pub struct HpoLogger {
1652    /// Internal study.
1653    pub study: Option<HpoStudy>,
1654}
1655
1656impl HpoLogger {
1657    /// Create a logger with a new study.
1658    pub fn new(name: impl Into<String>, direction: OptDirection) -> Self {
1659        Self {
1660            study: Some(HpoStudy::new(name, direction)),
1661        }
1662    }
1663
1664    /// Log a new trial.
1665    pub fn log_trial(&mut self, trial: HpoTrial) {
1666        if let Some(s) = &mut self.study {
1667            s.add_trial(trial);
1668        }
1669    }
1670
1671    /// Return the best trial across all logged runs.
1672    pub fn best(&self) -> Option<&HpoTrial> {
1673        self.study.as_ref().and_then(|s| s.best_trial())
1674    }
1675}
1676
1677// ─────────────────────────────────────────────────────────────────────────────
1678// §10. Transfer-Learning HPO (Warm Starting)
1679// ─────────────────────────────────────────────────────────────────────────────
1680
/// Results of an earlier HPO study, used as a source of warm starts.
#[derive(Clone, Debug)]
pub struct PreviousStudy {
    /// Parameter names, position-aligned with each trial's parameter vector.
    pub param_names: Vec<String>,
    /// Evaluated trials as `(parameter vector, score)` pairs.
    pub trials: Vec<(Vec<f64>, f64)>,
}

impl PreviousStudy {
    /// Bundle parameter names and trial results into a `PreviousStudy`.
    pub fn new(param_names: Vec<String>, trials: Vec<(Vec<f64>, f64)>) -> Self {
        PreviousStudy {
            trials,
            param_names,
        }
    }
}
1699
1700/// Warm-start sampler — uses top results from previous studies as initial
1701/// candidates for a new study.
1702#[derive(Debug, Clone)]
1703pub struct WarmStartSampler {
1704    /// Previously completed studies.
1705    pub previous: Vec<PreviousStudy>,
1706}
1707
1708impl WarmStartSampler {
1709    /// Create a new sampler from a set of previous studies.
1710    pub fn new(previous: Vec<PreviousStudy>) -> Self {
1711        Self { previous }
1712    }
1713
1714    /// Map a parameter vector from a source study into the target space by
1715    /// matching parameters by name and normalising.
1716    pub fn map_parameters(
1717        params: &[f64],
1718        source_names: &[String],
1719        target_space: &HpSpace,
1720    ) -> Vec<f64> {
1721        target_space
1722            .params
1723            .iter()
1724            .map(|hp| {
1725                // Find matching dimension in source by name
1726                source_names
1727                    .iter()
1728                    .position(|n| n == &hp.name)
1729                    .and_then(|idx| params.get(idx).copied())
1730                    .map(|v| {
1731                        // Re-normalise to target bounds: assume source already in target range
1732                        v.clamp(hp.bounds.0, hp.bounds.1)
1733                    })
1734                    .unwrap_or_else(|| {
1735                        // Default: midpoint of target bounds
1736                        (hp.bounds.0 + hp.bounds.1) / 2.0
1737                    })
1738            })
1739            .collect()
1740    }
1741
1742    /// Return the top-`n` warm-start configurations from all previous studies.
1743    /// They are projected into the `target_space` by name-matching.
1744    pub fn select_warm_starts(&self, target_space: &HpSpace, n: usize) -> Vec<Vec<f64>> {
1745        // Collect all (projected_params, score) from previous studies
1746        let mut all: Vec<(Vec<f64>, f64)> = self
1747            .previous
1748            .iter()
1749            .flat_map(|study| {
1750                study.trials.iter().map(|(params, score)| {
1751                    let mapped = Self::map_parameters(params, &study.param_names, target_space);
1752                    (mapped, *score)
1753                })
1754            })
1755            .collect();
1756        // Sort by score descending
1757        all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1758        all.into_iter().take(n).map(|(p, _)| p).collect()
1759    }
1760}
1761
1762// ─────────────────────────────────────────────────────────────────────────────
1763// Linear algebra helpers (private)
1764// ─────────────────────────────────────────────────────────────────────────────
1765
/// Cholesky decomposition `A = L Lᵀ`. Returns `L` (lower triangular, `n × n`).
///
/// Only the lower triangle of `a` is read, i.e. entries `a[i][j]` with `j <= i`.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when:
/// - a row `i` of `a` has fewer than `i + 1` entries. Such a row cannot hold its lower
///   triangle, and this previously caused an index panic.
/// - a pivot `a[i][i] - Σₖ L[i][k]²` is not strictly positive, including NaN. Such a
///   matrix is not (numerically) positive definite, and accepting a zero pivot would
///   produce a singular factor.
fn cholesky(a: &[Vec<f64>]) -> std::result::Result<Vec<Vec<f64>>, String> {
    let n = a.len();
    if let Some((i, row)) = a.iter().enumerate().find(|&(i, row)| row.len() <= i) {
        return Err(format!(
            "Cholesky input row {i} has {} entries; at least {} required",
            row.len(),
            i + 1
        ));
    }
    let mut l = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        for j in 0..=i {
            let sum: f64 = (0..j).map(|k| l[i][k] * l[j][k]).sum();
            if i == j {
                let val = a[i][i] - sum;
                // Reject NaN and zero pivots as well as negative ones.
                if val.is_nan() || val <= 0.0 {
                    return Err(format!("Matrix not positive definite at ({i},{i}): {val}"));
                }
                l[i][j] = val.sqrt();
            } else {
                // l[j][j] = sqrt(pivot) with pivot > 0 (checked above for row j), so it is
                // at least sqrt(smallest subnormal) ≈ 2.2e-162 and the division is safe.
                l[i][j] = (a[i][j] - sum) / l[j][j];
            }
        }
    }
    Ok(l)
}
1790
/// Forward substitution: solves `L x = b` for `x`, where `L` is lower triangular.
///
/// The length of `b` sets the system size. When a diagonal pivot has magnitude below
/// `1e-300`, the matching component of `x` is set to `0.0` instead of dividing by it.
fn forward_sub(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let mut x: Vec<f64> = Vec::with_capacity(b.len());
    for (i, &rhs) in b.iter().enumerate() {
        let row = &l[i];
        // Contribution of the already-solved components x[0..i].
        let dot: f64 = row[..i].iter().zip(&x).map(|(lij, xj)| lij * xj).sum();
        let pivot = row[i];
        let xi = if pivot.abs() < 1e-300 {
            0.0
        } else {
            (rhs - dot) / pivot
        };
        x.push(xi);
    }
    x
}
1806
/// Backward substitution: solves `Lᵀ x = b` for `x`, given the lower-triangular `L`.
///
/// The length of `b` sets the system size. `Lᵀ` is never materialised; its row `i`
/// is read as column `i` of `L`. When a diagonal pivot has magnitude below `1e-300`,
/// the matching component of `x` is set to `0.0` instead of dividing by it.
fn backward_sub(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        // Σ_{j>i} L[j][i] · x[j], accumulated in ascending j.
        let dot: f64 = l[i + 1..n]
            .iter()
            .zip(&x[i + 1..])
            .map(|(l_row, &xj)| l_row[i] * xj)
            .sum();
        let pivot = l[i][i];
        x[i] = if pivot.abs() < 1e-300 {
            0.0
        } else {
            (b[i] - dot) / pivot
        };
    }
    x
}
1822
1823/// Solve (L L^T) x = b via Cholesky (forward then backward substitution).
1824fn chol_solve(l: &[Vec<f64>], b: &[f64]) -> Result<Vec<f64>> {
1825    let y = forward_sub(l, b);
1826    Ok(backward_sub(l, &y))
1827}
1828
1829// ─────────────────────────────────────────────────────────────────────────────
1830// Statistics helpers (private)
1831// ─────────────────────────────────────────────────────────────────────────────
1832
/// Density of the standard normal distribution: `φ(x) = exp(-x²/2) / √(2π)`.
fn normal_pdf(x: f64) -> f64 {
    let exponent = -0.5 * x * x;
    let norm = (2.0 * std::f64::consts::PI).sqrt();
    exponent.exp() / norm
}
1836
1837fn normal_cdf(x: f64) -> f64 {
1838    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
1839}
1840
/// Error function `erf(x)`, approximated with Abramowitz & Stegun formula 7.1.26.
///
/// The formula is evaluated on `|x|` as
/// `1 - (a₁t + a₂t² + a₃t³ + a₄t⁴ + a₅t⁵)·exp(-x²)`, with `t = 1 / (1 + p·|x|)`.
/// The sign is restored afterwards using odd symmetry: inputs with `x >= 0.0`
/// (including `-0.0`) keep a positive sign, and all other inputs (NaN included)
/// are negated.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // a₁ … a₅; evaluated below in Horner form from a₅ down to a₁.
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let horner = A.iter().rev().fold(0.0_f64, |acc, &coef| coef + t * acc);
    let poly = t * horner;
    let erf_abs = 1.0 - poly * (-magnitude * magnitude).exp();
    if x >= 0.0 {
        erf_abs
    } else {
        -erf_abs
    }
}
1851
/// Arithmetic mean of `v`. An empty slice yields `0.0`.
fn mean(v: &[f64]) -> f64 {
    match v.len() {
        0 => 0.0,
        len => v.iter().sum::<f64>() / len as f64,
    }
}
1858
1859fn variance(v: &[f64]) -> f64 {
1860    if v.len() < 2 {
1861        return 0.0;
1862    }
1863    let m = mean(v);
1864    v.iter().map(|&x| (x - m).powi(2)).sum::<f64>() / (v.len() - 1) as f64
1865}
1866
/// Nearest-rank `p`-th percentile (with `p` in percent) of an ascending-sorted slice.
///
/// The rank `p/100 · (len - 1)` is rounded to the nearest index. Out-of-range `p`
/// values are clamped: negative or NaN ranks saturate to index 0, and ranks past the
/// end map to the last element. An empty slice yields `0.0`.
fn percentile_sorted(sorted: &[f64], p: f64) -> f64 {
    let Some(last) = sorted.len().checked_sub(1) else {
        return 0.0;
    };
    let rank = ((p / 100.0) * last as f64).round() as usize;
    sorted[rank.min(last)]
}
1875
1876/// Generate a standard normal sample using the Box-Muller transform.
1877fn standard_normal(rng: &mut StdRng) -> f64 {
1878    let u1: f64 = rng.random::<f64>().max(1e-300);
1879    let u2: f64 = rng.random();
1880    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
1881}
1882
1883// ─────────────────────────────────────────────────────────────────────────────
1884// Eigen-decomposition helper (Jacobi iterations for symmetric matrices)
1885// ─────────────────────────────────────────────────────────────────────────────
1886
/// Symmetric Jacobi eigen-decomposition of a `d × d` matrix stored row-major in `a_flat`.
///
/// The method repeatedly applies the plane rotation that zeroes the largest
/// off-diagonal entry. It stops once every off-diagonal entry is at most `1e-10`
/// times the largest absolute entry of the input, or after `100 · d²` rotations.
///
/// Returns `(v, eigenvalues)`:
/// - `eigenvalues[k]` is the `k`-th diagonal entry of the rotated matrix. Eigenvalues
///   are not sorted.
/// - Column `k` of `v` (i.e. `v[r][k]` for `r in 0..d`) is the corresponding
///   eigenvector.
///
/// Only the first `d * d` entries of `a_flat` are read. Only its upper triangle is
/// used when searching for pivots, and the input is assumed symmetric.
///
/// # Panics
///
/// Panics if `a_flat.len() < d * d`.
fn eigen_decompose_sym(a_flat: &[f64], d: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
    let mut a: Vec<Vec<f64>> = (0..d)
        .map(|i| (0..d).map(|j| a_flat[i * d + j]).collect())
        .collect();
    // Identity eigenvector matrix
    let mut v: Vec<Vec<f64>> = (0..d)
        .map(|i| {
            let mut row = vec![0.0_f64; d];
            row[i] = 1.0;
            row
        })
        .collect();
    // The convergence threshold is relative to the matrix scale. An absolute threshold
    // stops immediately on tiny-scale matrices, leaving them undiagonalised, and may
    // never be reached on huge-scale ones, burning every iteration.
    let scale = a.iter().flatten().fold(0.0_f64, |m, &x| m.max(x.abs()));
    let tol = 1e-10 * scale;
    let max_iter = 100 * d * d;
    for _ in 0..max_iter {
        // Find largest off-diagonal element (upper triangle; `a` is kept symmetric).
        let mut p = 0usize;
        let mut q = 1usize;
        let mut max_val = 0.0_f64;
        for i in 0..d {
            for j in i + 1..d {
                let val = a[i][j].abs();
                if val > max_val {
                    max_val = val;
                    p = i;
                    q = j;
                }
            }
        }
        // `<=` also terminates when every off-diagonal entry is zero, including d <= 1
        // and the all-zero matrix (tol == 0). In that case the placeholder
        // (p, q) = (0, 1) is never used for indexing.
        if max_val <= tol {
            break;
        }
        // Rotation angle chosen so that the rotated a[p][q] becomes zero:
        // tan(2θ) = 2·a_pq / (a_qq − a_pp); θ = π/4 when the diagonal entries coincide.
        let theta = if (a[q][q] - a[p][p]).abs() < 1e-15 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * ((2.0 * a[p][q]) / (a[q][q] - a[p][p])).atan()
        };
        let (s, c) = (theta.sin(), theta.cos());
        // Rotate rows/cols p, q
        let app = c * c * a[p][p] - 2.0 * s * c * a[p][q] + s * s * a[q][q];
        let aqq = s * s * a[p][p] + 2.0 * s * c * a[p][q] + c * c * a[q][q];
        a[p][q] = 0.0;
        a[q][p] = 0.0;
        a[p][p] = app;
        a[q][q] = aqq;
        for r in 0..d {
            if r != p && r != q {
                let apr = c * a[p][r] - s * a[q][r];
                let aqr = s * a[p][r] + c * a[q][r];
                a[p][r] = apr;
                a[r][p] = apr;
                a[q][r] = aqr;
                a[r][q] = aqr;
            }
        }
        // Accumulate the rotation into the eigenvector columns p and q.
        for r in 0..d {
            let vpr = c * v[r][p] - s * v[r][q];
            let vqr = s * v[r][p] + c * v[r][q];
            v[r][p] = vpr;
            v[r][q] = vqr;
        }
    }
    let eigenvalues: Vec<f64> = (0..d).map(|i| a[i][i]).collect();
    (v, eigenvalues)
}
1955
1956// ─────────────────────────────────────────────────────────────────────────────
1957// Tests
1958// ─────────────────────────────────────────────────────────────────────────────
1959
1960#[cfg(test)]
1961mod tests;