// scirs2_interpolate/deep_kriging/types.rs
//! Types for Deep Kriging and GP Surrogate modules.
//!
//! This module defines configuration types, kernel specifications,
//! acquisition functions, and result containers for neural-basis kriging
//! and Gaussian process surrogate modelling.

// ---------------------------------------------------------------------------
// Activation function
// ---------------------------------------------------------------------------

11/// Activation function used in MLP layers.
12#[derive(Debug, Clone, Copy, PartialEq)]
13#[non_exhaustive]
14pub enum Activation {
15    /// Rectified Linear Unit: max(0, x)
16    ReLU,
17    /// Hyperbolic tangent
18    Tanh,
19    /// Logistic sigmoid: 1 / (1 + exp(-x))
20    Sigmoid,
21    /// Exponential Linear Unit: x if x > 0 else alpha*(exp(x)-1)
22    ELU { alpha: f64 },
23}
24
25impl Activation {
26    /// Apply the activation function element-wise.
27    pub fn apply(&self, x: f64) -> f64 {
28        match self {
29            Activation::ReLU => x.max(0.0),
30            Activation::Tanh => x.tanh(),
31            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
32            Activation::ELU { alpha } => {
33                if x > 0.0 {
34                    x
35                } else {
36                    alpha * (x.exp() - 1.0)
37                }
38            }
39        }
40    }
41
42    /// Derivative of the activation function.
43    pub fn derivative(&self, x: f64) -> f64 {
44        match self {
45            Activation::ReLU => {
46                if x > 0.0 {
47                    1.0
48                } else {
49                    0.0
50                }
51            }
52            Activation::Tanh => {
53                let t = x.tanh();
54                1.0 - t * t
55            }
56            Activation::Sigmoid => {
57                let s = 1.0 / (1.0 + (-x).exp());
58                s * (1.0 - s)
59            }
60            Activation::ELU { alpha } => {
61                if x > 0.0 {
62                    1.0
63                } else {
64                    alpha * x.exp()
65                }
66            }
67        }
68    }
69}

// ---------------------------------------------------------------------------
// Deep Kriging config
// ---------------------------------------------------------------------------

75/// Configuration for Neural Basis Kriging (Deep Kriging).
76///
77/// An MLP learns nonlinear basis functions that are fed into ordinary kriging.
78#[derive(Debug, Clone)]
79pub struct DeepKrigingConfig {
80    /// Sizes of hidden layers in the MLP (e.g. `[32, 16]`).
81    pub hidden_layers: Vec<usize>,
82    /// Learning rate for gradient descent on MLP weights.
83    pub learning_rate: f64,
84    /// Number of training epochs (alternating optimisation steps).
85    pub epochs: usize,
86    /// Activation function used between layers.
87    pub activation: Activation,
88    /// Dimension of the basis output (last MLP layer).
89    pub basis_dim: usize,
90    /// Seed for reproducible weight initialisation.
91    pub seed: u64,
92}
93
94impl Default for DeepKrigingConfig {
95    fn default() -> Self {
96        Self {
97            hidden_layers: vec![32, 16],
98            learning_rate: 0.01,
99            epochs: 100,
100            activation: Activation::Tanh,
101            basis_dim: 8,
102            seed: 42,
103        }
104    }
105}

// ---------------------------------------------------------------------------
// Kernel types for GP surrogate
// ---------------------------------------------------------------------------

/// Kernel (covariance function) type for Gaussian process surrogate.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum KernelType {
    /// Squared exponential (RBF) kernel:
    /// k(x, x') = variance * exp(-||x-x'||^2 / (2 * lengthscale^2))
    SquaredExponential {
        /// Characteristic lengthscale.
        lengthscale: f64,
        /// Signal variance.
        variance: f64,
    },
    /// Matern kernel with smoothness parameter nu.
    /// Supported nu values: 0.5 (exponential), 1.5, 2.5.
    Matern {
        /// Smoothness parameter (0.5, 1.5, or 2.5).
        nu: f64,
        /// Characteristic lengthscale.
        lengthscale: f64,
        /// Signal variance.
        variance: f64,
    },
    /// Rational quadratic kernel:
    /// k(x, x') = variance * (1 + ||x-x'||^2 / (2 * alpha * lengthscale^2))^(-alpha)
    RationalQuadratic {
        /// Scale mixture parameter.
        alpha: f64,
        /// Characteristic lengthscale.
        lengthscale: f64,
        /// Signal variance.
        variance: f64,
    },
}

impl Default for KernelType {
    /// Default kernel: unit-lengthscale, unit-variance squared exponential.
    fn default() -> Self {
        KernelType::SquaredExponential {
            lengthscale: 1.0,
            variance: 1.0,
        }
    }
}

