scirs2_stats/gaussian_process/
kernel.rs1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::Float;
8
9pub trait Kernel: Clone + Send + Sync {
11 fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64;
13
14 fn compute_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
16 let n = x.nrows();
17 let mut k = Array2::zeros((n, n));
18
19 for i in 0..n {
20 for j in 0..=i {
21 let kij = self.compute(&x.row(i), &x.row(j));
22 k[[i, j]] = kij;
23 if i != j {
24 k[[j, i]] = kij;
25 }
26 }
27 }
28
29 k
30 }
31
32 fn compute_cross_matrix(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Array2<f64> {
34 let n1 = x1.nrows();
35 let n2 = x2.nrows();
36 let mut k = Array2::zeros((n1, n2));
37
38 for i in 0..n1 {
39 for j in 0..n2 {
40 k[[i, j]] = self.compute(&x1.row(i), &x2.row(j));
41 }
42 }
43
44 k
45 }
46
47 fn get_params(&self) -> Vec<f64>;
49
50 fn set_params(&mut self, params: &[f64]);
52
53 fn n_params(&self) -> usize {
55 self.get_params().len()
56 }
57}
58
/// Squared-exponential (RBF) kernel:
/// `k(x, x') = signal_variance · exp(-‖x − x'‖² / (2 · length_scale²))`.
#[derive(Debug, Clone, PartialEq)]
pub struct SquaredExponential {
    /// Length scale `ℓ`. Expected to be positive; it is not validated, and `0.0` makes
    /// `k(x, x)` evaluate to NaN.
    pub length_scale: f64,
    /// Multiplicative variance `σ²`; `k(x, x)` equals this value.
    pub signal_variance: f64,
}

impl SquaredExponential {
    /// Creates a kernel with the given length scale and signal variance (no validation).
    #[must_use]
    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
        Self {
            length_scale,
            signal_variance,
        }
    }
}

impl Default for SquaredExponential {
    /// Unit length scale and unit signal variance.
    fn default() -> Self {
        Self::new(1.0, 1.0)
    }
}
89
90impl Kernel for SquaredExponential {
91 fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
92 let mut sq_dist = 0.0;
93 for i in 0..x1.len() {
94 let diff = x1[i] - x2[i];
95 sq_dist += diff * diff;
96 }
97
98 self.signal_variance * (-0.5 * sq_dist / (self.length_scale * self.length_scale)).exp()
99 }
100
101 fn get_params(&self) -> Vec<f64> {
102 vec![self.length_scale, self.signal_variance]
103 }
104
105 fn set_params(&mut self, params: &[f64]) {
106 if params.len() >= 2 {
107 self.length_scale = params[0];
108 self.signal_variance = params[1];
109 }
110 }
111}
112
/// Matérn ν = 1/2 (exponential) kernel: `k(x, x') = signal_variance · exp(-r / length_scale)`,
/// where `r = ‖x − x'‖` is the Euclidean distance.
#[derive(Debug, Clone, PartialEq)]
pub struct Matern12 {
    /// Length scale `ℓ`. Expected to be positive; it is not validated, and `0.0` makes
    /// `k(x, x)` evaluate to NaN.
    pub length_scale: f64,
    /// Multiplicative variance `σ²`; `k(x, x)` equals this value.
    pub signal_variance: f64,
}

impl Matern12 {
    /// Creates a kernel with the given length scale and signal variance (no validation).
    #[must_use]
    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
        Self {
            length_scale,
            signal_variance,
        }
    }
}

