1use super::{CovGrad, CovGradError, Kernel, KernelError};
2use nalgebra::{
3 base::constraint::{SameNumberOfColumns, ShapeConstraint},
4 dvector, EuclideanNorm,
5};
6use nalgebra::{base::storage::Storage, Norm};
7use nalgebra::{DMatrix, DVector, Dim, Matrix};
8use std::f64;
9use std::f64::consts::PI;
10
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13
14#[derive(Clone, Debug, PartialEq)]
17#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
18#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
19pub struct ExpSineSquaredKernel {
20 length_scale: f64,
21 periodicity: f64,
22}
23
24impl ExpSineSquaredKernel {
25 pub fn new(
27 length_scale: f64,
28 periodicity: f64,
29 ) -> Result<Self, KernelError> {
30 if length_scale <= 0.0 {
31 Err(KernelError::ParameterOutOfBounds {
32 name: "length_scale".to_string(),
33 given: length_scale,
34 bounds: (0.0, f64::INFINITY),
35 })
36 } else if periodicity <= 0.0 {
37 Err(KernelError::ParameterOutOfBounds {
38 name: "periodicity".to_string(),
39 given: periodicity,
40 bounds: (0.0, f64::INFINITY),
41 })
42 } else {
43 Ok(Self {
44 length_scale,
45 periodicity,
46 })
47 }
48 }
49
50 pub fn new_unchecked(length_scale: f64, periodicity: f64) -> Self {
52 Self {
53 length_scale,
54 periodicity,
55 }
56 }
57}
58
59impl Kernel for ExpSineSquaredKernel {
60 fn n_parameters(&self) -> usize {
61 2
62 }
63
64 fn covariance<R1, R2, C1, C2, S1, S2>(
65 &self,
66 x1: &Matrix<f64, R1, C1, S1>,
67 x2: &Matrix<f64, R2, C2, S2>,
68 ) -> DMatrix<f64>
69 where
70 R1: Dim,
71 R2: Dim,
72 C1: Dim,
73 C2: Dim,
74 S1: Storage<f64, R1, C1>,
75 S2: Storage<f64, R2, C2>,
76 ShapeConstraint: SameNumberOfColumns<C1, C2>,
77 {
78 assert!(x1.ncols() == x2.ncols());
79 let metric = EuclideanNorm {};
80 let l2 = self.length_scale.powi(2);
81
82 DMatrix::from_fn(x1.nrows(), x2.nrows(), |i, j| {
83 let d = metric.metric_distance(&x1.row(i), &x2.row(j));
84 let s2 = (PI * d / self.periodicity).sin().powi(2);
85 (-2.0 * s2 / l2).exp()
86 })
87 }
88
89 fn is_stationary(&self) -> bool {
90 true
91 }
92
93 fn diag<R, C, S>(&self, x: &Matrix<f64, R, C, S>) -> DVector<f64>
95 where
96 R: Dim,
97 C: Dim,
98 S: Storage<f64, R, C>,
99 {
100 DVector::repeat(x.len(), 1.0)
101 }
102
103 fn parameters(&self) -> DVector<f64> {
106 dvector![self.length_scale.ln(), self.periodicity.ln()]
107 }
108
109 fn reparameterize(&self, params: &[f64]) -> Result<Self, KernelError> {
112 match params {
113 [] => Err(KernelError::MissingParameters(2)),
114 [_] => Err(KernelError::MissingParameters(1)),
115 [length_scale, periodicity] => {
116 Self::new(length_scale.exp(), periodicity.exp())
117 }
118 _ => Err(KernelError::ExtraneousParameters(params.len() - 1)),
119 }
120 }
121
122 fn covariance_with_gradient<R, C, S>(
124 &self,
125 x: &Matrix<f64, R, C, S>,
126 ) -> Result<(DMatrix<f64>, CovGrad), CovGradError>
127 where
128 R: Dim,
129 C: Dim,
130 S: Storage<f64, R, C>,
131 {
132 let n = x.nrows();
133 let mut cov = DMatrix::zeros(n, n);
134 let mut grad = CovGrad::zeros(n, 2);
135 let metric = EuclideanNorm {};
136 let l2 = self.length_scale.powi(2);
137
138 for i in 0..n {
140 for j in 0..i {
141 let d = metric.metric_distance(&x.row(i), &x.row(j));
142 let arg = PI * d / self.periodicity;
143
144 let sin_arg = arg.sin();
145 let sin_arg_2 = sin_arg.powi(2);
146 let cos_arg = arg.cos();
147
148 let k = (-2.0 * sin_arg_2 / l2).exp();
149 cov[(i, j)] = k;
150 cov[(j, i)] = k;
151
152 let dk_dl = 4.0 * sin_arg_2 * k / l2;
153 grad[(i, j, 0)] = dk_dl;
154 grad[(j, i, 0)] = dk_dl;
155
156 let dk_dp = (4.0 * arg / l2) * cos_arg * sin_arg * k;
157 grad[(i, j, 1)] = dk_dp;
158 grad[(j, i, 1)] = dk_dp;
159 }
160 cov[(i, i)] = 1.0;
162 }
163 Ok((cov, grad))
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 5×5 matrix from 25 values listed in row-major order.
    fn square5(values: &[f64]) -> DMatrix<f64> {
        DMatrix::from_row_slice(5, 5, values)
    }

    /// The single-column input shared by both tests.
    fn sample_points() -> DMatrix<f64> {
        DMatrix::from_row_slice(5, 1, &[-4.0, -3.0, -2.0, -1.0, 1.0])
    }

    #[test]
    fn expsinesquared_kernel_a() -> Result<(), KernelError> {
        let probe = ExpSineSquaredKernel::new(3.0, 5.0)?;
        assert!(probe.is_stationary());

        let kernel = ExpSineSquaredKernel::new(1.0, 1.0)?;
        let x = sample_points();
        let y = x.map(|v| v.sin());

        // x holds integers and the period is 1, so sin² only depends on
        // y_j; every row of the cross-covariance is therefore identical.
        let row = [
            0.383_938_97,
            0.692_107_48,
            0.853_810_81,
            0.633_565_1,
            0.633_565_1,
        ];
        let tiled: Vec<f64> = row.iter().copied().cycle().take(25).collect();
        let cross = kernel.covariance(&x, &y);
        assert!(cross.relative_eq(&square5(&tiled), 1E-7, 1E-7));

        let (cov, grad) = kernel.covariance_with_gradient(&x)?;

        // All pairwise distances are whole multiples of the period.
        let ones = DMatrix::from_element(5, 5, 1.0);

        let d_length_scale = square5(&[
            0.0, 5.999_039_13e-32, 2.399_615_65e-31, 5.399_135_22e-31, 1.499_759_78e-30,
            5.999_039_13e-32, 0.0, 5.999_039_13e-32, 2.399_615_65e-31, 9.598_462_61e-31,
            2.399_615_65e-31, 5.999_039_13e-32, 0.0, 5.999_039_13e-32, 5.399_135_22e-31,
            5.399_135_22e-31, 2.399_615_65e-31, 5.999_039_13e-32, 0.0, 2.399_615_65e-31,
            1.499_759_78e-30, 9.598_462_61e-31, 5.399_135_22e-31, 2.399_615_65e-31, 0.0,
        ]);
        let d_periodicity = square5(&[
            0.0, -1.538_936_55e-15, -6.155_746_22e-15, -1.385_042_90e-14, -3.847_341_39e-14,
            -1.538_936_55e-15, 0.0, -1.538_936_55e-15, -6.155_746_22e-15, -2.462_298_49e-14,
            -6.155_746_22e-15, -1.538_936_55e-15, 0.0, -1.538_936_55e-15, -1.385_042_90e-14,
            -1.385_042_90e-14, -6.155_746_22e-15, -1.538_936_55e-15, 0.0, -6.155_746_22e-15,
            -3.847_341_39e-14, -2.462_298_49e-14, -1.385_042_90e-14, -6.155_746_22e-15, 0.0,
        ]);
        let expected_grad = CovGrad::new_unchecked(&[d_length_scale, d_periodicity]);

        assert!(cov.relative_eq(&ones, 1E-8, 1E-8));
        assert!(grad.relative_eq(&expected_grad, 1E-8, 1E-8));
        Ok(())
    }

    #[test]
    fn expsinesquared_kernel_b() -> Result<(), KernelError> {
        let x = sample_points();
        let kernel = ExpSineSquaredKernel::new(5.0, 2.0 * f64::consts::PI)?;
        let (cov, grad) = kernel.covariance_with_gradient(&x)?;

        let expected_cov = square5(&[
            1., 0.981_780_12, 0.944_928_63, 0.923_485_94, 0.971_753_11,
            0.981_780_12, 1., 0.981_780_12, 0.944_928_63, 0.935_994_44,
            0.944_928_63, 0.981_780_12, 1., 0.981_780_12, 0.923_485_94,
            0.923_485_94, 0.944_928_63, 0.981_780_12, 1., 0.944_928_63,
            0.971_753_11, 0.935_994_44, 0.923_485_94, 0.944_928_63, 1.,
        ]);

        let d_length_scale = square5(&[
            0., 0.036_105_76, 0.107_052_62, 0.147_018_41, 0.055_688_28,
            0.036_105_76, 0., 0.036_105_76, 0.107_052_62, 0.123_824_1,
            0.107_052_62, 0.036_105_76, 0., 0.036_105_76, 0.147_018_41,
            0.147_018_41, 0.107_052_62, 0.036_105_76, 0., 0.107_052_62,
            0.055_688_28, 0.123_824_1, 0.147_018_41, 0.107_052_62, 0.,
        ]);
        let d_periodicity = square5(&[
            0., 0.033_045_58, 0.068_737_69, 0.015_638_68, -0.186_367_53,
            0.033_045_58, 0., 0.033_045_58, 0.068_737_69, -0.113_338_07,
            0.068_737_69, 0.033_045_58, 0., 0.033_045_58, 0.015_638_68,
            0.015_638_68, 0.068_737_69, 0.033_045_58, 0., 0.068_737_69,
            -0.186_367_53, -0.113_338_07, 0.015_638_68, 0.068_737_69, 0.,
        ]);
        let expected_grad = CovGrad::new_unchecked(&[d_length_scale, d_periodicity]);

        assert!(cov.relative_eq(&expected_cov, 1E-8, 1E-8));
        assert!(grad.relative_eq(&expected_grad, 1E-8, 1E-8));

        Ok(())
    }
}