// scirs2_interpolate/deep_kriging/types.rs

//! Types for Deep Kriging and GP Surrogate modules.
//!
//! This module defines configuration types, kernel specifications,
//! acquisition functions, and result containers for neural-basis kriging
//! and Gaussian process surrogate modelling.

// ---------------------------------------------------------------------------
// Activation function
// ---------------------------------------------------------------------------
/// Activation function used in MLP layers.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum Activation {
    /// Rectified Linear Unit: max(0, x)
    ReLU,
    /// Hyperbolic tangent
    Tanh,
    /// Logistic sigmoid: 1 / (1 + exp(-x))
    Sigmoid,
    /// Exponential Linear Unit: x if x > 0 else alpha*(exp(x)-1)
    ELU { alpha: f64 },
}

impl Activation {
    /// Apply the activation function element-wise.
    ///
    /// Maps one pre-activation value `x` to its activated output.
    pub fn apply(&self, x: f64) -> f64 {
        match *self {
            Activation::ReLU => x.max(0.0),
            Activation::Tanh => x.tanh(),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            // ELU: identity on the positive side, scaled shifted
            // exponential on the non-positive side.
            Activation::ELU { .. } if x > 0.0 => x,
            Activation::ELU { alpha } => alpha * (x.exp() - 1.0),
        }
    }

    /// Derivative of the activation function.
    ///
    /// Evaluated at the pre-activation value `x` (not at the output).
    pub fn derivative(&self, x: f64) -> f64 {
        match *self {
            // Subgradient convention: the derivative at x == 0 is 0.
            Activation::ReLU if x > 0.0 => 1.0,
            Activation::ReLU => 0.0,
            // d/dx tanh(x) = 1 - tanh(x)^2
            Activation::Tanh => 1.0 - x.tanh().powi(2),
            // d/dx sigma(x) = sigma(x) * (1 - sigma(x))
            Activation::Sigmoid => {
                let s = self.apply(x);
                s * (1.0 - s)
            }
            Activation::ELU { .. } if x > 0.0 => 1.0,
            Activation::ELU { alpha } => alpha * x.exp(),
        }
    }
}

// ---------------------------------------------------------------------------
// Deep Kriging config
// ---------------------------------------------------------------------------

75/// Configuration for Neural Basis Kriging (Deep Kriging).
76///
77/// An MLP learns nonlinear basis functions that are fed into ordinary kriging.
78#[derive(Debug, Clone)]
79pub struct DeepKrigingConfig {
80 /// Sizes of hidden layers in the MLP (e.g. `[32, 16]`).
81 pub hidden_layers: Vec<usize>,
82 /// Learning rate for gradient descent on MLP weights.
83 pub learning_rate: f64,
84 /// Number of training epochs (alternating optimisation steps).
85 pub epochs: usize,
86 /// Activation function used between layers.
87 pub activation: Activation,
88 /// Dimension of the basis output (last MLP layer).
89 pub basis_dim: usize,
90 /// Seed for reproducible weight initialisation.
91 pub seed: u64,
92}
93
94impl Default for DeepKrigingConfig {
95 fn default() -> Self {
96 Self {
97 hidden_layers: vec![32, 16],
98 learning_rate: 0.01,
99 epochs: 100,
100 activation: Activation::Tanh,
101 basis_dim: 8,
102 seed: 42,
103 }
104 }
105}

// ---------------------------------------------------------------------------
// Kernel types for GP surrogate
// ---------------------------------------------------------------------------

/// Kernel (covariance function) type for Gaussian process surrogate.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum KernelType {
    /// Squared exponential (RBF) kernel:
    /// k(x, x') = variance * exp(-||x-x'||^2 / (2 * lengthscale^2))
    SquaredExponential {
        /// Characteristic lengthscale.
        lengthscale: f64,
        /// Signal variance.
        variance: f64,
    },
    /// Matern kernel with smoothness parameter nu.
    /// Supported nu values: 0.5 (exponential), 1.5, 2.5.
    Matern {
        /// Smoothness parameter (0.5, 1.5, or 2.5).
        nu: f64,
        /// Characteristic lengthscale.
        lengthscale: f64,
        /// Signal variance.
        variance: f64,
    },
    /// Rational quadratic kernel:
    /// k(x, x') = variance * (1 + ||x-x'||^2 / (2 * alpha * lengthscale^2))^(-alpha)
    RationalQuadratic {
        /// Scale mixture parameter.
        alpha: f64,
        /// Characteristic lengthscale.
        lengthscale: f64,
        /// Signal variance.
        variance: f64,
    },
}

impl Default for KernelType {
    /// Unit-lengthscale, unit-variance squared exponential kernel.
    fn default() -> Self {
        KernelType::SquaredExponential {
            lengthscale: 1.0,
            variance: 1.0,
        }
    }
}

