Skip to main content

scirs2_optimize/surrogate/
mod.rs

1//! Surrogate-Assisted Optimization
2//!
3//! This module provides surrogate model-based optimization methods that build
4//! approximate models of expensive objective functions to guide the search
5//! for the optimum with fewer function evaluations.
6//!
7//! ## Surrogate Models
8//!
9//! - **RBF Surrogate**: Radial Basis Function interpolation (polyharmonic, multiquadric, thin-plate)
10//! - **Kriging**: Gaussian Process surrogate with nugget parameter for noise handling
11//! - **Ensemble**: Ensemble of surrogates with automatic model selection
12//!
13//! ## Usage
14//!
15//! All surrogates implement the [`SurrogateModel`] trait, which provides a common
16//! interface for fitting, predicting, and estimating uncertainty.
17
18pub mod ensemble;
19pub mod kriging;
20pub mod rbf_surrogate;
21
22pub use ensemble::{EnsembleOptions, EnsembleSurrogate, ModelSelectionCriterion};
23pub use kriging::{CorrelationFunction, KrigingOptions, KrigingSurrogate};
24pub use rbf_surrogate::{RbfKernel, RbfOptions, RbfSurrogate};
25
26use crate::error::{OptimizeError, OptimizeResult};
27use scirs2_core::ndarray::{Array1, Array2};
28
29/// Common trait for surrogate models
30pub trait SurrogateModel {
31    /// Fit the surrogate model to the provided data
32    ///
33    /// # Arguments
34    /// * `x` - Training points, shape (n_samples, n_features)
35    /// * `y` - Function values at training points, shape (n_samples,)
36    fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()>;
37
38    /// Predict the function value at a new point
39    ///
40    /// # Arguments
41    /// * `x` - Point to predict at, shape (n_features,)
42    ///
43    /// # Returns
44    /// Predicted function value
45    fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64>;
46
47    /// Predict the function value and uncertainty at a new point
48    ///
49    /// # Arguments
50    /// * `x` - Point to predict at, shape (n_features,)
51    ///
52    /// # Returns
53    /// (predicted mean, predicted standard deviation)
54    fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)>;
55
56    /// Predict at multiple points
57    ///
58    /// # Arguments
59    /// * `x` - Points to predict at, shape (n_points, n_features)
60    ///
61    /// # Returns
62    /// Predicted values, shape (n_points,)
63    fn predict_batch(&self, x: &Array2<f64>) -> OptimizeResult<Array1<f64>> {
64        let n = x.nrows();
65        let mut predictions = Array1::zeros(n);
66        for i in 0..n {
67            predictions[i] = self.predict(&x.row(i).to_owned())?;
68        }
69        Ok(predictions)
70    }
71
72    /// Get the number of training points
73    fn n_samples(&self) -> usize;
74
75    /// Get the dimensionality of the problem
76    fn n_features(&self) -> usize;
77
78    /// Add a new data point and update the model
79    fn update(&mut self, x: &Array1<f64>, y: f64) -> OptimizeResult<()>;
80}
81
82/// Compute pairwise squared Euclidean distances between rows of X and Y
83pub fn pairwise_sq_distances(x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
84    let n = x.nrows();
85    let m = y.nrows();
86    let mut dists = Array2::zeros((n, m));
87    for i in 0..n {
88        for j in 0..m {
89            let mut sq_dist = 0.0;
90            for k in 0..x.ncols() {
91                let diff = x[[i, k]] - y[[j, k]];
92                sq_dist += diff * diff;
93            }
94            dists[[i, j]] = sq_dist;
95        }
96    }
97    dists
98}
99
100/// Solve a symmetric positive definite linear system Ax = b using Cholesky decomposition
101/// Returns x
102pub fn solve_spd(a: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
103    let n = a.nrows();
104    if n != a.ncols() {
105        return Err(OptimizeError::InvalidInput(
106            "Matrix must be square".to_string(),
107        ));
108    }
109    if n != b.len() {
110        return Err(OptimizeError::InvalidInput(
111            "Matrix and vector dimensions must match".to_string(),
112        ));
113    }
114
115    // Cholesky factorization: A = L * L^T
116    let mut l = Array2::zeros((n, n));
117    for j in 0..n {
118        let mut sum = 0.0;
119        for k in 0..j {
120            sum += l[[j, k]] * l[[j, k]];
121        }
122        let diag = a[[j, j]] - sum;
123        if diag <= 0.0 {
124            return Err(OptimizeError::ComputationError(
125                "Matrix is not positive definite".to_string(),
126            ));
127        }
128        l[[j, j]] = diag.sqrt();
129
130        for i in (j + 1)..n {
131            let mut sum = 0.0;
132            for k in 0..j {
133                sum += l[[i, k]] * l[[j, k]];
134            }
135            l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
136        }
137    }
138
139    // Forward substitution: L * z = b
140    let mut z = Array1::zeros(n);
141    for i in 0..n {
142        let mut sum = 0.0;
143        for j in 0..i {
144            sum += l[[i, j]] * z[j];
145        }
146        z[i] = (b[i] - sum) / l[[i, i]];
147    }
148
149    // Back substitution: L^T * x = z
150    let mut x = Array1::zeros(n);
151    for i in (0..n).rev() {
152        let mut sum = 0.0;
153        for j in (i + 1)..n {
154            sum += l[[j, i]] * x[j];
155        }
156        x[i] = (z[i] - sum) / l[[i, i]];
157    }
158
159    Ok(x)
160}
161
162/// Solve a general linear system Ax = b using LU decomposition with partial pivoting
163pub fn solve_general(a: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
164    let n = a.nrows();
165    if n != a.ncols() || n != b.len() {
166        return Err(OptimizeError::InvalidInput(
167            "Dimension mismatch in linear system".to_string(),
168        ));
169    }
170
171    // LU decomposition with partial pivoting
172    let mut lu = a.clone();
173    let mut perm: Vec<usize> = (0..n).collect();
174
175    for k in 0..n {
176        // Find pivot
177        let mut max_val = lu[[k, k]].abs();
178        let mut max_row = k;
179        for i in (k + 1)..n {
180            if lu[[i, k]].abs() > max_val {
181                max_val = lu[[i, k]].abs();
182                max_row = i;
183            }
184        }
185
186        if max_val < 1e-30 {
187            return Err(OptimizeError::ComputationError(
188                "Singular or near-singular matrix in linear solve".to_string(),
189            ));
190        }
191
192        // Swap rows
193        if max_row != k {
194            perm.swap(k, max_row);
195            for j in 0..n {
196                let tmp = lu[[k, j]];
197                lu[[k, j]] = lu[[max_row, j]];
198                lu[[max_row, j]] = tmp;
199            }
200        }
201
202        // Eliminate
203        for i in (k + 1)..n {
204            lu[[i, k]] /= lu[[k, k]];
205            for j in (k + 1)..n {
206                lu[[i, j]] -= lu[[i, k]] * lu[[k, j]];
207            }
208        }
209    }
210
211    // Apply permutation to b
212    let mut pb = Array1::zeros(n);
213    for i in 0..n {
214        pb[i] = b[perm[i]];
215    }
216
217    // Forward substitution (L * y = Pb)
218    let mut y = pb;
219    for i in 1..n {
220        for j in 0..i {
221            y[i] -= lu[[i, j]] * y[j];
222        }
223    }
224
225    // Back substitution (U * x = y)
226    let mut x = y;
227    for i in (0..n).rev() {
228        for j in (i + 1)..n {
229            x[i] -= lu[[i, j]] * x[j];
230        }
231        x[i] /= lu[[i, i]];
232    }
233
234    Ok(x)
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Each row of `x` differs from each row of `y` in exactly one unit coordinate,
    /// so every squared distance is 1.
    #[test]
    fn test_pairwise_distances() {
        let x = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0])
            .expect("Array creation failed");
        let y = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0])
            .expect("Array creation failed");
        let dists = pairwise_sq_distances(&x, &y);
        for i in 0..2 {
            for j in 0..2 {
                assert!((dists[[i, j]] - 1.0).abs() < 1e-10);
            }
        }
    }

    /// A = [[4, 2], [2, 3]], b = [1, 2] has the solution x = [-1/8, 3/4].
    #[test]
    fn test_solve_spd() {
        let a = Array2::from_shape_vec((2, 2), vec![4.0, 2.0, 2.0, 3.0])
            .expect("Array creation failed");
        let b = Array1::from_vec(vec![1.0, 2.0]);
        let x = solve_spd(&a, &b).expect("SPD solve failed");
        let expected = [-0.125, 0.75];
        for (i, want) in expected.iter().enumerate() {
            assert!((x[i] - want).abs() < 1e-10);
        }
    }

    /// A = [[1, 2], [3, 4]], b = [5, 11] has the solution x = [1, 2].
    #[test]
    fn test_solve_general() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0])
            .expect("Array creation failed");
        let b = Array1::from_vec(vec![5.0, 11.0]);
        let x = solve_general(&a, &b).expect("General solve failed");
        let expected = [1.0, 2.0];
        for (i, want) in expected.iter().enumerate() {
            assert!((x[i] - want).abs() < 1e-10);
        }
    }
}
275}