// scirs2_interpolate/auto_kernel_gp/mod.rs
//! Gaussian Process interpolator with automatic kernel structure discovery.
//!
//! [`AutoKernelGp`] implements a simplified version of the Duvenaud-style
//! Automatic Statistician approach: it searches a grammar of base kernels
//! (`Rbf`, `Matern52`, `Periodic`, `Linear`, `WhiteNoise`) combined by sum
//! and product operators, and selects the best composite kernel using k-fold
//! cross-validated mean squared error (CV-MSE).
//!
//! ## Algorithm summary
//!
//! 1. **Generate candidates** — BFS over the kernel grammar up to `max_depth`
//!    (default 2), producing depth-0 (base), depth-1 (base ± base), and
//!    depth-2 (depth-1 ± base) expressions.  Commutative duplicates are pruned.
//!
//! 2. **Optimise hyperparameters** — For each candidate:
//!    - RBF / Matérn: golden-section on ℓ ∈ [0.05, 10].
//!    - Periodic: grid on p ∈ {0.1, 0.2, 0.5, 1.0, 2.0, 5.0}, golden-section on ℓ.
//!    - Linear / WhiteNoise: golden-section on variance.
//!    - Composite: sub-expressions optimised independently (greedy).
//!
//! 3. **Cross-validate** — k-fold (default 5) leave-out MSE on the training data.
//!    The kernel with the lowest CV-MSE is selected.
//!
//! 4. **Final fit** — GP alpha vector `α = (K + σ²I)⁻¹ · y` is computed via
//!    inline Cholesky (Banachiewicz), with a small jitter `1e-8` for stability.
//!
//! 5. **Predict** — `y* = K(x*, X) · α`.
//!
//! ## Selection criterion
//!
//! CV-MSE on held-out data is used (not log marginal likelihood on the full
//! training set) to guard against over-fitting by complex kernels and to make
//! the selection criterion directly interpretable as predictive error.
pub mod kernel;
pub mod search;

pub use kernel::{BaseKernel, KernelExpr};

use crate::error::{InterpolateError, InterpolateResult};
// NOTE(review): `build_cross_kernel` is not used in this file's visible code;
// it may be re-exported or used elsewhere — confirm before removing.
use search::{build_cross_kernel, gp_fit, gp_predict, search_kernels};

/// Default period grid for the `Periodic` kernel's hyperparameter search
/// (matches the grid documented in the module-level doc comment).
const DEFAULT_PERIOD_GRID: &[f64] = &[0.1, 0.2, 0.5, 1.0, 2.0, 5.0];

// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Configuration for [`AutoKernelGp`].
#[derive(Debug, Clone)]
pub struct AutoKernelGpConfig {
    /// Maximum depth of kernel expression tree.  Default: 2.
    ///
    /// - Depth 0: only base kernels.
    /// - Depth 1: base ± base (10 unique pairs × 2 ops = 20 extra).
    /// - Depth 2: depth-1 ± base (≈ 40 extra).
    pub max_depth: usize,
    /// Random restarts for hyperparameter optimisation (increases search
    /// budget for golden-section).  Default: 3.
    pub n_restarts: usize,
    /// Observation noise variance added to the kernel diagonal.  Default: 0.01.
    pub noise_variance: f64,
    /// Number of cross-validation folds.  Default: 5.
    pub cv_folds: usize,
    /// Deterministic seed (reserved for future use with random restarts).
    pub seed: u64,
}

impl Default for AutoKernelGpConfig {
    /// Default configuration: depth-2 search, 3 restarts, noise variance
    /// 0.01, 5 CV folds, seed 42 (values mirror the field docs on
    /// [`AutoKernelGpConfig`]).
    fn default() -> Self {
        Self {
            max_depth: 2,
            n_restarts: 3,
            noise_variance: 0.01,
            cv_folds: 5,
            seed: 42,
        }
    }
}

// ---------------------------------------------------------------------------
// AutoKernelGp
// ---------------------------------------------------------------------------