impl Default for Matern12 {
    /// Unit length scale and unit signal variance.
    fn default() -> Self {
        Self::new(1.0, 1.0)
    }
}
140
141impl Kernel for Matern12 {
142 fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
143 let mut sq_dist = 0.0;
144 for i in 0..x1.len() {
145 let diff = x1[i] - x2[i];
146 sq_dist += diff * diff;
147 }
148 let dist = sq_dist.sqrt();
149
150 self.signal_variance * (-dist / self.length_scale).exp()
151 }
152
153 fn get_params(&self) -> Vec<f64> {
154 vec![self.length_scale, self.signal_variance]
155 }
156
157 fn set_params(&mut self, params: &[f64]) {
158 if params.len() >= 2 {
159 self.length_scale = params[0];
160 self.signal_variance = params[1];
161 }
162 }
163}
164
/// Matérn ν = 3/2 kernel:
/// `k(x, x') = signal_variance · (1 + √3·r/ℓ) · exp(-√3·r/ℓ)`, with `r = ‖x − x'‖`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matern32 {
    /// Length scale `ℓ`. Expected to be positive; it is not validated, and `0.0` makes
    /// `k(x, x)` evaluate to NaN.
    pub length_scale: f64,
    /// Multiplicative variance `σ²`; `k(x, x)` equals this value.
    pub signal_variance: f64,
}

impl Matern32 {
    /// Creates a kernel with the given length scale and signal variance (no validation).
    #[must_use]
    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
        Self {
            length_scale,
            signal_variance,
        }
    }
}

impl Default for Matern32 {
    /// Unit length scale and unit signal variance.
    fn default() -> Self {
        Self::new(1.0, 1.0)
    }
}
192
193impl Kernel for Matern32 {
194 fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
195 let mut sq_dist = 0.0;
196 for i in 0..x1.len() {
197 let diff = x1[i] - x2[i];
198 sq_dist += diff * diff;
199 }
200 let dist = sq_dist.sqrt();
201 let sqrt3 = 3.0_f64.sqrt();
202 let arg = sqrt3 * dist / self.length_scale;
203
204 self.signal_variance * (1.0 + arg) * (-arg).exp()
205 }
206
207 fn get_params(&self) -> Vec<f64> {
208 vec![self.length_scale, self.signal_variance]
209 }
210
211 fn set_params(&mut self, params: &[f64]) {
212 if params.len() >= 2 {
213 self.length_scale = params[0];
214 self.signal_variance = params[1];
215 }
216 }
217}
218
/// Matérn ν = 5/2 kernel:
/// `k(x, x') = signal_variance · (1 + √5·r/ℓ + 5r²/(3ℓ²)) · exp(-√5·r/ℓ)`, with `r = ‖x − x'‖`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matern52 {
    /// Length scale `ℓ`. Expected to be positive; it is not validated, and `0.0` makes
    /// `k(x, x)` evaluate to NaN.
    pub length_scale: f64,
    /// Multiplicative variance `σ²`; `k(x, x)` equals this value.
    pub signal_variance: f64,
}

impl Matern52 {
    /// Creates a kernel with the given length scale and signal variance (no validation).
    #[must_use]
    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
        Self {
            length_scale,
            signal_variance,
        }
    }
}

impl Default for Matern52 {
    /// Unit length scale and unit signal variance.
    fn default() -> Self {
        Self::new(1.0, 1.0)
    }
}
246
247impl Kernel for Matern52 {
248 fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
249 let mut sq_dist = 0.0;
250 for i in 0..x1.len() {
251 let diff = x1[i] - x2[i];
252 sq_dist += diff * diff;
253 }
254 let dist = sq_dist.sqrt();
255 let sqrt5 = 5.0_f64.sqrt();
256 let arg = sqrt5 * dist / self.length_scale;
257 let arg2 = 5.0 * sq_dist / (3.0 * self.length_scale * self.length_scale);
258
259 self.signal_variance * (1.0 + arg + arg2) * (-arg).exp()
260 }
261
262 fn get_params(&self) -> Vec<f64> {
263 vec![self.length_scale, self.signal_variance]
264 }
265
266 fn set_params(&mut self, params: &[f64]) {
267 if params.len() >= 2 {
268 self.length_scale = params[0];
269 self.signal_variance = params[1];
270 }
271 }
272}
273
/// White-noise kernel: returns `noise_level` for coinciding inputs (every coordinate within
/// `1e-10`) and `0.0` otherwise.
#[derive(Debug, Clone, PartialEq)]
pub struct WhiteKernel {
    /// Value returned when the two inputs coincide.
    pub noise_level: f64,
}

