use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::{Rng, SeedableRng, StdRng};
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
pub enum HyperparamValue {
    Float(f64),
    Int(i64),
    Bool(bool),
    String(String),
}

impl HyperparamValue {
    pub fn as_float(&self) -> Option<f64> {
        match self {
            HyperparamValue::Float(v) => Some(*v),
            HyperparamValue::Int(v) => Some(*v as f64),
            _ => None,
        }
    }

    pub fn as_int(&self) -> Option<i64> {
        match self {
            HyperparamValue::Int(v) => Some(*v),
            HyperparamValue::Float(v) => Some(*v as i64),
            _ => None,
        }
    }

    pub fn as_bool(&self) -> Option<bool> {
        match self {
            HyperparamValue::Bool(v) => Some(*v),
            _ => None,
        }
    }

    pub fn as_string(&self) -> Option<&str> {
        match self {
            HyperparamValue::String(v) => Some(v),
            _ => None,
        }
    }
}
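// Note: the numeric accessors coerce across variants. `as_float` widens Int
// losslessly, while `as_int` truncates Float toward zero (3.5 -> 3), matching
// Rust's `as` cast semantics.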

#[derive(Debug, Clone)]
pub enum HyperparamSpace {
    Discrete(Vec<HyperparamValue>),
    Continuous { min: f64, max: f64 },
    LogUniform { min: f64, max: f64 },
    IntRange { min: i64, max: i64 },
}

impl HyperparamSpace {
    pub fn discrete(values: Vec<HyperparamValue>) -> TrainResult<Self> {
        if values.is_empty() {
            return Err(TrainError::InvalidParameter(
                "Discrete space cannot be empty".to_string(),
            ));
        }
        Ok(Self::Discrete(values))
    }

    pub fn continuous(min: f64, max: f64) -> TrainResult<Self> {
        if min >= max {
            return Err(TrainError::InvalidParameter(
                "min must be less than max".to_string(),
            ));
        }
        Ok(Self::Continuous { min, max })
    }

    pub fn log_uniform(min: f64, max: f64) -> TrainResult<Self> {
        if min <= 0.0 || max <= 0.0 || min >= max {
            return Err(TrainError::InvalidParameter(
                "min and max must be positive and min < max".to_string(),
            ));
        }
        Ok(Self::LogUniform { min, max })
    }

    pub fn int_range(min: i64, max: i64) -> TrainResult<Self> {
        if min >= max {
            return Err(TrainError::InvalidParameter(
                "min must be less than max".to_string(),
            ));
        }
        Ok(Self::IntRange { min, max })
    }

    pub fn sample(&self, rng: &mut StdRng) -> HyperparamValue {
        match self {
            HyperparamSpace::Discrete(values) => {
                let idx = rng.gen_range(0..values.len());
                values[idx].clone()
            }
            HyperparamSpace::Continuous { min, max } => {
                let value = min + (max - min) * rng.random::<f64>();
                HyperparamValue::Float(value)
            }
            HyperparamSpace::LogUniform { min, max } => {
                let log_min = min.ln();
                let log_max = max.ln();
                let log_value = log_min + (log_max - log_min) * rng.random::<f64>();
                HyperparamValue::Float(log_value.exp())
            }
            HyperparamSpace::IntRange { min, max } => {
                let value = rng.gen_range(*min..=*max);
                HyperparamValue::Int(value)
            }
        }
    }
    pub fn grid_values(&self, num_samples: usize) -> Vec<HyperparamValue> {
        match self {
            HyperparamSpace::Discrete(values) => values.clone(),
            HyperparamSpace::IntRange { min, max } => {
                let range_size = (max - min + 1) as usize;
                let step = (range_size / num_samples).max(1);
                (*min..=*max)
                    .step_by(step)
                    .map(HyperparamValue::Int)
                    .collect()
            }
            HyperparamSpace::Continuous { min, max } => {
                // Evenly spaced grid that includes both endpoints (linspace
                // semantics); dividing by num_samples would leave `max` unsampled.
                let denom = num_samples.saturating_sub(1).max(1) as f64;
                let step = (max - min) / denom;
                (0..num_samples)
                    .map(|i| HyperparamValue::Float(min + step * i as f64))
                    .collect()
            }
            HyperparamSpace::LogUniform { min, max } => {
                let log_min = min.ln();
                let log_max = max.ln();
                let denom = num_samples.saturating_sub(1).max(1) as f64;
                let log_step = (log_max - log_min) / denom;
                (0..num_samples)
                    .map(|i| HyperparamValue::Float((log_min + log_step * i as f64).exp()))
                    .collect()
            }
        }
    }
}
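// A minimal usage sketch for defining and sampling a space (the constructors
// validate their bounds, hence the `?`):
//
//     let space = HyperparamSpace::log_uniform(1e-4, 1e-1)?;
//     let mut rng = StdRng::seed_from_u64(0);
//     let lr = space.sample(&mut rng);   // one random draw, uniform in log space
//     let grid = space.grid_values(5);   // five log-spaced grid points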

pub type HyperparamConfig = HashMap<String, HyperparamValue>;

#[derive(Debug, Clone)]
pub struct HyperparamResult {
    pub config: HyperparamConfig,
    pub score: f64,
    pub metrics: HashMap<String, f64>,
}

impl HyperparamResult {
    pub fn new(config: HyperparamConfig, score: f64) -> Self {
        Self {
            config,
            score,
            metrics: HashMap::new(),
        }
    }

    pub fn with_metric(mut self, name: String, value: f64) -> Self {
        self.metrics.insert(name, value);
        self
    }
}

#[derive(Debug)]
pub struct GridSearch {
    param_space: HashMap<String, HyperparamSpace>,
    num_grid_points: usize,
    results: Vec<HyperparamResult>,
}

impl GridSearch {
    pub fn new(param_space: HashMap<String, HyperparamSpace>, num_grid_points: usize) -> Self {
        Self {
            param_space,
            num_grid_points,
            results: Vec::new(),
        }
    }

    pub fn generate_configs(&self) -> Vec<HyperparamConfig> {
        if self.param_space.is_empty() {
            return vec![HashMap::new()];
        }

        let mut param_names: Vec<String> = self.param_space.keys().cloned().collect();
        param_names.sort();

        let mut all_values: Vec<Vec<HyperparamValue>> = Vec::new();
        for name in &param_names {
            let space = &self.param_space[name];
            all_values.push(space.grid_values(self.num_grid_points));
        }

        let mut configs = Vec::new();
        self.generate_cartesian_product(
            &param_names,
            &all_values,
            0,
            &mut HashMap::new(),
            &mut configs,
        );

        configs
    }

    #[allow(clippy::only_used_in_recursion)]
    fn generate_cartesian_product(
        &self,
        param_names: &[String],
        all_values: &[Vec<HyperparamValue>],
        depth: usize,
        current_config: &mut HyperparamConfig,
        configs: &mut Vec<HyperparamConfig>,
    ) {
        if depth == param_names.len() {
            configs.push(current_config.clone());
            return;
        }

        let param_name = &param_names[depth];
        let values = &all_values[depth];

        for value in values {
            current_config.insert(param_name.clone(), value.clone());
            self.generate_cartesian_product(
                param_names,
                all_values,
                depth + 1,
                current_config,
                configs,
            );
        }

        current_config.remove(param_name);
    }

    pub fn add_result(&mut self, result: HyperparamResult) {
        self.results.push(result);
    }

    pub fn best_result(&self) -> Option<&HyperparamResult> {
        self.results.iter().max_by(|a, b| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results
    }

    pub fn results(&self) -> &[HyperparamResult] {
        &self.results
    }

    pub fn total_configs(&self) -> usize {
        self.generate_configs().len()
    }
}
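// A minimal driver loop, sketched under the assumption that `evaluate` is some
// user-supplied function returning a validation score for a config:
//
//     let mut search = GridSearch::new(param_space, 5);
//     for config in search.generate_configs() {
//         let score = evaluate(&config); // hypothetical evaluation callback
//         search.add_result(HyperparamResult::new(config, score));
//     }
//     let best = search.best_result();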

#[derive(Debug)]
pub struct RandomSearch {
    param_space: HashMap<String, HyperparamSpace>,
    num_samples: usize,
    rng: StdRng,
    results: Vec<HyperparamResult>,
}

impl RandomSearch {
    pub fn new(
        param_space: HashMap<String, HyperparamSpace>,
        num_samples: usize,
        seed: u64,
    ) -> Self {
        Self {
            param_space,
            num_samples,
            rng: StdRng::seed_from_u64(seed),
            results: Vec::new(),
        }
    }

    pub fn generate_configs(&mut self) -> Vec<HyperparamConfig> {
        let mut configs = Vec::with_capacity(self.num_samples);

        for _ in 0..self.num_samples {
            let mut config = HashMap::new();

            for (name, space) in &self.param_space {
                let value = space.sample(&mut self.rng);
                config.insert(name.clone(), value);
            }

            configs.push(config);
        }

        configs
    }

    pub fn add_result(&mut self, result: HyperparamResult) {
        self.results.push(result);
    }

    pub fn best_result(&self) -> Option<&HyperparamResult> {
        self.results.iter().max_by(|a, b| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results
    }

    pub fn results(&self) -> &[HyperparamResult] {
        &self.results
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AcquisitionFunction {
    ExpectedImprovement { xi: f64 },
    UpperConfidenceBound { kappa: f64 },
    ProbabilityOfImprovement { xi: f64 },
}

impl Default for AcquisitionFunction {
    fn default() -> Self {
        Self::ExpectedImprovement { xi: 0.01 }
    }
}

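/// Covariance kernels for the Gaussian-process surrogate, in their standard
/// forms with `r = ||x - x'||`:
///
/// - RBF: `k(x, x') = sigma^2 * exp(-r^2 / (2 * length_scale^2))`
/// - Matern 3/2: `k(x, x') = sigma^2 * (1 + sqrt(3) * r / length_scale)
///   * exp(-sqrt(3) * r / length_scale)`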
#[derive(Debug, Clone, Copy)]
pub enum GpKernel {
    Rbf {
        sigma: f64,
        length_scale: f64,
    },
    Matern32 {
        sigma: f64,
        length_scale: f64,
    },
}

impl Default for GpKernel {
    fn default() -> Self {
        Self::Rbf {
            sigma: 1.0,
            length_scale: 1.0,
        }
    }
}

impl GpKernel {
    /// Evaluates the kernel for a precomputed squared Euclidean distance.
    fn apply(&self, dist_sq: f64) -> f64 {
        match self {
            Self::Rbf {
                sigma,
                length_scale,
            } => sigma.powi(2) * (-dist_sq / (2.0 * length_scale.powi(2))).exp(),
            Self::Matern32 {
                sigma,
                length_scale,
            } => {
                let r = dist_sq.sqrt();
                let sqrt3_r_l = (3.0_f64).sqrt() * r / length_scale;
                sigma.powi(2) * (1.0 + sqrt3_r_l) * (-sqrt3_r_l).exp()
            }
        }
    }

    fn compute_kernel(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Array2<f64> {
        let n1 = x1.nrows();
        let n2 = x2.nrows();
        let mut k = Array2::zeros((n1, n2));

        for i in 0..n1 {
            for j in 0..n2 {
                let dist_sq = x1
                    .row(i)
                    .iter()
                    .zip(x2.row(j).iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum::<f64>();
                k[[i, j]] = self.apply(dist_sq);
            }
        }

        k
    }

    fn compute_kernel_vector(&self, x_train: &Array2<f64>, x_test: &Array1<f64>) -> Array1<f64> {
        let n = x_train.nrows();
        let mut k = Array1::zeros(n);

        for i in 0..n {
            let dist_sq = x_train
                .row(i)
                .iter()
                .zip(x_test.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>();
            k[i] = self.apply(dist_sq);
        }

        k
    }
}

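/// A zero-mean Gaussian-process regressor used as the surrogate model.
///
/// Targets are standardized before fitting, and the kernel matrix is factored
/// once with a Cholesky decomposition so predictions only need triangular solves.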
#[derive(Debug)]
pub struct GaussianProcess {
    kernel: GpKernel,
    noise_variance: f64,
    x_train: Option<Array2<f64>>,
    y_train: Option<Array1<f64>>,
    y_mean: f64,
    y_std: f64,
    l_matrix: Option<Array2<f64>>,
    alpha: Option<Array1<f64>>,
}

impl GaussianProcess {
    pub fn new(kernel: GpKernel, noise_variance: f64) -> Self {
        Self {
            kernel,
            noise_variance,
            x_train: None,
            y_train: None,
            y_mean: 0.0,
            y_std: 1.0,
            l_matrix: None,
            alpha: None,
        }
    }

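    /// Fits the GP: standardizes `y`, forms `K + noise * I`, factors it as
    /// `L * L^T`, and solves for `alpha = K^{-1} y` via two triangular solves.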
    pub fn fit(&mut self, x: Array2<f64>, y: Array1<f64>) -> TrainResult<()> {
        if x.nrows() != y.len() {
            return Err(TrainError::InvalidParameter(
                "X and y must have same number of samples".to_string(),
            ));
        }

        let y_mean = y.mean().unwrap_or(0.0);
        let y_std = y.std(0.0).max(1e-8);
        let y_standardized = (&y - y_mean) / y_std;

        let k = self.kernel.compute_kernel(&x, &x);

        let mut k_noisy = k;
        for i in 0..k_noisy.nrows() {
            k_noisy[[i, i]] += self.noise_variance;
        }

        let l = self.cholesky(&k_noisy)?;

        let alpha_prime = self.forward_substitution(&l, &y_standardized)?;
        let alpha = self.backward_substitution(&l, &alpha_prime)?;

        self.x_train = Some(x);
        self.y_train = Some(y_standardized);
        self.y_mean = y_mean;
        self.y_std = y_std;
        self.l_matrix = Some(l);
        self.alpha = Some(alpha);

        Ok(())
    }

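    /// Predictive posterior at each test row: mean `k_*^T alpha` (de-standardized)
    /// and standard deviation from `k(x, x) - v^T v` with `v = L^{-1} k_*`.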
    pub fn predict(&self, x_test: &Array2<f64>) -> TrainResult<(Array1<f64>, Array1<f64>)> {
        let x_train = self
            .x_train
            .as_ref()
            .ok_or_else(|| TrainError::InvalidParameter("GP not fitted".to_string()))?;
        let l_matrix = self.l_matrix.as_ref().unwrap();
        let alpha = self.alpha.as_ref().unwrap();

        let n_test = x_test.nrows();
        let mut means = Array1::zeros(n_test);
        let mut stds = Array1::zeros(n_test);

        for i in 0..n_test {
            let x = x_test.row(i).to_owned();

            let k_star = self.kernel.compute_kernel_vector(x_train, &x);

            let mean_standardized = k_star.dot(alpha);
            means[i] = mean_standardized * self.y_std + self.y_mean;

            let k_star_star = self
                .kernel
                .compute_kernel_vector(&x_test.slice(s![i..i + 1, ..]).to_owned(), &x)[0];
            let v = self
                .forward_substitution(l_matrix, &k_star)
                .unwrap_or_else(|_| Array1::zeros(k_star.len()));
            let variance_standardized = k_star_star - v.dot(&v);
            stds[i] = (variance_standardized.max(1e-10) * self.y_std.powi(2)).sqrt();
        }

        Ok((means, stds))
    }

    fn cholesky(&self, k: &Array2<f64>) -> TrainResult<Array2<f64>> {
        let n = k.nrows();
        let mut l = Array2::zeros((n, n));

        for i in 0..n {
            for j in 0..=i {
                let mut sum = 0.0;
                for k_idx in 0..j {
                    sum += l[[i, k_idx]] * l[[j, k_idx]];
                }

                if i == j {
                    let val = k[[i, i]] - sum;
                    if val <= 0.0 {
                        // Add jitter for numerical stability; clamp so the
                        // square root never sees a negative argument.
                        l[[i, j]] = (val + 1e-6).max(1e-12).sqrt();
                    } else {
                        l[[i, j]] = val.sqrt();
                    }
                } else {
                    l[[i, j]] = (k[[i, j]] - sum) / l[[j, j]];
                }
            }
        }

        Ok(l)
    }

    fn forward_substitution(&self, l: &Array2<f64>, b: &Array1<f64>) -> TrainResult<Array1<f64>> {
        let n = l.nrows();
        let mut x = Array1::zeros(n);

        for i in 0..n {
            let mut sum = 0.0;
            for j in 0..i {
                sum += l[[i, j]] * x[j];
            }
            x[i] = (b[i] - sum) / l[[i, i]];
        }

        Ok(x)
    }

    fn backward_substitution(&self, l: &Array2<f64>, b: &Array1<f64>) -> TrainResult<Array1<f64>> {
        let n = l.nrows();
        let mut x = Array1::zeros(n);

        for i in (0..n).rev() {
            let mut sum = 0.0;
            for j in (i + 1)..n {
                sum += l[[j, i]] * x[j];
            }
            x[i] = (b[i] - sum) / l[[i, i]];
        }

        Ok(x)
    }
}

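/// Bayesian optimization with a Gaussian-process surrogate.
///
/// Observed configs are mapped to the unit cube [0, 1]^d (`config_to_vector`),
/// the GP is fitted in that normalized space, the acquisition function is
/// maximized by random candidate search over the same cube, and the winning
/// point is mapped back to raw hyperparameter values (`vector_to_config`).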
pub struct BayesianOptimization {
    param_space: HashMap<String, HyperparamSpace>,
    n_iterations: usize,
    n_initial_points: usize,
    acquisition_fn: AcquisitionFunction,
    kernel: GpKernel,
    noise_variance: f64,
    rng: StdRng,
    results: Vec<HyperparamResult>,
    bounds: Vec<(f64, f64)>,
    param_names: Vec<String>,
}

impl BayesianOptimization {
    pub fn new(
        param_space: HashMap<String, HyperparamSpace>,
        n_iterations: usize,
        n_initial_points: usize,
        seed: u64,
    ) -> Self {
        let mut param_names: Vec<String> = param_space.keys().cloned().collect();
        param_names.sort();

        let bounds = Self::extract_bounds(&param_space, &param_names);

        Self {
            param_space,
            n_iterations,
            n_initial_points,
            acquisition_fn: AcquisitionFunction::default(),
            kernel: GpKernel::default(),
            noise_variance: 1e-6,
            rng: StdRng::seed_from_u64(seed),
            results: Vec::new(),
            bounds,
            param_names,
        }
    }

    pub fn with_acquisition(mut self, acquisition_fn: AcquisitionFunction) -> Self {
        self.acquisition_fn = acquisition_fn;
        self
    }

    pub fn with_kernel(mut self, kernel: GpKernel) -> Self {
        self.kernel = kernel;
        self
    }

    pub fn with_noise(mut self, noise_variance: f64) -> Self {
        self.noise_variance = noise_variance;
        self
    }

    fn extract_bounds(
        param_space: &HashMap<String, HyperparamSpace>,
        param_names: &[String],
    ) -> Vec<(f64, f64)> {
        param_names
            .iter()
            .map(|name| match &param_space[name] {
                HyperparamSpace::Continuous { min, max } => (*min, *max),
                HyperparamSpace::LogUniform { min, max } => (min.ln(), max.ln()),
                HyperparamSpace::IntRange { min, max } => (*min as f64, *max as f64),
                HyperparamSpace::Discrete(values) => (0.0, (values.len() - 1) as f64),
            })
            .collect()
    }

    pub fn suggest(&mut self) -> TrainResult<HyperparamConfig> {
        if self.results.len() < self.n_initial_points {
            return Ok(self.random_sample());
        }

        let (x_observed, y_observed) = self.get_observations();

        let mut gp = GaussianProcess::new(self.kernel, self.noise_variance);
        gp.fit(x_observed, y_observed)?;

        let best_x = self.optimize_acquisition(&gp)?;

        self.vector_to_config(&best_x)
    }

    fn get_observations(&self) -> (Array2<f64>, Array1<f64>) {
        let n_samples = self.results.len();
        let n_dims = self.param_names.len();

        let mut x = Array2::zeros((n_samples, n_dims));
        let mut y = Array1::zeros(n_samples);

        for (i, result) in self.results.iter().enumerate() {
            let x_vec = self.config_to_vector(&result.config);
            for (j, &val) in x_vec.iter().enumerate() {
                x[[i, j]] = val;
            }
            y[i] = result.score;
        }

        (x, y)
    }

    fn optimize_acquisition(&mut self, gp: &GaussianProcess) -> TrainResult<Array1<f64>> {
        let n_dims = self.param_names.len();
        let n_candidates = 1000;
        let n_restarts = 10;

        let mut best_acq_value = f64::NEG_INFINITY;
        let mut best_x = Array1::zeros(n_dims);

        for _ in 0..n_restarts {
            for _ in 0..(n_candidates / n_restarts) {
                // The GP is trained on normalized [0, 1] coordinates (see
                // `config_to_vector`), so candidates must be sampled in the
                // same unit cube; sampling in the raw bounds here would put
                // candidates in a different space from the observations.
                // `vector_to_config` maps the winner back to raw bounds.
                let mut x_candidate = Array1::zeros(n_dims);
                for i in 0..n_dims {
                    x_candidate[i] = self.rng.random::<f64>();
                }

                let acq_value = self.evaluate_acquisition(gp, &x_candidate)?;

                if acq_value > best_acq_value {
                    best_acq_value = acq_value;
                    best_x = x_candidate;
                }
            }
        }

        Ok(best_x)
    }

    fn evaluate_acquisition(&self, gp: &GaussianProcess, x: &Array1<f64>) -> TrainResult<f64> {
        let x_mat = x.clone().into_shape_with_order((1, x.len())).unwrap();
        let (mean, std) = gp.predict(&x_mat)?;
        let mu = mean[0];
        let sigma = std[0];

        if sigma < 1e-10 {
            return Ok(0.0);
        }

        let f_best = self
            .results
            .iter()
            .map(|r| r.score)
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0);

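        // Standard acquisition formulas, with z = (mu - f_best - xi) / sigma:
        //   EI(x)  = (mu - f_best - xi) * Phi(z) + sigma * phi(z)
        //   UCB(x) = mu + kappa * sigma
        //   PI(x)  = Phi(z)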
        let acq = match self.acquisition_fn {
            AcquisitionFunction::ExpectedImprovement { xi } => {
                let z = (mu - f_best - xi) / sigma;
                let phi = Self::normal_cdf(z);
                let pdf = Self::normal_pdf(z);
                (mu - f_best - xi) * phi + sigma * pdf
            }
            AcquisitionFunction::UpperConfidenceBound { kappa } => mu + kappa * sigma,
            AcquisitionFunction::ProbabilityOfImprovement { xi } => {
                let z = (mu - f_best - xi) / sigma;
                Self::normal_cdf(z)
            }
        };

        Ok(acq)
    }

    fn normal_cdf(x: f64) -> f64 {
        0.5 * (1.0 + Self::erf(x / 2.0_f64.sqrt()))
    }

    fn normal_pdf(x: f64) -> f64 {
        (-0.5 * x.powi(2)).exp() / (2.0 * std::f64::consts::PI).sqrt()
    }

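    // Polynomial approximation of erf from Abramowitz & Stegun (formula 7.1.26);
    // absolute error is below ~1.5e-7, which is ample for ranking candidates.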
    fn erf(x: f64) -> f64 {
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

        sign * y
    }

    fn config_to_vector(&self, config: &HyperparamConfig) -> Array1<f64> {
        let n_dims = self.param_names.len();
        let mut x = Array1::zeros(n_dims);

        for (i, name) in self.param_names.iter().enumerate() {
            let value = &config[name];
            let (min, max) = self.bounds[i];

            x[i] = match &self.param_space[name] {
                HyperparamSpace::Continuous { .. } => {
                    let v = value.as_float().unwrap();
                    (v - min) / (max - min)
                }
                HyperparamSpace::LogUniform { .. } => {
                    let v = value.as_float().unwrap();
                    let log_v = v.ln();
                    (log_v - min) / (max - min)
                }
                HyperparamSpace::IntRange { .. } => {
                    let v = value.as_int().unwrap() as f64;
                    (v - min) / (max - min)
                }
                HyperparamSpace::Discrete(values) => {
                    let idx = values.iter().position(|v| v == value).unwrap_or(0);
                    // Guard against a single-element discrete space, where
                    // max == min and the normalization would divide by zero.
                    if max > min {
                        (idx as f64 - min) / (max - min)
                    } else {
                        0.0
                    }
                }
            };
        }

        x
    }

    fn vector_to_config(&self, x: &Array1<f64>) -> TrainResult<HyperparamConfig> {
        let mut config = HashMap::new();

        for (i, name) in self.param_names.iter().enumerate() {
            let normalized = x[i].clamp(0.0, 1.0);
            let (min, max) = self.bounds[i];
            let value_raw = min + normalized * (max - min);

            let value = match &self.param_space[name] {
                HyperparamSpace::Continuous { .. } => HyperparamValue::Float(value_raw),
                HyperparamSpace::LogUniform { .. } => HyperparamValue::Float(value_raw.exp()),
                HyperparamSpace::IntRange { .. } => HyperparamValue::Int(value_raw.round() as i64),
                HyperparamSpace::Discrete(values) => {
                    let idx = value_raw.round() as usize;
                    values[idx.min(values.len() - 1)].clone()
                }
            };

            config.insert(name.clone(), value);
        }

        Ok(config)
    }

    fn random_sample(&mut self) -> HyperparamConfig {
        let mut config = HashMap::new();

        for (name, space) in &self.param_space {
            let value = space.sample(&mut self.rng);
            config.insert(name.clone(), value);
        }

        config
    }

    pub fn add_result(&mut self, result: HyperparamResult) {
        self.results.push(result);
    }

    pub fn best_result(&self) -> Option<&HyperparamResult> {
        self.results.iter().max_by(|a, b| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    }

    pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
        let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results
    }

    pub fn results(&self) -> &[HyperparamResult] {
        &self.results
    }

    pub fn is_complete(&self) -> bool {
        self.results.len() >= self.n_iterations + self.n_initial_points
    }

    pub fn current_iteration(&self) -> usize {
        self.results.len()
    }

    pub fn total_budget(&self) -> usize {
        self.n_iterations + self.n_initial_points
    }
}
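// A minimal optimization loop, sketched under the assumption that `evaluate`
// is some user-supplied objective returning a score to maximize:
//
//     let mut opt = BayesianOptimization::new(param_space, 20, 5, 42);
//     while !opt.is_complete() {
//         let config = opt.suggest()?;
//         let score = evaluate(&config); // hypothetical evaluation callback
//         opt.add_result(HyperparamResult::new(config, score));
//     }
//     let best = opt.best_result();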

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hyperparam_value() {
        let float_val = HyperparamValue::Float(3.5);
        assert_eq!(float_val.as_float(), Some(3.5));
        assert_eq!(float_val.as_int(), Some(3));

        let int_val = HyperparamValue::Int(42);
        assert_eq!(int_val.as_int(), Some(42));
        assert_eq!(int_val.as_float(), Some(42.0));

        let bool_val = HyperparamValue::Bool(true);
        assert_eq!(bool_val.as_bool(), Some(true));

        let string_val = HyperparamValue::String("test".to_string());
        assert_eq!(string_val.as_string(), Some("test"));
    }

    #[test]
    fn test_hyperparam_space_discrete() {
        let space = HyperparamSpace::discrete(vec![
            HyperparamValue::Float(0.1),
            HyperparamValue::Float(0.01),
        ])
        .unwrap();

        let values = space.grid_values(10);
        assert_eq!(values.len(), 2);

        let mut rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut rng);
        assert!(matches!(sampled, HyperparamValue::Float(_)));
    }

    #[test]
    fn test_hyperparam_space_continuous() {
        let space = HyperparamSpace::continuous(0.0, 1.0).unwrap();

        let values = space.grid_values(5);
        assert_eq!(values.len(), 5);

        let mut rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut rng);
        if let HyperparamValue::Float(v) = sampled {
            assert!((0.0..=1.0).contains(&v));
        } else {
            panic!("Expected Float value");
        }
    }

    #[test]
    fn test_hyperparam_space_log_uniform() {
        let space = HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap();

        let values = space.grid_values(3);
        assert_eq!(values.len(), 3);

        let mut rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut rng);
        if let HyperparamValue::Float(v) = sampled {
            assert!((1e-4..=1e-1).contains(&v));
        } else {
            panic!("Expected Float value");
        }
    }

    #[test]
    fn test_hyperparam_space_int_range() {
        let space = HyperparamSpace::int_range(1, 10).unwrap();

        let values = space.grid_values(5);
        assert!(!values.is_empty());

        let mut rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut rng);
        if let HyperparamValue::Int(v) = sampled {
            assert!((1..=10).contains(&v));
        } else {
            panic!("Expected Int value");
        }
    }

    #[test]
    fn test_hyperparam_space_invalid() {
        assert!(HyperparamSpace::discrete(vec![]).is_err());
        assert!(HyperparamSpace::continuous(1.0, 0.0).is_err());
        assert!(HyperparamSpace::log_uniform(0.0, 1.0).is_err());
        assert!(HyperparamSpace::log_uniform(1.0, 0.5).is_err());
        assert!(HyperparamSpace::int_range(10, 5).is_err());
    }

    #[test]
    fn test_grid_search() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::discrete(vec![
                HyperparamValue::Float(0.1),
                HyperparamValue::Float(0.01),
            ])
            .unwrap(),
        );
        param_space.insert(
            "batch_size".to_string(),
            HyperparamSpace::int_range(16, 64).unwrap(),
        );

        let grid_search = GridSearch::new(param_space, 3);

        let configs = grid_search.generate_configs();
        assert!(!configs.is_empty());
        assert!(configs.len() >= 2);
    }

    #[test]
    fn test_grid_search_results() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::discrete(vec![HyperparamValue::Float(0.1)]).unwrap(),
        );

        let mut grid_search = GridSearch::new(param_space, 3);

        let mut config = HashMap::new();
        config.insert("lr".to_string(), HyperparamValue::Float(0.1));

        grid_search.add_result(HyperparamResult::new(config.clone(), 0.9));
        grid_search.add_result(HyperparamResult::new(config.clone(), 0.95));
        grid_search.add_result(HyperparamResult::new(config, 0.85));

        let best = grid_search.best_result().unwrap();
        assert_eq!(best.score, 0.95);

        let sorted = grid_search.sorted_results();
        assert_eq!(sorted[0].score, 0.95);
        assert_eq!(sorted[1].score, 0.9);
        assert_eq!(sorted[2].score, 0.85);
    }

    #[test]
    fn test_random_search() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::continuous(1e-4, 1e-1).unwrap(),
        );
        param_space.insert(
            "dropout".to_string(),
            HyperparamSpace::continuous(0.0, 0.5).unwrap(),
        );

        let mut random_search = RandomSearch::new(param_space, 10, 42);

        let configs = random_search.generate_configs();
        assert_eq!(configs.len(), 10);

        for config in &configs {
            assert!(config.contains_key("lr"));
            assert!(config.contains_key("dropout"));
        }
    }

    #[test]
    fn test_random_search_results() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::discrete(vec![HyperparamValue::Float(0.1)]).unwrap(),
        );

        let mut random_search = RandomSearch::new(param_space, 5, 42);

        let mut config = HashMap::new();
        config.insert("lr".to_string(), HyperparamValue::Float(0.1));

        random_search.add_result(HyperparamResult::new(config.clone(), 0.8));
        random_search.add_result(HyperparamResult::new(config, 0.9));

        let best = random_search.best_result().unwrap();
        assert_eq!(best.score, 0.9);

        assert_eq!(random_search.results().len(), 2);
    }

    #[test]
    fn test_hyperparam_result_with_metrics() {
        let mut config = HashMap::new();
        config.insert("lr".to_string(), HyperparamValue::Float(0.1));

        let result = HyperparamResult::new(config, 0.95)
            .with_metric("accuracy".to_string(), 0.95)
            .with_metric("loss".to_string(), 0.05);

        assert_eq!(result.score, 0.95);
        assert_eq!(result.metrics.get("accuracy"), Some(&0.95));
        assert_eq!(result.metrics.get("loss"), Some(&0.05));
    }

    #[test]
    fn test_grid_search_empty_space() {
        let grid_search = GridSearch::new(HashMap::new(), 3);
        let configs = grid_search.generate_configs();
        assert_eq!(configs.len(), 1);
        assert!(configs[0].is_empty());
    }

    #[test]
    fn test_grid_search_total_configs() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::discrete(vec![
                HyperparamValue::Float(0.1),
                HyperparamValue::Float(0.01),
            ])
            .unwrap(),
        );

        let grid_search = GridSearch::new(param_space, 3);
        assert_eq!(grid_search.total_configs(), 2);
    }

    #[test]
    fn test_gp_kernel_rbf() {
        let kernel = GpKernel::Rbf {
            sigma: 1.0,
            length_scale: 1.0,
        };

        let x1 = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).unwrap();
        let x2 = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 0.5, 0.5]).unwrap();

        let k = kernel.compute_kernel(&x1, &x2);
        assert_eq!(k.shape(), &[2, 2]);
        assert!((k[[0, 0]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gp_kernel_matern() {
        let kernel = GpKernel::Matern32 {
            sigma: 1.0,
            length_scale: 1.0,
        };

        let x = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).unwrap();
        let k = kernel.compute_kernel(&x, &x);

        assert!((k[[0, 0]] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gp_fit_and_predict() {
        let kernel = GpKernel::Rbf {
            sigma: 1.0,
            length_scale: 0.5,
        };
        let mut gp = GaussianProcess::new(kernel, 1e-6);

        let x_train = Array2::from_shape_vec((5, 1), vec![0.0, 0.5, 1.0, 1.5, 2.0]).unwrap();
        let y_train = Array1::from_vec(vec![0.0, 0.25, 1.0, 2.25, 4.0]);

        gp.fit(x_train, y_train).unwrap();

        let x_test = Array2::from_shape_vec((2, 1), vec![0.75, 1.25]).unwrap();
        let (means, _stds) = gp.predict(&x_test).unwrap();

        assert_eq!(means.len(), 2);
        assert!(means[0] >= 0.0 && means[0] <= 4.0);
        assert!(means[1] >= 0.0 && means[1] <= 4.0);
    }

    #[test]
    fn test_gp_predict_error_not_fitted() {
        let kernel = GpKernel::default();
        let gp = GaussianProcess::new(kernel, 1e-6);

        let x_test = Array2::from_shape_vec((1, 1), vec![0.5]).unwrap();
        let result = gp.predict(&x_test);

        assert!(result.is_err());
    }

    #[test]
    fn test_gp_fit_dimension_mismatch() {
        let kernel = GpKernel::default();
        let mut gp = GaussianProcess::new(kernel, 1e-6);

        let x = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 1.0]);

        let result = gp.fit(x, y);
        assert!(result.is_err());
    }

    #[test]
    fn test_acquisition_function_ei() {
        let acq = AcquisitionFunction::ExpectedImprovement { xi: 0.01 };
        assert!(matches!(
            acq,
            AcquisitionFunction::ExpectedImprovement { .. }
        ));
    }

    #[test]
    fn test_acquisition_function_ucb() {
        let acq = AcquisitionFunction::UpperConfidenceBound { kappa: 2.0 };
        assert!(matches!(
            acq,
            AcquisitionFunction::UpperConfidenceBound { .. }
        ));
    }

    #[test]
    fn test_acquisition_function_pi() {
        let acq = AcquisitionFunction::ProbabilityOfImprovement { xi: 0.01 };
        assert!(matches!(
            acq,
            AcquisitionFunction::ProbabilityOfImprovement { .. }
        ));
    }

    #[test]
    fn test_bayesian_optimization_creation() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap(),
        );

        let bayes_opt = BayesianOptimization::new(param_space, 10, 5, 42);

        assert_eq!(bayes_opt.total_budget(), 15);
        assert_eq!(bayes_opt.current_iteration(), 0);
        assert!(!bayes_opt.is_complete());
    }

    #[test]
    fn test_bayesian_optimization_suggest_initial() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::continuous(0.0, 1.0).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 5, 3, 42);

        for _ in 0..3 {
            let config = bayes_opt.suggest().unwrap();
            assert!(config.contains_key("lr"));

            bayes_opt.add_result(HyperparamResult::new(config, 0.5));
        }

        assert_eq!(bayes_opt.current_iteration(), 3);
    }

    #[test]
    fn test_bayesian_optimization_suggest_gp_phase() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "x".to_string(),
            HyperparamSpace::continuous(0.0, 1.0).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 5, 2, 42);

        let mut config1 = HashMap::new();
        config1.insert("x".to_string(), HyperparamValue::Float(0.25));
        bayes_opt.add_result(HyperparamResult::new(config1, 0.5));

        let mut config2 = HashMap::new();
        config2.insert("x".to_string(), HyperparamValue::Float(0.75));
        bayes_opt.add_result(HyperparamResult::new(config2, 0.8));

        let config = bayes_opt.suggest().unwrap();
        assert!(config.contains_key("x"));
    }

    #[test]
    fn test_bayesian_optimization_with_acquisition() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap(),
        );

        let bayes_opt = BayesianOptimization::new(param_space, 10, 5, 42)
            .with_acquisition(AcquisitionFunction::UpperConfidenceBound { kappa: 2.0 })
            .with_kernel(GpKernel::Matern32 {
                sigma: 1.0,
                length_scale: 0.5,
            })
            .with_noise(1e-5);

        assert_eq!(bayes_opt.total_budget(), 15);
    }

    #[test]
    fn test_bayesian_optimization_best_result() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "x".to_string(),
            HyperparamSpace::continuous(0.0, 1.0).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 5, 2, 42);

        let mut config1 = HashMap::new();
        config1.insert("x".to_string(), HyperparamValue::Float(0.3));
        bayes_opt.add_result(HyperparamResult::new(config1, 0.6));

        let mut config2 = HashMap::new();
        config2.insert("x".to_string(), HyperparamValue::Float(0.7));
        bayes_opt.add_result(HyperparamResult::new(config2, 0.9));

        let best = bayes_opt.best_result().unwrap();
        assert_eq!(best.score, 0.9);
    }

    #[test]
    fn test_bayesian_optimization_is_complete() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "x".to_string(),
            HyperparamSpace::continuous(0.0, 1.0).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 2, 1, 42);

        assert!(!bayes_opt.is_complete());

        for i in 0..3 {
            let mut config = HashMap::new();
            config.insert("x".to_string(), HyperparamValue::Float(i as f64 * 0.3));
            bayes_opt.add_result(HyperparamResult::new(config, i as f64 * 0.2));
        }

        assert!(bayes_opt.is_complete());
    }

    #[test]
    fn test_bayesian_optimization_multivariate() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "lr".to_string(),
            HyperparamSpace::log_uniform(1e-4, 1e-1).unwrap(),
        );
        param_space.insert(
            "batch_size".to_string(),
            HyperparamSpace::int_range(16, 128).unwrap(),
        );
        param_space.insert(
            "dropout".to_string(),
            HyperparamSpace::continuous(0.0, 0.5).unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 10, 3, 42);

        let config = bayes_opt.suggest().unwrap();
        assert_eq!(config.len(), 3);
        assert!(config.contains_key("lr"));
        assert!(config.contains_key("batch_size"));
        assert!(config.contains_key("dropout"));
    }

    #[test]
    fn test_bayesian_optimization_discrete_space() {
        let mut param_space = HashMap::new();
        param_space.insert(
            "optimizer".to_string(),
            HyperparamSpace::discrete(vec![
                HyperparamValue::String("adam".to_string()),
                HyperparamValue::String("sgd".to_string()),
                HyperparamValue::String("rmsprop".to_string()),
            ])
            .unwrap(),
        );

        let mut bayes_opt = BayesianOptimization::new(param_space, 5, 2, 42);

        let config = bayes_opt.suggest().unwrap();
        assert!(config.contains_key("optimizer"));

        let optimizer = config.get("optimizer").unwrap();
        assert!(matches!(optimizer, HyperparamValue::String(_)));
    }

    #[test]
    fn test_normal_cdf() {
        let cdf_0 = BayesianOptimization::normal_cdf(0.0);
        assert!((cdf_0 - 0.5).abs() < 1e-4);

        let cdf_pos = BayesianOptimization::normal_cdf(1.96);
        assert!((cdf_pos - 0.975).abs() < 1e-2);

        let cdf_neg = BayesianOptimization::normal_cdf(-1.96);
        assert!((cdf_neg - 0.025).abs() < 1e-2);
    }

    #[test]
    fn test_normal_pdf() {
        let pdf_0 = BayesianOptimization::normal_pdf(0.0);
        let expected = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
        assert!((pdf_0 - expected).abs() < 1e-6);

        let pdf_pos = BayesianOptimization::normal_pdf(1.0);
        let pdf_neg = BayesianOptimization::normal_pdf(-1.0);
        assert!((pdf_pos - pdf_neg).abs() < 1e-10);
    }

    #[test]
    fn test_erf() {
        assert!((BayesianOptimization::erf(0.0) - 0.0).abs() < 1e-6);
        assert!((BayesianOptimization::erf(1.0) - 0.8427).abs() < 1e-3);
        assert!((BayesianOptimization::erf(-1.0) + 0.8427).abs() < 1e-3);
    }
}