/// Gaussian Process interpolator with automatic kernel structure discovery.
///
/// # Example
///
/// ```rust
/// use scirs2_interpolate::auto_kernel_gp::{AutoKernelGp, AutoKernelGpConfig};
///
/// let x: Vec<f64> = (0..20).map(|i| i as f64 * std::f64::consts::PI / 10.0).collect();
/// let y: Vec<f64> = x.iter().map(|&xi| xi.sin()).collect();
///
/// let config = AutoKernelGpConfig {
///     max_depth: 1,
///     cv_folds: 3,
///     ..Default::default()
/// };
/// let mut gp = AutoKernelGp::new(config);
/// gp.fit(&x, &y).expect("fit ok");
///
/// let x_new = vec![0.1, 0.5, 1.0];
/// let preds = gp.predict(&x_new).expect("predict ok");
/// assert_eq!(preds.len(), 3);
/// ```
pub struct AutoKernelGp {
    /// The selected kernel expression.
    best_kernel: KernelExpr,
    /// Cross-validation score of the selected kernel (lower = better).
    best_cv_score: f64,
    /// GP dual variables α = (K + σ²I)⁻¹ y.
    alpha: Vec<f64>,
    /// Training input locations.
    train_x: Vec<f64>,
    /// Training output values.
    train_y: Vec<f64>,
    /// Configuration.
    config: AutoKernelGpConfig,
    /// All kernel-search results, sorted by CV score (best first).
    search_results: Vec<(String, f64)>,
    /// Whether the GP has been fitted.
    is_fitted: bool,
}

