// tensorlogic_train/hyperparameter.rs

//! Hyperparameter optimization utilities.
//!
//! This module provides various hyperparameter search strategies:
//! - Grid search (exhaustive search over a parameter grid)
//! - Random search (random sampling from the parameter space)
//! - Bayesian optimization (Gaussian Process-based optimization with acquisition functions)
//! - Parameter space definition
//! - Result tracking and comparison
10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{s, Array1, Array2};
12use scirs2_core::random::{Rng, SeedableRng, StdRng};
13use std::collections::HashMap;
14
/// Hyperparameter value type.
///
/// A single concrete setting for one hyperparameter. Configurations map
/// parameter names to these values (see `HyperparamConfig`).
#[derive(Debug, Clone, PartialEq)]
pub enum HyperparamValue {
    /// Floating-point value.
    Float(f64),
    /// Integer value.
    Int(i64),
    /// Boolean value.
    Bool(bool),
    /// String value.
    String(String),
}
27
28impl HyperparamValue {
29    /// Get as f64, if possible.
30    pub fn as_float(&self) -> Option<f64> {
31        match self {
32            HyperparamValue::Float(v) => Some(*v),
33            HyperparamValue::Int(v) => Some(*v as f64),
34            _ => None,
35        }
36    }
37
38    /// Get as i64, if possible.
39    pub fn as_int(&self) -> Option<i64> {
40        match self {
41            HyperparamValue::Int(v) => Some(*v),
42            HyperparamValue::Float(v) => Some(*v as i64),
43            _ => None,
44        }
45    }
46
47    /// Get as bool, if possible.
48    pub fn as_bool(&self) -> Option<bool> {
49        match self {
50            HyperparamValue::Bool(v) => Some(*v),
51            _ => None,
52        }
53    }
54
55    /// Get as string, if possible.
56    pub fn as_string(&self) -> Option<&str> {
57        match self {
58            HyperparamValue::String(v) => Some(v),
59            _ => None,
60        }
61    }
62}
63
/// Hyperparameter space definition.
///
/// Describes the set of values one hyperparameter may take; search
/// strategies sample from it (`sample`) or enumerate it (`grid_values`).
#[derive(Debug, Clone)]
pub enum HyperparamSpace {
    /// Discrete choices.
    Discrete(Vec<HyperparamValue>),
    /// Continuous range [min, max].
    Continuous { min: f64, max: f64 },
    /// Log-uniform distribution over [min, max] (both bounds must be positive).
    LogUniform { min: f64, max: f64 },
    /// Integer range [min, max], inclusive on both ends.
    IntRange { min: i64, max: i64 },
}
76
77impl HyperparamSpace {
78    /// Create a discrete choice space.
79    pub fn discrete(values: Vec<HyperparamValue>) -> TrainResult<Self> {
80        if values.is_empty() {
81            return Err(TrainError::InvalidParameter(
82                "Discrete space cannot be empty".to_string(),
83            ));
84        }
85        Ok(Self::Discrete(values))
86    }
87
88    /// Create a continuous range space.
89    pub fn continuous(min: f64, max: f64) -> TrainResult<Self> {
90        if min >= max {
91            return Err(TrainError::InvalidParameter(
92                "min must be less than max".to_string(),
93            ));
94        }
95        Ok(Self::Continuous { min, max })
96    }
97
98    /// Create a log-uniform distribution space.
99    pub fn log_uniform(min: f64, max: f64) -> TrainResult<Self> {
100        if min <= 0.0 || max <= 0.0 || min >= max {
101            return Err(TrainError::InvalidParameter(
102                "min and max must be positive and min < max".to_string(),
103            ));
104        }
105        Ok(Self::LogUniform { min, max })
106    }
107
108    /// Create an integer range space.
109    pub fn int_range(min: i64, max: i64) -> TrainResult<Self> {
110        if min >= max {
111            return Err(TrainError::InvalidParameter(
112                "min must be less than max".to_string(),
113            ));
114        }
115        Ok(Self::IntRange { min, max })
116    }
117
118    /// Sample a value from this space.
119    pub fn sample(&self, rng: &mut StdRng) -> HyperparamValue {
120        match self {
121            HyperparamSpace::Discrete(values) => {
122                let idx = rng.gen_range(0..values.len());
123                values[idx].clone()
124            }
125            HyperparamSpace::Continuous { min, max } => {
126                let value = min + (max - min) * rng.random::<f64>();
127                HyperparamValue::Float(value)
128            }
129            HyperparamSpace::LogUniform { min, max } => {
130                let log_min = min.ln();
131                let log_max = max.ln();
132                let log_value = log_min + (log_max - log_min) * rng.random::<f64>();
133                HyperparamValue::Float(log_value.exp())
134            }
135            HyperparamSpace::IntRange { min, max } => {
136                let value = rng.gen_range(*min..=*max);
137                HyperparamValue::Int(value)
138            }
139        }
140    }
141
142    /// Get all possible values for grid search (for discrete/int spaces).
143    pub fn grid_values(&self, num_samples: usize) -> Vec<HyperparamValue> {
144        match self {
145            HyperparamSpace::Discrete(values) => values.clone(),
146            HyperparamSpace::IntRange { min, max } => {
147                let range_size = (max - min + 1) as usize;
148                let step = (range_size / num_samples).max(1);
149                (*min..=*max)
150                    .step_by(step)
151                    .map(HyperparamValue::Int)
152                    .collect()
153            }
154            HyperparamSpace::Continuous { min, max } => {
155                let step = (max - min) / (num_samples as f64);
156                (0..num_samples)
157                    .map(|i| HyperparamValue::Float(min + step * i as f64))
158                    .collect()
159            }
160            HyperparamSpace::LogUniform { min, max } => {
161                let log_min = min.ln();
162                let log_max = max.ln();
163                let log_step = (log_max - log_min) / (num_samples as f64);
164                (0..num_samples)
165                    .map(|i| HyperparamValue::Float((log_min + log_step * i as f64).exp()))
166                    .collect()
167            }
168        }
169    }
170}
171
172/// Hyperparameter configuration (a single point in parameter space).
173pub type HyperparamConfig = HashMap<String, HyperparamValue>;
174
/// Result of a hyperparameter evaluation.
///
/// Pairs the configuration that was tried with the score it achieved and
/// any auxiliary metrics recorded during evaluation.
#[derive(Debug, Clone)]
pub struct HyperparamResult {
    /// Hyperparameter configuration used.
    pub config: HyperparamConfig,
    /// Evaluation score (higher is better).
    pub score: f64,
    /// Additional metrics (e.g. per-split losses), keyed by name.
    pub metrics: HashMap<String, f64>,
}
185
186impl HyperparamResult {
187    /// Create a new result.
188    pub fn new(config: HyperparamConfig, score: f64) -> Self {
189        Self {
190            config,
191            score,
192            metrics: HashMap::new(),
193        }
194    }
195
196    /// Add a metric to the result.
197    pub fn with_metric(mut self, name: String, value: f64) -> Self {
198        self.metrics.insert(name, value);
199        self
200    }
201}
202
/// Grid search strategy for hyperparameter optimization.
///
/// Exhaustively searches over a grid of hyperparameter values: every
/// combination of each parameter's grid values is evaluated.
#[derive(Debug)]
pub struct GridSearch {
    /// Parameter space definition, keyed by parameter name.
    param_space: HashMap<String, HyperparamSpace>,
    /// Number of grid points per continuous parameter.
    num_grid_points: usize,
    /// Results from all evaluations, in insertion order.
    results: Vec<HyperparamResult>,
}
215
216impl GridSearch {
217    /// Create a new grid search.
218    ///
219    /// # Arguments
220    /// * `param_space` - Hyperparameter space definition
221    /// * `num_grid_points` - Number of points for continuous parameters
222    pub fn new(param_space: HashMap<String, HyperparamSpace>, num_grid_points: usize) -> Self {
223        Self {
224            param_space,
225            num_grid_points,
226            results: Vec::new(),
227        }
228    }
229
230    /// Generate all parameter configurations for grid search.
231    pub fn generate_configs(&self) -> Vec<HyperparamConfig> {
232        if self.param_space.is_empty() {
233            return vec![HashMap::new()];
234        }
235
236        let mut param_names: Vec<String> = self.param_space.keys().cloned().collect();
237        param_names.sort(); // Ensure deterministic order
238
239        let mut all_values: Vec<Vec<HyperparamValue>> = Vec::new();
240        for name in &param_names {
241            let space = &self.param_space[name];
242            all_values.push(space.grid_values(self.num_grid_points));
243        }
244
245        // Generate Cartesian product
246        let mut configs = Vec::new();
247        self.generate_cartesian_product(
248            &param_names,
249            &all_values,
250            0,
251            &mut HashMap::new(),
252            &mut configs,
253        );
254
255        configs
256    }
257
258    /// Recursively generate Cartesian product of parameter values.
259    #[allow(clippy::only_used_in_recursion)]
260    fn generate_cartesian_product(
261        &self,
262        param_names: &[String],
263        all_values: &[Vec<HyperparamValue>],
264        depth: usize,
265        current_config: &mut HyperparamConfig,
266        configs: &mut Vec<HyperparamConfig>,
267    ) {
268        if depth == param_names.len() {
269            configs.push(current_config.clone());
270            return;
271        }
272
273        let param_name = &param_names[depth];
274        let values = &all_values[depth];
275
276        for value in values {
277            current_config.insert(param_name.clone(), value.clone());
278            self.generate_cartesian_product(
279                param_names,
280                all_values,
281                depth + 1,
282                current_config,
283                configs,
284            );
285        }
286
287        current_config.remove(param_name);
288    }
289
290    /// Add a result from evaluating a configuration.
291    pub fn add_result(&mut self, result: HyperparamResult) {
292        self.results.push(result);
293    }
294
295    /// Get the best result found so far.
296    pub fn best_result(&self) -> Option<&HyperparamResult> {
297        self.results.iter().max_by(|a, b| {
298            a.score
299                .partial_cmp(&b.score)
300                .unwrap_or(std::cmp::Ordering::Equal)
301        })
302    }
303
304    /// Get all results sorted by score (descending).
305    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
306        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
307        results.sort_by(|a, b| {
308            b.score
309                .partial_cmp(&a.score)
310                .unwrap_or(std::cmp::Ordering::Equal)
311        });
312        results
313    }
314
315    /// Get all results.
316    pub fn results(&self) -> &[HyperparamResult] {
317        &self.results
318    }
319
320    /// Get total number of configurations to evaluate.
321    pub fn total_configs(&self) -> usize {
322        self.generate_configs().len()
323    }
324}
325
/// Random search strategy for hyperparameter optimization.
///
/// Randomly samples configurations from the hyperparameter space. Runs are
/// reproducible via the seed supplied to [`RandomSearch::new`].
#[derive(Debug)]
pub struct RandomSearch {
    /// Parameter space definition, keyed by parameter name.
    param_space: HashMap<String, HyperparamSpace>,
    /// Number of random samples to evaluate.
    num_samples: usize,
    /// Seeded random number generator driving the sampling.
    rng: StdRng,
    /// Results from all evaluations, in insertion order.
    results: Vec<HyperparamResult>,
}
340
341impl RandomSearch {
342    /// Create a new random search.
343    ///
344    /// # Arguments
345    /// * `param_space` - Hyperparameter space definition
346    /// * `num_samples` - Number of random configurations to try
347    /// * `seed` - Random seed for reproducibility
348    pub fn new(
349        param_space: HashMap<String, HyperparamSpace>,
350        num_samples: usize,
351        seed: u64,
352    ) -> Self {
353        Self {
354            param_space,
355            num_samples,
356            rng: StdRng::seed_from_u64(seed),
357            results: Vec::new(),
358        }
359    }
360
361    /// Generate random parameter configurations.
362    pub fn generate_configs(&mut self) -> Vec<HyperparamConfig> {
363        let mut configs = Vec::with_capacity(self.num_samples);
364
365        for _ in 0..self.num_samples {
366            let mut config = HashMap::new();
367
368            for (name, space) in &self.param_space {
369                let value = space.sample(&mut self.rng);
370                config.insert(name.clone(), value);
371            }
372
373            configs.push(config);
374        }
375
376        configs
377    }
378
379    /// Add a result from evaluating a configuration.
380    pub fn add_result(&mut self, result: HyperparamResult) {
381        self.results.push(result);
382    }
383
384    /// Get the best result found so far.
385    pub fn best_result(&self) -> Option<&HyperparamResult> {
386        self.results.iter().max_by(|a, b| {
387            a.score
388                .partial_cmp(&b.score)
389                .unwrap_or(std::cmp::Ordering::Equal)
390        })
391    }
392
393    /// Get all results sorted by score (descending).
394    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
395        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
396        results.sort_by(|a, b| {
397            b.score
398                .partial_cmp(&a.score)
399                .unwrap_or(std::cmp::Ordering::Equal)
400        });
401        results
402    }
403
404    /// Get all results.
405    pub fn results(&self) -> &[HyperparamResult] {
406        &self.results
407    }
408}
409
410// ============================================================================
411// Bayesian Optimization
412// ============================================================================
413
/// Acquisition function type for Bayesian Optimization.
///
/// Scores candidate points by trading off the GP's predicted mean
/// (exploitation) against its predictive uncertainty (exploration).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AcquisitionFunction {
    /// Expected Improvement - balances exploration and exploitation.
    /// `xi` is the exploration margin subtracted from the improvement.
    ExpectedImprovement { xi: f64 },
    /// Upper Confidence Bound - uses uncertainty to guide exploration.
    /// `kappa` scales the uncertainty term added to the mean.
    UpperConfidenceBound { kappa: f64 },
    /// Probability of Improvement - probability of improving over best.
    /// `xi` is the minimum improvement margin.
    ProbabilityOfImprovement { xi: f64 },
}
424
425impl Default for AcquisitionFunction {
426    fn default() -> Self {
427        Self::ExpectedImprovement { xi: 0.01 }
428    }
429}
430
/// Gaussian Process kernel for Bayesian Optimization.
///
/// Both variants are stationary covariance functions of the Euclidean
/// distance r = ||x - x'|| with signal variance σ² and length scale l.
#[derive(Debug, Clone, Copy)]
pub enum GpKernel {
    /// Radial Basis Function (RBF) / Squared Exponential kernel.
    /// K(x, x') = σ² * exp(-||x - x'||² / (2 * l²))
    Rbf {
        /// Signal variance (output scale).
        sigma: f64,
        /// Length scale (input scale).
        length_scale: f64,
    },
    /// Matérn kernel with ν = 3/2 (rougher than RBF).
    /// K(x, x') = σ² * (1 + √3 * r / l) * exp(-√3 * r / l)
    Matern32 {
        /// Signal variance.
        sigma: f64,
        /// Length scale.
        length_scale: f64,
    },
}
451
452impl Default for GpKernel {
453    fn default() -> Self {
454        Self::Rbf {
455            sigma: 1.0,
456            length_scale: 1.0,
457        }
458    }
459}
460
461impl GpKernel {
462    /// Compute kernel matrix K(X, X').
463    fn compute_kernel(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Array2<f64> {
464        let n1 = x1.nrows();
465        let n2 = x2.nrows();
466        let mut k = Array2::zeros((n1, n2));
467
468        for i in 0..n1 {
469            for j in 0..n2 {
470                let x1_row = x1.row(i);
471                let x2_row = x2.row(j);
472                let dist_sq = x1_row
473                    .iter()
474                    .zip(x2_row.iter())
475                    .map(|(a, b)| (a - b).powi(2))
476                    .sum::<f64>();
477
478                k[[i, j]] = match self {
479                    Self::Rbf {
480                        sigma,
481                        length_scale,
482                    } => sigma.powi(2) * (-dist_sq / (2.0 * length_scale.powi(2))).exp(),
483                    Self::Matern32 {
484                        sigma,
485                        length_scale,
486                    } => {
487                        let r = dist_sq.sqrt();
488                        let sqrt3_r_l = (3.0_f64).sqrt() * r / length_scale;
489                        sigma.powi(2) * (1.0 + sqrt3_r_l) * (-sqrt3_r_l).exp()
490                    }
491                };
492            }
493        }
494
495        k
496    }
497
498    /// Compute kernel vector k(X, x).
499    fn compute_kernel_vector(&self, x_train: &Array2<f64>, x_test: &Array1<f64>) -> Array1<f64> {
500        let n = x_train.nrows();
501        let mut k = Array1::zeros(n);
502
503        for i in 0..n {
504            let x_train_row = x_train.row(i);
505            let dist_sq = x_train_row
506                .iter()
507                .zip(x_test.iter())
508                .map(|(a, b)| (a - b).powi(2))
509                .sum::<f64>();
510
511            k[i] = match self {
512                Self::Rbf {
513                    sigma,
514                    length_scale,
515                } => sigma.powi(2) * (-dist_sq / (2.0 * length_scale.powi(2))).exp(),
516                Self::Matern32 {
517                    sigma,
518                    length_scale,
519                } => {
520                    let r = dist_sq.sqrt();
521                    let sqrt3_r_l = (3.0_f64).sqrt() * r / length_scale;
522                    sigma.powi(2) * (1.0 + sqrt3_r_l) * (-sqrt3_r_l).exp()
523                }
524            };
525        }
526
527        k
528    }
529}
530
/// Gaussian Process regressor for Bayesian Optimization.
///
/// Provides probabilistic predictions with uncertainty estimates. Targets
/// are standardized internally; predictions are mapped back to the
/// original scale.
#[derive(Debug)]
pub struct GaussianProcess {
    /// Kernel function.
    kernel: GpKernel,
    /// Noise variance (observation noise added to the kernel diagonal).
    noise_variance: f64,
    /// Training inputs (normalized to [0, 1]); `None` until `fit` is called.
    x_train: Option<Array2<f64>>,
    /// Training outputs (standardized); `None` until `fit` is called.
    y_train: Option<Array1<f64>>,
    /// Mean of training outputs (for standardization).
    y_mean: f64,
    /// Std of training outputs (for standardization).
    y_std: f64,
    /// Cholesky decomposition of K + σ²I (cached for efficiency).
    l_matrix: Option<Array2<f64>>,
    /// Alpha = L^T \ (L \ y) (cached solve against the standardized targets).
    alpha: Option<Array1<f64>>,
}
553
impl GaussianProcess {
    /// Create a new (unfitted) Gaussian Process.
    ///
    /// `noise_variance` is the observation noise σ² added to the kernel
    /// diagonal during fitting; it also serves as regularizing jitter.
    pub fn new(kernel: GpKernel, noise_variance: f64) -> Self {
        Self {
            kernel,
            noise_variance,
            x_train: None,
            y_train: None,
            // Identity standardization until `fit` is called.
            y_mean: 0.0,
            y_std: 1.0,
            l_matrix: None,
            alpha: None,
        }
    }

    /// Fit the GP to training data.
    ///
    /// Standardizes `y`, forms K + σ²I, and caches its Cholesky factor L
    /// along with alpha = (K + σ²I)⁻¹ y for use in `predict`.
    ///
    /// # Errors
    /// Returns an error if `x` and `y` disagree on the number of samples.
    pub fn fit(&mut self, x: Array2<f64>, y: Array1<f64>) -> TrainResult<()> {
        if x.nrows() != y.len() {
            return Err(TrainError::InvalidParameter(
                "X and y must have same number of samples".to_string(),
            ));
        }

        // Standardize y; the std is floored to avoid division by zero on
        // constant targets.
        let y_mean = y.mean().unwrap_or(0.0);
        let y_std = y.std(0.0).max(1e-8);
        let y_standardized = (&y - y_mean) / y_std;

        // Compute kernel matrix
        let k = self.kernel.compute_kernel(&x, &x);

        // Add noise: K + σ²I
        let mut k_noisy = k;
        for i in 0..k_noisy.nrows() {
            k_noisy[[i, i]] += self.noise_variance;
        }

        // Cholesky decomposition
        let l = self.cholesky(&k_noisy)?;

        // Solve L * alpha' = y
        let alpha_prime = self.forward_substitution(&l, &y_standardized)?;
        // Solve L^T * alpha = alpha'
        let alpha = self.backward_substitution(&l, &alpha_prime)?;

        self.x_train = Some(x);
        self.y_train = Some(y_standardized);
        self.y_mean = y_mean;
        self.y_std = y_std;
        self.l_matrix = Some(l);
        self.alpha = Some(alpha);

        Ok(())
    }

    /// Predict mean and standard deviation at test points.
    ///
    /// # Errors
    /// Returns an error if the GP has not been fitted yet.
    pub fn predict(&self, x_test: &Array2<f64>) -> TrainResult<(Array1<f64>, Array1<f64>)> {
        let x_train = self
            .x_train
            .as_ref()
            .ok_or_else(|| TrainError::InvalidParameter("GP not fitted".to_string()))?;
        // Safe: `fit` always sets these together with `x_train`.
        let l_matrix = self.l_matrix.as_ref().unwrap();
        let alpha = self.alpha.as_ref().unwrap();

        let n_test = x_test.nrows();
        let mut means = Array1::zeros(n_test);
        let mut stds = Array1::zeros(n_test);

        for i in 0..n_test {
            let x = x_test.row(i).to_owned();

            // Compute k(X*, x)
            let k_star = self.kernel.compute_kernel_vector(x_train, &x);

            // Mean: k(X*, x)^T * alpha, then undo the y standardization.
            let mean_standardized = k_star.dot(alpha);
            means[i] = mean_standardized * self.y_std + self.y_mean;

            // Variance: k(x, x) - k(X*, x)^T * (K + σ²I)^(-1) * k(X*, x)
            // k(x, x) is obtained from a 1-row slice against itself
            // (distance zero), i.e. the kernel's value at r = 0.
            let k_star_star = self
                .kernel
                .compute_kernel_vector(&x_test.slice(s![i..i + 1, ..]).to_owned(), &x)[0];
            // v = L⁻¹ k*, so vᵀv = k*ᵀ (K + σ²I)⁻¹ k*.
            let v = self
                .forward_substitution(l_matrix, &k_star)
                .unwrap_or_else(|_| Array1::zeros(k_star.len()));
            let variance_standardized = k_star_star - v.dot(&v);
            // Clamp tiny negative variances from round-off before sqrt.
            stds[i] = (variance_standardized.max(1e-10) * self.y_std.powi(2)).sqrt();
        }

        Ok((means, stds))
    }

    /// Cholesky decomposition: K = L * L^T, returning lower-triangular L.
    fn cholesky(&self, k: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let n = k.nrows();
        let mut l = Array2::zeros((n, n));

        for i in 0..n {
            for j in 0..=i {
                let mut sum = 0.0;
                for k_idx in 0..j {
                    sum += l[[i, k_idx]] * l[[j, k_idx]];
                }

                if i == j {
                    let val = k[[i, i]] - sum;
                    if val <= 0.0 {
                        // Add jitter for numerical stability
                        l[[i, j]] = (k[[i, i]] - sum + 1e-6).sqrt();
                    } else {
                        l[[i, j]] = val.sqrt();
                    }
                } else {
                    l[[i, j]] = (k[[i, j]] - sum) / l[[j, j]];
                }
            }
        }

        Ok(l)
    }

    /// Forward substitution: solve L * x = b for lower-triangular L.
    fn forward_substitution(&self, l: &Array2<f64>, b: &Array1<f64>) -> TrainResult<Array1<f64>> {
        let n = l.nrows();
        let mut x = Array1::zeros(n);

        for i in 0..n {
            let mut sum = 0.0;
            for j in 0..i {
                sum += l[[i, j]] * x[j];
            }
            x[i] = (b[i] - sum) / l[[i, i]];
        }

        Ok(x)
    }

    /// Backward substitution: solve L^T * x = b for lower-triangular L.
    fn backward_substitution(&self, l: &Array2<f64>, b: &Array1<f64>) -> TrainResult<Array1<f64>> {
        let n = l.nrows();
        let mut x = Array1::zeros(n);

        for i in (0..n).rev() {
            let mut sum = 0.0;
            for j in (i + 1)..n {
                sum += l[[j, i]] * x[j];
            }
            x[i] = (b[i] - sum) / l[[i, i]];
        }

        Ok(x)
    }
}
707
/// Bayesian Optimization for hyperparameter tuning.
///
/// Uses Gaussian Processes to model the objective function and acquisition
/// functions to intelligently select the next hyperparameters to evaluate.
///
/// # Algorithm
/// 1. Initialize with random samples
/// 2. Fit Gaussian Process to observed data
/// 3. Optimize acquisition function to find next point
/// 4. Evaluate objective at new point
/// 5. Repeat steps 2-4 until budget exhausted
///
/// # Example
/// ```
/// use tensorlogic_train::*;
/// use std::collections::HashMap;
///
/// let mut param_space = HashMap::new();
/// param_space.insert(
///     "lr".to_string(),
///     HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap(),
/// );
///
/// let mut bayes_opt = BayesianOptimization::new(
///     param_space,
///     10,  // n_iterations
///     5,   // n_initial_points
///     42,  // seed
/// );
///
/// // In practice, you would evaluate your model here
/// // bayes_opt.add_result(result);
/// ```
pub struct BayesianOptimization {
    /// Parameter space definition, keyed by parameter name.
    param_space: HashMap<String, HyperparamSpace>,
    /// Number of optimization iterations.
    n_iterations: usize,
    /// Number of random initial points before the GP takes over.
    n_initial_points: usize,
    /// Acquisition function scoring candidate points.
    acquisition_fn: AcquisitionFunction,
    /// Gaussian Process kernel.
    kernel: GpKernel,
    /// Observation noise added to the GP's kernel diagonal.
    noise_variance: f64,
    /// Seeded random number generator.
    rng: StdRng,
    /// Results from all evaluations, in insertion order.
    results: Vec<HyperparamResult>,
    /// Search bounds [min, max] per dimension (log space for LogUniform,
    /// index range for Discrete).
    bounds: Vec<(f64, f64)>,
    /// Parameter names in sorted (deterministic) order.
    param_names: Vec<String>,
}
763
764impl BayesianOptimization {
    /// Create a new Bayesian Optimization instance.
    ///
    /// Defaults: Expected Improvement acquisition, RBF kernel, and a tiny
    /// observation noise of 1e-6; override via the `with_*` builders.
    ///
    /// # Arguments
    /// * `param_space` - Hyperparameter space definition
    /// * `n_iterations` - Number of optimization iterations
    /// * `n_initial_points` - Number of random initialization points
    /// * `seed` - Random seed for reproducibility
    pub fn new(
        param_space: HashMap<String, HyperparamSpace>,
        n_iterations: usize,
        n_initial_points: usize,
        seed: u64,
    ) -> Self {
        let mut param_names: Vec<String> = param_space.keys().cloned().collect();
        param_names.sort(); // Ensure deterministic order

        // Numeric search bounds per dimension (log space for LogUniform).
        let bounds = Self::extract_bounds(&param_space, &param_names);

        Self {
            param_space,
            n_iterations,
            n_initial_points,
            acquisition_fn: AcquisitionFunction::default(),
            kernel: GpKernel::default(),
            noise_variance: 1e-6,
            rng: StdRng::seed_from_u64(seed),
            results: Vec::new(),
            bounds,
            param_names,
        }
    }
796
797    /// Set acquisition function.
798    pub fn with_acquisition(mut self, acquisition_fn: AcquisitionFunction) -> Self {
799        self.acquisition_fn = acquisition_fn;
800        self
801    }
802
803    /// Set kernel.
804    pub fn with_kernel(mut self, kernel: GpKernel) -> Self {
805        self.kernel = kernel;
806        self
807    }
808
809    /// Set noise variance.
810    pub fn with_noise(mut self, noise_variance: f64) -> Self {
811        self.noise_variance = noise_variance;
812        self
813    }
814
815    /// Extract bounds from parameter space.
816    fn extract_bounds(
817        param_space: &HashMap<String, HyperparamSpace>,
818        param_names: &[String],
819    ) -> Vec<(f64, f64)> {
820        param_names
821            .iter()
822            .map(|name| {
823                match &param_space[name] {
824                    HyperparamSpace::Continuous { min, max } => (*min, *max),
825                    HyperparamSpace::LogUniform { min, max } => (min.ln(), max.ln()),
826                    HyperparamSpace::IntRange { min, max } => (*min as f64, *max as f64),
827                    HyperparamSpace::Discrete(values) => {
828                        // For discrete, we'll use indices
829                        (0.0, (values.len() - 1) as f64)
830                    }
831                }
832            })
833            .collect()
834    }
835
    /// Suggest next hyperparameter configuration to evaluate.
    ///
    /// The first `n_initial_points` suggestions are purely random; after
    /// that a Gaussian Process is fit to all observed results and the
    /// acquisition function is maximized to pick the next candidate.
    ///
    /// # Errors
    /// Propagates failures from GP fitting or acquisition optimization.
    pub fn suggest(&mut self) -> TrainResult<HyperparamConfig> {
        // Use random sampling for initial points
        if self.results.len() < self.n_initial_points {
            return Ok(self.random_sample());
        }

        // Build GP from observed data
        let (x_observed, y_observed) = self.get_observations();
        let mut gp = GaussianProcess::new(self.kernel, self.noise_variance);
        gp.fit(x_observed, y_observed)?;

        // Optimize acquisition function
        let best_x = self.optimize_acquisition(&gp)?;

        // Convert to hyperparameter configuration
        self.vector_to_config(&best_x)
    }
854
855    /// Get observations as (X, y) matrices.
856    fn get_observations(&self) -> (Array2<f64>, Array1<f64>) {
857        let n_samples = self.results.len();
858        let n_dims = self.param_names.len();
859
860        let mut x = Array2::zeros((n_samples, n_dims));
861        let mut y = Array1::zeros(n_samples);
862
863        for (i, result) in self.results.iter().enumerate() {
864            let x_vec = self.config_to_vector(&result.config);
865            for (j, &val) in x_vec.iter().enumerate() {
866                x[[i, j]] = val;
867            }
868            y[i] = result.score;
869        }
870
871        (x, y)
872    }
873
874    /// Optimize acquisition function to find next point.
875    fn optimize_acquisition(&mut self, gp: &GaussianProcess) -> TrainResult<Array1<f64>> {
876        let n_dims = self.param_names.len();
877        let n_candidates = 1000;
878        let n_restarts = 10;
879
880        let mut best_acq_value = f64::NEG_INFINITY;
881        let mut best_x = Array1::zeros(n_dims);
882
883        // Random search with multiple restarts
884        for _ in 0..n_restarts {
885            for _ in 0..(n_candidates / n_restarts) {
886                // Generate random candidate
887                let mut x_candidate = Array1::zeros(n_dims);
888                for (i, (min, max)) in self.bounds.iter().enumerate() {
889                    x_candidate[i] = min + (max - min) * self.rng.random::<f64>();
890                }
891
892                // Evaluate acquisition
893                let acq_value = self.evaluate_acquisition(gp, &x_candidate)?;
894
895                if acq_value > best_acq_value {
896                    best_acq_value = acq_value;
897                    best_x = x_candidate;
898                }
899            }
900        }
901
902        Ok(best_x)
903    }
904
    /// Evaluate acquisition function at a point.
    ///
    /// Returns 0.0 when the predictive standard deviation collapses
    /// (an effectively already-sampled point), steering the search away
    /// from duplicates.
    fn evaluate_acquisition(&self, gp: &GaussianProcess, x: &Array1<f64>) -> TrainResult<f64> {
        // Treat the single point as a 1-row matrix for GP prediction.
        let x_mat = x.clone().into_shape_with_order((1, x.len())).unwrap();
        let (mean, std) = gp.predict(&x_mat)?;
        let mu = mean[0];
        let sigma = std[0];

        if sigma < 1e-10 {
            return Ok(0.0);
        }

        // Best observed score so far (incumbent); 0.0 if nothing observed.
        let f_best = self
            .results
            .iter()
            .map(|r| r.score)
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0);

        let acq = match self.acquisition_fn {
            AcquisitionFunction::ExpectedImprovement { xi } => {
                // EI = (μ - f* - ξ) Φ(z) + σ φ(z), with z = (μ - f* - ξ) / σ
                let z = (mu - f_best - xi) / sigma;
                let phi = Self::normal_cdf(z);
                let pdf = Self::normal_pdf(z);
                (mu - f_best - xi) * phi + sigma * pdf
            }
            AcquisitionFunction::UpperConfidenceBound { kappa } => mu + kappa * sigma,
            AcquisitionFunction::ProbabilityOfImprovement { xi } => {
                // PI = Φ((μ - f* - ξ) / σ)
                let z = (mu - f_best - xi) / sigma;
                Self::normal_cdf(z)
            }
        };

        Ok(acq)
    }
939
940    /// Standard normal CDF (cumulative distribution function).
941    fn normal_cdf(x: f64) -> f64 {
942        0.5 * (1.0 + Self::erf(x / 2.0_f64.sqrt()))
943    }
944
945    /// Standard normal PDF (probability density function).
946    fn normal_pdf(x: f64) -> f64 {
947        (-0.5 * x.powi(2)).exp() / (2.0 * std::f64::consts::PI).sqrt()
948    }
949
950    /// Error function approximation.
951    fn erf(x: f64) -> f64 {
952        // Abramowitz and Stegun approximation
953        let a1 = 0.254829592;
954        let a2 = -0.284496736;
955        let a3 = 1.421413741;
956        let a4 = -1.453152027;
957        let a5 = 1.061405429;
958        let p = 0.3275911;
959
960        let sign = if x < 0.0 { -1.0 } else { 1.0 };
961        let x = x.abs();
962
963        let t = 1.0 / (1.0 + p * x);
964        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
965
966        sign * y
967    }
968
969    /// Convert configuration to normalized vector [0, 1]^d.
970    fn config_to_vector(&self, config: &HyperparamConfig) -> Array1<f64> {
971        let n_dims = self.param_names.len();
972        let mut x = Array1::zeros(n_dims);
973
974        for (i, name) in self.param_names.iter().enumerate() {
975            let value = &config[name];
976            let (min, max) = self.bounds[i];
977
978            x[i] = match &self.param_space[name] {
979                HyperparamSpace::Continuous { .. } => {
980                    let v = value.as_float().unwrap();
981                    (v - min) / (max - min)
982                }
983                HyperparamSpace::LogUniform { .. } => {
984                    let v = value.as_float().unwrap();
985                    let log_v = v.ln();
986                    (log_v - min) / (max - min)
987                }
988                HyperparamSpace::IntRange { .. } => {
989                    let v = value.as_int().unwrap() as f64;
990                    (v - min) / (max - min)
991                }
992                HyperparamSpace::Discrete(values) => {
993                    let idx = values.iter().position(|v| v == value).unwrap_or(0);
994                    (idx as f64 - min) / (max - min)
995                }
996            };
997        }
998
999        x
1000    }
1001
1002    /// Convert normalized vector to configuration.
1003    fn vector_to_config(&self, x: &Array1<f64>) -> TrainResult<HyperparamConfig> {
1004        let mut config = HashMap::new();
1005
1006        for (i, name) in self.param_names.iter().enumerate() {
1007            let normalized = x[i].clamp(0.0, 1.0);
1008            let (min, max) = self.bounds[i];
1009            let value_raw = min + normalized * (max - min);
1010
1011            let value = match &self.param_space[name] {
1012                HyperparamSpace::Continuous { .. } => HyperparamValue::Float(value_raw),
1013                HyperparamSpace::LogUniform { .. } => HyperparamValue::Float(value_raw.exp()),
1014                HyperparamSpace::IntRange { .. } => HyperparamValue::Int(value_raw.round() as i64),
1015                HyperparamSpace::Discrete(values) => {
1016                    let idx = value_raw.round() as usize;
1017                    values[idx.min(values.len() - 1)].clone()
1018                }
1019            };
1020
1021            config.insert(name.clone(), value);
1022        }
1023
1024        Ok(config)
1025    }
1026
1027    /// Generate a random sample from parameter space.
1028    fn random_sample(&mut self) -> HyperparamConfig {
1029        let mut config = HashMap::new();
1030
1031        for (name, space) in &self.param_space {
1032            let value = space.sample(&mut self.rng);
1033            config.insert(name.clone(), value);
1034        }
1035
1036        config
1037    }
1038
    /// Record the result of evaluating a suggested configuration.
    ///
    /// Each recorded result counts toward the optimization budget
    /// (`is_complete` compares the result count against `total_budget`).
    pub fn add_result(&mut self, result: HyperparamResult) {
        self.results.push(result);
    }
1043
1044    /// Get the best result found so far.
1045    pub fn best_result(&self) -> Option<&HyperparamResult> {
1046        self.results.iter().max_by(|a, b| {
1047            a.score
1048                .partial_cmp(&b.score)
1049                .unwrap_or(std::cmp::Ordering::Equal)
1050        })
1051    }
1052
1053    /// Get all results sorted by score (descending).
1054    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
1055        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
1056        results.sort_by(|a, b| {
1057            b.score
1058                .partial_cmp(&a.score)
1059                .unwrap_or(std::cmp::Ordering::Equal)
1060        });
1061        results
1062    }
1063
    /// Get all recorded results, in insertion order.
    pub fn results(&self) -> &[HyperparamResult] {
        &self.results
    }
