// scirs2_interpolate — active_learning.rs

1//! Active sampling to minimise interpolation error.
2//!
//! Provides an acquisition-function–driven strategy, backed by a lightweight
//! Gaussian-process surrogate, for selecting the next point to query in order
//! to reduce interpolation error most efficiently.  Three acquisition
//! strategies are supported:
6//!
7//! - **MaximumVariance**: maximise the GP posterior variance at the candidate
8//!   points — pure exploration.
9//! - **ExpectedImprovement**: standard EI using GP posterior mean and variance.
10//! - **LeverageScore**: statistical leverage score of the candidate against the
11//!   kernel matrix formed by the observed points.
12//!
//! Candidate points are generated using a deterministic pseudo-random
//! sequence (XorShift64-based), ensuring reproducibility.
15//!
16//! ## References
17//!
18//! - Settles, B. (2009). *Active Learning Literature Survey*.
19//! - Srinivas, N. et al. (2010). *Gaussian Process Optimization in the Bandit
20//!   Setting: No Regret and Experimental Design*.
21
22use crate::error::InterpolateError;
23
24// ---------------------------------------------------------------------------
25// Acquisition function enum
26// ---------------------------------------------------------------------------
27
/// Acquisition function used to rank candidate query points.
///
/// Higher acquisition values indicate more informative candidates; the
/// sampler selects the candidate with the maximum value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActiveAcquisitionFunction {
    /// Select the point with maximum GP posterior variance (pure exploration).
    MaximumVariance,
    /// Standard Expected Improvement (exploits current best observation).
    /// Smaller observed values are treated as better (minimisation form).
    ExpectedImprovement,
    /// Statistical leverage score: measures the influence of a new point on the
    /// Gram matrix.
    LeverageScore,
}
39
40// ---------------------------------------------------------------------------
41// Configuration
42// ---------------------------------------------------------------------------
43
/// Configuration for [`ActiveSampler`].
#[derive(Debug, Clone)]
pub struct ActiveSamplerConfig {
    /// Acquisition function to use when ranking candidate points.
    pub acquisition: ActiveAcquisitionFunction,
    /// Number of candidate points sampled per `suggest_next` call.
    pub n_candidates: usize,
    /// Domain bounds for each dimension: `domain[d] = [min, max]`.
    /// The length of this vector fixes the sampler's dimensionality.
    pub domain: Vec<[f64; 2]>,
    /// Seed for the candidate generator.  It is combined with the current
    /// observation count on each `suggest_next` call, so successive
    /// suggestions see fresh candidates while staying reproducible.
    pub seed: u64,
}
56
57impl Default for ActiveSamplerConfig {
58    fn default() -> Self {
59        Self {
60            acquisition: ActiveAcquisitionFunction::MaximumVariance,
61            n_candidates: 64,
62            domain: vec![[0.0, 1.0], [0.0, 1.0]],
63            seed: 42,
64        }
65    }
66}
67
68// ---------------------------------------------------------------------------
69// ActiveSampler
70// ---------------------------------------------------------------------------
71
/// Active sampling strategy for minimising interpolation error.
///
/// Holds a growing set of `(point, value)` observations and, on request,
/// proposes the next query point by ranking pseudo-random candidates with the
/// configured acquisition function.
///
/// # Example
///
/// ```rust
/// use scirs2_interpolate::active_learning::{
///     ActiveSampler, ActiveSamplerConfig, ActiveAcquisitionFunction,
/// };
///
/// let config = ActiveSamplerConfig {
///     acquisition: ActiveAcquisitionFunction::MaximumVariance,
///     n_candidates: 20,
///     domain: vec![[0.0, 1.0], [0.0, 1.0]],
///     seed: 7,
/// };
/// let mut sampler = ActiveSampler::new(config);
///
/// // Seed with one observation
/// sampler.observe(vec![0.5, 0.5], 1.0);
///
/// let next = sampler.suggest_next();
/// assert_eq!(next.len(), 2);
/// ```
#[derive(Debug)]
pub struct ActiveSampler {
    /// Sampler configuration (acquisition function, candidate count, domain
    /// bounds, RNG seed).
    config: ActiveSamplerConfig,
    /// Previously queried points (one inner `Vec<f64>` per observation).
    observed_points: Vec<Vec<f64>>,
    /// Observed values, kept parallel to `observed_points`.
    observed_values: Vec<f64>,
    /// Domain dimensionality, inferred from `config.domain.len()` (min 1).
    n_dims: usize,
}
102
103impl ActiveSampler {
104    /// Create a new sampler.  The number of dimensions is inferred from
105    /// `config.domain.len()`.
106    pub fn new(config: ActiveSamplerConfig) -> Self {
107        let n_dims = config.domain.len().max(1);
108        Self {
109            config,
110            observed_points: Vec::new(),
111            observed_values: Vec::new(),
112            n_dims,
113        }
114    }
115
116    /// Select the next query point by evaluating the acquisition function at
117    /// `config.n_candidates` randomly sampled candidate points.
118    ///
119    /// If no candidates score above zero, returns the first candidate (random).
120    pub fn suggest_next(&self) -> Vec<f64> {
121        let mut rng = XorShift64::new(self.config.seed.wrapping_add(self.n_observed() as u64));
122        let candidates =
123            generate_candidates(&self.config.domain, self.config.n_candidates, &mut rng);
124
125        if candidates.is_empty() {
126            // Fallback: return domain centre
127            return self
128                .config
129                .domain
130                .iter()
131                .map(|&[lo, hi]| 0.5 * (lo + hi))
132                .collect();
133        }
134
135        // Rank candidates
136        let best = candidates.iter().cloned().enumerate().fold(
137            (0usize, f64::NEG_INFINITY),
138            |(bi, bv), (i, ref cand)| {
139                let score = self.acquisition_value(cand);
140                if score > bv {
141                    (i, score)
142                } else {
143                    (bi, bv)
144                }
145            },
146        );
147
148        candidates.into_iter().nth(best.0).unwrap_or_else(|| {
149            self.config
150                .domain
151                .iter()
152                .map(|&[lo, hi]| 0.5 * (lo + hi))
153                .collect()
154        })
155    }
156
157    /// Register a new observation.
158    pub fn observe(&mut self, point: Vec<f64>, value: f64) {
159        self.observed_points.push(point);
160        self.observed_values.push(value);
161    }
162
163    /// Compute the acquisition value for a single candidate `point`.
164    ///
165    /// Returns 0.0 when there are no observations.
166    pub fn acquisition_value(&self, point: &[f64]) -> f64 {
167        if self.observed_points.is_empty() {
168            return 1.0; // no data → treat every point as equally informative
169        }
170        match self.config.acquisition {
171            ActiveAcquisitionFunction::MaximumVariance => {
172                gp_posterior_variance(&self.observed_points, &self.observed_values, point, 1e-6)
173            }
174            ActiveAcquisitionFunction::ExpectedImprovement => {
175                expected_improvement(&self.observed_points, &self.observed_values, point, 1e-6)
176            }
177            ActiveAcquisitionFunction::LeverageScore => {
178                leverage_score(&self.observed_points, point, 1e-6)
179            }
180        }
181    }
182
183    /// Leave-one-out cross-validation error estimate.
184    ///
185    /// For each observed point, fits a simple GP on the remaining n-1 points
186    /// and measures the squared prediction error.  Returns the RMS LOO error.
187    ///
188    /// Returns 0.0 when fewer than 2 observations are available.
189    pub fn loo_error(&self) -> f64 {
190        let n = self.observed_points.len();
191        if n < 2 {
192            return 0.0;
193        }
194        let mut sum_sq = 0.0_f64;
195        for leave_out in 0..n {
196            // Collect remaining points
197            let rem_pts: Vec<Vec<f64>> = self
198                .observed_points
199                .iter()
200                .enumerate()
201                .filter(|(i, _)| *i != leave_out)
202                .map(|(_, p)| p.clone())
203                .collect();
204            let rem_vals: Vec<f64> = self
205                .observed_values
206                .iter()
207                .enumerate()
208                .filter(|(i, _)| *i != leave_out)
209                .map(|(_, &v)| v)
210                .collect();
211
212            // Predict the left-out point
213            let pred =
214                gp_posterior_mean(&rem_pts, &rem_vals, &self.observed_points[leave_out], 1e-6);
215            let err = pred - self.observed_values[leave_out];
216            sum_sq += err * err;
217        }
218        (sum_sq / n as f64).sqrt()
219    }
220
221    /// Number of observations recorded so far.
222    pub fn n_observed(&self) -> usize {
223        self.observed_points.len()
224    }
225
226    /// Slice of all observed points.
227    pub fn observed_points(&self) -> &[Vec<f64>] {
228        &self.observed_points
229    }
230
231    /// Dimensionality of the domain.
232    pub fn n_dims(&self) -> usize {
233        self.n_dims
234    }
235}
236
237// ---------------------------------------------------------------------------
238// GP helper functions
239// ---------------------------------------------------------------------------
240
/// Squared-exponential (RBF) kernel: `k(x, x') = exp(-‖x - x'‖² / (2 l²))`.
pub fn rbf_kernel_sq(x1: &[f64], x2: &[f64], length_scale: f64) -> f64 {
    // Accumulate the squared Euclidean distance over paired coordinates.
    let mut sq_dist = 0.0_f64;
    for (&a, &b) in x1.iter().zip(x2.iter()) {
        let d = a - b;
        sq_dist += d * d;
    }
    let denom = 2.0 * length_scale * length_scale;
    (-sq_dist / denom).exp()
}
250
251/// GP posterior variance at `query` given observations.
252///
253/// Uses a zero-mean GP with SE kernel.  Solves K w = k_star via Gaussian
254/// elimination.  Returns max(0, k_star_star - k_star^T (K + nugget I)^{-1} k_star).
255pub fn gp_posterior_variance(
256    obs_points: &[Vec<f64>],
257    obs_vals: &[f64],
258    query: &[f64],
259    nugget: f64,
260) -> f64 {
261    let n = obs_points.len();
262    if n == 0 {
263        return 1.0;
264    }
265    let ls = auto_length_scale(obs_points);
266
267    // k* = kernel vector between query and observations
268    let k_star: Vec<f64> = obs_points
269        .iter()
270        .map(|p| rbf_kernel_sq(query, p, ls))
271        .collect();
272
273    // K + nugget I
274    let k_mat = build_kernel_matrix(obs_points, ls, nugget);
275
276    // Solve (K + σI) alpha = k_star  →  alpha = (K+σI)^{-1} k_star
277    let alpha = match crate::gpu_rbf::solve_linear_system(&k_mat, &k_star, n) {
278        Ok(a) => a,
279        Err(_) => return 1.0, // fallback on singular matrix
280    };
281
282    let reduction: f64 = k_star.iter().zip(alpha.iter()).map(|(k, a)| k * a).sum();
283    let k_ss = rbf_kernel_sq(query, query, ls); // = 1.0 for SE kernel
284    let var = k_ss - reduction;
285    var.max(0.0)
286}
287
288/// GP posterior mean at `query`.
289fn gp_posterior_mean(obs_points: &[Vec<f64>], obs_vals: &[f64], query: &[f64], nugget: f64) -> f64 {
290    let n = obs_points.len();
291    if n == 0 {
292        return 0.0;
293    }
294    let ls = auto_length_scale(obs_points);
295    let k_star: Vec<f64> = obs_points
296        .iter()
297        .map(|p| rbf_kernel_sq(query, p, ls))
298        .collect();
299    let k_mat = build_kernel_matrix(obs_points, ls, nugget);
300    let alpha = match crate::gpu_rbf::solve_linear_system(&k_mat, obs_vals, n) {
301        Ok(a) => a,
302        Err(_) => return 0.0,
303    };
304    k_star.iter().zip(alpha.iter()).map(|(k, a)| k * a).sum()
305}
306
307/// Expected Improvement acquisition function.
308///
309/// EI(x) = (μ - y_best) Φ(z) + σ φ(z)  where z = (μ - y_best) / σ.
310fn expected_improvement(
311    obs_points: &[Vec<f64>],
312    obs_vals: &[f64],
313    query: &[f64],
314    nugget: f64,
315) -> f64 {
316    if obs_vals.is_empty() {
317        return 1.0;
318    }
319    let y_best = obs_vals.iter().cloned().fold(f64::INFINITY, f64::min);
320    let ls = auto_length_scale(obs_points);
321    let n = obs_points.len();
322    let k_star: Vec<f64> = obs_points
323        .iter()
324        .map(|p| rbf_kernel_sq(query, p, ls))
325        .collect();
326    let k_mat = build_kernel_matrix(obs_points, ls, nugget);
327
328    let alpha = match crate::gpu_rbf::solve_linear_system(&k_mat, obs_vals, n) {
329        Ok(a) => a,
330        Err(_) => return 0.0,
331    };
332    let mu: f64 = k_star.iter().zip(alpha.iter()).map(|(k, a)| k * a).sum();
333
334    // Variance
335    let alpha_v = match crate::gpu_rbf::solve_linear_system(&k_mat, &k_star, n) {
336        Ok(a) => a,
337        Err(_) => return 0.0,
338    };
339    let reduction: f64 = k_star.iter().zip(alpha_v.iter()).map(|(k, a)| k * a).sum();
340    let sigma2 = (rbf_kernel_sq(query, query, ls) - reduction).max(1e-18);
341    let sigma = sigma2.sqrt();
342
343    let z = (y_best - mu) / sigma;
344    // Φ(z) and φ(z) via erf approximation
345    let phi_z = 0.5 * (1.0 + erf_approx(z / std::f64::consts::SQRT_2));
346    let pdf_z = (-0.5 * z * z).exp() / (2.0 * std::f64::consts::PI).sqrt();
347    let ei = (y_best - mu) * phi_z + sigma * pdf_z;
348    ei.max(0.0)
349}
350
351/// Statistical leverage score of `query` given observed points.
352///
353/// The leverage score measures how much the new point would influence the
354/// Gram matrix: h = k*^T (K + σI)^{-1} k*.
355fn leverage_score(obs_points: &[Vec<f64>], query: &[f64], nugget: f64) -> f64 {
356    let n = obs_points.len();
357    if n == 0 {
358        return 1.0;
359    }
360    let ls = auto_length_scale(obs_points);
361    let k_star: Vec<f64> = obs_points
362        .iter()
363        .map(|p| rbf_kernel_sq(query, p, ls))
364        .collect();
365    let k_mat = build_kernel_matrix(obs_points, ls, nugget);
366    let alpha = match crate::gpu_rbf::solve_linear_system(&k_mat, &k_star, n) {
367        Ok(a) => a,
368        Err(_) => return 0.0,
369    };
370    k_star
371        .iter()
372        .zip(alpha.iter())
373        .map(|(k, a)| k * a)
374        .sum::<f64>()
375        .max(0.0)
376}
377
378// ---------------------------------------------------------------------------
379// Candidate generation
380// ---------------------------------------------------------------------------
381
/// XorShift64 PRNG (local copy for this module).
struct XorShift64(u64);

