1use scirs2_core::ndarray::Array2;
7use scirs2_core::random::{rng, Rng, RngExt, SeedableRng};
8use std::collections::HashMap;
9
10use crate::error::{ClusteringError, Result};
11
12use super::config::*;
13
/// Acquisition-function-driven optimizer over a fixed list of named,
/// real-valued parameters.
pub struct BayesianOptimizer {
    /// Parameter names; their order defines the layout of every feature
    /// vector extracted from a parameter map.
    parameter_names: Vec<String>,
    /// Acquisition function used to score candidate points.
    acquisition_function: AcquisitionFunction,
    /// Surrogate-model state: recorded observations, kernel hyperparameters,
    /// the cached covariance matrix and the current best score.
    bayesian_state: BayesianState,
    /// Seed for candidate sampling; when `None`, candidate generation falls
    /// back to a fixed seed of 42.
    random_seed: Option<u64>,
}
21
22impl BayesianOptimizer {
23 pub fn new(
25 parameter_names: Vec<String>,
26 acquisition_function: AcquisitionFunction,
27 random_seed: Option<u64>,
28 ) -> Self {
29 let bayesian_state = BayesianState {
30 observations: Vec::new(),
31 gp_mean: None,
32 gp_covariance: None,
33 acquisition_values: Vec::new(),
34 parameter_names: parameter_names.clone(),
35 gp_hyperparameters: GpHyperparameters {
36 length_scales: vec![1.0; parameter_names.len()],
37 signal_variance: 1.0,
38 noise_variance: 0.1,
39 kernel_type: KernelType::RBF { length_scale: 1.0 },
40 },
41 noise_level: 0.1,
42 currentbest: f64::NEG_INFINITY,
43 };
44
45 Self {
46 parameter_names,
47 acquisition_function,
48 bayesian_state,
49 random_seed,
50 }
51 }
52
53 pub fn update_observations(&mut self, combinations: &[HashMap<String, f64>]) {
55 if combinations.is_empty() {
56 return;
57 }
58
59 let n_samples = combinations.len();
60 let _n_features = self.parameter_names.len();
61
62 if n_samples < 2 {
63 return;
64 }
65
66 self.optimize_gp_hyperparameters(combinations);
67 self.build_covariance_matrix(combinations);
68 }
69
70 pub fn optimize_acquisition_function(
72 &self,
73 search_space: &SearchSpace,
74 ) -> Result<HashMap<String, f64>> {
75 let mut best_acquisition = f64::NEG_INFINITY;
76 let mut best_point = HashMap::new();
77
78 let n_candidates = 1000;
79 let candidates = self.generate_random_candidates(search_space, n_candidates)?;
80
81 for candidate in candidates {
82 let acquisition_value = self.evaluate_acquisition_function(&candidate);
83
84 if acquisition_value > best_acquisition {
85 best_acquisition = acquisition_value;
86 best_point = candidate;
87 }
88 }
89
90 Ok(best_point)
91 }
92
93 fn generate_random_candidates(
95 &self,
96 search_space: &SearchSpace,
97 n_candidates: usize,
98 ) -> Result<Vec<HashMap<String, f64>>> {
99 let mut candidates = Vec::new();
100 let mut rng = match self.random_seed {
101 Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
102 None => scirs2_core::random::rngs::StdRng::seed_from_u64(42),
103 };
104
105 for _ in 0..n_candidates {
106 let mut candidate = HashMap::new();
107
108 for (name, param) in &search_space.parameters {
109 let value = match param {
110 HyperParameter::Integer { min, max } => rng.random_range(*min..=*max) as f64,
111 HyperParameter::Float { min, max } => rng.random_range(*min..=*max),
112 HyperParameter::Categorical { choices } => {
113 rng.random_range(0..choices.len()) as f64
114 }
115 HyperParameter::Boolean => {
116 if rng.random_range(0.0..1.0) < 0.5 {
117 1.0
118 } else {
119 0.0
120 }
121 }
122 HyperParameter::LogUniform { min, max } => {
123 let log_min = min.ln();
124 let log_max = max.ln();
125 let log_value = rng.random_range(log_min..=log_max);
126 log_value.exp()
127 }
128 HyperParameter::IntegerChoices { choices } => {
129 let idx = rng.random_range(0..choices.len());
130 choices[idx] as f64
131 }
132 };
133
134 candidate.insert(name.clone(), value);
135 }
136
137 candidates.push(candidate);
138 }
139
140 Ok(candidates)
141 }
142
143 fn evaluate_acquisition_function(&self, point: &HashMap<String, f64>) -> f64 {
145 let x = self.extract_feature_vector(point);
146 let (mean, variance) = self.predict_gp(&x);
147 let std_dev = variance.sqrt();
148
149 match &self.acquisition_function {
150 AcquisitionFunction::ExpectedImprovement => {
151 self.expected_improvement(mean, std_dev, self.bayesian_state.currentbest)
152 }
153 AcquisitionFunction::UpperConfidenceBound { beta } => mean + beta * std_dev,
154 AcquisitionFunction::ProbabilityOfImprovement => {
155 self.probability_of_improvement(mean, std_dev, self.bayesian_state.currentbest)
156 }
157 AcquisitionFunction::EntropySearch => -variance * (variance.ln()),
158 AcquisitionFunction::KnowledgeGradient => std_dev * (1.0 / (1.0 + variance)),
159 AcquisitionFunction::ThompsonSampling => {
160 let mut rng = scirs2_core::random::rng();
161 let sample: f64 = rng.random_range(0.0..1.0);
162 mean + std_dev * self.inverse_normal_cdf(sample)
163 }
164 }
165 }
166
167 fn expected_improvement(&self, mean: f64, std_dev: f64, currentbest: f64) -> f64 {
169 if std_dev <= 1e-10 {
170 return 0.0;
171 }
172
173 let improvement = mean - currentbest;
174 let z = improvement / std_dev;
175
176 improvement * self.normal_cdf(z) + std_dev * self.normal_pdf(z)
177 }
178
179 fn probability_of_improvement(&self, mean: f64, std_dev: f64, currentbest: f64) -> f64 {
181 if std_dev <= 1e-10 {
182 return if mean > currentbest { 1.0 } else { 0.0 };
183 }
184
185 let z = (mean - currentbest) / std_dev;
186 self.normal_cdf(z)
187 }
188
189 fn predict_gp(&self, x: &[f64]) -> (f64, f64) {
191 if self.bayesian_state.observations.is_empty() {
192 return (0.0, 1.0);
193 }
194
195 let mut mean = 0.0;
196 let mut variance = 1.0;
197
198 let mut total_weight = 0.0;
199 for (params, score) in &self.bayesian_state.observations {
200 let x_obs = self.extract_feature_vector(params);
201 let similarity = self.compute_kernel(x, &x_obs);
202 mean += similarity * score;
203 total_weight += similarity;
204 }
205
206 if total_weight > 1e-10 {
207 mean /= total_weight;
208 variance = 1.0 - total_weight.min(1.0);
209 }
210
211 (mean, variance.max(1e-6))
212 }
213
214 fn compute_kernel(&self, x1: &[f64], x2: &[f64]) -> f64 {
216 match &self.bayesian_state.gp_hyperparameters.kernel_type {
217 KernelType::RBF { length_scale } => {
218 let squared_distance: f64 =
219 x1.iter().zip(x2.iter()).map(|(a, b)| (a - b).powi(2)).sum();
220 self.bayesian_state.gp_hyperparameters.signal_variance
221 * (-squared_distance / (2.0 * length_scale.powi(2))).exp()
222 }
223 KernelType::Matern { length_scale, nu } => {
224 let distance: f64 = x1
225 .iter()
226 .zip(x2.iter())
227 .map(|(a, b)| (a - b).powi(2))
228 .sum::<f64>()
229 .sqrt();
230
231 if distance == 0.0 {
232 self.bayesian_state.gp_hyperparameters.signal_variance
233 } else {
234 let scaled_distance = (2.0 * nu).sqrt() * distance / length_scale;
235 let bessel_term = if nu == &0.5 {
236 (-scaled_distance).exp()
237 } else if nu == &1.5 {
238 (1.0 + scaled_distance) * (-scaled_distance).exp()
239 } else {
240 (-scaled_distance).exp()
241 };
242 self.bayesian_state.gp_hyperparameters.signal_variance * bessel_term
243 }
244 }
245 KernelType::Linear => {
246 let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
247 self.bayesian_state.gp_hyperparameters.signal_variance * dot_product
248 }
249 KernelType::Polynomial { degree } => {
250 let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
251 self.bayesian_state.gp_hyperparameters.signal_variance
252 * (1.0 + dot_product).powf(*degree as f64)
253 }
254 }
255 }
256
257 fn optimize_gp_hyperparameters(&mut self, combinations: &[HashMap<String, f64>]) {
259 if combinations.len() < 3 {
260 return;
261 }
262
263 for (i, param_name) in self.parameter_names.iter().enumerate() {
264 let values: Vec<f64> = combinations
265 .iter()
266 .filter_map(|c| c.get(param_name))
267 .copied()
268 .collect();
269
270 if !values.is_empty() {
271 let mean = values.iter().sum::<f64>() / values.len() as f64;
272 let variance =
273 values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
274
275 if i < self.bayesian_state.gp_hyperparameters.length_scales.len() {
276 self.bayesian_state.gp_hyperparameters.length_scales[i] =
277 variance.sqrt().max(0.1);
278 }
279 }
280 }
281
282 if !self.bayesian_state.observations.is_empty() {
283 let scores: Vec<f64> = self
284 .bayesian_state
285 .observations
286 .iter()
287 .map(|(_, s)| *s)
288 .collect();
289 let score_mean = scores.iter().sum::<f64>() / scores.len() as f64;
290 let score_variance =
291 scores.iter().map(|s| (s - score_mean).powi(2)).sum::<f64>() / scores.len() as f64;
292
293 self.bayesian_state.gp_hyperparameters.signal_variance = score_variance.max(0.1);
294 self.bayesian_state.gp_hyperparameters.noise_variance =
295 (score_variance * 0.1).max(0.01);
296 }
297 }
298
299 fn build_covariance_matrix(&mut self, combinations: &[HashMap<String, f64>]) {
301 let n_samples = combinations.len();
302 let mut covariance = Array2::zeros((n_samples, n_samples));
303
304 for i in 0..n_samples {
305 for j in 0..n_samples {
306 let x_i = self.extract_feature_vector(&combinations[i]);
307 let x_j = self.extract_feature_vector(&combinations[j]);
308 covariance[[i, j]] = self.compute_kernel(&x_i, &x_j);
309 }
310 }
311
312 for i in 0..n_samples {
313 covariance[[i, i]] += self.bayesian_state.gp_hyperparameters.noise_variance;
314 }
315
316 self.bayesian_state.gp_covariance = Some(covariance);
317 }
318
319 fn extract_feature_vector(&self, params: &HashMap<String, f64>) -> Vec<f64> {
321 self.parameter_names
322 .iter()
323 .map(|name| params.get(name).copied().unwrap_or(0.0))
324 .collect()
325 }
326
327 fn normal_cdf(&self, x: f64) -> f64 {
329 0.5 * (1.0 + self.erf(x / 2.0_f64.sqrt()))
330 }
331
332 fn normal_pdf(&self, x: f64) -> f64 {
334 (-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt()
335 }
336
337 fn erf(&self, x: f64) -> f64 {
339 let a1 = 0.254829592;
340 let a2 = -0.284496736;
341 let a3 = 1.421413741;
342 let a4 = -1.453152027;
343 let a5 = 1.061405429;
344 let p = 0.3275911;
345
346 let sign = if x < 0.0 { -1.0 } else { 1.0 };
347 let x = x.abs();
348
349 let t = 1.0 / (1.0 + p * x);
350 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
351
352 sign * y
353 }
354
355 fn inverse_normal_cdf(&self, p: f64) -> f64 {
357 if p <= 0.0 {
358 return f64::NEG_INFINITY;
359 }
360 if p >= 1.0 {
361 return f64::INFINITY;
362 }
363 if (p - 0.5).abs() < 1e-10 {
364 return 0.0;
365 }
366
367 let a0 = -3.969683028665376e+01;
368 let a1 = 2.209460984245205e+02;
369 let a2 = -2.759285104469687e+02;
370 let a3 = 1.383_577_518_672_69e2;
371 let a4 = -3.066479806614716e+01;
372 let a5 = 2.506628277459239e+00;
373
374 let b1 = -5.447609879822406e+01;
375 let b2 = 1.615858368580409e+02;
376 let b3 = -1.556989798598866e+02;
377 let b4 = 6.680131188771972e+01;
378 let b5 = -1.328068155288572e+01;
379
380 let c0 = -7.784894002430293e-03;
381 let c1 = -3.223964580411365e-01;
382 let c2 = -2.400758277161838e+00;
383 let c3 = -2.549732539343734e+00;
384 let c4 = 4.374664141464968e+00;
385 let c5 = 2.938163982698783e+00;
386
387 let d1 = 7.784695709041462e-03;
388 let d2 = 3.224671290700398e-01;
389 let d3 = 2.445134137142996e+00;
390 let d4 = 3.754408661907416e+00;
391
392 let p_low = 0.02425;
393 let p_high = 1.0 - p_low;
394
395 if p < p_low {
396 let q = (-2.0 * p.ln()).sqrt();
397 return (((((c0 * q + c1) * q + c2) * q + c3) * q + c4) * q + c5)
398 / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1.0);
399 }
400
401 if p <= p_high {
402 let q = p - 0.5;
403 let r = q * q;
404 return (((((a0 * r + a1) * r + a2) * r + a3) * r + a4) * r + a5) * q
405 / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1.0);
406 }
407
408 let q = (-2.0 * (1.0 - p).ln()).sqrt();
409 -(((((c0 * q + c1) * q + c2) * q + c3) * q + c4) * q + c5)
410 / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1.0)
411 }
412}