impl WhiteKernel {
    /// Creates a white-noise kernel with the given noise level (no validation).
    #[must_use]
    pub fn new(noise_level: f64) -> Self {
        Self { noise_level }
    }
}

impl Default for WhiteKernel {
    /// Noise level `0.01`.
    fn default() -> Self {
        Self::new(0.01)
    }
}
294
295impl Kernel for WhiteKernel {
296 fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
297 let identical = x1
299 .iter()
300 .zip(x2.iter())
301 .all(|(&a, &b)| (a - b).abs() < 1e-10);
302
303 if identical {
304 self.noise_level
305 } else {
306 0.0
307 }
308 }
309
310 fn get_params(&self) -> Vec<f64> {
311 vec![self.noise_level]
312 }
313
314 fn set_params(&mut self, params: &[f64]) {
315 if !params.is_empty() {
316 self.noise_level = params[0];
317 }
318 }
319}
320
/// Sum of two kernels: `k(x, x') = kernel1(x, x') + kernel2(x, x')`, e.g. a smooth kernel
/// plus a [`WhiteKernel`] noise term.
///
/// The struct places no bounds on `K1`/`K2`. It implements [`Kernel`] when both components do,
/// and the bounds are stated on that impl.
#[derive(Debug, Clone, PartialEq)]
pub struct SumKernel<K1, K2> {
    /// First summand; its parameters come first in the flattened parameter vector.
    pub kernel1: K1,
    /// Second summand; its parameters follow those of `kernel1`.
    pub kernel2: K2,
}