127impl AutoKernelGp {
128    /// Create a new (unfitted) `AutoKernelGp`.
129    pub fn new(config: AutoKernelGpConfig) -> Self {
130        Self {
131            best_kernel: KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 }),
132            best_cv_score: f64::MAX,
133            alpha: Vec::new(),
134            train_x: Vec::new(),
135            train_y: Vec::new(),
136            config,
137            search_results: Vec::new(),
138            is_fitted: false,
139        }
140    }
141
142    /// Search the kernel grammar and fit the GP on `x`, `y`.
143    ///
144    /// `x` must be strictly sorted (ascending) for well-defined kernel matrices,
145    /// though this is not enforced.  Duplicate `x` values will cause a
146    /// near-singular kernel matrix which is handled by jitter.
147    pub fn fit(&mut self, x: &[f64], y: &[f64]) -> InterpolateResult<()> {
148        if x.len() != y.len() {
149            return Err(InterpolateError::DimensionMismatch(format!(
150                "x length {} ≠ y length {}",
151                x.len(),
152                y.len()
153            )));
154        }
155        if x.len() < 2 {
156            return Err(InterpolateError::InvalidInput {
157                message: "at least 2 training points are required".to_string(),
158            });
159        }
160
161        // Run kernel search.
162        let (ranked, best_kernel) = search_kernels(
163            x,
164            y,
165            self.config.max_depth,
166            self.config.n_restarts,
167            self.config.noise_variance,
168            self.config.cv_folds,
169            DEFAULT_PERIOD_GRID,
170        );
171
172        self.search_results = ranked;
173        self.best_cv_score = self
174            .search_results
175            .first()
176            .map(|(_, s)| *s)
177            .unwrap_or(f64::MAX);
178        self.best_kernel = best_kernel;
179
180        // Final fit with the selected kernel on the full training set.
181        self.alpha =
182            gp_fit(&self.best_kernel, x, y, self.config.noise_variance).ok_or_else(|| {
183                InterpolateError::ComputationError(
184                    "Cholesky failed for selected kernel on full training set".to_string(),
185                )
186            })?;
187
188        self.train_x = x.to_vec();
189        self.train_y = y.to_vec();
190        self.is_fitted = true;
191        Ok(())
192    }
193
194    /// Predict at new input locations `x_new`.
195    pub fn predict(&self, x_new: &[f64]) -> InterpolateResult<Vec<f64>> {
196        if !self.is_fitted {
197            return Err(InterpolateError::InvalidState(
198                "GP must be fitted before prediction".to_string(),
199            ));
200        }
201        if x_new.is_empty() {
202            return Ok(Vec::new());
203        }
204        Ok(gp_predict(
205            &self.best_kernel,
206            &self.train_x,
207            &self.alpha,
208            x_new,
209        ))
210    }
211
212    /// Return a human-readable description of the selected kernel structure.
213    pub fn selected_kernel_description(&self) -> String {
214        self.best_kernel.description()
215    }
216
217    /// Return the best cross-validation score (lower is better).
218    pub fn best_cv_score(&self) -> f64 {
219        self.best_cv_score
220    }
221
222    /// Return the full ranked list of `(kernel_description, cv_mse_score)` pairs.
223    ///
224    /// Sorted by CV-MSE ascending (best first).
225    pub fn kernel_search_results(&self) -> &[(String, f64)] {
226        &self.search_results
227    }
228
229    /// Return the selected kernel expression.
230    pub fn kernel(&self) -> &KernelExpr {
231        &self.best_kernel
232    }
233}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` points of one full period of sin(x) on [0, 2π).
    fn sin_data(n: usize) -> (Vec<f64>, Vec<f64>) {
        let x: Vec<f64> = (0..n)
            .map(|i| i as f64 * 2.0 * std::f64::consts::PI / n as f64)
            .collect();
        let y: Vec<f64> = x.iter().map(|&xi| xi.sin()).collect();
        (x, y)
    }

    #[test]
    fn auto_kernel_gp_fits_sin_data() {
        let (x, y) = sin_data(20);
        let config = AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        };
        let mut gp = AutoKernelGp::new(config);
        gp.fit(&x, &y).expect("fit: should succeed on sin data");
        // Predict at training points — should be close.
        let preds = gp.predict(&x).expect("predict: should succeed");
        assert_eq!(preds.len(), x.len());
        let mse: f64 = preds
            .iter()
            .zip(y.iter())
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f64>()
            / x.len() as f64;
        // Loose bound: default noise variance 0.01 keeps the GP from
        // interpolating exactly, so demand only "small" error.
        assert!(
            mse < 0.5,
            "MSE at training points should be small, got {mse}"
        );
    }

    #[test]
    fn auto_kernel_gp_predict_shape_correct() {
        let (x, y) = sin_data(15);
        let config = AutoKernelGpConfig {
            max_depth: 0, // only base kernels
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        };
        let mut gp = AutoKernelGp::new(config);
        gp.fit(&x, &y).expect("fit ok");
        let x_new = vec![0.1, 0.5, 1.0, 2.0, 4.0];
        let preds = gp.predict(&x_new).expect("predict ok");
        assert_eq!(
            preds.len(),
            x_new.len(),
            "prediction shape must match query length"
        );
    }

    #[test]
    fn auto_kernel_gp_description_is_nonempty() {
        let (x, y) = sin_data(12);
        let config = AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        };
        let mut gp = AutoKernelGp::new(config);
        gp.fit(&x, &y).expect("fit ok");
        let desc = gp.selected_kernel_description();
        assert!(
            !desc.is_empty(),
            "kernel description must not be empty: '{desc}'"
        );
    }

    #[test]
    fn auto_kernel_gp_selects_periodic_kernel_for_sin() {
        // With enough data and depth ≥ 1 including the Periodic base kernel,
        // the search should find a kernel that fits sin well.  We use a soft
        // assertion: the CV score is finite.
        let (x, y) = sin_data(20);
        let config = AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 2,
            noise_variance: 1e-4,
            ..Default::default()
        };
        let mut gp = AutoKernelGp::new(config);
        gp.fit(&x, &y).expect("fit ok");
        assert!(
            gp.best_cv_score().is_finite(),
            "CV score must be finite, got {}",
            gp.best_cv_score()
        );
        assert!(
            !gp.kernel_search_results().is_empty(),
            "search results must not be empty"
        );
    }

    #[test]
    fn auto_kernel_gp_cv_score_improves_with_depth() {
        // Depth-2 search should find a kernel with CV-MSE ≤ depth-1 on sin data
        // (same data — strictly: at least one depth-2 candidate covers all depth-1 ones).
        let (x, y) = sin_data(18);
        let mut scores = Vec::new();
        for max_depth in [0usize, 1, 2] {
            let config = AutoKernelGpConfig {
                max_depth,
                cv_folds: 3,
                n_restarts: 1,
                noise_variance: 1e-3,
                ..Default::default()
            };
            let mut gp = AutoKernelGp::new(config);
            gp.fit(&x, &y).expect("fit ok");
            scores.push(gp.best_cv_score());
        }
        // Each deeper search covers at least as many candidates as the shallower one,
        // so CV score should be non-increasing.  The 10% tolerance absorbs
        // hyperparameter-optimisation jitter between runs.
        assert!(
            scores[1] <= scores[0] * 1.1,
            "depth-1 score {} should be ≤ depth-0 score {} (with 10% tolerance)",
            scores[1],
            scores[0]
        );
        assert!(
            scores[2] <= scores[1] * 1.1,
            "depth-2 score {} should be ≤ depth-1 score {} (with 10% tolerance)",
            scores[2],
            scores[1]
        );
    }

    #[test]
    fn auto_kernel_gp_predict_before_fit_errors() {
        let gp = AutoKernelGp::new(AutoKernelGpConfig::default());
        let result = gp.predict(&[0.5, 1.0]);
        assert!(result.is_err(), "predict before fit should return an error");
    }
}
381}