// tensorlogic_train/hyperparameter/bayesian.rs
use crate::TrainResult;

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{RngExt, SeedableRng, StdRng};
use std::collections::HashMap;

use super::acquisition::AcquisitionFunction;
use super::gp::GaussianProcess;
use super::kernel::GpKernel;
use super::space::HyperparamSpace;
use super::value::{HyperparamConfig, HyperparamResult, HyperparamValue};
13
/// Bayesian hyperparameter optimization driver.
///
/// Fits a Gaussian-process surrogate to observed (configuration, score)
/// pairs and proposes new configurations by maximizing an acquisition
/// function over the search space.
#[derive(Debug)]
pub struct BayesianOptimization {
    /// Search-space definition for each hyperparameter, keyed by name.
    param_space: HashMap<String, HyperparamSpace>,
    /// Number of model-guided iterations after the initial random phase.
    n_iterations: usize,
    /// Number of purely random configurations before the GP is consulted.
    n_initial_points: usize,
    /// Acquisition function used to score candidate points.
    acquisition_fn: AcquisitionFunction,
    /// Covariance kernel for the Gaussian-process surrogate.
    kernel: GpKernel,
    /// Noise variance passed to the GP (presumably observation noise
    /// added to the kernel diagonal — confirm in `gp` module).
    noise_variance: f64,
    /// Seeded RNG for reproducible sampling.
    rng: StdRng,
    /// All evaluated configurations together with their scores.
    results: Vec<HyperparamResult>,
    /// Per-dimension (min, max) bounds aligned with `param_names`
    /// (log-space for `LogUniform`, index range for `Discrete`).
    bounds: Vec<(f64, f64)>,
    /// Parameter names in sorted order; fixes the dimension ordering.
    param_names: Vec<String>,
}
70
71impl BayesianOptimization {
72 pub fn new(
80 param_space: HashMap<String, HyperparamSpace>,
81 n_iterations: usize,
82 n_initial_points: usize,
83 seed: u64,
84 ) -> Self {
85 let mut param_names: Vec<String> = param_space.keys().cloned().collect();
86 param_names.sort();
87 let bounds = Self::extract_bounds(¶m_space, ¶m_names);
88 Self {
89 param_space,
90 n_iterations,
91 n_initial_points,
92 acquisition_fn: AcquisitionFunction::default(),
93 kernel: GpKernel::default(),
94 noise_variance: 1e-6,
95 rng: StdRng::seed_from_u64(seed),
96 results: Vec::new(),
97 bounds,
98 param_names,
99 }
100 }
101
102 pub fn with_acquisition(mut self, acquisition_fn: AcquisitionFunction) -> Self {
104 self.acquisition_fn = acquisition_fn;
105 self
106 }
107
108 pub fn with_kernel(mut self, kernel: GpKernel) -> Self {
110 self.kernel = kernel;
111 self
112 }
113
114 pub fn with_noise(mut self, noise_variance: f64) -> Self {
116 self.noise_variance = noise_variance;
117 self
118 }
119
120 fn extract_bounds(
122 param_space: &HashMap<String, HyperparamSpace>,
123 param_names: &[String],
124 ) -> Vec<(f64, f64)> {
125 param_names
126 .iter()
127 .map(|name| match ¶m_space[name] {
128 HyperparamSpace::Continuous { min, max } => (*min, *max),
129 HyperparamSpace::LogUniform { min, max } => (min.ln(), max.ln()),
130 HyperparamSpace::IntRange { min, max } => (*min as f64, *max as f64),
131 HyperparamSpace::Discrete(values) => (0.0, (values.len() - 1) as f64),
132 })
133 .collect()
134 }
135
    /// Suggests the next hyperparameter configuration to evaluate.
    ///
    /// Returns a uniformly random configuration while fewer than
    /// `n_initial_points` results have been recorded; afterwards fits a
    /// Gaussian process to all observations and returns the candidate
    /// that maximizes the acquisition function.
    ///
    /// # Errors
    /// Propagates failures from GP fitting, prediction, or acquisition
    /// evaluation.
    pub fn suggest(&mut self) -> TrainResult<HyperparamConfig> {
        // Bootstrap phase: not enough data for a useful surrogate yet.
        if self.results.len() < self.n_initial_points {
            return Ok(self.random_sample());
        }
        let (x_observed, y_observed) = self.get_observations();
        // A fresh GP is fit on every call; `self.kernel` is passed by
        // value (this compiles only if GpKernel is Copy — presumably so).
        let mut gp = GaussianProcess::new(self.kernel, self.noise_variance);
        gp.fit(x_observed, y_observed)?;
        let best_x = self.optimize_acquisition(&gp)?;
        self.vector_to_config(&best_x)
    }
147
148 fn get_observations(&self) -> (Array2<f64>, Array1<f64>) {
150 let n_samples = self.results.len();
151 let n_dims = self.param_names.len();
152 let mut x = Array2::zeros((n_samples, n_dims));
153 let mut y = Array1::zeros(n_samples);
154 for (i, result) in self.results.iter().enumerate() {
155 let x_vec = self.config_to_vector(&result.config);
156 for (j, &val) in x_vec.iter().enumerate() {
157 x[[i, j]] = val;
158 }
159 y[i] = result.score;
160 }
161 (x, y)
162 }
163
164 fn optimize_acquisition(&mut self, gp: &GaussianProcess) -> TrainResult<Array1<f64>> {
166 let n_dims = self.param_names.len();
167 let n_candidates = 1000;
168 let n_restarts = 10;
169 let mut best_acq_value = f64::NEG_INFINITY;
170 let mut best_x = Array1::zeros(n_dims);
171 for _ in 0..n_restarts {
172 for _ in 0..(n_candidates / n_restarts) {
173 let mut x_candidate = Array1::zeros(n_dims);
174 for (i, (min, max)) in self.bounds.iter().enumerate() {
175 x_candidate[i] = min + (max - min) * self.rng.random::<f64>();
176 }
177 let acq_value = self.evaluate_acquisition(gp, &x_candidate)?;
178 if acq_value > best_acq_value {
179 best_acq_value = acq_value;
180 best_x = x_candidate;
181 }
182 }
183 }
184 Ok(best_x)
185 }
186
187 fn evaluate_acquisition(&self, gp: &GaussianProcess, x: &Array1<f64>) -> TrainResult<f64> {
189 let x_mat = x
190 .clone()
191 .into_shape_with_order((1, x.len()))
192 .expect("shape and data length match");
193 let (mean, std) = gp.predict(&x_mat)?;
194 let mu = mean[0];
195 let sigma = std[0];
196 if sigma < 1e-10 {
197 return Ok(0.0);
198 }
199 let f_best = self
200 .results
201 .iter()
202 .map(|r| r.score)
203 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
204 .unwrap_or(0.0);
205 let acq = match self.acquisition_fn {
206 AcquisitionFunction::ExpectedImprovement { xi } => {
207 let z = (mu - f_best - xi) / sigma;
208 let phi = Self::normal_cdf(z);
209 let pdf = Self::normal_pdf(z);
210 (mu - f_best - xi) * phi + sigma * pdf
211 }
212 AcquisitionFunction::UpperConfidenceBound { kappa } => mu + kappa * sigma,
213 AcquisitionFunction::ProbabilityOfImprovement { xi } => {
214 let z = (mu - f_best - xi) / sigma;
215 Self::normal_cdf(z)
216 }
217 };
218 Ok(acq)
219 }
220
221 pub(super) fn normal_cdf(x: f64) -> f64 {
223 0.5 * (1.0 + Self::erf(x / 2.0_f64.sqrt()))
224 }
225
226 pub(super) fn normal_pdf(x: f64) -> f64 {
228 (-0.5 * x.powi(2)).exp() / (2.0 * std::f64::consts::PI).sqrt()
229 }
230
231 pub(super) fn erf(x: f64) -> f64 {
233 let a1 = 0.254829592;
234 let a2 = -0.284496736;
235 let a3 = 1.421413741;
236 let a4 = -1.453152027;
237 let a5 = 1.061405429;
238 let p = 0.3275911;
239 let sign = if x < 0.0 { -1.0 } else { 1.0 };
240 let x = x.abs();
241 let t = 1.0 / (1.0 + p * x);
242 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
243 sign * y
244 }
245
    /// Encodes a configuration as a normalized vector in [0, 1]^d, one
    /// coordinate per parameter in `param_names` order.
    ///
    /// Each value is min-max scaled against `bounds` (log-scaled first
    /// for `LogUniform`; `Discrete` values are mapped to their index).
    /// This is the inverse of `vector_to_config`.
    ///
    /// # Panics
    /// Panics if a value's variant does not match its declared space
    /// (the `expect` calls below).
    /// NOTE(review): a space with `min == max` makes the division yield
    /// NaN/inf — presumably degenerate spaces are rejected upstream;
    /// confirm.
    fn config_to_vector(&self, config: &HyperparamConfig) -> Array1<f64> {
        let n_dims = self.param_names.len();
        let mut x = Array1::zeros(n_dims);
        for (i, name) in self.param_names.iter().enumerate() {
            let value = &config[name];
            let (min, max) = self.bounds[i];
            x[i] = match &self.param_space[name] {
                HyperparamSpace::Continuous { .. } => {
                    let v = value
                        .as_float()
                        .expect("Continuous space requires float value");
                    (v - min) / (max - min)
                }
                HyperparamSpace::LogUniform { .. } => {
                    let v = value
                        .as_float()
                        .expect("LogUniform space requires float value");
                    // Bounds are stored in log space for this variant.
                    let log_v = v.ln();
                    (log_v - min) / (max - min)
                }
                HyperparamSpace::IntRange { .. } => {
                    let v = value.as_int().expect("IntRange space requires int value") as f64;
                    (v - min) / (max - min)
                }
                HyperparamSpace::Discrete(values) => {
                    // Unknown values silently map to index 0.
                    let idx = values.iter().position(|v| v == value).unwrap_or(0);
                    (idx as f64 - min) / (max - min)
                }
            };
        }
        x
    }
279
280 fn vector_to_config(&self, x: &Array1<f64>) -> TrainResult<HyperparamConfig> {
282 let mut config = HashMap::new();
283 for (i, name) in self.param_names.iter().enumerate() {
284 let normalized = x[i].clamp(0.0, 1.0);
285 let (min, max) = self.bounds[i];
286 let value_raw = min + normalized * (max - min);
287 let value = match &self.param_space[name] {
288 HyperparamSpace::Continuous { .. } => HyperparamValue::Float(value_raw),
289 HyperparamSpace::LogUniform { .. } => HyperparamValue::Float(value_raw.exp()),
290 HyperparamSpace::IntRange { .. } => HyperparamValue::Int(value_raw.round() as i64),
291 HyperparamSpace::Discrete(values) => {
292 let idx = value_raw.round() as usize;
293 values[idx.min(values.len() - 1)].clone()
294 }
295 };
296 config.insert(name.clone(), value);
297 }
298 Ok(config)
299 }
300
301 fn random_sample(&mut self) -> HyperparamConfig {
303 let mut config = HashMap::new();
304 for (name, space) in &self.param_space {
305 let value = space.sample(&mut self.rng);
306 config.insert(name.clone(), value);
307 }
308 config
309 }
310
    /// Records the score of an evaluated configuration so future
    /// suggestions can condition on it.
    pub fn add_result(&mut self, result: HyperparamResult) {
        self.results.push(result);
    }
315
316 pub fn best_result(&self) -> Option<&HyperparamResult> {
318 self.results.iter().max_by(|a, b| {
319 a.score
320 .partial_cmp(&b.score)
321 .unwrap_or(std::cmp::Ordering::Equal)
322 })
323 }
324
325 pub fn sorted_results(&self) -> Vec<&HyperparamResult> {
327 let mut results: Vec<&HyperparamResult> = self.results.iter().collect();
328 results.sort_by(|a, b| {
329 b.score
330 .partial_cmp(&a.score)
331 .unwrap_or(std::cmp::Ordering::Equal)
332 });
333 results
334 }
335
336 pub fn results(&self) -> &[HyperparamResult] {
338 &self.results
339 }
340
341 pub fn is_complete(&self) -> bool {
343 self.results.len() >= self.n_iterations + self.n_initial_points
344 }
345
    /// Number of results recorded so far (random and model-guided).
    pub fn current_iteration(&self) -> usize {
        self.results.len()
    }
350
351 pub fn total_budget(&self) -> usize {
353 self.n_iterations + self.n_initial_points
354 }
355}