impl XorShift64 {
    /// Construct from a seed.  A zero seed is remapped to a fixed non-zero
    /// constant because all-zero is a fixed point of the xorshift update.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 0xDEAD_BEEF_CAFE_BABE } else { seed };
        XorShift64(state)
    }

    /// Advance the state (shifts 13, 7, 17) and return the next 64-bit value.
    fn next_u64(&mut self) -> u64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        self.0
    }

    /// Next value mapped into the open interval (0, 1).
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() as f64 + 0.5) / (u64::MAX as f64 + 1.0)
    }
}

/// Generate `n` uniformly random candidate points inside `domain` from a seed.
pub fn generate_candidates_with_seed(domain: &[[f64; 2]], n: usize, seed: u64) -> Vec<Vec<f64>> {
    generate_candidates(domain, n, &mut XorShift64::new(seed))
}

/// Generate `n` uniformly random candidate points inside `domain`.
fn generate_candidates(domain: &[[f64; 2]], n: usize, rng: &mut XorShift64) -> Vec<Vec<f64>> {
    if n == 0 || domain.is_empty() {
        return Vec::new();
    }
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        // One RNG draw per dimension, in dimension order (keeps the stream
        // layout identical for a given seed).
        let mut point = Vec::with_capacity(domain.len());
        for &[lo, hi] in domain {
            point.push(lo + rng.next_f64() * (hi - lo));
        }
        out.push(point);
    }
    out
}
426
427// ---------------------------------------------------------------------------
428// Numerical helpers
429// ---------------------------------------------------------------------------
430
431/// Build SE kernel matrix K_{ij} = k(x_i, x_j) + nugget * δ_{ij}.
432fn build_kernel_matrix(obs_points: &[Vec<f64>], ls: f64, nugget: f64) -> Vec<f64> {
433    let n = obs_points.len();
434    let mut k = vec![0.0f64; n * n];
435    for i in 0..n {
436        for j in 0..n {
437            k[i * n + j] = rbf_kernel_sq(&obs_points[i], &obs_points[j], ls);
438        }
439        k[i * n + i] += nugget;
440    }
441    k
442}
443
/// Heuristic SE length-scale: median pairwise Euclidean distance divided by
/// √2, floored at 1e-6 (so coincident points never yield a zero scale).
///
/// Returns 1.0 when fewer than two points are available (no pairwise
/// distances to estimate from).
fn auto_length_scale(points: &[Vec<f64>]) -> f64 {
    let n = points.len();
    if n <= 1 {
        return 1.0;
    }
    let mut dists: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
    for i in 0..n {
        for j in (i + 1)..n {
            let d2: f64 = points[i]
                .iter()
                .zip(points[j].iter())
                .map(|(&a, &b)| (a - b) * (a - b))
                .sum();
            dists.push(d2.sqrt());
        }
    }
    // total_cmp gives a total order over f64 (no fallback branch needed) and
    // sort_unstable avoids the stable sort's extra allocation.
    dists.sort_unstable_by(f64::total_cmp);
    // n >= 2 guarantees at least one pairwise distance, so indexing is safe.
    let median = dists[dists.len() / 2];
    (median / std::f64::consts::SQRT_2).max(1e-6)
}
469
/// Approximation of erf(x) using Abramowitz & Stegun formula 7.1.26
/// (maximum absolute error ≈ 1.5e-7).
fn erf_approx(x: f64) -> f64 {
    // Coefficients of A&S 7.1.26, evaluated in Horner form.
    const P: f64 = 0.3275911;
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;

    let t = 1.0 / (1.0 + P * x.abs());
    let poly = t * (A1 + t * (A2 + t * (A3 + t * (A4 + t * A5))));
    // The approximation is for |x|; erf is odd, so restore the sign.
    let magnitude = 1.0 - poly * (-x * x).exp();
    if x >= 0.0 { magnitude } else { -magnitude }
}
479
480// ---------------------------------------------------------------------------
481// Tests
482// ---------------------------------------------------------------------------
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 2-D unit-square sampler with maximum-variance
    /// acquisition and 50 candidates per suggestion.
    fn make_sampler(seed: u64) -> ActiveSampler {
        ActiveSampler::new(ActiveSamplerConfig {
            acquisition: ActiveAcquisitionFunction::MaximumVariance,
            n_candidates: 50,
            domain: vec![[0.0, 1.0], [0.0, 1.0]],
            seed,
        })
    }

    /// suggest_next must return a point within the domain bounds.
    #[test]
    fn test_suggest_next_within_domain() {
        let mut sampler = make_sampler(42);
        sampler.observe(vec![0.5, 0.5], 1.0);
        let next = sampler.suggest_next();
        assert_eq!(next.len(), 2, "suggested point should have 2 dimensions");
        // Private-field access is allowed here: `tests` is a child module.
        let domain = &sampler.config.domain;
        for (d, &v) in next.iter().enumerate() {
            assert!(
                v >= domain[d][0] && v <= domain[d][1],
                "dim {d}: {v} not in [{}, {}]",
                domain[d][0],
                domain[d][1]
            );
        }
    }

    /// observe increases n_observed by 1 each time.
    #[test]
    fn test_observe_increments_count() {
        let mut sampler = make_sampler(1);
        assert_eq!(sampler.n_observed(), 0);
        sampler.observe(vec![0.1, 0.2], 0.5);
        assert_eq!(sampler.n_observed(), 1);
        sampler.observe(vec![0.8, 0.3], 1.5);
        assert_eq!(sampler.n_observed(), 2);
    }

    /// loo_error changes after adding a new observation.
    #[test]
    fn test_loo_error_changes_after_observation() {
        let mut sampler = make_sampler(3);
        sampler.observe(vec![0.0, 0.0], 0.0);
        sampler.observe(vec![1.0, 0.0], 1.0);
        sampler.observe(vec![0.5, 1.0], 0.5);

        let err_before = sampler.loo_error();

        sampler.observe(vec![0.2, 0.8], 0.2);

        let err_after = sampler.loo_error();
        // The two errors should differ (adding a point changes the LOO estimate)
        // We allow them to be the same only by coincidence, so just check they're finite
        assert!(err_before.is_finite(), "loo_error before should be finite");
        assert!(err_after.is_finite(), "loo_error after should be finite");
        // At least one of them should be non-zero given non-trivial data
        // (deliberately lenient: the disjunction also passes when both are 0).
        assert!(
            err_before != err_after || err_after == 0.0,
            "loo_error should change (or be 0) after new observation"
        );
    }

    /// Two different seeds should yield different suggested points.
    #[test]
    fn test_different_seeds_different_suggestions() {
        let mut s1 = make_sampler(7);
        let mut s2 = make_sampler(99999);
        s1.observe(vec![0.5, 0.5], 1.0);
        s2.observe(vec![0.5, 0.5], 1.0);

        let n1 = s1.suggest_next();
        let n2 = s2.suggest_next();
        // Coordinate-wise comparison with a tolerance, since both are f64.
        let differ = n1.iter().zip(n2.iter()).any(|(a, b)| (a - b).abs() > 1e-10);
        assert!(
            differ,
            "Different seeds should produce different suggested points (got {:?} and {:?})",
            n1, n2
        );
    }

    /// ExpectedImprovement acquisition returns non-negative values.
    #[test]
    fn test_ei_non_negative() {
        let mut sampler = ActiveSampler::new(ActiveSamplerConfig {
            acquisition: ActiveAcquisitionFunction::ExpectedImprovement,
            n_candidates: 20,
            domain: vec![[0.0, 1.0]],
            seed: 5,
        });
        sampler.observe(vec![0.3], 2.0);
        sampler.observe(vec![0.7], 1.0);

        for x in [0.1, 0.5, 0.9] {
            let v = sampler.acquisition_value(&[x]);
            assert!(v >= 0.0, "EI must be non-negative, got {v} at x={x}");
        }
    }

    /// LeverageScore acquisition returns values in [0, 1].
    #[test]
    fn test_leverage_score_range() {
        let mut sampler = ActiveSampler::new(ActiveSamplerConfig {
            acquisition: ActiveAcquisitionFunction::LeverageScore,
            n_candidates: 20,
            domain: vec![[0.0, 1.0], [0.0, 1.0]],
            seed: 10,
        });
        sampler.observe(vec![0.2, 0.3], 1.0);
        sampler.observe(vec![0.8, 0.7], 2.0);

        let v = sampler.acquisition_value(&[0.5, 0.5]);
        assert!(
            v >= 0.0 && v <= 1.0 + 1e-10,
            "leverage score should be in [0, 1], got {v}"
        );
    }

    /// rbf_kernel_sq at identical points returns 1.0.
    #[test]
    fn test_rbf_kernel_sq_at_zero() {
        let x = vec![0.3, 0.7];
        let v = rbf_kernel_sq(&x, &x, 1.0);
        assert!(
            (v - 1.0).abs() < 1e-15,
            "SE kernel at r=0 should be 1.0, got {v}"
        );
    }
}