Skip to main content

scirs2_interpolate/auto_kernel_gp/
kernel.rs

//! Kernel expressions for automatic kernel structure discovery.
//!
//! Provides a closed grammar of base kernels and composite expressions formed
//! by sum and product operators. Each [`KernelExpr`] can be evaluated at a
//! pair of scalar inputs and described as a human-readable string.
//!
//! ## Base kernels
//!
//! In the table below, `r = |x − y|` is the distance between the two inputs.
//!
//! | Variant | Formula |
//! |---------|---------|
//! | `Rbf` | exp(-r² / (2 ℓ²)) |
//! | `Matern52` | (1 + √5·r/ℓ + 5r²/(3ℓ²)) exp(-√5·r/ℓ) |
//! | `Periodic` | exp(-2 sin²(π r / p) / ℓ²) |
//! | `Linear` | σ² · x · y |
//! | `WhiteNoise` | σ² · 𝟙{x == y} |

/// An atomic (non-composite) covariance function of the kernel grammar.
///
/// Composite kernels are built from these with [`KernelExpr`] sums and products.
#[derive(Debug, Clone, PartialEq)]
pub enum BaseKernel {
    /// Squared-exponential (RBF) kernel.
    Rbf {
        /// Length scale ℓ; expected to be > 0.
        length_scale: f64,
    },
    /// Matérn kernel with smoothness ν = 5/2.
    Matern52 {
        /// Length scale ℓ; expected to be > 0.
        length_scale: f64,
    },
    /// Periodic (exp-sine-squared) kernel.
    Periodic {
        /// Period p; expected to be > 0.
        period: f64,
        /// Length scale ℓ; expected to be > 0.
        length_scale: f64,
    },
    /// Linear (dot-product) kernel, k(x, y) = σ² · x · y.
    Linear {
        /// Variance σ²; expected to be > 0.
        variance: f64,
    },
    /// White-noise kernel: k(x, x) = σ² and k(x, y) = 0 when x ≠ y.
    WhiteNoise {
        /// Noise variance σ².
        variance: f64,
    },
}
31
32impl BaseKernel {
33    /// Evaluate k(x1, x2).
34    pub fn eval(&self, x1: f64, x2: f64) -> f64 {
35        match self {
36            BaseKernel::Rbf { length_scale } => {
37                let ell = length_scale.max(1e-10);
38                let d = x1 - x2;
39                (-d * d / (2.0 * ell * ell)).exp()
40            }
41            BaseKernel::Matern52 { length_scale } => {
42                let ell = length_scale.max(1e-10);
43                let r = (x1 - x2).abs();
44                let s = 5.0_f64.sqrt() * r / ell;
45                (1.0 + s + s * s / 3.0) * (-s).exp()
46            }
47            BaseKernel::Periodic {
48                period,
49                length_scale,
50            } => {
51                let p = period.max(1e-10);
52                let ell = length_scale.max(1e-10);
53                let d = (x1 - x2).abs();
54                let sin_val = (std::f64::consts::PI * d / p).sin();
55                (-2.0 * sin_val * sin_val / (ell * ell)).exp()
56            }
57            BaseKernel::Linear { variance } => variance * x1 * x2,
58            BaseKernel::WhiteNoise { variance } => {
59                if (x1 - x2).abs() < 1e-12 {
60                    *variance
61                } else {
62                    0.0
63                }
64            }
65        }
66    }
67
68    /// Short human-readable name.
69    pub fn name(&self) -> &'static str {
70        match self {
71            BaseKernel::Rbf { .. } => "RBF",
72            BaseKernel::Matern52 { .. } => "Matern52",
73            BaseKernel::Periodic { .. } => "Periodic",
74            BaseKernel::Linear { .. } => "Linear",
75            BaseKernel::WhiteNoise { .. } => "WhiteNoise",
76        }
77    }
78
79    /// Return a version of this kernel with updated hyperparameters.
80    pub fn with_length_scale(&self, new_ell: f64) -> Self {
81        match self {
82            BaseKernel::Rbf { .. } => BaseKernel::Rbf {
83                length_scale: new_ell,
84            },
85            BaseKernel::Matern52 { .. } => BaseKernel::Matern52 {
86                length_scale: new_ell,
87            },
88            BaseKernel::Periodic { period, .. } => BaseKernel::Periodic {
89                period: *period,
90                length_scale: new_ell,
91            },
92            BaseKernel::Linear { .. } => BaseKernel::Linear { variance: new_ell },
93            BaseKernel::WhiteNoise { .. } => BaseKernel::WhiteNoise { variance: new_ell },
94        }
95    }
96}
97
// ---------------------------------------------------------------------------
// Kernel expression tree
// ---------------------------------------------------------------------------

