// scirs2_interpolate/auto_kernel_gp/kernel.rs
/// A primitive one-dimensional covariance function together with its
/// hyperparameters.
///
/// Values are computed by [`BaseKernel::eval`]; several of these can be
/// combined into larger expressions.
#[derive(Debug, Clone, PartialEq)]
pub enum BaseKernel {
    /// Squared-exponential kernel `exp(-(x1 - x2)^2 / (2 * length_scale^2))`.
    Rbf { length_scale: f64 },
    /// Matérn kernel with smoothness 5/2:
    /// `(1 + s + s^2 / 3) * exp(-s)` where `s = sqrt(5) * |x1 - x2| / length_scale`.
    Matern52 { length_scale: f64 },
    /// Exp-sine-squared kernel
    /// `exp(-2 * sin^2(pi * |x1 - x2| / period) / length_scale^2)`.
    Periodic { period: f64, length_scale: f64 },
    /// Dot-product kernel `variance * x1 * x2`.
    Linear { variance: f64 },
    /// `variance` when the two inputs differ by less than `1e-12`, otherwise `0`.
    WhiteNoise { variance: f64 },
}

impl BaseKernel {
    /// Evaluates the kernel at the scalar input pair `(x1, x2)`.
    ///
    /// Any length scale or period below `1e-10` (including zero, negative
    /// values and NaN, via `f64::max`) is replaced by `1e-10` before use, so
    /// the formulas never divide by zero.
    pub fn eval(&self, x1: f64, x2: f64) -> f64 {
        // Lower bound applied to length scales and periods.
        const FLOOR: f64 = 1e-10;
        let diff = x1 - x2;
        match *self {
            Self::Rbf { length_scale } => {
                let scale = length_scale.max(FLOOR);
                (-diff * diff / (2.0 * scale * scale)).exp()
            }
            Self::Matern52 { length_scale } => {
                let scale = length_scale.max(FLOOR);
                let t = 5.0_f64.sqrt() * diff.abs() / scale;
                (1.0 + t + t * t / 3.0) * (-t).exp()
            }
            Self::Periodic {
                period,
                length_scale,
            } => {
                let span = period.max(FLOOR);
                let scale = length_scale.max(FLOOR);
                let phase = (std::f64::consts::PI * diff.abs() / span).sin();
                (-2.0 * phase * phase / (scale * scale)).exp()
            }
            Self::Linear { variance } => variance * x1 * x2,
            Self::WhiteNoise { variance } => {
                // Treat the inputs as the same point only when they are
                // numerically indistinguishable.
                let same_point = diff.abs() < 1e-12;
                if same_point {
                    variance
                } else {
                    0.0
                }
            }
        }
    }

    /// Returns the display name of this kernel's family. Hyperparameter values
    /// are not included.
    pub fn name(&self) -> &'static str {
        let label = match self {
            Self::Rbf { .. } => "RBF",
            Self::Matern52 { .. } => "Matern52",
            Self::Periodic { .. } => "Periodic",
            Self::Linear { .. } => "Linear",
            Self::WhiteNoise { .. } => "WhiteNoise",
        };
        label
    }

    /// Returns a copy of this kernel with its primary scale hyperparameter set
    /// to `new_ell`.
    ///
    /// For `Rbf`, `Matern52` and `Periodic` this is the length scale; a
    /// periodic kernel keeps its existing `period`. `Linear` and `WhiteNoise`
    /// have no length scale, so for them `new_ell` becomes the new `variance`.
    pub fn with_length_scale(&self, new_ell: f64) -> Self {
        match *self {
            Self::Rbf { .. } => Self::Rbf {
                length_scale: new_ell,
            },
            Self::Matern52 { .. } => Self::Matern52 {
                length_scale: new_ell,
            },
            Self::Periodic { period, .. } => Self::Periodic {
                period,
                length_scale: new_ell,
            },
            Self::Linear { .. } => Self::Linear { variance: new_ell },
            Self::WhiteNoise { .. } => Self::WhiteNoise { variance: new_ell },
        }
    }
}

