// scirs2_interpolate/auto_kernel_gp/mod.rs
1//! Gaussian Process interpolator with automatic kernel structure discovery.
2//!
3//! [`AutoKernelGp`] implements a simplified version of the Duvenaud-style
4//! Automatic Statistician approach: it searches a grammar of base kernels
5//! (`Rbf`, `Matern52`, `Periodic`, `Linear`, `WhiteNoise`) combined by sum
6//! and product operators, and selects the best composite kernel using k-fold
7//! cross-validated mean squared error (CV-MSE).
8//!
9//! ## Algorithm summary
10//!
//! 1. **Generate candidates** — BFS over the kernel grammar up to `max_depth`
//!    (default 2), producing depth-0 (base), depth-1 (base + base, base × base),
//!    and depth-2 (a depth-1 expression summed or multiplied with a base kernel).
//!    Commutative duplicates are pruned.
14//!
15//! 2. **Optimise hyperparameters** — For each candidate:
16//! - RBF / Matérn: golden-section on ℓ ∈ [0.05, 10].
17//! - Periodic: grid on p ∈ {0.1, 0.2, 0.5, 1.0, 2.0, 5.0}, golden-section on ℓ.
18//! - Linear / WhiteNoise: golden-section on variance.
19//! - Composite: sub-expressions optimised independently (greedy).
20//!
21//! 3. **Cross-validate** — k-fold (default 5) leave-out MSE on the training data.
22//! The kernel with the lowest CV-MSE is selected.
23//!
24//! 4. **Final fit** — GP alpha vector `α = (K + σ²I)⁻¹ · y` is computed via
25//! inline Cholesky (Banachiewicz), with a small jitter `1e-8` for stability.
26//!
27//! 5. **Predict** — `y* = K(x*, X) · α`.
28//!
29//! ## Selection criterion
30//!
31//! CV-MSE on held-out data is used (not log marginal likelihood on the full
32//! training set) to guard against over-fitting by complex kernels and to make
33//! the selection criterion directly interpretable as predictive error.
34
35pub mod kernel;
36pub mod search;
37
38pub use kernel::{BaseKernel, KernelExpr};
39
40use crate::error::{InterpolateError, InterpolateResult};
41use search::{build_cross_kernel, gp_fit, gp_predict, search_kernels};
42
/// Candidate periods tried for the `Periodic` base kernel during hyperparameter
/// search (roughly log-spaced over [0.1, 5.0]); for each period the length
/// scale is tuned separately (golden-section, per the module docs).
const DEFAULT_PERIOD_GRID: &[f64] = &[0.1, 0.2, 0.5, 1.0, 2.0, 5.0];
45
46// ---------------------------------------------------------------------------
47// Configuration
48// ---------------------------------------------------------------------------
49
/// Configuration for [`AutoKernelGp`].
#[derive(Debug, Clone)]
pub struct AutoKernelGpConfig {
    /// Maximum depth of kernel expression tree. Default: 2.
    ///
    /// - Depth 0: only base kernels.
    /// - Depth 1: base + base or base × base (10 unique pairs × 2 ops = 20 extra).
    /// - Depth 2: a depth-1 expression combined with a base kernel (≈ 40 extra).
    pub max_depth: usize,
    /// Random restarts for hyperparameter optimisation (increases search
    /// budget for golden-section). Default: 3.
    pub n_restarts: usize,
    /// Observation noise variance σ² added to the kernel diagonal. Default: 0.01.
    pub noise_variance: f64,
    /// Number of cross-validation folds used to score candidates. Default: 5.
    pub cv_folds: usize,
    /// Deterministic seed (reserved for future use with random restarts;
    /// `fit` does not currently pass it to the search).
    pub seed: u64,
}
69
70impl Default for AutoKernelGpConfig {
71 fn default() -> Self {
72 Self {
73 max_depth: 2,
74 n_restarts: 3,
75 noise_variance: 0.01,
76 cv_folds: 5,
77 seed: 42,
78 }
79 }
80}
81
82// ---------------------------------------------------------------------------
83// AutoKernelGp
84// ---------------------------------------------------------------------------
85
/// Gaussian Process interpolator with automatic kernel structure discovery.
///
/// # Example
///
/// ```rust
/// use scirs2_interpolate::auto_kernel_gp::{AutoKernelGp, AutoKernelGpConfig};
///
/// let x: Vec<f64> = (0..20).map(|i| i as f64 * std::f64::consts::PI / 10.0).collect();
/// let y: Vec<f64> = x.iter().map(|&xi| xi.sin()).collect();
///
/// let config = AutoKernelGpConfig {
///     max_depth: 1,
///     cv_folds: 3,
///     ..Default::default()
/// };
/// let mut gp = AutoKernelGp::new(config);
/// gp.fit(&x, &y).expect("fit ok");
///
/// let x_new = vec![0.1, 0.5, 1.0];
/// let preds = gp.predict(&x_new).expect("predict ok");
/// assert_eq!(preds.len(), 3);
/// ```
pub struct AutoKernelGp {
    /// The selected kernel expression (a placeholder RBF with ℓ = 1 until
    /// `fit` replaces it with the search winner).
    best_kernel: KernelExpr,
    /// Cross-validation score of the selected kernel (lower = better);
    /// initialised to `f64::MAX` by `new`.
    best_cv_score: f64,
    /// GP dual variables α = (K + σ²I)⁻¹ y; `predict` only needs these.
    alpha: Vec<f64>,
    /// Training input locations (needed to build the cross-kernel at predict time).
    train_x: Vec<f64>,
    /// Training output values (stored for reference; predictions go through `alpha`).
    train_y: Vec<f64>,
    /// Configuration.
    config: AutoKernelGpConfig,
    /// All kernel-search results as `(description, cv_mse)`, sorted by CV score.
    search_results: Vec<(String, f64)>,
    /// Whether the GP has been fitted (guards `predict`).
    is_fitted: bool,
}
126
127impl AutoKernelGp {
128 /// Create a new (unfitted) `AutoKernelGp`.
129 pub fn new(config: AutoKernelGpConfig) -> Self {
130 Self {
131 best_kernel: KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 }),
132 best_cv_score: f64::MAX,
133 alpha: Vec::new(),
134 train_x: Vec::new(),
135 train_y: Vec::new(),
136 config,
137 search_results: Vec::new(),
138 is_fitted: false,
139 }
140 }
141
142 /// Search the kernel grammar and fit the GP on `x`, `y`.
143 ///
144 /// `x` must be strictly sorted (ascending) for well-defined kernel matrices,
145 /// though this is not enforced. Duplicate `x` values will cause a
146 /// near-singular kernel matrix which is handled by jitter.
147 pub fn fit(&mut self, x: &[f64], y: &[f64]) -> InterpolateResult<()> {
148 if x.len() != y.len() {
149 return Err(InterpolateError::DimensionMismatch(format!(
150 "x length {} ≠ y length {}",
151 x.len(),
152 y.len()
153 )));
154 }
155 if x.len() < 2 {
156 return Err(InterpolateError::InvalidInput {
157 message: "at least 2 training points are required".to_string(),
158 });
159 }
160
161 // Run kernel search.
162 let (ranked, best_kernel) = search_kernels(
163 x,
164 y,
165 self.config.max_depth,
166 self.config.n_restarts,
167 self.config.noise_variance,
168 self.config.cv_folds,
169 DEFAULT_PERIOD_GRID,
170 );
171
172 self.search_results = ranked;
173 self.best_cv_score = self
174 .search_results
175 .first()
176 .map(|(_, s)| *s)
177 .unwrap_or(f64::MAX);
178 self.best_kernel = best_kernel;
179
180 // Final fit with the selected kernel on the full training set.
181 self.alpha =
182 gp_fit(&self.best_kernel, x, y, self.config.noise_variance).ok_or_else(|| {
183 InterpolateError::ComputationError(
184 "Cholesky failed for selected kernel on full training set".to_string(),
185 )
186 })?;
187
188 self.train_x = x.to_vec();
189 self.train_y = y.to_vec();
190 self.is_fitted = true;
191 Ok(())
192 }
193
194 /// Predict at new input locations `x_new`.
195 pub fn predict(&self, x_new: &[f64]) -> InterpolateResult<Vec<f64>> {
196 if !self.is_fitted {
197 return Err(InterpolateError::InvalidState(
198 "GP must be fitted before prediction".to_string(),
199 ));
200 }
201 if x_new.is_empty() {
202 return Ok(Vec::new());
203 }
204 Ok(gp_predict(
205 &self.best_kernel,
206 &self.train_x,
207 &self.alpha,
208 x_new,
209 ))
210 }
211
212 /// Return a human-readable description of the selected kernel structure.
213 pub fn selected_kernel_description(&self) -> String {
214 self.best_kernel.description()
215 }
216
217 /// Return the best cross-validation score (lower is better).
218 pub fn best_cv_score(&self) -> f64 {
219 self.best_cv_score
220 }
221
222 /// Return the full ranked list of `(kernel_description, cv_mse_score)` pairs.
223 ///
224 /// Sorted by CV-MSE ascending (best first).
225 pub fn kernel_search_results(&self) -> &[(String, f64)] {
226 &self.search_results
227 }
228
229 /// Return the selected kernel expression.
230 pub fn kernel(&self) -> &KernelExpr {
231 &self.best_kernel
232 }
233}
234
235// ---------------------------------------------------------------------------
236// Tests
237// ---------------------------------------------------------------------------
238
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` evenly spaced samples of sin(x) over one full period [0, 2π).
    fn sin_data(n: usize) -> (Vec<f64>, Vec<f64>) {
        let mut x = Vec::with_capacity(n);
        let mut y = Vec::with_capacity(n);
        for i in 0..n {
            let xi = i as f64 * 2.0 * std::f64::consts::PI / n as f64;
            x.push(xi);
            y.push(xi.sin());
        }
        (x, y)
    }

    #[test]
    fn auto_kernel_gp_fits_sin_data() {
        let (x, y) = sin_data(20);
        let mut model = AutoKernelGp::new(AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        });
        model.fit(&x, &y).expect("fit: should succeed on sin data");

        // In-sample predictions should track the targets closely.
        let preds = model.predict(&x).expect("predict: should succeed");
        assert_eq!(preds.len(), x.len());

        let mut sq_err = 0.0;
        for (p, t) in preds.iter().zip(y.iter()) {
            sq_err += (p - t).powi(2);
        }
        let mse = sq_err / x.len() as f64;
        assert!(
            mse < 0.5,
            "MSE at training points should be small, got {mse}"
        );
    }

    #[test]
    fn auto_kernel_gp_predict_shape_correct() {
        let (x, y) = sin_data(15);
        // Depth 0 restricts the search to the base kernels alone.
        let cfg = AutoKernelGpConfig {
            max_depth: 0,
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        };
        let mut model = AutoKernelGp::new(cfg);
        model.fit(&x, &y).expect("fit ok");

        let queries = vec![0.1, 0.5, 1.0, 2.0, 4.0];
        let preds = model.predict(&queries).expect("predict ok");
        assert_eq!(
            preds.len(),
            queries.len(),
            "prediction shape must match query length"
        );
    }

    #[test]
    fn auto_kernel_gp_description_is_nonempty() {
        let (x, y) = sin_data(12);
        let cfg = AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 1,
            ..Default::default()
        };
        let mut model = AutoKernelGp::new(cfg);
        model.fit(&x, &y).expect("fit ok");

        let desc = model.selected_kernel_description();
        assert!(
            !desc.is_empty(),
            "kernel description must not be empty: '{desc}'"
        );
    }

    #[test]
    fn auto_kernel_gp_selects_periodic_kernel_for_sin() {
        // With enough data and depth ≥ 1 including the Periodic base kernel,
        // the search should find a kernel that fits sin well. Soft assertion:
        // the CV score is finite and the ranking is populated.
        let (x, y) = sin_data(20);
        let cfg = AutoKernelGpConfig {
            max_depth: 1,
            cv_folds: 3,
            n_restarts: 2,
            noise_variance: 1e-4,
            ..Default::default()
        };
        let mut model = AutoKernelGp::new(cfg);
        model.fit(&x, &y).expect("fit ok");

        let score = model.best_cv_score();
        assert!(score.is_finite(), "CV score must be finite, got {}", score);
        assert!(
            !model.kernel_search_results().is_empty(),
            "search results must not be empty"
        );
    }

    #[test]
    fn auto_kernel_gp_cv_score_improves_with_depth() {
        // Each deeper search covers a superset of the shallower candidates on
        // the same data, so the best CV score should be non-increasing.
        let (x, y) = sin_data(18);
        let scores: Vec<f64> = [0usize, 1, 2]
            .iter()
            .map(|&max_depth| {
                let mut model = AutoKernelGp::new(AutoKernelGpConfig {
                    max_depth,
                    cv_folds: 3,
                    n_restarts: 1,
                    noise_variance: 1e-3,
                    ..Default::default()
                });
                model.fit(&x, &y).expect("fit ok");
                model.best_cv_score()
            })
            .collect();

        assert!(
            scores[1] <= scores[0] * 1.1,
            "depth-1 score {} should be ≤ depth-0 score {} (with 10% tolerance)",
            scores[1],
            scores[0]
        );
        assert!(
            scores[2] <= scores[1] * 1.1,
            "depth-2 score {} should be ≤ depth-1 score {} (with 10% tolerance)",
            scores[2],
            scores[1]
        );
    }

    #[test]
    fn auto_kernel_gp_predict_before_fit_errors() {
        let unfitted = AutoKernelGp::new(AutoKernelGpConfig::default());
        assert!(
            unfitted.predict(&[0.5, 1.0]).is_err(),
            "predict before fit should return an error"
        );
    }
}
381}