rv/process/gaussian/kernel/
mod.rs1use nalgebra::base::constraint::{SameNumberOfColumns, ShapeConstraint};
4use nalgebra::base::storage::Storage;
5use nalgebra::{DMatrix, DVector, Dim, Matrix};
6use std::f64;
7
8#[cfg(feature = "serde1")]
9use serde::{Deserialize, Serialize};
10
11mod covgrad;
12pub use covgrad::*;
13
14mod misc;
15pub use self::misc::*;
16
17mod constant_kernel;
18pub use self::constant_kernel::*;
19
20mod ops;
21pub use self::ops::*;
22
23mod rbf;
24pub use self::rbf::*;
25mod white_kernel;
26pub use self::white_kernel::*;
27mod rational_quadratic;
28pub use self::rational_quadratic::*;
29mod exp_sin_squared;
30pub use self::exp_sin_squared::*;
31mod seard;
32pub use self::seard::*;
33mod matern;
34pub use self::matern::*;
35
36pub trait Kernel: std::fmt::Debug + Clone + PartialEq {
38 fn n_parameters(&self) -> usize;
40
41 fn covariance<R1, R2, C1, C2, S1, S2>(
43 &self,
44 x1: &Matrix<f64, R1, C1, S1>,
45 x2: &Matrix<f64, R2, C2, S2>,
46 ) -> DMatrix<f64>
47 where
48 R1: Dim,
49 R2: Dim,
50 C1: Dim,
51 C2: Dim,
52 S1: Storage<f64, R1, C1>,
53 S2: Storage<f64, R2, C2>,
54 ShapeConstraint: SameNumberOfColumns<C1, C2>;
55
56 fn is_stationary(&self) -> bool;
58
59 fn diag<R, C, S>(&self, x: &Matrix<f64, R, C, S>) -> DVector<f64>
61 where
62 R: Dim,
63 C: Dim,
64 S: Storage<f64, R, C>;
65
66 fn parameters(&self) -> DVector<f64>;
69
70 fn reparameterize(&self, params: &[f64]) -> Result<Self, KernelError>;
73
74 fn consume_parameters<I: IntoIterator<Item = f64>>(
78 &self,
79 params: I,
80 ) -> Result<(Self, I::IntoIter), KernelError> {
81 let mut iter = params.into_iter();
82 let n = self.n_parameters();
83 let mut parameters: Vec<f64> = Vec::with_capacity(n);
84
85 for i in 0..n {
87 parameters.push(
88 iter.next().ok_or(KernelError::MissingParameters(n - i))?,
89 );
90 }
91
92 Ok((self.reparameterize(¶meters)?, iter))
93 }
94
95 fn covariance_with_gradient<R, C, S>(
97 &self,
98 x: &Matrix<f64, R, C, S>,
99 ) -> Result<(DMatrix<f64>, CovGrad), CovGradError>
100 where
101 R: Dim,
102 C: Dim,
103 S: Storage<f64, R, C>;
104
105 fn add<B: Kernel>(self, other: B) -> AddKernel<Self, B> {
106 AddKernel::new(self, other)
107 }
108
109 fn mul<B: Kernel>(self, other: B) -> ProductKernel<Self, B> {
110 ProductKernel::new(self, other)
111 }
112}
113
114#[derive(Debug, Clone, PartialEq)]
116#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
117#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
118pub enum KernelError {
119 ImproperBounds(f64, f64),
121 ParameterOutOfBounds {
123 name: String,
125 given: f64,
127 bounds: (f64, f64),
129 },
130 ExtraneousParameters(usize),
132 MissingParameters(usize),
134 CovGrad(CovGradError),
136}
137
138impl std::error::Error for KernelError {}
139
140impl std::fmt::Display for KernelError {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 match self {
143 Self::ImproperBounds(lower, upper) => {
144 writeln!(f, "Bounds are not in order: ({}, {})", lower, upper)
145 }
146 Self::ParameterOutOfBounds {
147 name,
148 given,
149 bounds,
150 } => writeln!(
151 f,
152 "Parameter {} is out of bounds ({}, {}), given: {}",
153 name, bounds.0, bounds.1, given
154 ),
155 Self::ExtraneousParameters(n) => {
156 writeln!(f, "{} extra parameters proved to kernel", n)
157 }
158 Self::MissingParameters(n) => {
159 writeln!(f, "Missing {} parameters", n)
160 }
161 Self::CovGrad(e) => {
162 writeln!(f, "Covariance Gradient couldn't be computed: {}", e)
163 }
164 }
165 }
166}
167
168impl From<CovGradError> for KernelError {
169 fn from(e: CovGradError) -> Self {
170 Self::CovGrad(e)
171 }
172}
173
/// Gives a concrete kernel type operator overloads: `a + b` builds an
/// [`AddKernel`] and `a * b` builds a [`ProductKernel`]. The right-hand side
/// may be any type implementing [`Kernel`].
macro_rules! impl_mul_add {
    ($kernel:ty) => {
        impl<Rhs> std::ops::Add<Rhs> for $kernel
        where
            Rhs: Kernel,
        {
            type Output = AddKernel<Self, Rhs>;

            fn add(self, other: Rhs) -> AddKernel<Self, Rhs> {
                AddKernel::new(self, other)
            }
        }

        impl<Rhs> std::ops::Mul<Rhs> for $kernel
        where
            Rhs: Kernel,
        {
            type Output = ProductKernel<Self, Rhs>;

            fn mul(self, other: Rhs) -> ProductKernel<Self, Rhs> {
                ProductKernel::new(self, other)
            }
        }
    };
}
199
200impl_mul_add!(ConstantKernel);
201impl_mul_add!(RBFKernel);
202impl_mul_add!(SEardKernel);
203impl_mul_add!(ExpSineSquaredKernel);
204impl_mul_add!(RationalQuadratic);
205impl_mul_add!(WhiteKernel);
206impl_mul_add!(MaternKernel);