impl<K1, K2> SumKernel<K1, K2> {
    /// Combines two kernels into their sum.
    #[must_use]
    pub fn new(kernel1: K1, kernel2: K2) -> Self {
        Self { kernel1, kernel2 }
    }
}
333
334impl<K1: Kernel, K2: Kernel> Kernel for SumKernel<K1, K2> {
335 fn compute(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
336 self.kernel1.compute(x1, x2) + self.kernel2.compute(x1, x2)
337 }
338
339 fn get_params(&self) -> Vec<f64> {
340 let mut params = self.kernel1.get_params();
341 params.extend(self.kernel2.get_params());
342 params
343 }
344
345 fn set_params(&mut self, params: &[f64]) {
346 let n1 = self.kernel1.n_params();
347 if params.len() >= n1 {
348 self.kernel1.set_params(¶ms[..n1]);
349 if params.len() > n1 {
350 self.kernel2.set_params(¶ms[n1..]);
351 }
352 }
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    #[test]
    fn test_squared_exponential() {
        let kernel = SquaredExponential::default();
        let x1 = array![0.0, 0.0];
        let x2 = array![1.0, 1.0];

        // k(x, x) equals the signal variance.
        assert!((kernel.compute(&x1.view(), &x1.view()) - 1.0).abs() < TOL);

        // ‖x1 − x2‖² = 2, so k = exp(-0.5 · 2) = e⁻¹.
        let k12 = kernel.compute(&x1.view(), &x2.view());
        assert!(k12 > 0.0 && k12 < 1.0);
        assert!((k12 - (-1.0_f64).exp()).abs() < TOL);
    }

    #[test]
    fn test_matern_kernels() {
        let m12 = Matern12::default();
        let m32 = Matern32::default();
        let m52 = Matern52::default();

        let x1 = array![0.0];
        let x2 = array![1.0];

        let k12 = m12.compute(&x1.view(), &x2.view());
        let k32 = m32.compute(&x1.view(), &x2.view());
        let k52 = m52.compute(&x1.view(), &x2.view());

        assert!(k12 > 0.0 && k12 < 1.0);
        assert!(k32 > 0.0 && k32 < 1.0);
        assert!(k52 > 0.0 && k52 < 1.0);

        // Closed forms at r = ℓ = 1.
        let sqrt3 = 3.0_f64.sqrt();
        let sqrt5 = 5.0_f64.sqrt();
        assert!((k12 - (-1.0_f64).exp()).abs() < TOL);
        assert!((k32 - (1.0 + sqrt3) * (-sqrt3).exp()).abs() < TOL);
        assert!((k52 - (1.0 + sqrt5 + 5.0 / 3.0) * (-sqrt5).exp()).abs() < TOL);

        // Smoother members of the Matérn family decay more slowly at the same distance.
        assert!(k12 < k32 && k32 < k52);
    }

    #[test]
    fn test_white_kernel() {
        let kernel = WhiteKernel::new(0.1);
        let x1 = array![0.0, 0.0];
        let x2 = array![1.0, 1.0];

        // Same point: the noise level.
        assert!((kernel.compute(&x1.view(), &x1.view()) - 0.1).abs() < TOL);

        // Distinct points: zero.
        assert!((kernel.compute(&x1.view(), &x2.view())).abs() < TOL);
    }

    #[test]
    fn test_sum_kernel() {
        let rbf = SquaredExponential::default();
        let noise = WhiteKernel::new(0.1);
        let kernel = SumKernel::new(rbf, noise);

        let x1 = array![0.0];

        // 1.0 from the RBF plus 0.1 of white noise at the same point.
        let k = kernel.compute(&x1.view(), &x1.view());
        assert!((k - 1.1).abs() < TOL);
    }

    #[test]
    fn test_compute_matrix_is_symmetric_with_variance_diagonal() {
        let kernel = Matern32::new(0.7, 2.0);
        let x = array![[0.0, 1.0], [0.5, -0.5], [2.0, 0.0]];

        let k = kernel.compute_matrix(&x);
        assert_eq!(k.dim(), (3, 3));
        for i in 0..3 {
            assert!((k[[i, i]] - 2.0).abs() < TOL);
            for j in 0..3 {
                assert_eq!(k[[i, j]], k[[j, i]]);
                assert!((k[[i, j]] - kernel.compute(&x.row(i), &x.row(j))).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_compute_cross_matrix_shape_and_entries() {
        let kernel = SquaredExponential::new(1.5, 0.5);
        let a = array![[0.0], [1.0]];
        let b = array![[0.0], [2.0], [3.0]];

        let k = kernel.compute_cross_matrix(&a, &b);
        assert_eq!(k.dim(), (2, 3));
        for i in 0..2 {
            for j in 0..3 {
                assert!((k[[i, j]] - kernel.compute(&a.row(i), &b.row(j))).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_params_round_trip() {
        let mut kernel = Matern52::default();
        kernel.set_params(&[0.3, 4.0]);
        assert_eq!(kernel.get_params(), vec![0.3, 4.0]);
        assert_eq!(kernel.n_params(), 2);

        // Too-short slices leave the kernel unchanged.
        kernel.set_params(&[9.0]);
        assert_eq!(kernel.get_params(), vec![0.3, 4.0]);
    }

    #[test]
    fn test_sum_kernel_params() {
        let mut kernel = SumKernel::new(SquaredExponential::default(), WhiteKernel::new(0.1));
        assert_eq!(kernel.get_params(), vec![1.0, 1.0, 0.1]);
        assert_eq!(kernel.n_params(), 3);

        kernel.set_params(&[2.0, 3.0, 0.5]);
        assert_eq!(kernel.kernel1.length_scale, 2.0);
        assert_eq!(kernel.kernel1.signal_variance, 3.0);
        assert_eq!(kernel.kernel2.noise_level, 0.5);
    }
}