scry_learn/search/bayes.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Bayesian hyperparameter optimization via Tree-structured Parzen Estimator (TPE).
//!
//! [`BayesSearchCV`] uses a TPE surrogate model to guide the search towards
//! promising regions of the hyperparameter space, typically finding good
//! configurations in fewer evaluations than grid or random search.
//!
//! # Examples
//!
//! ```ignore
//! use scry_learn::prelude::*;
//! use scry_learn::search::*;
//!
//! let mut space = ParamSpace::new();
//! space.insert("max_depth".into(), ParamDistribution::IntUniform { low: 2, high: 10 });
//! space.insert("learning_rate".into(), ParamDistribution::LogUniform { low: 0.001, high: 1.0 });
//!
//! let result = BayesSearchCV::new(GradientBoostingClassifier::new(), space)
//!     .n_iter(30)
//!     .cv(5)
//!     .scoring(accuracy)
//!     .fit(&data)
//!     .unwrap();
//!
//! println!("Best: {:?} → {:.3}", result.best_params(), result.best_score());
//! ```

use std::collections::HashMap;

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::metrics::accuracy;
use crate::rng::FastRng;
use crate::split::{k_fold, stratified_k_fold, ScoringFn};

use super::{evaluate_combo, CvResult, ParamValue, Tunable};

// ---------------------------------------------------------------------------
// ParamDistribution + ParamSpace
// ---------------------------------------------------------------------------

/// A distribution from which hyperparameter values can be sampled.
///
/// Used with [`BayesSearchCV`] to define a continuous or discrete search space
/// for each hyperparameter.
///
/// # Examples
///
/// ```
/// use scry_learn::search::ParamDistribution;
///
/// let lr = ParamDistribution::LogUniform { low: 0.001, high: 1.0 };
/// let depth = ParamDistribution::IntUniform { low: 2, high: 10 };
/// ```
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ParamDistribution {
    /// A set of discrete candidate values (any [`ParamValue`] variant).
    Categorical(Vec<ParamValue>),
    /// Continuous uniform distribution over `[low, high]`.
    Uniform {
        /// Lower bound (inclusive).
        low: f64,
        /// Upper bound (inclusive).
        high: f64,
    },
    /// Log-uniform distribution over `[low, high]` (sampled in log space).
    /// Both `low` and `high` must be positive.
    LogUniform {
        /// Lower bound (inclusive, positive).
        low: f64,
        /// Upper bound (inclusive, positive).
        high: f64,
    },
    /// Discrete uniform distribution over integers `[low, high]`.
    IntUniform {
        /// Lower bound (inclusive).
        low: usize,
        /// Upper bound (inclusive).
        high: usize,
    },
}

/// A mapping from parameter names to their search distributions.
///
/// # Examples
///
/// ```
/// use scry_learn::search::{ParamDistribution, ParamSpace};
///
/// let mut space = ParamSpace::new();
/// space.insert("max_depth".into(), ParamDistribution::IntUniform { low: 2, high: 10 });
/// ```
pub type ParamSpace = HashMap<String, ParamDistribution>;

// ---------------------------------------------------------------------------
// BayesSearchCV
// ---------------------------------------------------------------------------

/// Bayesian hyperparameter optimization with cross-validation.
///
/// Uses a Tree-structured Parzen Estimator (TPE) to model the objective
/// function and focus evaluations on the most promising hyperparameter
/// combinations.
///
/// # Algorithm
///
/// 1. Evaluate `n_initial` random samples to bootstrap the surrogate model.
/// 2. For each remaining iteration, split observed results at the `gamma`
///    quantile into "good" and "bad" groups.
/// 3. Build factored 1D kernel density estimates for each group.
/// 4. Draw 100 random candidates and pick the one maximizing `l(x) / g(x)`
///    (see the note below).
/// 5. Evaluate the chosen candidate and add it to the history.
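///
/// In step 4, `l` is the density fitted to the "good" observations and `g`
/// the density fitted to the "bad" ones; maximizing the ratio `l(x) / g(x)`
/// is equivalent to maximizing expected improvement under the TPE model
/// (Bergstra et al., 2011).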
///
/// # Examples
///
/// ```ignore
/// use scry_learn::prelude::*;
/// use scry_learn::search::*;
///
/// let mut space = ParamSpace::new();
/// space.insert("max_depth".into(), ParamDistribution::IntUniform { low: 2, high: 10 });
///
/// let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
///     .n_iter(20)
///     .cv(3)
///     .fit(&data)
///     .unwrap();
///
/// println!("Best score: {:.3}", result.best_score());
/// ```
#[non_exhaustive]
pub struct BayesSearchCV {
    base_model: Box<dyn Tunable>,
    param_space: ParamSpace,
    n_iter: usize,
    n_initial: usize,
    gamma: f64,
    cv: usize,
    scorer: ScoringFn,
    seed: u64,
    stratified: bool,
    // Results (populated after fit)
    best_params_: Option<HashMap<String, ParamValue>>,
    best_score_: f64,
    cv_results_: Vec<CvResult>,
}

impl std::fmt::Debug for BayesSearchCV {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BayesSearchCV")
            .field("n_iter", &self.n_iter)
            .field("n_initial", &self.n_initial)
            .field("gamma", &self.gamma)
            .field("cv", &self.cv)
            .field("seed", &self.seed)
            .field("stratified", &self.stratified)
            .field("best_score_", &self.best_score_)
            .field("cv_results_len", &self.cv_results_.len())
            .finish()
    }
}
164
impl BayesSearchCV {
    /// Create a Bayesian search over the given model and parameter space.
    ///
    /// Defaults: 30 iterations, 10 initial random samples, gamma 0.25,
    /// 5-fold CV, accuracy scorer, seed 42, non-stratified.
    pub fn new(model: impl Tunable + 'static, param_space: ParamSpace) -> Self {
        Self {
            base_model: Box::new(model),
            param_space,
            n_iter: 30,
            n_initial: 10,
            gamma: 0.25,
            cv: 5,
            scorer: accuracy,
            seed: 42,
            stratified: false,
            best_params_: None,
            best_score_: f64::NEG_INFINITY,
            cv_results_: Vec::new(),
        }
    }

    /// Set the total number of iterations (default: 30).
    pub fn n_iter(mut self, n: usize) -> Self {
        self.n_iter = n;
        self
    }

    /// Set the number of initial random exploration samples (default: 10).
    pub fn n_initial(mut self, n: usize) -> Self {
        self.n_initial = n;
        self
    }

    /// Set the quantile threshold for splitting good/bad observations (default: 0.25).
    pub fn gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set the number of cross-validation folds (default: 5).
    pub fn cv(mut self, k: usize) -> Self {
        self.cv = k;
        self
    }

    /// Set the scoring function (default: `accuracy`).
    pub fn scoring(mut self, scorer: ScoringFn) -> Self {
        self.scorer = scorer;
        self
    }

    /// Set the random seed (default: 42).
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Enable stratified k-fold CV (default: `false`).
    ///
    /// When `true`, uses [`stratified_k_fold`](crate::split::stratified_k_fold)
    /// to preserve class proportions in each fold.
    pub fn stratified(mut self, stratified: bool) -> Self {
        self.stratified = stratified;
        self
    }
231
    /// Run the Bayesian optimization search.
    ///
    /// Returns `self` for chained accessor calls.
    pub fn fit(mut self, data: &Dataset) -> Result<Self> {
        if self.cv < 2 {
            return Err(ScryLearnError::InvalidParameter(format!(
                "cv must be >= 2, got {}",
                self.cv
            )));
        }
        if self.param_space.is_empty() {
            return Err(ScryLearnError::InvalidParameter(
                "parameter space is empty".into(),
            ));
        }
        if self.n_iter == 0 {
            return Err(ScryLearnError::InvalidParameter(
                "n_iter must be >= 1".into(),
            ));
        }

        let folds = if self.stratified {
            stratified_k_fold(data, self.cv, self.seed)
        } else {
            k_fold(data, self.cv, self.seed)
        };

        let mut rng = FastRng::new(self.seed);

        // Sorted parameter names for deterministic ordering.
        let param_names: Vec<String> = {
            let mut names: Vec<String> = self.param_space.keys().cloned().collect();
            names.sort();
            names
        };

        // Phase 1: random exploration.
        let n_initial = self.n_initial.min(self.n_iter);
        for _ in 0..n_initial {
            let combo = sample_random(&self.param_space, &param_names, &mut rng);
            let result = evaluate_combo(&*self.base_model, &combo, &folds, self.scorer)?;
            self.update_best(&result);
            self.cv_results_.push(result);
        }

        // Phase 2: TPE-guided search.
        let n_tpe = self.n_iter - n_initial;
        for _ in 0..n_tpe {
            // Split history into good/bad at the gamma quantile.
            let mut scores: Vec<f64> = self
                .cv_results_
                .iter()
                .map(|r| r.mean_score)
                .filter(|s| s.is_finite())
                .collect();
            scores.sort_by(|a, b| a.total_cmp(b));

            let n_good = ((scores.len() as f64 * self.gamma).ceil() as usize).max(1);
            // `get` guards the all-NaN case (empty `scores`): the resulting
            // NEG_INFINITY threshold leaves `bad` empty below, which triggers
            // the random-sampling fallback instead of an out-of-bounds panic.
            let threshold = scores
                .get(scores.len().saturating_sub(n_good))
                .copied()
                .unwrap_or(f64::NEG_INFINITY);

            let (good, bad): (Vec<&CvResult>, Vec<&CvResult>) = self
                .cv_results_
                .iter()
                .filter(|r| r.mean_score.is_finite())
                .partition(|r| r.mean_score >= threshold);

            // If all observations are "good" (e.g. equal scores), fall back to random.
            let combo = if bad.is_empty() {
                sample_random(&self.param_space, &param_names, &mut rng)
            } else {
                // Build KDEs for good and bad, sample candidates, pick best EI.
                let good_kde = build_factored_kde(&good, &param_names, &self.param_space);
                let bad_kde = build_factored_kde(&bad, &param_names, &self.param_space);

                let n_candidates = 100;
                let mut best_candidate = sample_random(&self.param_space, &param_names, &mut rng);
                let mut best_ei = f64::NEG_INFINITY;

                for _ in 0..n_candidates {
                    let candidate = sample_random(&self.param_space, &param_names, &mut rng);
                    let l = evaluate_kde(&good_kde, &candidate, &param_names, &self.param_space);
                    let g = evaluate_kde(&bad_kde, &candidate, &param_names, &self.param_space);
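                    // Guard against a vanishing "bad" density: score the
                    // candidate with a huge ratio rather than dividing by zero.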
                    let ei = if g > 1e-300 { l / g } else { l * 1e300 };
                    if ei > best_ei {
                        best_ei = ei;
                        best_candidate = candidate;
                    }
                }
                best_candidate
            };

            let result = evaluate_combo(&*self.base_model, &combo, &folds, self.scorer)?;
            self.update_best(&result);
            self.cv_results_.push(result);
        }

        if self.best_params_.is_none() {
            return Err(ScryLearnError::InvalidParameter(
                "all parameter combinations produced NaN scores".into(),
            ));
        }

        Ok(self)
    }
336
    /// The best parameter combination found.
    ///
    /// # Panics
    ///
    /// Panics if called before [`fit`](Self::fit).
    pub fn best_params(&self) -> &HashMap<String, ParamValue> {
        self.best_params_.as_ref().expect("call fit() first")
    }

    /// The best mean CV score achieved.
    pub fn best_score(&self) -> f64 {
        self.best_score_
    }

    /// All evaluated combinations with their scores.
    pub fn cv_results(&self) -> &[CvResult] {
        &self.cv_results_
    }

    fn update_best(&mut self, result: &CvResult) {
        if result.mean_score.is_finite()
            && (self.best_params_.is_none() || result.mean_score > self.best_score_)
        {
            self.best_score_ = result.mean_score;
            self.best_params_ = Some(result.params.clone());
        }
    }
}
365
// ---------------------------------------------------------------------------
// Sampling helpers
// ---------------------------------------------------------------------------

/// Sample a random parameter combination from the search space.
fn sample_random(
    space: &ParamSpace,
    param_names: &[String],
    rng: &mut FastRng,
) -> HashMap<String, ParamValue> {
    let mut combo = HashMap::new();
    for name in param_names {
        let dist = &space[name];
        let value = match dist {
            ParamDistribution::Categorical(values) => {
                let idx = rng.usize(0..values.len());
                values[idx].clone()
            }
            ParamDistribution::Uniform { low, high } => {
                ParamValue::Float(low + rng.f64() * (high - low))
            }
            ParamDistribution::LogUniform { low, high } => {
                let log_low = low.ln();
                let log_high = high.ln();
                ParamValue::Float((log_low + rng.f64() * (log_high - log_low)).exp())
            }
            ParamDistribution::IntUniform { low, high } => {
                if high > low {
                    ParamValue::Int(low + rng.usize(0..=(high - low)))
                } else {
                    ParamValue::Int(*low)
                }
            }
        };
        combo.insert(name.clone(), value);
    }
    combo
}
404
// ---------------------------------------------------------------------------
// Factored KDE (1D Gaussian kernels per dimension)
// ---------------------------------------------------------------------------

/// A factored KDE: for each parameter we store either continuous observations
/// (normalized to [0,1]) or categorical frequency counts.
enum ParamKde {
    /// Normalized observations in [0,1] plus the Scott's-rule bandwidth.
    Continuous {
        observations: Vec<f64>,
        bandwidth: f64,
    },
    /// Frequency of each categorical index, with Laplace smoothing applied.
    Categorical {
        /// Probability for each index.
        probs: Vec<f64>,
    },
}

/// One KDE per parameter dimension (factored assumption).
struct FactoredKde {
    kdes: Vec<(String, ParamKde)>,
}

/// Build a factored KDE from a set of [`CvResult`] observations.
fn build_factored_kde(
    observations: &[&CvResult],
    param_names: &[String],
    space: &ParamSpace,
) -> FactoredKde {
    let mut kdes = Vec::with_capacity(param_names.len());

    for name in param_names {
        let dist = &space[name];
        if let ParamDistribution::Categorical(values) = dist {
            let n_categories = values.len();
            // Count frequencies with Laplace smoothing.
            let mut counts = vec![1.0_f64; n_categories]; // Laplace prior
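            // With k categories and n matching observations, the
            // normalization below yields (count_i + 1) / (n + k) for
            // category i, so unseen categories keep nonzero probability.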
            for obs in observations {
                if let Some(val) = obs.params.get(name) {
                    if let Some(idx) = values.iter().position(|v| v == val) {
                        counts[idx] += 1.0;
                    }
                }
            }
            let total: f64 = counts.iter().sum();
            let probs: Vec<f64> = counts.iter().map(|c| c / total).collect();
            kdes.push((name.clone(), ParamKde::Categorical { probs }));
        } else {
            // Normalize observations to [0,1].
            let obs_normalized: Vec<f64> = observations
                .iter()
                .filter_map(|r| r.params.get(name))
                .map(|v| normalize_param(v, dist))
                .collect();

            // Scott's rule bandwidth for d = 1: n^(-1/(d+4)) = n^(-1/5). The
            // usual sigma factor is omitted because the observations are
            // pre-normalized to [0, 1].
            let bw = if obs_normalized.is_empty() {
                1.0
            } else {
                (obs_normalized.len() as f64).powf(-1.0 / 5.0)
            };

            kdes.push((
                name.clone(),
                ParamKde::Continuous {
                    observations: obs_normalized,
                    bandwidth: bw,
                },
            ));
        }
    }

    FactoredKde { kdes }
}
480
/// Evaluate the factored KDE density at a candidate point: the product of
/// per-dimension densities, accumulated in log space for numerical stability.
fn evaluate_kde(
    kde: &FactoredKde,
    candidate: &HashMap<String, ParamValue>,
    _param_names: &[String],
    space: &ParamSpace,
) -> f64 {
    let mut log_density = 0.0_f64;

    for (name, param_kde) in &kde.kdes {
        let Some(val) = candidate.get(name) else {
            continue;
        };
        let dist = &space[name];

        match param_kde {
            ParamKde::Continuous {
                observations,
                bandwidth,
            } => {
                let x = normalize_param(val, dist);
                let n = observations.len() as f64;
                if n < 1.0 {
                    continue;
                }
                // Mean of Gaussian kernel values.
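                // density(x) = (1 / (n * h * sqrt(2 * pi)))
                //     * sum_i exp(-0.5 * ((x - obs_i) / h)^2)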
                let mut density_sum = 0.0_f64;
                for &obs in observations {
                    let z = (x - obs) / bandwidth;
                    density_sum += (-0.5 * z * z).exp();
                }
                let density = density_sum / (n * bandwidth * (std::f64::consts::TAU).sqrt());
                // Clamp to avoid log(0).
                log_density += density.max(1e-300).ln();
            }
            ParamKde::Categorical { probs } => {
                if let ParamDistribution::Categorical(values) = dist {
                    if let Some(idx) = values.iter().position(|v| v == val) {
                        log_density += probs[idx].max(1e-300).ln();
                    } else {
                        // Unknown category: use a uniform fallback.
                        log_density += (1.0 / probs.len() as f64).ln();
                    }
                }
            }
        }
    }

    log_density.exp()
}
531
/// Normalize a parameter value to [0, 1] given its distribution.
fn normalize_param(value: &ParamValue, dist: &ParamDistribution) -> f64 {
    match (value, dist) {
        (ParamValue::Float(v), ParamDistribution::Uniform { low, high }) => {
            if (high - low).abs() < 1e-300 {
                0.5
            } else {
                (v - low) / (high - low)
            }
        }
        (ParamValue::Float(v), ParamDistribution::LogUniform { low, high }) => {
            let log_low = low.ln();
            let log_high = high.ln();
            if (log_high - log_low).abs() < 1e-300 {
                0.5
            } else {
                (v.ln() - log_low) / (log_high - log_low)
            }
        }
        (ParamValue::Int(v), ParamDistribution::IntUniform { low, high }) => {
            if high == low {
                0.5
            } else {
                (*v as f64 - *low as f64) / (*high as f64 - *low as f64)
            }
        }
        // Fallback: treat the raw float-like value as already in some space.
        (ParamValue::Float(v), _) => v.clamp(0.0, 1.0),
        (ParamValue::Int(v), _) => (*v as f64).clamp(0.0, 1.0),
        _ => 0.5,
    }
}
564
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::DecisionTreeClassifier;

    /// Build an Iris-like dataset with 3 well-separated classes.
    fn iris_like() -> Dataset {
        let n_per_class = 30;
        let n = n_per_class * 3;
        let mut f0 = Vec::with_capacity(n);
        let mut f1 = Vec::with_capacity(n);
        let mut f2 = Vec::with_capacity(n);
        let mut f3 = Vec::with_capacity(n);
        let mut target = Vec::with_capacity(n);

        let mut rng = FastRng::new(123);

        for _ in 0..n_per_class {
            f0.push(1.0 + rng.f64() * 0.5);
            f1.push(1.0 + rng.f64() * 0.5);
            f2.push(0.5 + rng.f64() * 0.3);
            f3.push(0.1 + rng.f64() * 0.2);
            target.push(0.0);
        }
        for _ in 0..n_per_class {
            f0.push(5.0 + rng.f64() * 0.5);
            f1.push(3.0 + rng.f64() * 0.5);
            f2.push(3.5 + rng.f64() * 0.5);
            f3.push(1.0 + rng.f64() * 0.3);
            target.push(1.0);
        }
        for _ in 0..n_per_class {
            f0.push(6.5 + rng.f64() * 0.5);
            f1.push(3.0 + rng.f64() * 0.5);
            f2.push(5.5 + rng.f64() * 0.5);
            f3.push(2.0 + rng.f64() * 0.3);
            target.push(2.0);
        }

        Dataset::new(
            vec![f0, f1, f2, f3],
            target,
            vec![
                "sepal_len".into(),
                "sepal_wid".into(),
                "petal_len".into(),
                "petal_wid".into(),
            ],
            "species",
        )
    }
620
    #[test]
    fn test_bayes_search_int_uniform() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 10 },
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(15)
            .n_initial(5)
            .cv(3)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.7,
            "bayes best score {:.3} too low",
            result.best_score()
        );
        assert_eq!(result.cv_results().len(), 15);
        assert!(result.best_params().contains_key("max_depth"));
    }

    #[test]
    fn test_bayes_search_categorical() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::Categorical(vec![
                ParamValue::Int(2),
                ParamValue::Int(4),
                ParamValue::Int(6),
                ParamValue::Int(8),
            ]),
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(10)
            .n_initial(4)
            .cv(3)
            .seed(99)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.5,
            "bayes categorical best score {:.3} too low",
            result.best_score()
        );
        assert!(result.best_params().contains_key("max_depth"));
    }

    #[test]
    fn test_bayes_search_mixed_space() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 8 },
        );
        space.insert(
            "min_samples_split".into(),
            ParamDistribution::IntUniform { low: 2, high: 10 },
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(12)
            .n_initial(5)
            .cv(3)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 12);
        assert!(result.best_params().contains_key("max_depth"));
        assert!(result.best_params().contains_key("min_samples_split"));
    }

    #[test]
    fn test_bayes_search_stratified() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 8 },
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(10)
            .n_initial(5)
            .cv(3)
            .stratified(true)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.7,
            "stratified bayes best score {:.3} too low",
            result.best_score()
        );
    }

    #[test]
    fn test_bayes_search_empty_space() {
        let data = iris_like();
        let space = ParamSpace::new();
        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space).fit(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_bayes_search_n_iter_zero() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 8 },
        );
        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(0)
            .fit(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_bayes_search_all_initial() {
        // When n_initial >= n_iter, all samples are random (no TPE phase).
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 6 },
        );

        let result = BayesSearchCV::new(DecisionTreeClassifier::new(), space)
            .n_iter(5)
            .n_initial(10)
            .cv(3)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert_eq!(result.cv_results().len(), 5);
    }

    #[test]
    fn test_bayes_search_gbc_mixed() {
        let data = iris_like();
        let mut space = ParamSpace::new();
        space.insert(
            "n_estimators".into(),
            ParamDistribution::Categorical(vec![
                ParamValue::Int(5),
                ParamValue::Int(10),
                ParamValue::Int(20),
            ]),
        );
        space.insert(
            "max_depth".into(),
            ParamDistribution::IntUniform { low: 2, high: 4 },
        );

        let result = BayesSearchCV::new(crate::tree::GradientBoostingClassifier::new(), space)
            .n_iter(10)
            .n_initial(5)
            .cv(3)
            .scoring(crate::metrics::accuracy)
            .seed(42)
            .fit(&data)
            .unwrap();

        assert!(
            result.best_score() > 0.5,
            "gbc bayes best score {:.3} too low",
            result.best_score()
        );
    }
803
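    // A small added sanity check (a sketch, assuming `FastRng::f64` yields
    // values in [0, 1), which the samplers above rely on): log-uniform
    // samples must stay within the requested bounds, modulo rounding.
    #[test]
    fn test_sample_random_log_uniform_bounds() {
        let mut space = ParamSpace::new();
        space.insert(
            "learning_rate".into(),
            ParamDistribution::LogUniform {
                low: 0.001,
                high: 1.0,
            },
        );
        let names = vec!["learning_rate".to_string()];
        let mut rng = FastRng::new(7);
        for _ in 0..200 {
            let combo = sample_random(&space, &names, &mut rng);
            match combo.get("learning_rate") {
                // Allow a tiny tolerance for ln/exp round-tripping.
                Some(ParamValue::Float(v)) => {
                    assert!(*v > 0.0009 && *v < 1.0001, "sample {} out of bounds", v);
                }
                other => panic!("expected a float sample, got {:?}", other),
            }
        }
    }
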
    #[test]
    fn test_normalize_param() {
        let dist = ParamDistribution::Uniform {
            low: 0.0,
            high: 10.0,
        };
        let val = ParamValue::Float(5.0);
        let norm = normalize_param(&val, &dist);
        assert!((norm - 0.5).abs() < 1e-10);

        let dist_int = ParamDistribution::IntUniform { low: 0, high: 10 };
        let val_int = ParamValue::Int(5);
        let norm_int = normalize_param(&val_int, &dist_int);
        assert!((norm_int - 0.5).abs() < 1e-10);
    }
}
819}