//! Surrogate models for optimization (`scirs2_optimize/surrogate/mod.rs`).
//!
//! Re-exports the concrete surrogates from the [`ensemble`], [`kriging`] and
//! [`rbf_surrogate`] submodules, and defines the [`SurrogateModel`] trait together with
//! small dense linear-algebra helpers (`pairwise_sq_distances`, `solve_spd`, `solve_general`).

pub mod ensemble;
pub mod kriging;
pub mod rbf_surrogate;

pub use ensemble::{EnsembleOptions, EnsembleSurrogate, ModelSelectionCriterion};
pub use kriging::{CorrelationFunction, KrigingOptions, KrigingSurrogate};
pub use rbf_surrogate::{RbfKernel, RbfOptions, RbfSurrogate};

use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, Array2};
28
29pub trait SurrogateModel {
31 fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()>;
37
38 fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64>;
46
47 fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)>;
55
56 fn predict_batch(&self, x: &Array2<f64>) -> OptimizeResult<Array1<f64>> {
64 let n = x.nrows();
65 let mut predictions = Array1::zeros(n);
66 for i in 0..n {
67 predictions[i] = self.predict(&x.row(i).to_owned())?;
68 }
69 Ok(predictions)
70 }
71
72 fn n_samples(&self) -> usize;
74
75 fn n_features(&self) -> usize;
77
78 fn update(&mut self, x: &Array1<f64>, y: f64) -> OptimizeResult<()>;
80}
81
82pub fn pairwise_sq_distances(x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
84 let n = x.nrows();
85 let m = y.nrows();
86 let mut dists = Array2::zeros((n, m));
87 for i in 0..n {
88 for j in 0..m {
89 let mut sq_dist = 0.0;
90 for k in 0..x.ncols() {
91 let diff = x[[i, k]] - y[[j, k]];
92 sq_dist += diff * diff;
93 }
94 dists[[i, j]] = sq_dist;
95 }
96 }
97 dists
98}
99
100pub fn solve_spd(a: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
103 let n = a.nrows();
104 if n != a.ncols() {
105 return Err(OptimizeError::InvalidInput(
106 "Matrix must be square".to_string(),
107 ));
108 }
109 if n != b.len() {
110 return Err(OptimizeError::InvalidInput(
111 "Matrix and vector dimensions must match".to_string(),
112 ));
113 }
114
115 let mut l = Array2::zeros((n, n));
117 for j in 0..n {
118 let mut sum = 0.0;
119 for k in 0..j {
120 sum += l[[j, k]] * l[[j, k]];
121 }
122 let diag = a[[j, j]] - sum;
123 if diag <= 0.0 {
124 return Err(OptimizeError::ComputationError(
125 "Matrix is not positive definite".to_string(),
126 ));
127 }
128 l[[j, j]] = diag.sqrt();
129
130 for i in (j + 1)..n {
131 let mut sum = 0.0;
132 for k in 0..j {
133 sum += l[[i, k]] * l[[j, k]];
134 }
135 l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
136 }
137 }
138
139 let mut z = Array1::zeros(n);
141 for i in 0..n {
142 let mut sum = 0.0;
143 for j in 0..i {
144 sum += l[[i, j]] * z[j];
145 }
146 z[i] = (b[i] - sum) / l[[i, i]];
147 }
148
149 let mut x = Array1::zeros(n);
151 for i in (0..n).rev() {
152 let mut sum = 0.0;
153 for j in (i + 1)..n {
154 sum += l[[j, i]] * x[j];
155 }
156 x[i] = (z[i] - sum) / l[[i, i]];
157 }
158
159 Ok(x)
160}
161
162pub fn solve_general(a: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
164 let n = a.nrows();
165 if n != a.ncols() || n != b.len() {
166 return Err(OptimizeError::InvalidInput(
167 "Dimension mismatch in linear system".to_string(),
168 ));
169 }
170
171 let mut lu = a.clone();
173 let mut perm: Vec<usize> = (0..n).collect();
174
175 for k in 0..n {
176 let mut max_val = lu[[k, k]].abs();
178 let mut max_row = k;
179 for i in (k + 1)..n {
180 if lu[[i, k]].abs() > max_val {
181 max_val = lu[[i, k]].abs();
182 max_row = i;
183 }
184 }
185
186 if max_val < 1e-30 {
187 return Err(OptimizeError::ComputationError(
188 "Singular or near-singular matrix in linear solve".to_string(),
189 ));
190 }
191
192 if max_row != k {
194 perm.swap(k, max_row);
195 for j in 0..n {
196 let tmp = lu[[k, j]];
197 lu[[k, j]] = lu[[max_row, j]];
198 lu[[max_row, j]] = tmp;
199 }
200 }
201
202 for i in (k + 1)..n {
204 lu[[i, k]] /= lu[[k, k]];
205 for j in (k + 1)..n {
206 lu[[i, j]] -= lu[[i, k]] * lu[[k, j]];
207 }
208 }
209 }
210
211 let mut pb = Array1::zeros(n);
213 for i in 0..n {
214 pb[i] = b[perm[i]];
215 }
216
217 let mut y = pb;
219 for i in 1..n {
220 for j in 0..i {
221 y[i] -= lu[[i, j]] * y[j];
222 }
223 }
224
225 let mut x = y;
227 for i in (0..n).rev() {
228 for j in (i + 1)..n {
229 x[i] -= lu[[i, j]] * x[j];
230 }
231 x[i] /= lu[[i, i]];
232 }
233
234 Ok(x)
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row-major `rows × cols` matrix from `data`.
    fn mat(rows: usize, cols: usize, data: Vec<f64>) -> Array2<f64> {
        Array2::from_shape_vec((rows, cols), data).expect("Array creation failed")
    }

    /// Absolute-tolerance float comparison used throughout these tests.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    /// Minimal model whose prediction is the sum of its input features. Inputs with a
    /// negative feature are rejected so the error path of `predict_batch` can be exercised.
    struct SumModel;

    impl SurrogateModel for SumModel {
        fn fit(&mut self, _x: &Array2<f64>, _y: &Array1<f64>) -> OptimizeResult<()> {
            Ok(())
        }

        fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64> {
            if x.iter().any(|v| *v < 0.0) {
                return Err(OptimizeError::InvalidInput("negative feature".to_string()));
            }
            Ok(x.sum())
        }

        fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)> {
            Ok((self.predict(x)?, 0.0))
        }

        fn n_samples(&self) -> usize {
            0
        }

        fn n_features(&self) -> usize {
            0
        }

        fn update(&mut self, _x: &Array1<f64>, _y: f64) -> OptimizeResult<()> {
            Ok(())
        }
    }

    #[test]
    fn test_pairwise_distances() {
        let x = mat(2, 2, vec![0.0, 0.0, 1.0, 1.0]);
        let y = mat(2, 2, vec![1.0, 0.0, 0.0, 1.0]);
        let dists = pairwise_sq_distances(&x, &y);
        assert!(close(dists[[0, 0]], 1.0));
        assert!(close(dists[[0, 1]], 1.0));
        assert!(close(dists[[1, 0]], 1.0));
        assert!(close(dists[[1, 1]], 1.0));
    }

    #[test]
    fn test_pairwise_distances_rectangular() {
        let x = mat(3, 2, vec![0.0, 0.0, 3.0, 4.0, 1.0, 1.0]);
        let y = mat(1, 2, vec![0.0, 0.0]);
        let dists = pairwise_sq_distances(&x, &y);
        assert_eq!(dists.dim(), (3, 1));
        assert!(close(dists[[0, 0]], 0.0));
        assert!(close(dists[[1, 0]], 25.0));
        assert!(close(dists[[2, 0]], 2.0));
    }

    #[test]
    fn test_solve_spd() {
        let a = mat(2, 2, vec![4.0, 2.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![1.0, 2.0]);
        let x = solve_spd(&a, &b).expect("SPD solve failed");
        assert!(close(x[0], -0.125));
        assert!(close(x[1], 0.75));
    }

    #[test]
    fn test_solve_spd_rejects_non_square() {
        let a = mat(2, 3, vec![1.0; 6]);
        let b = Array1::from_vec(vec![1.0, 2.0]);
        assert!(matches!(
            solve_spd(&a, &b),
            Err(OptimizeError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_solve_spd_rejects_indefinite() {
        let a = mat(2, 2, vec![1.0, 2.0, 2.0, 1.0]);
        let b = Array1::from_vec(vec![1.0, 1.0]);
        assert!(matches!(
            solve_spd(&a, &b),
            Err(OptimizeError::ComputationError(_))
        ));
    }

    #[test]
    fn test_solve_general() {
        let a = mat(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let b = Array1::from_vec(vec![5.0, 11.0]);
        let x = solve_general(&a, &b).expect("General solve failed");
        assert!(close(x[0], 1.0));
        assert!(close(x[1], 2.0));
    }

    #[test]
    fn test_solve_general_requires_pivoting() {
        // Zero in the (0, 0) position: only solvable because rows are swapped.
        let a = mat(2, 2, vec![0.0, 1.0, 1.0, 0.0]);
        let b = Array1::from_vec(vec![2.0, 3.0]);
        let x = solve_general(&a, &b).expect("General solve failed");
        assert!(close(x[0], 3.0));
        assert!(close(x[1], 2.0));
    }

    #[test]
    fn test_solve_general_rejects_singular() {
        let a = mat(2, 2, vec![1.0, 2.0, 2.0, 4.0]);
        let b = Array1::from_vec(vec![1.0, 2.0]);
        assert!(matches!(
            solve_general(&a, &b),
            Err(OptimizeError::ComputationError(_))
        ));
    }

    #[test]
    fn test_solve_general_rejects_dimension_mismatch() {
        let a = mat(2, 2, vec![1.0, 0.0, 0.0, 1.0]);
        let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(matches!(
            solve_general(&a, &b),
            Err(OptimizeError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_empty_systems() {
        let a = Array2::<f64>::zeros((0, 0));
        let b = Array1::<f64>::zeros(0);
        assert_eq!(solve_spd(&a, &b).expect("SPD solve failed").len(), 0);
        assert_eq!(
            solve_general(&a, &b).expect("General solve failed").len(),
            0
        );
    }

    #[test]
    fn test_predict_batch_default() {
        let model = SumModel;
        let x = mat(3, 2, vec![1.0, 2.0, 3.0, 4.0, 0.5, 0.25]);
        let preds = model.predict_batch(&x).expect("Batch prediction failed");
        assert_eq!(preds.len(), 3);
        assert!(close(preds[0], 3.0));
        assert!(close(preds[1], 7.0));
        assert!(close(preds[2], 0.75));

        // The second row is rejected by `predict`, so the whole batch must fail.
        let bad = mat(2, 2, vec![1.0, 1.0, -1.0, 0.0]);
        assert!(model.predict_batch(&bad).is_err());
    }
}