98#[derive(Debug, Clone)]
108pub enum KernelExpr {
109 Base(BaseKernel),
111 Sum(Box<KernelExpr>, Box<KernelExpr>),
113 Product(Box<KernelExpr>, Box<KernelExpr>),
115}
116
117impl KernelExpr {
118 pub fn eval(&self, x1: f64, x2: f64) -> f64 {
120 match self {
121 KernelExpr::Base(k) => k.eval(x1, x2),
122 KernelExpr::Sum(a, b) => a.eval(x1, x2) + b.eval(x1, x2),
123 KernelExpr::Product(a, b) => a.eval(x1, x2) * b.eval(x1, x2),
124 }
125 }
126
127 pub fn description(&self) -> String {
129 match self {
130 KernelExpr::Base(k) => k.name().to_string(),
131 KernelExpr::Sum(a, b) => format!("{} + {}", a.description(), b.description()),
132 KernelExpr::Product(a, b) => {
133 let ad = a.description();
135 let bd = b.description();
136 let left = if matches!(a.as_ref(), KernelExpr::Sum(_, _)) {
137 format!("({ad})")
138 } else {
139 ad
140 };
141 let right = if matches!(b.as_ref(), KernelExpr::Sum(_, _)) {
142 format!("({bd})")
143 } else {
144 bd
145 };
146 format!("{left} × {right}")
147 }
148 }
149 }
150
151 pub fn depth(&self) -> usize {
153 match self {
154 KernelExpr::Base(_) => 0,
155 KernelExpr::Sum(a, b) | KernelExpr::Product(a, b) => 1 + a.depth().max(b.depth()),
156 }
157 }
158}
159
160pub fn base_kernels() -> Vec<BaseKernel> {
162 vec![
163 BaseKernel::Rbf { length_scale: 1.0 },
164 BaseKernel::Matern52 { length_scale: 1.0 },
165 BaseKernel::Periodic {
166 period: 1.0,
167 length_scale: 1.0,
168 },
169 BaseKernel::Linear { variance: 1.0 },
170 BaseKernel::WhiteNoise { variance: 0.1 },
171 ]
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-12;

    #[test]
    fn kernel_expr_eval_rbf_is_symmetric() {
        let k = KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 });
        let a = k.eval(0.5, 1.5);
        let b = k.eval(1.5, 0.5);
        assert!(
            (a - b).abs() < TOL,
            "RBF must be symmetric: k(0.5,1.5)={a}, k(1.5,0.5)={b}"
        );
    }

    #[test]
    fn kernel_expr_sum_is_additive() {
        let k1 = KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 });
        let k2 = KernelExpr::Base(BaseKernel::Matern52 { length_scale: 1.0 });
        let sum = KernelExpr::Sum(Box::new(k1.clone()), Box::new(k2.clone()));
        let x = 0.3;
        let y = 1.2;
        let expected = k1.eval(x, y) + k2.eval(x, y);
        let actual = sum.eval(x, y);
        assert!(
            (actual - expected).abs() < TOL,
            "Sum kernel must equal sum of parts: {actual} vs {expected}"
        );
    }

    #[test]
    fn kernel_expr_product_is_multiplicative() {
        let k1 = KernelExpr::Base(BaseKernel::Rbf { length_scale: 0.8 });
        let k2 = KernelExpr::Base(BaseKernel::Linear { variance: 2.0 });
        let prod = KernelExpr::Product(Box::new(k1.clone()), Box::new(k2.clone()));
        let x = 1.5;
        let y = 0.5;
        let expected = k1.eval(x, y) * k2.eval(x, y);
        let actual = prod.eval(x, y);
        assert!(
            (actual - expected).abs() < TOL,
            "Product kernel must equal product of parts: {actual} vs {expected}"
        );
    }

    #[test]
    fn kernel_description_nonempty() {
        let k = KernelExpr::Sum(
            Box::new(KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 })),
            Box::new(KernelExpr::Product(
                Box::new(KernelExpr::Base(BaseKernel::Periodic {
                    period: 1.0,
                    length_scale: 0.5,
                })),
                Box::new(KernelExpr::Base(BaseKernel::Linear { variance: 1.0 })),
            )),
        );
        let desc = k.description();
        assert!(!desc.is_empty(), "description must not be empty");
        assert!(desc.contains("RBF"), "description should contain RBF");
        assert!(
            desc.contains("Periodic"),
            "description should contain Periodic"
        );
    }

    #[test]
    fn product_description_parenthesizes_sum_operands() {
        let sum = KernelExpr::Sum(
            Box::new(KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 })),
            Box::new(KernelExpr::Base(BaseKernel::Linear { variance: 1.0 })),
        );
        let prod = KernelExpr::Product(
            Box::new(sum),
            Box::new(KernelExpr::Base(BaseKernel::Periodic {
                period: 1.0,
                length_scale: 1.0,
            })),
        );
        assert_eq!(prod.description(), "(RBF + Linear) × Periodic");
    }

    #[test]
    fn depth_counts_nested_operators() {
        let leaf = || KernelExpr::Base(BaseKernel::Rbf { length_scale: 1.0 });
        assert_eq!(leaf().depth(), 0);
        let nested = KernelExpr::Sum(
            Box::new(leaf()),
            Box::new(KernelExpr::Product(Box::new(leaf()), Box::new(leaf()))),
        );
        assert_eq!(nested.depth(), 2);
    }

    #[test]
    fn stationary_kernels_equal_one_at_zero_distance() {
        let kernels = [
            BaseKernel::Rbf { length_scale: 0.7 },
            BaseKernel::Matern52 { length_scale: 0.7 },
            BaseKernel::Periodic {
                period: 2.0,
                length_scale: 0.7,
            },
        ];
        for k in &kernels {
            let v = k.eval(1.3, 1.3);
            assert!(
                (v - 1.0).abs() < TOL,
                "{} at zero distance should be 1, got {v}",
                k.name()
            );
        }
    }

    #[test]
    fn periodic_kernel_repeats_every_period() {
        let k = BaseKernel::Periodic {
            period: 2.5,
            length_scale: 0.8,
        };
        let base = k.eval(0.0, 0.4);
        let shifted = k.eval(0.0, 0.4 + 2.5);
        assert!(
            (base - shifted).abs() < TOL,
            "periodic kernel must repeat: {base} vs {shifted}"
        );
    }

    #[test]
    fn white_noise_is_nonzero_only_on_the_diagonal() {
        let k = BaseKernel::WhiteNoise { variance: 0.1 };
        assert_eq!(k.eval(2.0, 2.0), 0.1);
        assert_eq!(k.eval(2.0, 2.5), 0.0);
    }

    #[test]
    fn with_length_scale_keeps_other_hyperparameters() {
        let p = BaseKernel::Periodic {
            period: 3.0,
            length_scale: 1.0,
        };
        assert_eq!(
            p.with_length_scale(0.25),
            BaseKernel::Periodic {
                period: 3.0,
                length_scale: 0.25,
            }
        );
        // Linear has no length scale; the new value becomes its variance.
        assert_eq!(
            BaseKernel::Linear { variance: 1.0 }.with_length_scale(2.0),
            BaseKernel::Linear { variance: 2.0 }
        );
    }

    #[test]
    fn base_kernels_cover_every_family_once() {
        let names: Vec<&str> = base_kernels().iter().map(BaseKernel::name).collect();
        assert_eq!(names, ["RBF", "Matern52", "Periodic", "Linear", "WhiteNoise"]);
    }
}