impl KernelType {
    /// Evaluate the kernel between two points.
    ///
    /// Both points must have the same dimension; a mismatch is a caller bug
    /// (previously `zip` would silently drop the trailing coordinates of the
    /// longer point).
    ///
    /// # Panics
    ///
    /// In debug builds, panics when `x.len() != xp.len()`.
    pub fn evaluate(&self, x: &[f64], xp: &[f64]) -> f64 {
        // `zip` silently truncates to the shorter slice, so surface
        // dimension mismatches early in debug builds.
        debug_assert_eq!(x.len(), xp.len(), "kernel inputs must have equal dimension");

        // Squared Euclidean distance ||x - x'||^2.
        let sq_dist: f64 = x
            .iter()
            .zip(xp.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();

        match self {
            KernelType::SquaredExponential {
                lengthscale,
                variance,
            } => {
                let l2 = lengthscale * lengthscale;
                variance * (-sq_dist / (2.0 * l2)).exp()
            }
            KernelType::Matern {
                nu,
                lengthscale,
                variance,
            } => {
                // Scaled distance r = ||x - x'|| / lengthscale.
                let r = sq_dist.sqrt() / lengthscale;
                if r < 1e-12 {
                    // Coincident points: k(x, x) = variance for every nu.
                    return *variance;
                }
                if (*nu - 0.5).abs() < 1e-6 {
                    // Matern 1/2 = exponential kernel.
                    variance * (-r).exp()
                } else if (*nu - 1.5).abs() < 1e-6 {
                    // Matern 3/2: (1 + sqrt(3) r) exp(-sqrt(3) r).
                    let s3 = 3.0_f64.sqrt() * r;
                    variance * (1.0 + s3) * (-s3).exp()
                } else if (*nu - 2.5).abs() < 1e-6 {
                    // Matern 5/2: (1 + sqrt(5) r + 5 r^2 / 3) exp(-sqrt(5) r).
                    let s5 = 5.0_f64.sqrt() * r;
                    variance * (1.0 + s5 + s5 * s5 / 3.0) * (-s5).exp()
                } else {
                    // Unsupported nu: fall back to the squared exponential,
                    // which is the nu -> infinity limit of the Matern family.
                    variance * (-sq_dist / (2.0 * lengthscale * lengthscale)).exp()
                }
            }
            KernelType::RationalQuadratic {
                alpha,
                lengthscale,
                variance,
            } => {
                let l2 = lengthscale * lengthscale;
                variance * (1.0 + sq_dist / (2.0 * alpha * l2)).powf(-alpha)
            }
        }
    }

    /// Return the signal variance of the kernel.
    pub fn signal_variance(&self) -> f64 {
        match self {
            KernelType::SquaredExponential { variance, .. }
            | KernelType::Matern { variance, .. }
            | KernelType::RationalQuadratic { variance, .. } => *variance,
        }
    }

    /// Return the lengthscale of the kernel.
    pub fn lengthscale(&self) -> f64 {
        match self {
            KernelType::SquaredExponential { lengthscale, .. }
            | KernelType::Matern { lengthscale, .. }
            | KernelType::RationalQuadratic { lengthscale, .. } => *lengthscale,
        }
    }
}

// ---------------------------------------------------------------------------
// Acquisition functions
// ---------------------------------------------------------------------------

/// Acquisition function for Bayesian optimisation.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum AcquisitionFunction {
    /// Expected Improvement: EI(x) = sigma * [z*Phi(z) + phi(z)]
    EI,
    /// Probability of Improvement.
    PI,
    /// Upper Confidence Bound with exploration weight kappa.
    UCB(f64),
    /// Lower Confidence Bound with exploration weight kappa.
    LCB(f64),
}

impl Default for AcquisitionFunction {
    /// Expected Improvement is the default acquisition criterion.
    fn default() -> Self {
        Self::EI
    }
}

// ---------------------------------------------------------------------------
// GP Surrogate config
// ---------------------------------------------------------------------------

254/// Configuration for the Gaussian process surrogate model.
255#[derive(Debug, Clone)]
256pub struct GPSurrogateConfig {
257 /// Covariance kernel.
258 pub kernel: KernelType,
259 /// Observation noise variance (added to diagonal of K).
260 pub noise: f64,
261 /// Whether to optimise kernel hyperparameters via marginal likelihood.
262 pub optimize_hyperparams: bool,
263 /// Number of random restarts for hyperparameter optimisation.
264 pub n_restarts: usize,
265 /// Maximum number of optimisation iterations per restart.
266 pub max_opt_iterations: usize,
267}
268
269impl Default for GPSurrogateConfig {
270 fn default() -> Self {
271 Self {
272 kernel: KernelType::default(),
273 noise: 1e-6,
274 optimize_hyperparams: false,
275 n_restarts: 3,
276 max_opt_iterations: 100,
277 }
278 }
279}

// ---------------------------------------------------------------------------
// Result types
// ---------------------------------------------------------------------------

/// Result of a GP surrogate prediction.
///
/// `predictions` and `variances` are index-aligned with the query points.
/// Derives `PartialEq` for consistency with the other public types in this
/// module, and `Default` so an empty result can be built cheaply.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SurrogateResult {
    /// Predictive means at query points.
    pub predictions: Vec<f64>,
    /// Predictive variances at query points.
    pub variances: Vec<f64>,
    /// Optimised hyperparameters (kernel parameters + noise).
    pub hyperparameters: Vec<f64>,
}
294}