102/// A compositional kernel expression formed by summing or multiplying sub-kernels.
103///
104/// The expression depth is defined recursively:
105/// - `Base(k)` has depth 0.
106/// - `Sum(a, b)` and `Product(a, b)` have depth 1 + max(depth(a), depth(b)).
107#[derive(Debug, Clone)]
108pub enum KernelExpr {
109    /// A single base kernel.
110    Base(BaseKernel),
111    /// Sum of two kernel expressions: k(x, y) = a(x, y) + b(x, y).
112    Sum(Box<KernelExpr>, Box<KernelExpr>),
113    /// Product of two kernel expressions: k(x, y) = a(x, y) · b(x, y).
114    Product(Box<KernelExpr>, Box<KernelExpr>),
115}
116
117impl KernelExpr {
118    /// Evaluate the kernel at `(x1, x2)`.
119    pub fn eval(&self, x1: f64, x2: f64) -> f64 {
120        match self {
121            KernelExpr::Base(k) => k.eval(x1, x2),
122            KernelExpr::Sum(a, b) => a.eval(x1, x2) + b.eval(x1, x2),
123            KernelExpr::Product(a, b) => a.eval(x1, x2) * b.eval(x1, x2),
124        }
125    }
126
127    /// Human-readable description, e.g. `"RBF + Periodic × Linear"`.
128    pub fn description(&self) -> String {
129        match self {
130            KernelExpr::Base(k) => k.name().to_string(),
131            KernelExpr::Sum(a, b) => format!("{} + {}", a.description(), b.description()),
132            KernelExpr::Product(a, b) => {
133                // Add parens around sum sub-expressions to make precedence clear.
134                let ad = a.description();
135                let bd = b.description();
136                let left = if matches!(a.as_ref(), KernelExpr::Sum(_, _)) {
137                    format!("({ad})")
138                } else {
139                    ad
140                };
141                let right = if matches!(b.as_ref(), KernelExpr::Sum(_, _)) {
142                    format!("({bd})")
143                } else {
144                    bd
145                };
146                format!("{left} × {right}")
147            }
148        }
149    }
150
151    /// Compute the full depth of this expression tree.
152    pub fn depth(&self) -> usize {
153        match self {
154            KernelExpr::Base(_) => 0,
155            KernelExpr::Sum(a, b) | KernelExpr::Product(a, b) => 1 + a.depth().max(b.depth()),
156        }
157    }
158}
159
160/// Enumerate all base kernels with initial hyperparameters.
161pub fn base_kernels() -> Vec<BaseKernel> {
162    vec![
163        BaseKernel::Rbf { length_scale: 1.0 },
164        BaseKernel::Matern52 { length_scale: 1.0 },
165        BaseKernel::Periodic {
166            period: 1.0,
167            length_scale: 1.0,
168        },
169        BaseKernel::Linear { variance: 1.0 },
170        BaseKernel::WhiteNoise { variance: 0.1 },
171    ]
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kernel_expr_eval_rbf_is_symmetric() {
        let k = KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 });
        let a = k.eval(0.5, 1.5);
        let b = k.eval(1.5, 0.5);
        assert!(
            (a - b).abs() < 1e-12,
            "RBF must be symmetric: k(0.5,1.5)={a}, k(1.5,0.5)={b}"
        );
    }

    #[test]
    fn kernel_expr_sum_is_additive() {
        let k1 = KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 });
        let k2 = KernelExpr::Base(BaseKernel::Matern52 { length_scale: 1.0 });
        let sum = KernelExpr::Sum(Box::new(k1.clone()), Box::new(k2.clone()));
        let x = 0.3;
        let y = 1.2;
        let expected = k1.eval(x, y) + k2.eval(x, y);
        let actual = sum.eval(x, y);
        assert!(
            (actual - expected).abs() < 1e-12,
            "Sum kernel must equal sum of parts: {actual} vs {expected}"
        );
    }

    #[test]
    fn kernel_expr_product_is_multiplicative() {
        let k1 = KernelExpr::Base(BaseKernel::Rbf { length_scale: 0.8 });
        let k2 = KernelExpr::Base(BaseKernel::Linear { variance: 2.0 });
        let prod = KernelExpr::Product(Box::new(k1.clone()), Box::new(k2.clone()));
        let x = 1.5;
        let y = 0.5;
        let expected = k1.eval(x, y) * k2.eval(x, y);
        let actual = prod.eval(x, y);
        assert!(
            (actual - expected).abs() < 1e-12,
            "Product kernel must equal product of parts: {actual} vs {expected}"
        );
    }

    #[test]
    fn kernel_description_nonempty() {
        let k = KernelExpr::Sum(
            Box::new(KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 })),
            Box::new(KernelExpr::Product(
                Box::new(KernelExpr::Base(BaseKernel::Periodic {
                    period: 1.0,
                    length_scale: 0.5,
                })),
                Box::new(KernelExpr::Base(BaseKernel::Linear { variance: 1.0 })),
            )),
        );
        let desc = k.description();
        assert!(!desc.is_empty(), "description must not be empty");
        assert!(desc.contains("RBF"), "description should contain RBF");
        assert!(
            desc.contains("Periodic"),
            "description should contain Periodic"
        );
        // Pin the exact rendering: a product nested under a sum needs no parentheses.
        assert_eq!(desc, "RBF + Periodic × Linear");
    }

    #[test]
    fn kernel_description_parenthesizes_sum_operands_of_product() {
        let sum = KernelExpr::Sum(
            Box::new(KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 })),
            Box::new(KernelExpr::Base(BaseKernel::Linear { variance: 1.0 })),
        );
        let prod = KernelExpr::Product(
            Box::new(sum),
            Box::new(KernelExpr::Base(BaseKernel::WhiteNoise { variance: 0.1 })),
        );
        assert_eq!(prod.description(), "(RBF + Linear) × WhiteNoise");
    }

    #[test]
    fn kernel_depth_counts_nested_operators() {
        let leaf = || KernelExpr::Base(BaseKernel::Matern52 { length_scale: 1.0 });
        assert_eq!(leaf().depth(), 0);
        let one = KernelExpr::Sum(Box::new(leaf()), Box::new(leaf()));
        assert_eq!(one.depth(), 1);
        let two = KernelExpr::Product(Box::new(leaf()), Box::new(one));
        assert_eq!(two.depth(), 2);
    }

    #[test]
    fn stationary_kernels_are_one_at_zero_distance() {
        let kernels = [
            BaseKernel::Rbf { length_scale: 0.7 },
            BaseKernel::Matern52 { length_scale: 0.7 },
            BaseKernel::Periodic {
                period: 1.3,
                length_scale: 0.7,
            },
        ];
        for k in &kernels {
            let v = k.eval(0.4, 0.4);
            assert!((v - 1.0).abs() < 1e-12, "{} at zero distance = {v}", k.name());
        }
    }

    #[test]
    fn periodic_kernel_repeats_after_one_period() {
        let k = BaseKernel::Periodic {
            period: 2.0,
            length_scale: 0.7,
        };
        let a = k.eval(0.3, 1.1);
        let b = k.eval(0.3, 3.1);
        assert!((a - b).abs() < 1e-12, "periodic kernel must repeat: {a} vs {b}");
    }

    #[test]
    fn white_noise_is_nonzero_only_on_the_diagonal() {
        let k = BaseKernel::WhiteNoise { variance: 0.1 };
        assert_eq!(k.eval(0.4, 0.4), 0.1);
        assert_eq!(k.eval(0.4, 0.5), 0.0);
    }

    #[test]
    fn with_length_scale_keeps_period_and_maps_to_variance() {
        let p = BaseKernel::Periodic {
            period: 2.0,
            length_scale: 1.0,
        };
        assert_eq!(
            p.with_length_scale(0.5),
            BaseKernel::Periodic {
                period: 2.0,
                length_scale: 0.5,
            }
        );
        // Kernels without a length scale take the value as their variance.
        assert_eq!(
            BaseKernel::Linear { variance: 1.0 }.with_length_scale(3.0),
            BaseKernel::Linear { variance: 3.0 }
        );
    }

    #[test]
    fn base_kernels_lists_every_variant_once() {
        let names: Vec<&str> = base_kernels().iter().map(BaseKernel::name).collect();
        assert_eq!(names, ["RBF", "Matern52", "Periodic", "Linear", "WhiteNoise"]);
    }
}