impl KernelType {
    /// Evaluate the kernel between two points.
    ///
    /// `x` and `xp` are coordinate slices; the pairwise zip stops at the
    /// shorter of the two, so only the overlapping prefix contributes.
    pub fn evaluate(&self, x: &[f64], xp: &[f64]) -> f64 {
        // Squared Euclidean distance between the two points.
        let d2 = x
            .iter()
            .zip(xp.iter())
            .fold(0.0_f64, |acc, (a, b)| acc + (a - b) * (a - b));

        match self {
            KernelType::SquaredExponential { lengthscale, variance } => {
                let l2 = lengthscale * lengthscale;
                variance * (-d2 / (2.0 * l2)).exp()
            }
            KernelType::Matern { nu, lengthscale, variance } => {
                Self::matern(d2, *nu, *lengthscale, *variance)
            }
            KernelType::RationalQuadratic { alpha, lengthscale, variance } => {
                let l2 = lengthscale * lengthscale;
                variance * (1.0 + d2 / (2.0 * alpha * l2)).powf(-alpha)
            }
        }
    }

    /// Matern kernel evaluated from the squared distance `d2`.
    ///
    /// Closed forms exist only for nu in {0.5, 1.5, 2.5}; any other nu
    /// falls back to the squared exponential kernel.
    fn matern(d2: f64, nu: f64, lengthscale: f64, variance: f64) -> f64 {
        let r = d2.sqrt() / lengthscale;
        if r < 1e-12 {
            // Coincident points: k(x, x) = variance for every nu.
            return variance;
        }
        if (nu - 0.5).abs() < 1e-6 {
            // Matern 1/2 = exponential
            variance * (-r).exp()
        } else if (nu - 1.5).abs() < 1e-6 {
            // Matern 3/2
            let s = 3.0_f64.sqrt() * r;
            variance * (1.0 + s) * (-s).exp()
        } else if (nu - 2.5).abs() < 1e-6 {
            // Matern 5/2
            let s = 5.0_f64.sqrt() * r;
            variance * (1.0 + s + s * s / 3.0) * (-s).exp()
        } else {
            // Fall back to squared exponential for unsupported nu
            variance * (-d2 / (2.0 * lengthscale * lengthscale)).exp()
        }
    }

    /// Return the signal variance of the kernel.
    pub fn signal_variance(&self) -> f64 {
        match *self {
            KernelType::SquaredExponential { variance, .. }
            | KernelType::Matern { variance, .. }
            | KernelType::RationalQuadratic { variance, .. } => variance,
        }
    }

    /// Return the lengthscale of the kernel.
    pub fn lengthscale(&self) -> f64 {
        match *self {
            KernelType::SquaredExponential { lengthscale, .. }
            | KernelType::Matern { lengthscale, .. }
            | KernelType::RationalQuadratic { lengthscale, .. } => lengthscale,
        }
    }
}

// ---------------------------------------------------------------------------
// Acquisition functions
// ---------------------------------------------------------------------------

/// Acquisition function for Bayesian optimisation.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum AcquisitionFunction {
    /// Expected Improvement: EI(x) = sigma * [z*Phi(z) + phi(z)]
    EI,
    /// Probability of Improvement.
    PI,
    /// Upper Confidence Bound with exploration weight kappa.
    UCB(f64),
    /// Lower Confidence Bound with exploration weight kappa.
    LCB(f64),
}

impl Default for AcquisitionFunction {
    /// Expected Improvement is the default acquisition strategy.
    fn default() -> Self {
        Self::EI
    }
}

// ---------------------------------------------------------------------------
// GP Surrogate config
// ---------------------------------------------------------------------------

254/// Configuration for the Gaussian process surrogate model.
255#[derive(Debug, Clone)]
256pub struct GPSurrogateConfig {
257    /// Covariance kernel.
258    pub kernel: KernelType,
259    /// Observation noise variance (added to diagonal of K).
260    pub noise: f64,
261    /// Whether to optimise kernel hyperparameters via marginal likelihood.
262    pub optimize_hyperparams: bool,
263    /// Number of random restarts for hyperparameter optimisation.
264    pub n_restarts: usize,
265    /// Maximum number of optimisation iterations per restart.
266    pub max_opt_iterations: usize,
267}
268
269impl Default for GPSurrogateConfig {
270    fn default() -> Self {
271        Self {
272            kernel: KernelType::default(),
273            noise: 1e-6,
274            optimize_hyperparams: false,
275            n_restarts: 3,
276            max_opt_iterations: 100,
277        }
278    }
279}

// ---------------------------------------------------------------------------
// Result types
// ---------------------------------------------------------------------------

/// Result of a GP surrogate prediction.
///
/// All three vectors are indexed consistently: `predictions[i]` and
/// `variances[i]` correspond to the i-th query point.
#[derive(Debug, Clone)]
pub struct SurrogateResult {
    /// Predictive means at query points.
    pub predictions: Vec<f64>,
    /// Predictive variances at query points.
    pub variances: Vec<f64>,
    /// Optimised hyperparameters (kernel parameters + noise).
    pub hyperparameters: Vec<f64>,
}