Skip to main content

scirs2_stats/gaussian_process/
kernel.rs

//! Kernel functions for Gaussian Processes
//!
//! Kernels (covariance functions) determine the properties of functions drawn from
//! a Gaussian Process. They encode assumptions about the function being modeled.

6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::Float;
8
9/// Trait for kernel functions (covariance functions)
10pub trait Kernel: Clone + Send + Sync {
11    /// Compute covariance between two points
12    fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64;
13
14    /// Compute covariance matrix for a set of points
15    fn compute_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
16        let n = x.nrows();
17        let mut k = Array2::zeros((n, n));
18
19        for i in 0..n {
20            for j in 0..=i {
21                let kij = self.compute(&x.row(i), &x.row(j));
22                k[[i, j]] = kij;
23                if i != j {
24                    k[[j, i]] = kij;
25                }
26            }
27        }
28
29        k
30    }
31
32    /// Compute covariance matrix between two sets of points
33    fn compute_cross_matrix(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Array2<f64> {
34        let n1 = x1.nrows();
35        let n2 = x2.nrows();
36        let mut k = Array2::zeros((n1, n2));
37
38        for i in 0..n1 {
39            for j in 0..n2 {
40                k[[i, j]] = self.compute(&x1.row(i), &x2.row(j));
41            }
42        }
43
44        k
45    }
46
47    /// Get kernel parameters (for optimization)
48    fn get_params(&self) -> Vec<f64>;
49
50    /// Set kernel parameters
51    fn set_params(&mut self, params: &[f64]);
52
53    /// Get number of parameters
54    fn n_params(&self) -> usize {
55        self.get_params().len()
56    }
57}
58
59/// Squared Exponential (RBF) kernel: k(x, x') = σ² exp(-||x - x'||² / (2 l²))
60///
61/// This is the most commonly used kernel. It assumes smoothness and is infinitely
62/// differentiable.
63#[derive(Debug, Clone)]
64pub struct SquaredExponential {
65    /// Length scale parameter (controls how quickly correlation decays)
66    pub length_scale: f64,
67    /// Signal variance (controls output scale)
68    pub signal_variance: f64,
69}
70
71impl SquaredExponential {
72    /// Create a new Squared Exponential kernel
73    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
74        Self {
75            length_scale,
76            signal_variance,
77        }
78    }
79}
80
81impl Default for SquaredExponential {
82    fn default() -> Self {
83        Self {
84            length_scale: 1.0,
85            signal_variance: 1.0,
86        }
87    }
88}
89
90impl Kernel for SquaredExponential {
91    fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
92        let mut sq_dist = 0.0;
93        for i in 0..x1.len() {
94            let diff = x1[i] - x2[i];
95            sq_dist += diff * diff;
96        }
97
98        self.signal_variance * (-0.5 * sq_dist / (self.length_scale * self.length_scale)).exp()
99    }
100
101    fn get_params(&self) -> Vec<f64> {
102        vec![self.length_scale, self.signal_variance]
103    }
104
105    fn set_params(&mut self, params: &[f64]) {
106        if params.len() >= 2 {
107            self.length_scale = params[0];
108            self.signal_variance = params[1];
109        }
110    }
111}
112
113/// Matérn kernel with ν = 1/2
114///
115/// Equivalent to Exponential kernel: k(x, x') = σ² exp(-||x - x'|| / l)
116/// This kernel produces rougher functions than RBF.
117#[derive(Debug, Clone)]
118pub struct Matern12 {
119    pub length_scale: f64,
120    pub signal_variance: f64,
121}
122
123impl Matern12 {
124    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
125        Self {
126            length_scale,
127            signal_variance,
128        }
129    }
130}
131
132impl Default for Matern12 {
133    fn default() -> Self {
134        Self {
135            length_scale: 1.0,
136            signal_variance: 1.0,
137        }
138    }
139}
140
141impl Kernel for Matern12 {
142    fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
143        let mut sq_dist = 0.0;
144        for i in 0..x1.len() {
145            let diff = x1[i] - x2[i];
146            sq_dist += diff * diff;
147        }
148        let dist = sq_dist.sqrt();
149
150        self.signal_variance * (-dist / self.length_scale).exp()
151    }
152
153    fn get_params(&self) -> Vec<f64> {
154        vec![self.length_scale, self.signal_variance]
155    }
156
157    fn set_params(&mut self, params: &[f64]) {
158        if params.len() >= 2 {
159            self.length_scale = params[0];
160            self.signal_variance = params[1];
161        }
162    }
163}
164
165/// Matérn kernel with ν = 3/2
166///
167/// k(x, x') = σ² (1 + √3 r / l) exp(-√3 r / l), where r = ||x - x'||
168/// Once differentiable, smoother than Matérn 1/2.
169#[derive(Debug, Clone)]
170pub struct Matern32 {
171    pub length_scale: f64,
172    pub signal_variance: f64,
173}
174
175impl Matern32 {
176    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
177        Self {
178            length_scale,
179            signal_variance,
180        }
181    }
182}
183
184impl Default for Matern32 {
185    fn default() -> Self {
186        Self {
187            length_scale: 1.0,
188            signal_variance: 1.0,
189        }
190    }
191}
192
193impl Kernel for Matern32 {
194    fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
195        let mut sq_dist = 0.0;
196        for i in 0..x1.len() {
197            let diff = x1[i] - x2[i];
198            sq_dist += diff * diff;
199        }
200        let dist = sq_dist.sqrt();
201        let sqrt3 = 3.0_f64.sqrt();
202        let arg = sqrt3 * dist / self.length_scale;
203
204        self.signal_variance * (1.0 + arg) * (-arg).exp()
205    }
206
207    fn get_params(&self) -> Vec<f64> {
208        vec![self.length_scale, self.signal_variance]
209    }
210
211    fn set_params(&mut self, params: &[f64]) {
212        if params.len() >= 2 {
213            self.length_scale = params[0];
214            self.signal_variance = params[1];
215        }
216    }
217}
218
219/// Matérn kernel with ν = 5/2
220///
221/// k(x, x') = σ² (1 + √5 r / l + 5 r² / (3 l²)) exp(-√5 r / l)
222/// Twice differentiable, very smooth.
223#[derive(Debug, Clone)]
224pub struct Matern52 {
225    pub length_scale: f64,
226    pub signal_variance: f64,
227}
228
229impl Matern52 {
230    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
231        Self {
232            length_scale,
233            signal_variance,
234        }
235    }
236}
237
238impl Default for Matern52 {
239    fn default() -> Self {
240        Self {
241            length_scale: 1.0,
242            signal_variance: 1.0,
243        }
244    }
245}
246
247impl Kernel for Matern52 {
248    fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
249        let mut sq_dist = 0.0;
250        for i in 0..x1.len() {
251            let diff = x1[i] - x2[i];
252            sq_dist += diff * diff;
253        }
254        let dist = sq_dist.sqrt();
255        let sqrt5 = 5.0_f64.sqrt();
256        let arg = sqrt5 * dist / self.length_scale;
257        let arg2 = 5.0 * sq_dist / (3.0 * self.length_scale * self.length_scale);
258
259        self.signal_variance * (1.0 + arg + arg2) * (-arg).exp()
260    }
261
262    fn get_params(&self) -> Vec<f64> {
263        vec![self.length_scale, self.signal_variance]
264    }
265
266    fn set_params(&mut self, params: &[f64]) {
267        if params.len() >= 2 {
268            self.length_scale = params[0];
269            self.signal_variance = params[1];
270        }
271    }
272}
273
274/// White kernel (noise): k(x, x') = σ² δ(x, x')
275///
276/// This kernel only produces variance on the diagonal, representing
277/// independent noise on observations.
278#[derive(Debug, Clone)]
279pub struct WhiteKernel {
280    pub noise_level: f64,
281}
282
283impl WhiteKernel {
284    pub fn new(noise_level: f64) -> Self {
285        Self { noise_level }
286    }
287}
288
289impl Default for WhiteKernel {
290    fn default() -> Self {
291        Self { noise_level: 0.01 }
292    }
293}
294
295impl Kernel for WhiteKernel {
296    fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
297        // Check if points are identical
298        let identical = x1
299            .iter()
300            .zip(x2.iter())
301            .all(|(&a, &b)| (a - b).abs() < 1e-10);
302
303        if identical {
304            self.noise_level
305        } else {
306            0.0
307        }
308    }
309
310    fn get_params(&self) -> Vec<f64> {
311        vec![self.noise_level]
312    }
313
314    fn set_params(&mut self, params: &[f64]) {
315        if !params.is_empty() {
316            self.noise_level = params[0];
317        }
318    }
319}
320
321/// Sum of two kernels: k(x, x') = k1(x, x') + k2(x, x')
322#[derive(Debug, Clone)]
323pub struct SumKernel<K1: Kernel, K2: Kernel> {
324    pub kernel1: K1,
325    pub kernel2: K2,
326}
327
328impl<K1: Kernel, K2: Kernel> SumKernel<K1, K2> {
329    pub fn new(kernel1: K1, kernel2: K2) -> Self {
330        Self { kernel1, kernel2 }
331    }
332}
333
334impl<K1: Kernel, K2: Kernel> Kernel for SumKernel<K1, K2> {
335    fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
336        self.kernel1.compute(x1, x2) + self.kernel2.compute(x1, x2)
337    }
338
339    fn get_params(&self) -> Vec<f64> {
340        let mut params = self.kernel1.get_params();
341        params.extend(self.kernel2.get_params());
342        params
343    }
344
345    fn set_params(&mut self, params: &[f64]) {
346        let n1 = self.kernel1.n_params();
347        if params.len() >= n1 {
348            self.kernel1.set_params(&params[..n1]);
349            if params.len() > n1 {
350                self.kernel2.set_params(&params[n1..]);
351            }
352        }
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-10;

    /// True when `actual` lies within `TOL` of `expected`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_squared_exponential() {
        let rbf = SquaredExponential::default();
        let origin = array![0.0, 0.0];
        let corner = array![1.0, 1.0];

        // A point's covariance with itself equals the signal variance (1.0 by default).
        assert!(approx_eq(rbf.compute(&origin.view(), &origin.view()), 1.0));

        // Covariance between distinct points has decayed but stays positive.
        let cross = rbf.compute(&origin.view(), &corner.view());
        assert!(cross < 1.0);
        assert!(cross > 0.0);
    }

    #[test]
    fn test_matern_kernels() {
        let left = array![0.0];
        let right = array![1.0];
        let (a, b) = (left.view(), right.view());

        // Each Matérn variant decays with distance, at its own rate.
        let covariances = [
            Matern12::default().compute(&a, &b),
            Matern32::default().compute(&a, &b),
            Matern52::default().compute(&a, &b),
        ];

        for &k in covariances.iter() {
            assert!(k > 0.0 && k < 1.0);
        }
    }

    #[test]
    fn test_white_kernel() {
        let white = WhiteKernel::new(0.1);
        let origin = array![0.0, 0.0];
        let corner = array![1.0, 1.0];

        // Identical inputs pick up the noise level...
        assert!(approx_eq(white.compute(&origin.view(), &origin.view()), 0.1));

        // ...while distinct inputs are uncorrelated.
        assert!(approx_eq(white.compute(&origin.view(), &corner.view()), 0.0));
    }

    #[test]
    fn test_sum_kernel() {
        let combined = SumKernel::new(SquaredExponential::default(), WhiteKernel::new(0.1));
        let point = array![0.0];

        // k(x, x) = signal variance (1.0) + noise level (0.1).
        let k = combined.compute(&point.view(), &point.view());
        assert!(approx_eq(k, 1.1));
    }
}