rv/process/gaussian/kernel/
mod.rs

//! Kernel functions for Gaussian Processes
2
3use nalgebra::base::constraint::{SameNumberOfColumns, ShapeConstraint};
4use nalgebra::base::storage::Storage;
5use nalgebra::{DMatrix, DVector, Dim, Matrix};
6use std::f64;
7
8#[cfg(feature = "serde1")]
9use serde::{Deserialize, Serialize};
10
11mod covgrad;
12pub use covgrad::*;
13
14mod misc;
15pub use self::misc::*;
16
17mod constant_kernel;
18pub use self::constant_kernel::*;
19
20mod ops;
21pub use self::ops::*;
22
23mod rbf;
24pub use self::rbf::*;
25mod white_kernel;
26pub use self::white_kernel::*;
27mod rational_quadratic;
28pub use self::rational_quadratic::*;
29mod exp_sin_squared;
30pub use self::exp_sin_squared::*;
31mod seard;
32pub use self::seard::*;
33mod matern;
34pub use self::matern::*;
35
36/// Kernel Function
37pub trait Kernel: std::fmt::Debug + Clone + PartialEq {
38    /// Return the number of parameters used in this `Kernel`.
39    fn n_parameters(&self) -> usize;
40
41    /// Returns the covariance matrix for two equal sized vectors
42    fn covariance<R1, R2, C1, C2, S1, S2>(
43        &self,
44        x1: &Matrix<f64, R1, C1, S1>,
45        x2: &Matrix<f64, R2, C2, S2>,
46    ) -> DMatrix<f64>
47    where
48        R1: Dim,
49        R2: Dim,
50        C1: Dim,
51        C2: Dim,
52        S1: Storage<f64, R1, C1>,
53        S2: Storage<f64, R2, C2>,
54        ShapeConstraint: SameNumberOfColumns<C1, C2>;
55
56    /// Reports if the given kernel function is stationary.
57    fn is_stationary(&self) -> bool;
58
59    /// Returns the diagonal of the kernel(x, x)
60    fn diag<R, C, S>(&self, x: &Matrix<f64, R, C, S>) -> DVector<f64>
61    where
62        R: Dim,
63        C: Dim,
64        S: Storage<f64, R, C>;
65
66    /// Return the corresponding parameter vector
67    /// The parameters here are in a log-scale
68    fn parameters(&self) -> DVector<f64>;
69
70    /// Create a new kernel of the given type from the provided parameters.
71    /// The parameters here are in a log-scale
72    fn reparameterize(&self, params: &[f64]) -> Result<Self, KernelError>;
73
74    /// Takes a sequence of parameters and consumes only the ones it needs
75    /// to create itself.
76    /// The parameters here are in a log-scale
77    fn consume_parameters<I: IntoIterator<Item = f64>>(
78        &self,
79        params: I,
80    ) -> Result<(Self, I::IntoIter), KernelError> {
81        let mut iter = params.into_iter();
82        let n = self.n_parameters();
83        let mut parameters: Vec<f64> = Vec::with_capacity(n);
84
85        // TODO: Clean this up if/when `iter_next_chunk` is stabilized
86        for i in 0..n {
87            parameters.push(
88                iter.next().ok_or(KernelError::MissingParameters(n - i))?,
89            );
90        }
91
92        Ok((self.reparameterize(&parameters)?, iter))
93    }
94
95    /// Covariance and Gradient with the log-scaled hyper-parameters
96    fn covariance_with_gradient<R, C, S>(
97        &self,
98        x: &Matrix<f64, R, C, S>,
99    ) -> Result<(DMatrix<f64>, CovGrad), CovGradError>
100    where
101        R: Dim,
102        C: Dim,
103        S: Storage<f64, R, C>;
104
105    fn add<B: Kernel>(self, other: B) -> AddKernel<Self, B> {
106        AddKernel::new(self, other)
107    }
108
109    fn mul<B: Kernel>(self, other: B) -> ProductKernel<Self, B> {
110        ProductKernel::new(self, other)
111    }
112}
113
/// Errors from Kernel construction
///
/// Raised when a kernel is built or reparameterized with invalid values,
/// or when the covariance-gradient computation fails.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum KernelError {
    /// Lower bounds must be lower than upper bounds; carries `(lower, upper)` as given
    ImproperBounds(f64, f64),
    /// Parameter Out of Bounds
    ParameterOutOfBounds {
        /// Name of parameter
        name: String,
        /// Value given
        given: f64,
        /// Lower and upper bounds on value
        bounds: (f64, f64),
    },
    /// Too many parameters provided; carries the number of extras
    ExtraneousParameters(usize),
    /// Too few parameters provided; carries the number still missing
    MissingParameters(usize),
    /// An error in computing cov-grad (see [`CovGradError`])
    CovGrad(CovGradError),
}
137
// Marker impl: `Debug` + `Display` (both implemented below/above) satisfy
// all of `std::error::Error`'s default methods.
impl std::error::Error for KernelError {}
139
140impl std::fmt::Display for KernelError {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        match self {
143            Self::ImproperBounds(lower, upper) => {
144                writeln!(f, "Bounds are not in order: ({}, {})", lower, upper)
145            }
146            Self::ParameterOutOfBounds {
147                name,
148                given,
149                bounds,
150            } => writeln!(
151                f,
152                "Parameter {} is out of bounds ({}, {}), given: {}",
153                name, bounds.0, bounds.1, given
154            ),
155            Self::ExtraneousParameters(n) => {
156                writeln!(f, "{} extra parameters proved to kernel", n)
157            }
158            Self::MissingParameters(n) => {
159                writeln!(f, "Missing {} parameters", n)
160            }
161            Self::CovGrad(e) => {
162                writeln!(f, "Covariance Gradient couldn't be computed: {}", e)
163            }
164        }
165    }
166}
167
// Allow `?` to convert gradient-computation failures (`CovGradError`)
// into `KernelError` by wrapping them in the `CovGrad` variant.
impl From<CovGradError> for KernelError {
    fn from(e: CovGradError) -> Self {
        Self::CovGrad(e)
    }
}
173
/// Implements `std::ops::Mul` and `std::ops::Add` for a concrete kernel
/// type, so kernels compose with the `*` and `+` operators (yielding
/// [`ProductKernel`] and [`AddKernel`]) in addition to the
/// `Kernel::mul` / `Kernel::add` methods.
macro_rules! impl_mul_add {
    ($kernel: ty) => {
        impl<Rhs> std::ops::Mul<Rhs> for $kernel
        where
            Rhs: Kernel,
        {
            type Output = ProductKernel<$kernel, Rhs>;

            fn mul(self, other: Rhs) -> Self::Output {
                ProductKernel::new(self, other)
            }
        }

        impl<Rhs> std::ops::Add<Rhs> for $kernel
        where
            Rhs: Kernel,
        {
            type Output = AddKernel<$kernel, Rhs>;

            fn add(self, other: Rhs) -> Self::Output {
                AddKernel::new(self, other)
            }
        }
    };
}
199
// Wire up `*` and `+` operator overloading for every concrete kernel
// type exported from this module's submodules above.
impl_mul_add!(ConstantKernel);
impl_mul_add!(RBFKernel);
impl_mul_add!(SEardKernel);
impl_mul_add!(ExpSineSquaredKernel);
impl_mul_add!(RationalQuadratic);
impl_mul_add!(WhiteKernel);
impl_mul_add!(MaternKernel);