1use crate::error::InterpolateError;
23
/// Criterion used to score candidate sample locations.
///
/// The selected variant decides which scoring routine
/// `ActiveSampler::acquisition_value` dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActiveAcquisitionFunction {
    /// Score a candidate with `gp_posterior_variance`.
    MaximumVariance,
    /// Score a candidate with `expected_improvement`.
    ExpectedImprovement,
    /// Score a candidate with `leverage_score`.
    LeverageScore,
}
39
40#[derive(Debug, Clone)]
46pub struct ActiveSamplerConfig {
47 pub acquisition: ActiveAcquisitionFunction,
49 pub n_candidates: usize,
51 pub domain: Vec<[f64; 2]>,
53 pub seed: u64,
55}
56
57impl Default for ActiveSamplerConfig {
58 fn default() -> Self {
59 Self {
60 acquisition: ActiveAcquisitionFunction::MaximumVariance,
61 n_candidates: 64,
62 domain: vec![[0.0, 1.0], [0.0, 1.0]],
63 seed: 42,
64 }
65 }
66}
67
68#[derive(Debug)]
96pub struct ActiveSampler {
97 config: ActiveSamplerConfig,
98 observed_points: Vec<Vec<f64>>,
99 observed_values: Vec<f64>,
100 n_dims: usize,
101}
102
103impl ActiveSampler {
104 pub fn new(config: ActiveSamplerConfig) -> Self {
107 let n_dims = config.domain.len().max(1);
108 Self {
109 config,
110 observed_points: Vec::new(),
111 observed_values: Vec::new(),
112 n_dims,
113 }
114 }
115
116 pub fn suggest_next(&self) -> Vec<f64> {
121 let mut rng = XorShift64::new(self.config.seed.wrapping_add(self.n_observed() as u64));
122 let candidates =
123 generate_candidates(&self.config.domain, self.config.n_candidates, &mut rng);
124
125 if candidates.is_empty() {
126 return self
128 .config
129 .domain
130 .iter()
131 .map(|&[lo, hi]| 0.5 * (lo + hi))
132 .collect();
133 }
134
135 let best = candidates.iter().cloned().enumerate().fold(
137 (0usize, f64::NEG_INFINITY),
138 |(bi, bv), (i, ref cand)| {
139 let score = self.acquisition_value(cand);
140 if score > bv {
141 (i, score)
142 } else {
143 (bi, bv)
144 }
145 },
146 );
147
148 candidates.into_iter().nth(best.0).unwrap_or_else(|| {
149 self.config
150 .domain
151 .iter()
152 .map(|&[lo, hi]| 0.5 * (lo + hi))
153 .collect()
154 })
155 }
156
157 pub fn observe(&mut self, point: Vec<f64>, value: f64) {
159 self.observed_points.push(point);
160 self.observed_values.push(value);
161 }
162
163 pub fn acquisition_value(&self, point: &[f64]) -> f64 {
167 if self.observed_points.is_empty() {
168 return 1.0; }
170 match self.config.acquisition {
171 ActiveAcquisitionFunction::MaximumVariance => {
172 gp_posterior_variance(&self.observed_points, &self.observed_values, point, 1e-6)
173 }
174 ActiveAcquisitionFunction::ExpectedImprovement => {
175 expected_improvement(&self.observed_points, &self.observed_values, point, 1e-6)
176 }
177 ActiveAcquisitionFunction::LeverageScore => {
178 leverage_score(&self.observed_points, point, 1e-6)
179 }
180 }
181 }
182
183 pub fn loo_error(&self) -> f64 {
190 let n = self.observed_points.len();
191 if n < 2 {
192 return 0.0;
193 }
194 let mut sum_sq = 0.0_f64;
195 for leave_out in 0..n {
196 let rem_pts: Vec<Vec<f64>> = self
198 .observed_points
199 .iter()
200 .enumerate()
201 .filter(|(i, _)| *i != leave_out)
202 .map(|(_, p)| p.clone())
203 .collect();
204 let rem_vals: Vec<f64> = self
205 .observed_values
206 .iter()
207 .enumerate()
208 .filter(|(i, _)| *i != leave_out)
209 .map(|(_, &v)| v)
210 .collect();
211
212 let pred =
214 gp_posterior_mean(&rem_pts, &rem_vals, &self.observed_points[leave_out], 1e-6);
215 let err = pred - self.observed_values[leave_out];
216 sum_sq += err * err;
217 }
218 (sum_sq / n as f64).sqrt()
219 }
220
221 pub fn n_observed(&self) -> usize {
223 self.observed_points.len()
224 }
225
226 pub fn observed_points(&self) -> &[Vec<f64>] {
228 &self.observed_points
229 }
230
231 pub fn n_dims(&self) -> usize {
233 self.n_dims
234 }
235}
236
/// Squared-exponential (RBF) kernel `exp(-|x1 - x2|^2 / (2 * length_scale^2))`.
///
/// Coordinates are paired positionally; if the slices differ in length the
/// surplus coordinates of the longer one are ignored.
pub fn rbf_kernel_sq(x1: &[f64], x2: &[f64], length_scale: f64) -> f64 {
    let mut sq_dist = 0.0_f64;
    for (a, b) in x1.iter().zip(x2) {
        let diff = a - b;
        sq_dist += diff * diff;
    }
    let denom = 2.0 * length_scale * length_scale;
    (-sq_dist / denom).exp()
}
250
251pub fn gp_posterior_variance(
256 obs_points: &[Vec<f64>],
257 obs_vals: &[f64],
258 query: &[f64],
259 nugget: f64,
260) -> f64 {
261 let n = obs_points.len();
262 if n == 0 {
263 return 1.0;
264 }
265 let ls = auto_length_scale(obs_points);
266
267 let k_star: Vec<f64> = obs_points
269 .iter()
270 .map(|p| rbf_kernel_sq(query, p, ls))
271 .collect();
272
273 let k_mat = build_kernel_matrix(obs_points, ls, nugget);
275
276 let alpha = match crate::gpu_rbf::solve_linear_system(&k_mat, &k_star, n) {
278 Ok(a) => a,
279 Err(_) => return 1.0, };
281
282 let reduction: f64 = k_star.iter().zip(alpha.iter()).map(|(k, a)| k * a).sum();
283 let k_ss = rbf_kernel_sq(query, query, ls); let var = k_ss - reduction;
285 var.max(0.0)
286}
287
288fn gp_posterior_mean(obs_points: &[Vec<f64>], obs_vals: &[f64], query: &[f64], nugget: f64) -> f64 {
290 let n = obs_points.len();
291 if n == 0 {
292 return 0.0;
293 }
294 let ls = auto_length_scale(obs_points);
295 let k_star: Vec<f64> = obs_points
296 .iter()
297 .map(|p| rbf_kernel_sq(query, p, ls))
298 .collect();
299 let k_mat = build_kernel_matrix(obs_points, ls, nugget);
300 let alpha = match crate::gpu_rbf::solve_linear_system(&k_mat, obs_vals, n) {
301 Ok(a) => a,
302 Err(_) => return 0.0,
303 };
304 k_star.iter().zip(alpha.iter()).map(|(k, a)| k * a).sum()
305}
306
307fn expected_improvement(
311 obs_points: &[Vec<f64>],
312 obs_vals: &[f64],
313 query: &[f64],
314 nugget: f64,
315) -> f64 {
316 if obs_vals.is_empty() {
317 return 1.0;
318 }
319 let y_best = obs_vals.iter().cloned().fold(f64::INFINITY, f64::min);
320 let ls = auto_length_scale(obs_points);
321 let n = obs_points.len();
322 let k_star: Vec<f64> = obs_points
323 .iter()
324 .map(|p| rbf_kernel_sq(query, p, ls))
325 .collect();
326 let k_mat = build_kernel_matrix(obs_points, ls, nugget);
327
328 let alpha = match crate::gpu_rbf::solve_linear_system(&k_mat, obs_vals, n) {
329 Ok(a) => a,
330 Err(_) => return 0.0,
331 };
332 let mu: f64 = k_star.iter().zip(alpha.iter()).map(|(k, a)| k * a).sum();
333
334 let alpha_v = match crate::gpu_rbf::solve_linear_system(&k_mat, &k_star, n) {
336 Ok(a) => a,
337 Err(_) => return 0.0,
338 };
339 let reduction: f64 = k_star.iter().zip(alpha_v.iter()).map(|(k, a)| k * a).sum();
340 let sigma2 = (rbf_kernel_sq(query, query, ls) - reduction).max(1e-18);
341 let sigma = sigma2.sqrt();
342
343 let z = (y_best - mu) / sigma;
344 let phi_z = 0.5 * (1.0 + erf_approx(z / std::f64::consts::SQRT_2));
346 let pdf_z = (-0.5 * z * z).exp() / (2.0 * std::f64::consts::PI).sqrt();
347 let ei = (y_best - mu) * phi_z + sigma * pdf_z;
348 ei.max(0.0)
349}
350
351fn leverage_score(obs_points: &[Vec<f64>], query: &[f64], nugget: f64) -> f64 {
356 let n = obs_points.len();
357 if n == 0 {
358 return 1.0;
359 }
360 let ls = auto_length_scale(obs_points);
361 let k_star: Vec<f64> = obs_points
362 .iter()
363 .map(|p| rbf_kernel_sq(query, p, ls))
364 .collect();
365 let k_mat = build_kernel_matrix(obs_points, ls, nugget);
366 let alpha = match crate::gpu_rbf::solve_linear_system(&k_mat, &k_star, n) {
367 Ok(a) => a,
368 Err(_) => return 0.0,
369 };
370 k_star
371 .iter()
372 .zip(alpha.iter())
373 .map(|(k, a)| k * a)
374 .sum::<f64>()
375 .max(0.0)
376}
377
/// Minimal xorshift pseudo-random generator with a 64-bit state.
///
/// The wrapped value is the current state. A zero state would stay zero
/// forever (every step only XORs shifted copies of the state), which is why
/// [`XorShift64::new`] never stores zero.
struct XorShift64(u64);