1068
1069    /// Check if optimization is complete.
1070    pub fn is_complete(&self) -> bool {
1071        self.results.len() >= self.n_iterations + self.n_initial_points
1072    }
1073
    /// Get the current iteration number (the count of recorded results).
    pub fn current_iteration(&self) -> usize {
        self.results.len()
    }
1078
1079    /// Get total budget (initial + iterations).
1080    pub fn total_budget(&self) -> usize {
1081        self.n_iterations + self.n_initial_points
1082    }
1083}
1084
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hyperparam_value() {
        let float_val = HyperparamValue::Float(3.5);
        assert_eq!(float_val.as_float(), Some(3.5));
        assert_eq!(float_val.as_int(), Some(3));

        let int_val = HyperparamValue::Int(42);
        assert_eq!(int_val.as_int(), Some(42));
        assert_eq!(int_val.as_float(), Some(42.0));

        let bool_val = HyperparamValue::Bool(true);
        assert_eq!(bool_val.as_bool(), Some(true));

        let string_val = HyperparamValue::String("test".to_string());
        assert_eq!(string_val.as_string(), Some("test"));
    }

    #[test]
    fn test_hyperparam_space_discrete() {
        let space = HyperparamSpace::discrete(vec![
            HyperparamValue::Float(0.1),
            HyperparamValue::Float(0.01),
        ])
        .unwrap();

        let values = space.grid_values(10);
        assert_eq!(values.len(), 2);

        let mut rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut rng);
        assert!(matches!(sampled, HyperparamValue::Float(_)));
    }

    #[test]
    fn test_hyperparam_space_continuous() {
        let space = HyperparamSpace::continuous(0.0, 1.0).unwrap();

        let values = space.grid_values(5);
        assert_eq!(values.len(), 5);

        let mut rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut rng);
        if let HyperparamValue::Float(v) = sampled {
            assert!((0.0..=1.0).contains(&v));
        } else {
            panic!("Expected Float value");
        }
    }

    #[test]
    fn test_hyperparam_space_log_uniform() {
        let space = HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap();

        let values = space.grid_values(3);
        assert_eq!(values.len(), 3);

        let mut rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut rng);
        if let HyperparamValue::Float(v) = sampled {
            assert!((1e-4..=1e-1).contains(&v));
        } else {
            panic!("Expected Float value");
        }
    }

    #[test]
    fn test_hyperparam_space_int_range() {
        let space = HyperparamSpace::int_range(1, 10).unwrap();

        let values = space.grid_values(5);
        assert!(!values.is_empty());

        let mut rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut rng);
        if let HyperparamValue::Int(v) = sampled {
            assert!((1..=10).contains(&v));
        } else {
            panic!("Expected Int value");
        }
    }

    #[test]
    fn test_hyperparam_space_invalid() {
        // Constructors must reject empty/inverted/non-positive ranges.
        assert!(HyperparamSpace::discrete(vec![]).is_err());
        assert!(HyperparamSpace::continuous(1.0, 0.0).is_err());
        assert!(HyperparamSpace::log_uniform(0.0, 1.0).is_err());
        assert!(HyperparamSpace::log_uniform(1.0, 0.5).is_err());
        assert!(HyperparamSpace::int_range(10, 5).is_err());
    }

    #[test]
    fn test_grid_search() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::discrete(vec![
                HyperparamValue::Float(0.1),
                HyperparamValue::Float(0.01),
            ])
            .unwrap(),
        );
        param_space.insert(
            "batch_size".to_string(),
            HyperparamSpace::int_range(16, 64).unwrap(),
        );

        let grid_search = GridSearch::new(param_space, 3);

        let configs = grid_search.generate_configs();
        assert!(!configs.is_empty());

        // Should have 2 (lr values) * grid_points (batch_size values) configs
        assert!(configs.len() >= 2);
    }

    #[test]
    fn test_grid_search_results() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::discrete(vec![HyperparamValue::Float(0.1)]).unwrap(),
        );

        let mut grid_search = GridSearch::new(param_space, 3);

        let mut config = HashMap::new();
        config.insert("lr".to_string(), HyperparamValue::Float(0.1));

        grid_search.add_result(HyperparamResult::new(config.clone(), 0.9));
        grid_search.add_result(HyperparamResult::new(config.clone(), 0.95));
        grid_search.add_result(HyperparamResult::new(config, 0.85));

        let best = grid_search.best_result().unwrap();
        assert_eq!(best.score, 0.95);

        let sorted = grid_search.sorted_results();
        assert_eq!(sorted[0].score, 0.95);
        assert_eq!(sorted[1].score, 0.9);
        assert_eq!(sorted[2].score, 0.85);
    }

    #[test]
    fn test_random_search() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::continuous(1e-4, 1e-1).unwrap(),
        );
        param_space.insert(
            "dropout".to_string(),
            HyperparamSpace::continuous(0.0, 0.5).unwrap(),
        );

        let mut random_search = RandomSearch::new(param_space, 10, 42);

        let configs = random_search.generate_configs();
        assert_eq!(configs.len(), 10);

        // Check that each config has all parameters
        for config in &configs {
            assert!(config.contains_key("lr"));
            assert!(config.contains_key("dropout"));
        }
    }

    #[test]
    fn test_random_search_results() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::discrete(vec![HyperparamValue::Float(0.1)]).unwrap(),
        );

        let mut random_search = RandomSearch::new(param_space, 5, 42);

        let mut config = HashMap::new();
        config.insert("lr".to_string(), HyperparamValue::Float(0.1));

        random_search.add_result(HyperparamResult::new(config.clone(), 0.8));
        random_search.add_result(HyperparamResult::new(config, 0.9));

        let best = random_search.best_result().unwrap();
        assert_eq!(best.score, 0.9);

        assert_eq!(random_search.results().len(), 2);
    }

    #[test]
    fn test_hyperparam_result_with_metrics() {
        let mut config = HashMap::new();
        config.insert("lr".to_string(), HyperparamValue::Float(0.1));

        let result = HyperparamResult::new(config, 0.95)
            .with_metric("accuracy".to_string(), 0.95)
            .with_metric("loss".to_string(), 0.05);

        assert_eq!(result.score, 0.95);
        assert_eq!(result.metrics.get("accuracy"), Some(&0.95));
        assert_eq!(result.metrics.get("loss"), Some(&0.05));
    }

    #[test]
    fn test_grid_search_empty_space() {
        let grid_search = GridSearch::new(HashMap::new(), 3);
        let configs = grid_search.generate_configs();
        assert_eq!(configs.len(), 1); // One empty config
        assert!(configs[0].is_empty());
    }

    #[test]
    fn test_grid_search_total_configs() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::discrete(vec![
                HyperparamValue::Float(0.1),
                HyperparamValue::Float(0.01),
            ])
            .unwrap(),
        );

        let grid_search = GridSearch::new(param_space, 3);
        assert_eq!(grid_search.total_configs(), 2);
    }

    // ============================================================================
    // Bayesian Optimization Tests
    // ============================================================================

    #[test]
    fn test_gp_kernel_rbf() {
        let kernel = GpKernel::Rbf {
            sigma: 1.0,
            length_scale: 1.0,
        };

        let x1 = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let x2 = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 0.5, 0.5]).unwrap();

        let k = kernel.compute_kernel(&x1, &x2);
        assert_eq!(k.shape(), &[2, 2]);

        // K(x, x) should be sigma^2
        assert!((k[[0, 0]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gp_kernel_matern() {
        let kernel = GpKernel::Matern32 {
            sigma: 1.0,
            length_scale: 1.0,
        };

        let x = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).unwrap();
        let k = kernel.compute_kernel(&x, &x);

        // K(x, x) should be sigma^2 for Matern kernel at same point
        assert!((k[[0, 0]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gp_fit_and_predict() {
        let kernel = GpKernel::Rbf {
            sigma: 1.0,
            length_scale: 0.5,
        };
        let mut gp = GaussianProcess::new(kernel, 1e-6);

        // Training data: y = x^2
        let x_train = Array2::from_shape_vec((5, 1), vec![0.0, 0.5, 1.0, 1.5, 2.0]).unwrap();
        let y_train = Array1::from_vec(vec![0.0, 0.25, 1.0, 2.25, 4.0]);

        gp.fit(x_train, y_train).unwrap();

        // Test prediction
        let x_test = Array2::from_shape_vec((2, 1), vec![0.75, 1.25]).unwrap();
        let (means, _stds) = gp.predict(&x_test).unwrap();

        assert_eq!(means.len(), 2);
        // Predictions should be reasonable (between 0 and 4)
        assert!((0.0..=4.0).contains(&means[0]));
        assert!((0.0..=4.0).contains(&means[1]));
    }

    #[test]
    fn test_gp_predict_error_not_fitted() {
        let kernel = GpKernel::default();
        let gp = GaussianProcess::new(kernel, 1e-6);

        let x_test = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
        let result = gp.predict(&x_test);

        assert!(result.is_err());
    }

    #[test]
    fn test_gp_fit_dimension_mismatch() {
        let kernel = GpKernel::default();
        let mut gp = GaussianProcess::new(kernel, 1e-6);

        let x = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 1.0]); // Wrong size

        let result = gp.fit(x, y);
        assert!(result.is_err());
    }

    #[test]
    fn test_acquisition_function_ei() {
        let acq = AcquisitionFunction::ExpectedImprovement { xi: 0.01 };
        assert!(matches!(
            acq,
            AcquisitionFunction::ExpectedImprovement { .. }
        ));
    }

    #[test]
    fn test_acquisition_function_ucb() {
        let acq = AcquisitionFunction::UpperConfidenceBound { kappa: 2.0 };
        assert!(matches!(
            acq,
            AcquisitionFunction::UpperConfidenceBound { .. }
        ));
    }

    #[test]
    fn test_acquisition_function_pi() {
        let acq = AcquisitionFunction::ProbabilityOfImprovement { xi: 0.01 };
        assert!(matches!(
            acq,
            AcquisitionFunction::ProbabilityOfImprovement { .. }
        ));
    }

    #[test]
    fn test_bayesian_optimization_creation() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap(),
        );

        let bayes_opt = BayesianOptimization::new(param_space, 10, 5, 42);

        assert_eq!(bayes_opt.total_budget(), 15);
        assert_eq!(bayes_opt.current_iteration(), 0);
        assert!(!bayes_opt.is_complete());
    }

    #[test]
    fn test_bayesian_optimization_suggest_initial() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::continuous(0.0, 1.0).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 5, 3, 42);

        // First suggestions should be random (initial phase)
        for _ in 0..3 {
            let config = bayes_opt.suggest().unwrap();
            assert!(config.contains_key("lr"));

            // Simulate adding a result
            bayes_opt.add_result(HyperparamResult::new(config, 0.5));
        }

        assert_eq!(bayes_opt.current_iteration(), 3);
    }

    #[test]
    fn test_bayesian_optimization_suggest_gp_phase() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "x".to_string(),
            HyperparamSpace::continuous(0.0, 1.0).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 5, 2, 42);

        // Add initial observations
        let mut config1 = HashMap::new();
        config1.insert("x".to_string(), HyperparamValue::Float(0.25));
        bayes_opt.add_result(HyperparamResult::new(config1, 0.5));

        let mut config2 = HashMap::new();
        config2.insert("x".to_string(), HyperparamValue::Float(0.75));
        bayes_opt.add_result(HyperparamResult::new(config2, 0.8));

        // Next suggestion should use GP
        let config = bayes_opt.suggest().unwrap();
        assert!(config.contains_key("x"));
    }

    #[test]
    fn test_bayesian_optimization_with_acquisition() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap(),
        );

        let bayes_opt = BayesianOptimization::new(param_space, 10, 5, 42)
            .with_acquisition(AcquisitionFunction::UpperConfidenceBound { kappa: 2.0 })
            .with_kernel(GpKernel::Matern32 {
                sigma: 1.0,
                length_scale: 0.5,
            })
            .with_noise(1e-5);

        assert_eq!(bayes_opt.total_budget(), 15);
    }

    #[test]
    fn test_bayesian_optimization_best_result() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "x".to_string(),
            HyperparamSpace::continuous(0.0, 1.0).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 5, 2, 42);

        let mut config1 = HashMap::new();
        config1.insert("x".to_string(), HyperparamValue::Float(0.3));
        bayes_opt.add_result(HyperparamResult::new(config1, 0.6));

        let mut config2 = HashMap::new();
        config2.insert("x".to_string(), HyperparamValue::Float(0.7));
        bayes_opt.add_result(HyperparamResult::new(config2, 0.9));

        let best = bayes_opt.best_result().unwrap();
        assert_eq!(best.score, 0.9);
    }

    #[test]
    fn test_bayesian_optimization_is_complete() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "x".to_string(),
            HyperparamSpace::continuous(0.0, 1.0).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 2, 1, 42);

        assert!(!bayes_opt.is_complete());

        // Add results up to budget
        for i in 0..3 {
            let mut config = HashMap::new();
            config.insert("x".to_string(), HyperparamValue::Float(i as f64 * 0.3));
            bayes_opt.add_result(HyperparamResult::new(config, i as f64 * 0.2));
        }

        assert!(bayes_opt.is_complete());
    }

    #[test]
    fn test_bayesian_optimization_multivariate() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap(),
        );
        param_space.insert(
            "batch_size".to_string(),
            HyperparamSpace::int_range(16, 128).unwrap(),
        );
        param_space.insert(
            "dropout".to_string(),
            HyperparamSpace::continuous(0.0, 0.5).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 10, 3, 42);

        let config = bayes_opt.suggest().unwrap();
        assert_eq!(config.len(), 3);
        assert!(config.contains_key("lr"));
        assert!(config.contains_key("batch_size"));
        assert!(config.contains_key("dropout"));
    }

    #[test]
    fn test_bayesian_optimization_discrete_space() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "optimizer".to_string(),
            HyperparamSpace::discrete(vec![
                HyperparamValue::String("adam".to_string()),
                HyperparamValue::String("sgd".to_string()),
                HyperparamValue::String("rmsprop".to_string()),
            ])
            .unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 5, 2, 42);

        let config = bayes_opt.suggest().unwrap();
        assert!(config.contains_key("optimizer"));

        let optimizer = config.get("optimizer").unwrap();
        assert!(matches!(optimizer, HyperparamValue::String(_)));
    }

    #[test]
    fn test_normal_cdf() {
        // Test standard normal CDF at common points
        let cdf_0 = BayesianOptimization::normal_cdf(0.0);
        assert!((cdf_0 - 0.5).abs() < 1e-4);

        let cdf_pos = BayesianOptimization::normal_cdf(1.96);
        assert!((cdf_pos - 0.975).abs() < 1e-2);

        let cdf_neg = BayesianOptimization::normal_cdf(-1.96);
        assert!((cdf_neg - 0.025).abs() < 1e-2);
    }

    #[test]
    fn test_normal_pdf() {
        // Test standard normal PDF at 0
        let pdf_0 = BayesianOptimization::normal_pdf(0.0);
        let expected = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
        assert!((pdf_0 - expected).abs() < 1e-6);

        // PDF should be symmetric
        let pdf_pos = BayesianOptimization::normal_pdf(1.0);
        let pdf_neg = BayesianOptimization::normal_pdf(-1.0);
        assert!((pdf_pos - pdf_neg).abs() < 1e-10);
    }

    #[test]
    fn test_erf() {
        // Test error function at known points
        assert!((BayesianOptimization::erf(0.0) - 0.0).abs() < 1e-6);
        assert!((BayesianOptimization::erf(1.0) - 0.8427).abs() < 1e-3);
        assert!((BayesianOptimization::erf(-1.0) + 0.8427).abs() < 1e-3);
    }
}