Skip to main content

scirs2_cluster/tuning/
bayesian_optimization.rs

//! Bayesian optimization implementation for hyperparameter tuning
//!
//! This module provides Gaussian Process-based Bayesian optimization
//! for efficient hyperparameter search.
5
6use scirs2_core::ndarray::Array2;
7use scirs2_core::random::{rng, Rng, RngExt, SeedableRng};
8use std::collections::HashMap;
9
10use crate::error::{ClusteringError, Result};
11
12use super::config::*;
13
14/// Bayesian optimizer using Gaussian Processes
15pub struct BayesianOptimizer {
16    parameter_names: Vec<String>,
17    acquisition_function: AcquisitionFunction,
18    bayesian_state: BayesianState,
19    random_seed: Option<u64>,
20}
21
22impl BayesianOptimizer {
23    /// Create a new Bayesian optimizer
24    pub fn new(
25        parameter_names: Vec<String>,
26        acquisition_function: AcquisitionFunction,
27        random_seed: Option<u64>,
28    ) -> Self {
29        let bayesian_state = BayesianState {
30            observations: Vec::new(),
31            gp_mean: None,
32            gp_covariance: None,
33            acquisition_values: Vec::new(),
34            parameter_names: parameter_names.clone(),
35            gp_hyperparameters: GpHyperparameters {
36                length_scales: vec![1.0; parameter_names.len()],
37                signal_variance: 1.0,
38                noise_variance: 0.1,
39                kernel_type: KernelType::RBF { length_scale: 1.0 },
40            },
41            noise_level: 0.1,
42            currentbest: f64::NEG_INFINITY,
43        };
44
45        Self {
46            parameter_names,
47            acquisition_function,
48            bayesian_state,
49            random_seed,
50        }
51    }
52
53    /// Update observations with new parameter combinations
54    pub fn update_observations(&mut self, combinations: &[HashMap<String, f64>]) {
55        if combinations.is_empty() {
56            return;
57        }
58
59        let n_samples = combinations.len();
60        let _n_features = self.parameter_names.len();
61
62        if n_samples < 2 {
63            return;
64        }
65
66        self.optimize_gp_hyperparameters(combinations);
67        self.build_covariance_matrix(combinations);
68    }
69
70    /// Optimize acquisition function to find next evaluation point
71    pub fn optimize_acquisition_function(
72        &self,
73        search_space: &SearchSpace,
74    ) -> Result<HashMap<String, f64>> {
75        let mut best_acquisition = f64::NEG_INFINITY;
76        let mut best_point = HashMap::new();
77
78        let n_candidates = 1000;
79        let candidates = self.generate_random_candidates(search_space, n_candidates)?;
80
81        for candidate in candidates {
82            let acquisition_value = self.evaluate_acquisition_function(&candidate);
83
84            if acquisition_value > best_acquisition {
85                best_acquisition = acquisition_value;
86                best_point = candidate;
87            }
88        }
89
90        Ok(best_point)
91    }
92
93    /// Generate random candidate points for acquisition optimization
94    fn generate_random_candidates(
95        &self,
96        search_space: &SearchSpace,
97        n_candidates: usize,
98    ) -> Result<Vec<HashMap<String, f64>>> {
99        let mut candidates = Vec::new();
100        let mut rng = match self.random_seed {
101            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
102            None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
103        };
104
105        for _ in 0..n_candidates {
106            let mut candidate = HashMap::new();
107
108            for (name, param) in &search_space.parameters {
109                let value = match param {
110                    HyperParameter::Integer { min, max } => rng.random_range(*min..=*max) as f64,
111                    HyperParameter::Float { min, max } => rng.random_range(*min..=*max),
112                    HyperParameter::Categorical { choices } => {
113                        rng.random_range(0..choices.len()) as f64
114                    }
115                    HyperParameter::Boolean => {
116                        if rng.random_range(0.0..1.0) < 0.5 {
117                            1.0
118                        } else {
119                            0.0
120                        }
121                    }
122                    HyperParameter::LogUniform { min, max } => {
123                        let log_min = min.ln();
124                        let log_max = max.ln();
125                        let log_value = rng.random_range(log_min..=log_max);
126                        log_value.exp()
127                    }
128                    HyperParameter::IntegerChoices { choices } => {
129                        let idx = rng.random_range(0..choices.len());
130                        choices[idx] as f64
131                    }
132                };
133
134                candidate.insert(name.clone(), value);
135            }
136
137            candidates.push(candidate);
138        }
139
140        Ok(candidates)
141    }
142
143    /// Evaluate acquisition function at a point
144    fn evaluate_acquisition_function(&self, point: &HashMap<String, f64>) -> f64 {
145        let x = self.extract_feature_vector(point);
146        let (mean, variance) = self.predict_gp(&x);
147        let std_dev = variance.sqrt();
148
149        match &self.acquisition_function {
150            AcquisitionFunction::ExpectedImprovement => {
151                self.expected_improvement(mean, std_dev, self.bayesian_state.currentbest)
152            }
153            AcquisitionFunction::UpperConfidenceBound { beta } => mean + beta * std_dev,
154            AcquisitionFunction::ProbabilityOfImprovement => {
155                self.probability_of_improvement(mean, std_dev, self.bayesian_state.currentbest)
156            }
157            AcquisitionFunction::EntropySearch => -variance * (variance.ln()),
158            AcquisitionFunction::KnowledgeGradient => std_dev * (1.0 / (1.0 + variance)),
159            AcquisitionFunction::ThompsonSampling => {
160                let mut rng = scirs2_core::random::rng();
161                let sample: f64 = rng.random_range(0.0..1.0);
162                mean + std_dev * self.inverse_normal_cdf(sample)
163            }
164        }
165    }
166
167    /// Expected Improvement acquisition function
168    fn expected_improvement(&self, mean: f64, std_dev: f64, currentbest: f64) -> f64 {
169        if std_dev <= 1e-10 {
170            return 0.0;
171        }
172
173        let improvement = mean - currentbest;
174        let z = improvement / std_dev;
175
176        improvement * self.normal_cdf(z) + std_dev * self.normal_pdf(z)
177    }
178
179    /// Probability of Improvement acquisition function
180    fn probability_of_improvement(&self, mean: f64, std_dev: f64, currentbest: f64) -> f64 {
181        if std_dev <= 1e-10 {
182            return if mean > currentbest { 1.0 } else { 0.0 };
183        }
184
185        let z = (mean - currentbest) / std_dev;
186        self.normal_cdf(z)
187    }
188
189    /// Gaussian Process prediction
190    fn predict_gp(&self, x: &[f64]) -> (f64, f64) {
191        if self.bayesian_state.observations.is_empty() {
192            return (0.0, 1.0);
193        }
194
195        let mut mean = 0.0;
196        let mut variance = 1.0;
197
198        let mut total_weight = 0.0;
199        for (params, score) in &self.bayesian_state.observations {
200            let x_obs = self.extract_feature_vector(params);
201            let similarity = self.compute_kernel(x, &x_obs);
202            mean += similarity * score;
203            total_weight += similarity;
204        }
205
206        if total_weight > 1e-10 {
207            mean /= total_weight;
208            variance = 1.0 - total_weight.min(1.0);
209        }
210
211        (mean, variance.max(1e-6))
212    }
213
214    /// Compute kernel function
215    fn compute_kernel(&self, x1: &[f64], x2: &[f64]) -> f64 {
216        match &self.bayesian_state.gp_hyperparameters.kernel_type {
217            KernelType::RBF { length_scale } => {
218                let squared_distance: f64 =
219                    x1.iter().zip(x2.iter()).map(|(a, b)| (a - b).powi(2)).sum();
220                self.bayesian_state.gp_hyperparameters.signal_variance
221                    * (-squared_distance / (2.0 * length_scale.powi(2))).exp()
222            }
223            KernelType::Matern { length_scale, nu } => {
224                let distance: f64 = x1
225                    .iter()
226                    .zip(x2.iter())
227                    .map(|(a, b)| (a - b).powi(2))
228                    .sum::<f64>()
229                    .sqrt();
230
231                if distance == 0.0 {
232                    self.bayesian_state.gp_hyperparameters.signal_variance
233                } else {
234                    let scaled_distance = (2.0 * nu).sqrt() * distance / length_scale;
235                    let bessel_term = if nu == &0.5 {
236                        (-scaled_distance).exp()
237                    } else if nu == &1.5 {
238                        (1.0 + scaled_distance) * (-scaled_distance).exp()
239                    } else {
240                        (-scaled_distance).exp()
241                    };
242                    self.bayesian_state.gp_hyperparameters.signal_variance * bessel_term
243                }
244            }
245            KernelType::Linear => {
246                let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
247                self.bayesian_state.gp_hyperparameters.signal_variance * dot_product
248            }
249            KernelType::Polynomial { degree } => {
250                let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
251                self.bayesian_state.gp_hyperparameters.signal_variance
252                    * (1.0 + dot_product).powf(*degree as f64)
253            }
254        }
255    }
256
257    /// Optimize GP hyperparameters using maximum likelihood
258    fn optimize_gp_hyperparameters(&mut self, combinations: &[HashMap<String, f64>]) {
259        if combinations.len() < 3 {
260            return;
261        }
262
263        for (i, param_name) in self.parameter_names.iter().enumerate() {
264            let values: Vec<f64> = combinations
265                .iter()
266                .filter_map(|c| c.get(param_name))
267                .copied()
268                .collect();
269
270            if !values.is_empty() {
271                let mean = values.iter().sum::<f64>() / values.len() as f64;
272                let variance =
273                    values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
274
275                if i < self.bayesian_state.gp_hyperparameters.length_scales.len() {
276                    self.bayesian_state.gp_hyperparameters.length_scales[i] =
277                        variance.sqrt().max(0.1);
278                }
279            }
280        }
281
282        if !self.bayesian_state.observations.is_empty() {
283            let scores: Vec<f64> = self
284                .bayesian_state
285                .observations
286                .iter()
287                .map(|(_, s)| *s)
288                .collect();
289            let score_mean = scores.iter().sum::<f64>() / scores.len() as f64;
290            let score_variance =
291                scores.iter().map(|s| (s - score_mean).powi(2)).sum::<f64>() / scores.len() as f64;
292
293            self.bayesian_state.gp_hyperparameters.signal_variance = score_variance.max(0.1);
294            self.bayesian_state.gp_hyperparameters.noise_variance =
295                (score_variance * 0.1).max(0.01);
296        }
297    }
298
299    /// Build covariance matrix for Gaussian Process
300    fn build_covariance_matrix(&mut self, combinations: &[HashMap<String, f64>]) {
301        let n_samples = combinations.len();
302        let mut covariance = Array2::zeros((n_samples, n_samples));
303
304        for i in 0..n_samples {
305            for j in 0..n_samples {
306                let x_i = self.extract_feature_vector(&combinations[i]);
307                let x_j = self.extract_feature_vector(&combinations[j]);
308                covariance[[i, j]] = self.compute_kernel(&x_i, &x_j);
309            }
310        }
311
312        for i in 0..n_samples {
313            covariance[[i, i]] += self.bayesian_state.gp_hyperparameters.noise_variance;
314        }
315
316        self.bayesian_state.gp_covariance = Some(covariance);
317    }
318
319    /// Extract feature vector from parameter map
320    fn extract_feature_vector(&self, params: &HashMap<String, f64>) -> Vec<f64> {
321        self.parameter_names
322            .iter()
323            .map(|name| params.get(name).copied().unwrap_or(0.0))
324            .collect()
325    }
326
327    /// Standard normal CDF approximation
328    fn normal_cdf(&self, x: f64) -> f64 {
329        0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
330    }
331
332    /// Standard normal PDF
333    fn normal_pdf(&self, x: f64) -> f64 {
334        (-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt()
335    }
336
337    /// Error function approximation
338    fn erf(&self, x: f64) -> f64 {
339        let a1 = 0.254829592;
340        let a2 = -0.284496736;
341        let a3 = 1.421413741;
342        let a4 = -1.453152027;
343        let a5 = 1.061405429;
344        let p = 0.3275911;
345
346        let sign = if x < 0.0 { -1.0 } else { 1.0 };
347        let x = x.abs();
348
349        let t = 1.0 / (1.0 + p * x);
350        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
351
352        sign * y
353    }
354
355    /// Inverse normal CDF approximation
356    fn inverse_normal_cdf(&self, p: f64) -> f64 {
357        if p <= 0.0 {
358            return f64::NEG_INFINITY;
359        }
360        if p >= 1.0 {
361            return f64::INFINITY;
362        }
363        if (p - 0.5).abs() < 1e-10 {
364            return 0.0;
365        }
366
367        let a0 = -3.969683028665376e+01;
368        let a1 = 2.209460984245205e+02;
369        let a2 = -2.759285104469687e+02;
370        let a3 = 1.383_577_518_672_69e2;
371        let a4 = -3.066479806614716e+01;
372        let a5 = 2.506628277459239e+00;
373
374        let b1 = -5.447609879822406e+01;
375        let b2 = 1.615858368580409e+02;
376        let b3 = -1.556989798598866e+02;
377        let b4 = 6.680131188771972e+01;
378        let b5 = -1.328068155288572e+01;
379
380        let c0 = -7.784894002430293e-03;
381        let c1 = -3.223964580411365e-01;
382        let c2 = -2.400758277161838e+00;
383        let c3 = -2.549732539343734e+00;
384        let c4 = 4.374664141464968e+00;
385        let c5 = 2.938163982698783e+00;
386
387        let d1 = 7.784695709041462e-03;
388        let d2 = 3.224671290700398e-01;
389        let d3 = 2.445134137142996e+00;
390        let d4 = 3.754408661907416e+00;
391
392        let p_low = 0.02425;
393        let p_high = 1.0 - p_low;
394
395        if p < p_low {
396            let q = (-2.0 * p.ln()).sqrt();
397            return (((((c0 * q + c1) * q + c2) * q + c3) * q + c4) * q + c5)
398                / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1.0);
399        }
400
401        if p <= p_high {
402            let q = p - 0.5;
403            let r = q * q;
404            return (((((a0 * r + a1) * r + a2) * r + a3) * r + a4) * r + a5) * q
405                / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1.0);
406        }
407
408        let q = (-2.0 * (1.0 - p).ln()).sqrt();
409        -(((((c0 * q + c1) * q + c2) * q + c3) * q + c4) * q + c5)
410            / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1.0)
411    }
412}