impl XorShift64 {
    /// Creates a generator from `seed`, replacing a zero seed with a fixed
    /// non-zero constant.
    fn new(seed: u64) -> Self {
        Self(if seed == 0 {
            0xDEAD_BEEF_CAFE_BABE
        } else {
            seed
        })
    }

    /// Advances the state with the 13 / 7 / 17 shift triple and returns it.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Returns a value strictly inside the open interval `(0, 1)`.
    ///
    /// Converting a `u64` to `f64` rounds to 53 significant bits, so draws at
    /// or above `2^64 - 1024` round up to `2^64` and the unclamped quotient
    /// would be exactly `1.0`. The result is therefore capped at the largest
    /// `f64` below one. Every other draw keeps exactly the value the unclamped
    /// formula yields, so seeded sequences are unchanged apart from that
    /// extremely rare edge case.
    fn next_f64(&mut self) -> f64 {
        /// Largest `f64` strictly below `1.0`, i.e. `1 - 2^-53`.
        const LARGEST_BELOW_ONE: f64 = 1.0 - f64::EPSILON / 2.0;

        let unit = (self.next_u64() as f64 + 0.5) / (u64::MAX as f64 + 1.0);
        unit.min(LARGEST_BELOW_ONE)
    }
}
405
406pub fn generate_candidates_with_seed(domain: &[[f64; 2]], n: usize, seed: u64) -> Vec<Vec<f64>> {
408 let mut rng = XorShift64::new(seed);
409 generate_candidates(domain, n, &mut rng)
410}
411
412fn generate_candidates(domain: &[[f64; 2]], n: usize, rng: &mut XorShift64) -> Vec<Vec<f64>> {
414 if domain.is_empty() || n == 0 {
415 return Vec::new();
416 }
417 (0..n)
418 .map(|_| {
419 domain
420 .iter()
421 .map(|&[lo, hi]| lo + rng.next_f64() * (hi - lo))
422 .collect()
423 })
424 .collect()
425}
426
427fn build_kernel_matrix(obs_points: &[Vec<f64>], ls: f64, nugget: f64) -> Vec<f64> {
433 let n = obs_points.len();
434 let mut k = vec![0.0f64; n * n];
435 for i in 0..n {
436 for j in 0..n {
437 k[i * n + j] = rbf_kernel_sq(&obs_points[i], &obs_points[j], ls);
438 }
439 k[i * n + i] += nugget;
440 }
441 k
442}
443
/// Length-scale heuristic: the median pairwise Euclidean distance divided by
/// `sqrt(2)`, floored at `1e-6`.
///
/// With an even number of pairs the upper of the two middle distances is
/// used. Returns `1.0` for fewer than two points.
fn auto_length_scale(points: &[Vec<f64>]) -> f64 {
    if points.len() < 2 {
        return 1.0;
    }

    // Distances for every unordered pair (i, j) with i < j.
    let mut pair_dists: Vec<f64> = points
        .iter()
        .enumerate()
        .flat_map(|(i, p)| {
            points[i + 1..].iter().map(move |q| {
                p.iter()
                    .zip(q.iter())
                    .map(|(&a, &b)| (a - b) * (a - b))
                    .sum::<f64>()
                    .sqrt()
            })
        })
        .collect();

    // Incomparable values (NaN) are treated as equal to their neighbour.
    pair_dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let median = pair_dists
        .get(pair_dists.len() / 2)
        .copied()
        .unwrap_or(1.0);
    (median / std::f64::consts::SQRT_2).max(1e-6)
}
469
/// Polynomial approximation of the error function.
///
/// Evaluates a five-term polynomial in `t = 1 / (1 + 0.3275911 * |x|)` by
/// Horner's scheme, scales it by `exp(-x^2)`, and restores the sign of `x`
/// so that the result is odd in `x`.
fn erf_approx(x: f64) -> f64 {
    // Coefficients from the highest power of `t` down to the lowest.
    const LEADING: f64 = 1.061405429;
    const REMAINING: [f64; 4] = [-1.453152027, 1.421413741, -0.284496736, 0.254829592];

    let t = 1.0 / (1.0 + 0.3275911 * x.abs());
    let horner = REMAINING
        .iter()
        .fold(t * LEADING, |acc, &coeff| t * (coeff + acc));
    let magnitude = 1.0 - horner * (-x * x).exp();

    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Maximum-variance sampler on the unit square with 50 candidates.
    fn make_sampler(seed: u64) -> ActiveSampler {
        let config = ActiveSamplerConfig {
            acquisition: ActiveAcquisitionFunction::MaximumVariance,
            n_candidates: 50,
            domain: vec![[0.0, 1.0]; 2],
            seed,
        };
        ActiveSampler::new(config)
    }

    /// A suggestion has one coordinate per dimension, each inside its bounds.
    #[test]
    fn test_suggest_next_within_domain() {
        let mut sampler = make_sampler(42);
        sampler.observe(vec![0.5, 0.5], 1.0);

        let suggestion = sampler.suggest_next();
        assert_eq!(
            suggestion.len(),
            2,
            "suggested point should have 2 dimensions"
        );

        let bounds = &sampler.config.domain;
        for (d, &v) in suggestion.iter().enumerate() {
            let [lo, hi] = bounds[d];
            assert!(v >= lo && v <= hi, "dim {d}: {v} not in [{}, {}]", lo, hi);
        }
    }

    /// Every call to `observe` adds exactly one observation.
    #[test]
    fn test_observe_increments_count() {
        let mut sampler = make_sampler(1);
        assert_eq!(sampler.n_observed(), 0);

        let observations = [(vec![0.1, 0.2], 0.5), (vec![0.8, 0.3], 1.5)];
        for (expected, (point, value)) in (1usize..).zip(observations) {
            sampler.observe(point, value);
            assert_eq!(sampler.n_observed(), expected);
        }
    }

    /// The leave-one-out error stays finite and reacts to a new observation.
    #[test]
    fn test_loo_error_changes_after_observation() {
        let mut sampler = make_sampler(3);
        for (point, value) in [
            (vec![0.0, 0.0], 0.0),
            (vec![1.0, 0.0], 1.0),
            (vec![0.5, 1.0], 0.5),
        ] {
            sampler.observe(point, value);
        }
        let err_before = sampler.loo_error();

        sampler.observe(vec![0.2, 0.8], 0.2);
        let err_after = sampler.loo_error();

        assert!(err_before.is_finite(), "loo_error before should be finite");
        assert!(err_after.is_finite(), "loo_error after should be finite");
        assert!(
            err_before != err_after || err_after == 0.0,
            "loo_error should change (or be 0) after new observation"
        );
    }

    /// Samplers that differ only in their seed suggest different points.
    #[test]
    fn test_different_seeds_different_suggestions() {
        let mut first = make_sampler(7);
        let mut second = make_sampler(99999);
        first.observe(vec![0.5, 0.5], 1.0);
        second.observe(vec![0.5, 0.5], 1.0);

        let n1 = first.suggest_next();
        let n2 = second.suggest_next();

        let differ = n1
            .iter()
            .zip(n2.iter())
            .any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(
            differ,
            "Different seeds should produce different suggested points (got {:?} and {:?})",
            n1, n2
        );
    }

    /// Expected improvement is never negative.
    #[test]
    fn test_ei_non_negative() {
        let mut sampler = ActiveSampler::new(ActiveSamplerConfig {
            acquisition: ActiveAcquisitionFunction::ExpectedImprovement,
            n_candidates: 20,
            domain: vec![[0.0, 1.0]],
            seed: 5,
        });
        sampler.observe(vec![0.3], 2.0);
        sampler.observe(vec![0.7], 1.0);

        for x in [0.1, 0.5, 0.9] {
            let v = sampler.acquisition_value(&[x]);
            assert!(v >= 0.0, "EI must be non-negative, got {v} at x={x}");
        }
    }

    /// The leverage score lies in `[0, 1]` up to a small tolerance.
    #[test]
    fn test_leverage_score_range() {
        let mut sampler = ActiveSampler::new(ActiveSamplerConfig {
            acquisition: ActiveAcquisitionFunction::LeverageScore,
            n_candidates: 20,
            domain: vec![[0.0, 1.0], [0.0, 1.0]],
            seed: 10,
        });
        sampler.observe(vec![0.2, 0.3], 1.0);
        sampler.observe(vec![0.8, 0.7], 2.0);

        let v = sampler.acquisition_value(&[0.5, 0.5]);
        let allowed = 0.0..=1.0 + 1e-10;
        assert!(
            allowed.contains(&v),
            "leverage score should be in [0, 1], got {v}"
        );
    }

    /// The kernel of a point with itself is one.
    #[test]
    fn test_rbf_kernel_sq_at_zero() {
        let x = [0.3, 0.7];
        let v = rbf_kernel_sq(&x, &x, 1.0);
        let deviation = (v - 1.0).abs();
        assert!(
            deviation < 1e-15,
            "SE kernel at r=0 should be 1.0, got {v}"
        );
    }
}