1use super::{solve_general, SurrogateModel};
22use crate::error::{OptimizeError, OptimizeResult};
23use scirs2_core::ndarray::{Array1, Array2};
24use scirs2_core::random::rngs::StdRng;
25use scirs2_core::random::{Rng, SeedableRng};
26use scirs2_core::RngExt;
27
/// Correlation (kernel) family used by the Kriging surrogate.
///
/// In the formulas below `d_k = x1[k] - x2[k]` and `theta_k` is the per-dimension
/// correlation parameter.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum CorrelationFunction {
    /// `exp(-sum_k theta_k * d_k^2)`. This is the default.
    #[default]
    SquaredExponential,
    /// Matérn 3/2: `(1 + r) * exp(-r)` with `r = sqrt(3 * sum_k theta_k * d_k^2)`.
    Matern32,
    /// Matérn 5/2: `(1 + r + r^2 / 3) * exp(-r)` with `r = sqrt(5 * sum_k theta_k * d_k^2)`.
    Matern52,
    /// `exp(-sum_k theta_k * |d_k|)`.
    Exponential,
    /// `exp(-sum_k theta_k * |d_k|^p)`.
    PowerExponential {
        /// Exponent applied to `|d_k|`. It is not validated here; values in `(0, 2]`
        /// are the usual choice — confirm for your use case.
        p: f64,
    },
}
51
52#[derive(Debug, Clone)]
54pub struct KrigingOptions {
55 pub correlation: CorrelationFunction,
57 pub nugget: Option<f64>,
60 pub optimize_hyperparams: bool,
62 pub n_restarts: usize,
64 pub seed: Option<u64>,
66 pub initial_theta: Option<Vec<f64>>,
68 pub theta_lower: f64,
70 pub theta_upper: f64,
72}
73
74impl Default for KrigingOptions {
75 fn default() -> Self {
76 Self {
77 correlation: CorrelationFunction::default(),
78 nugget: Some(1e-6),
79 optimize_hyperparams: true,
80 n_restarts: 5,
81 seed: None,
82 initial_theta: None,
83 theta_lower: 1e-3,
84 theta_upper: 1e3,
85 }
86 }
87}
88
89pub struct KrigingSurrogate {
91 options: KrigingOptions,
92 x_train: Option<Array2<f64>>,
94 y_train: Option<Array1<f64>>,
96 theta: Option<Vec<f64>>,
98 nugget: f64,
100 alpha: Option<Array1<f64>>,
102 mu: f64,
104 sigma_sq: f64,
106 corr_matrix: Option<Array2<f64>>,
108 chol_factor: Option<Array2<f64>>,
110 x_min: Option<Array1<f64>>,
112 x_range: Option<Array1<f64>>,
113 y_mean: f64,
114 y_std: f64,
115}
116
117impl KrigingSurrogate {
118 pub fn new(options: KrigingOptions) -> Self {
120 let nugget = options.nugget.unwrap_or(1e-6);
121 Self {
122 options,
123 x_train: None,
124 y_train: None,
125 theta: None,
126 nugget,
127 alpha: None,
128 mu: 0.0,
129 sigma_sq: 1.0,
130 corr_matrix: None,
131 chol_factor: None,
132 x_min: None,
133 x_range: None,
134 y_mean: 0.0,
135 y_std: 1.0,
136 }
137 }
138
139 fn correlation(&self, x1: &[f64], x2: &[f64], theta: &[f64]) -> f64 {
141 let d = x1.len();
142 match self.options.correlation {
143 CorrelationFunction::SquaredExponential => {
144 let mut sum = 0.0;
145 for k in 0..d {
146 let diff = x1[k] - x2[k];
147 sum += theta[k.min(theta.len() - 1)] * diff * diff;
148 }
149 (-sum).exp()
150 }
151 CorrelationFunction::Matern32 => {
152 let mut weighted_sq_sum = 0.0;
153 for k in 0..d {
154 let diff = x1[k] - x2[k];
155 weighted_sq_sum += theta[k.min(theta.len() - 1)] * diff * diff;
156 }
157 let r = (3.0 * weighted_sq_sum).sqrt();
158 (1.0 + r) * (-r).exp()
159 }
160 CorrelationFunction::Matern52 => {
161 let mut weighted_sq_sum = 0.0;
162 for k in 0..d {
163 let diff = x1[k] - x2[k];
164 weighted_sq_sum += theta[k.min(theta.len() - 1)] * diff * diff;
165 }
166 let r = (5.0 * weighted_sq_sum).sqrt();
167 (1.0 + r + r * r / 3.0) * (-r).exp()
168 }
169 CorrelationFunction::Exponential => {
170 let mut sum = 0.0;
171 for k in 0..d {
172 let diff = (x1[k] - x2[k]).abs();
173 sum += theta[k.min(theta.len() - 1)] * diff;
174 }
175 (-sum).exp()
176 }
177 CorrelationFunction::PowerExponential { p } => {
178 let mut sum = 0.0;
179 for k in 0..d {
180 let diff = (x1[k] - x2[k]).abs();
181 sum += theta[k.min(theta.len() - 1)] * diff.powf(p);
182 }
183 (-sum).exp()
184 }
185 }
186 }
187
188 fn compute_correlation_matrix(
190 &self,
191 x: &Array2<f64>,
192 theta: &[f64],
193 nugget: f64,
194 ) -> Array2<f64> {
195 let n = x.nrows();
196 let mut r = Array2::zeros((n, n));
197 for i in 0..n {
198 r[[i, i]] = 1.0 + nugget;
199 let x_i: Vec<f64> = (0..x.ncols()).map(|k| x[[i, k]]).collect();
200 for j in (i + 1)..n {
201 let x_j: Vec<f64> = (0..x.ncols()).map(|k| x[[j, k]]).collect();
202 let c = self.correlation(&x_i, &x_j, theta);
203 r[[i, j]] = c;
204 r[[j, i]] = c;
205 }
206 }
207 r
208 }
209
210 fn compute_correlation_vector(
212 &self,
213 x: &[f64],
214 x_train: &Array2<f64>,
215 theta: &[f64],
216 ) -> Array1<f64> {
217 let n = x_train.nrows();
218 let mut r = Array1::zeros(n);
219 for i in 0..n {
220 let x_i: Vec<f64> = (0..x_train.ncols()).map(|k| x_train[[i, k]]).collect();
221 r[i] = self.correlation(x, &x_i, theta);
222 }
223 r
224 }
225
226 fn cholesky(&self, a: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
228 let n = a.nrows();
229 let mut l = Array2::zeros((n, n));
230 for j in 0..n {
231 let mut sum = 0.0;
232 for k in 0..j {
233 sum += l[[j, k]] * l[[j, k]];
234 }
235 let diag = a[[j, j]] - sum;
236 if diag <= 0.0 {
237 return Err(OptimizeError::ComputationError(
238 "Correlation matrix is not positive definite".to_string(),
239 ));
240 }
241 l[[j, j]] = diag.sqrt();
242 for i in (j + 1)..n {
243 let mut sum = 0.0;
244 for k in 0..j {
245 sum += l[[i, k]] * l[[j, k]];
246 }
247 l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
248 }
249 }
250 Ok(l)
251 }
252
253 fn solve_lower(&self, l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
255 let n = b.len();
256 let mut x = Array1::zeros(n);
257 for i in 0..n {
258 let mut sum = 0.0;
259 for j in 0..i {
260 sum += l[[i, j]] * x[j];
261 }
262 x[i] = (b[i] - sum) / l[[i, i]];
263 }
264 x
265 }
266
267 fn solve_upper(&self, l: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
269 let n = b.len();
270 let mut x = Array1::zeros(n);
271 for i in (0..n).rev() {
272 let mut sum = 0.0;
273 for j in (i + 1)..n {
274 sum += l[[j, i]] * x[j];
275 }
276 x[i] = (b[i] - sum) / l[[i, i]];
277 }
278 x
279 }
280
281 fn log_likelihood(
283 &self,
284 x_train: &Array2<f64>,
285 y_train: &Array1<f64>,
286 theta: &[f64],
287 nugget: f64,
288 ) -> f64 {
289 let n = x_train.nrows();
290 let r = self.compute_correlation_matrix(x_train, theta, nugget);
291
292 let chol = match self.cholesky(&r) {
293 Ok(l) => l,
294 Err(_) => return f64::NEG_INFINITY,
295 };
296
297 let log_det: f64 = (0..n).map(|i| chol[[i, i]].ln()).sum::<f64>() * 2.0;
299
300 let ones = Array1::ones(n);
302 let z = self.solve_lower(&chol, &ones);
303 let r_inv_ones = self.solve_upper(&chol, &z);
304 let ones_r_inv_ones: f64 = ones.dot(&r_inv_ones);
305
306 if ones_r_inv_ones.abs() < 1e-30 {
307 return f64::NEG_INFINITY;
308 }
309
310 let z_y = self.solve_lower(&chol, y_train);
312 let r_inv_y = self.solve_upper(&chol, &z_y);
313
314 let mu_hat = ones.dot(&r_inv_y) / ones_r_inv_ones;
315
316 let residual: Array1<f64> = y_train - mu_hat;
318 let z_res = self.solve_lower(&chol, &residual);
319 let r_inv_res = self.solve_upper(&chol, &z_res);
320 let sigma_sq = residual.dot(&r_inv_res) / n as f64;
321
322 if sigma_sq <= 0.0 {
323 return f64::NEG_INFINITY;
324 }
325
326 -0.5 * (n as f64 * sigma_sq.ln() + log_det)
328 }
329
330 fn optimize_hyperparameters(
332 &self,
333 x_train: &Array2<f64>,
334 y_train: &Array1<f64>,
335 ) -> (Vec<f64>, f64) {
336 let d = x_train.ncols();
337 let seed = self
338 .options
339 .seed
340 .unwrap_or_else(|| scirs2_core::random::rng().random_range(0..u64::MAX));
341 let mut rng = StdRng::seed_from_u64(seed);
342
343 let theta_lo = self.options.theta_lower;
344 let theta_hi = self.options.theta_upper;
345 let log_lo = theta_lo.ln();
346 let log_hi = theta_hi.ln();
347
348 let nugget = self.nugget;
349
350 let mut best_theta: Vec<f64> = self
352 .options
353 .initial_theta
354 .clone()
355 .unwrap_or_else(|| vec![1.0; d]);
356
357 let mut best_ll = self.log_likelihood(x_train, y_train, &best_theta, nugget);
358
359 for _ in 0..self.options.n_restarts {
361 let theta: Vec<f64> = (0..d)
362 .map(|_| rng.random_range(log_lo..log_hi).exp())
363 .collect();
364
365 let ll = self.log_likelihood(x_train, y_train, &theta, nugget);
366 if ll > best_ll {
367 best_ll = ll;
368 best_theta = theta;
369 }
370 }
371
372 for _ in 0..3 {
374 for k in 0..d {
375 let mut best_tk = best_theta[k];
376 let mut best_ll_k = best_ll;
377
378 for &factor in &[0.1, 0.3, 0.5, 0.7, 1.5, 2.0, 3.0, 10.0] {
379 let mut trial = best_theta.clone();
380 trial[k] = (best_theta[k] * factor).clamp(theta_lo, theta_hi);
381 let ll = self.log_likelihood(x_train, y_train, &trial, nugget);
382 if ll > best_ll_k {
383 best_ll_k = ll;
384 best_tk = trial[k];
385 }
386 }
387
388 best_theta[k] = best_tk;
389 best_ll = best_ll_k;
390 }
391 }
392
393 (best_theta, nugget)
394 }
395
396 fn normalize_x(&self, x: &Array2<f64>) -> Array2<f64> {
398 if let (Some(ref x_min), Some(ref x_range)) = (&self.x_min, &self.x_range) {
399 let mut normalized = x.clone();
400 for i in 0..x.nrows() {
401 for j in 0..x.ncols() {
402 let r = if x_range[j] > 1e-30 { x_range[j] } else { 1.0 };
403 normalized[[i, j]] = (x[[i, j]] - x_min[j]) / r;
404 }
405 }
406 normalized
407 } else {
408 x.clone()
409 }
410 }
411
412 fn normalize_x_point(&self, x: &Array1<f64>) -> Vec<f64> {
414 if let (Some(ref x_min), Some(ref x_range)) = (&self.x_min, &self.x_range) {
415 x.iter()
416 .enumerate()
417 .map(|(j, &xj)| {
418 let r = if x_range[j] > 1e-30 { x_range[j] } else { 1.0 };
419 (xj - x_min[j]) / r
420 })
421 .collect()
422 } else {
423 x.to_vec()
424 }
425 }
426}
427
428impl SurrogateModel for KrigingSurrogate {
429 fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
430 let n = x.nrows();
431 let d = x.ncols();
432
433 if n < 2 {
434 return Err(OptimizeError::InvalidInput(
435 "Need at least 2 data points for Kriging".to_string(),
436 ));
437 }
438
439 let mut x_min = Array1::zeros(d);
441 let mut x_max = Array1::zeros(d);
442 for j in 0..d {
443 let mut lo = f64::INFINITY;
444 let mut hi = f64::NEG_INFINITY;
445 for i in 0..n {
446 if x[[i, j]] < lo {
447 lo = x[[i, j]];
448 }
449 if x[[i, j]] > hi {
450 hi = x[[i, j]];
451 }
452 }
453 x_min[j] = lo;
454 x_max[j] = hi;
455 }
456 let x_range = &x_max - &x_min;
457 self.x_min = Some(x_min);
458 self.x_range = Some(x_range);
459
460 let y_sum: f64 = y.iter().sum();
461 self.y_mean = y_sum / n as f64;
462 let y_var: f64 = y.iter().map(|yi| (yi - self.y_mean).powi(2)).sum::<f64>() / n as f64;
463 self.y_std = y_var.sqrt().max(1e-30);
464
465 let x_norm = self.normalize_x(x);
467 let y_norm: Array1<f64> = y.mapv(|yi| (yi - self.y_mean) / self.y_std);
468
469 let (theta, nugget) = if self.options.optimize_hyperparams {
471 self.optimize_hyperparameters(&x_norm, &y_norm)
472 } else {
473 let theta = self
474 .options
475 .initial_theta
476 .clone()
477 .unwrap_or_else(|| vec![1.0; d]);
478 (theta, self.nugget)
479 };
480 self.theta = Some(theta.clone());
481 self.nugget = nugget;
482
483 let r = self.compute_correlation_matrix(&x_norm, &theta, nugget);
485 let chol = self.cholesky(&r)?;
486
487 let ones = Array1::ones(n);
489 let z = self.solve_lower(&chol, &ones);
490 let r_inv_ones = self.solve_upper(&chol, &z);
491 let ones_r_inv_ones = ones.dot(&r_inv_ones);
492
493 let z_y = self.solve_lower(&chol, &y_norm);
494 let r_inv_y = self.solve_upper(&chol, &z_y);
495
496 self.mu = if ones_r_inv_ones.abs() > 1e-30 {
497 ones.dot(&r_inv_y) / ones_r_inv_ones
498 } else {
499 y_norm.mean().unwrap_or(0.0)
500 };
501
502 let residual: Array1<f64> = &y_norm - self.mu;
504 let z_res = self.solve_lower(&chol, &residual);
505 let alpha = self.solve_upper(&chol, &z_res);
506
507 self.sigma_sq = (residual.dot(&alpha) / n as f64).max(1e-20);
509
510 self.alpha = Some(alpha);
511 self.corr_matrix = Some(r);
512 self.chol_factor = Some(chol);
513 self.x_train = Some(x_norm);
514 self.y_train = Some(y_norm);
515
516 Ok(())
517 }
518
519 fn predict(&self, x: &Array1<f64>) -> OptimizeResult<f64> {
520 let x_train = self
521 .x_train
522 .as_ref()
523 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
524 let alpha = self
525 .alpha
526 .as_ref()
527 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
528 let theta = self
529 .theta
530 .as_ref()
531 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
532
533 let x_norm = self.normalize_x_point(x);
534 let r = self.compute_correlation_vector(&x_norm, x_train, theta);
535
536 let prediction_norm = self.mu + r.dot(alpha);
537
538 Ok(prediction_norm * self.y_std + self.y_mean)
540 }
541
542 fn predict_with_uncertainty(&self, x: &Array1<f64>) -> OptimizeResult<(f64, f64)> {
543 let x_train = self
544 .x_train
545 .as_ref()
546 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
547 let alpha = self
548 .alpha
549 .as_ref()
550 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
551 let theta = self
552 .theta
553 .as_ref()
554 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
555 let chol = self
556 .chol_factor
557 .as_ref()
558 .ok_or_else(|| OptimizeError::ComputationError("Model not fitted".to_string()))?;
559
560 let n = x_train.nrows();
561 let x_norm = self.normalize_x_point(x);
562 let r = self.compute_correlation_vector(&x_norm, x_train, theta);
563
564 let prediction_norm = self.mu + r.dot(alpha);
566 let mean = prediction_norm * self.y_std + self.y_mean;
567
568 let z = self.solve_lower(chol, &r);
571 let rt_r_inv_r = z.dot(&z);
572
573 let ones = Array1::ones(n);
574 let z_ones = self.solve_lower(chol, &ones);
575 let ones_r_inv_r: f64 = z_ones.dot(&z);
576 let ones_r_inv_ones: f64 = z_ones.dot(&z_ones);
577
578 let numerator = (1.0 - ones_r_inv_r).powi(2);
579 let denominator = ones_r_inv_ones.max(1e-30);
580
581 let mse_norm = self.sigma_sq * (1.0 - rt_r_inv_r + numerator / denominator).max(0.0);
582 let std = (mse_norm * self.y_std * self.y_std).sqrt().max(1e-10);
583
584 Ok((mean, std))
585 }
586
587 fn n_samples(&self) -> usize {
588 self.x_train.as_ref().map_or(0, |x| x.nrows())
589 }
590
591 fn n_features(&self) -> usize {
592 self.x_train.as_ref().map_or(0, |x| x.ncols())
593 }
594
595 fn update(&mut self, x: &Array1<f64>, y: f64) -> OptimizeResult<()> {
596 let (new_x, new_y) =
598 if let (Some(ref x_train), Some(ref y_train)) = (&self.x_train, &self.y_train) {
599 let d = x_train.ncols();
600 let n = x_train.nrows();
601
602 let mut x_denorm = Array2::zeros((n, d));
604 for i in 0..n {
605 for j in 0..d {
606 let r = self.x_range.as_ref().map_or(1.0, |xr| {
607 if xr[j] > 1e-30 {
608 xr[j]
609 } else {
610 1.0
611 }
612 });
613 let m = self.x_min.as_ref().map_or(0.0, |xm| xm[j]);
614 x_denorm[[i, j]] = x_train[[i, j]] * r + m;
615 }
616 }
617 let y_denorm: Array1<f64> = y_train.mapv(|yi| yi * self.y_std + self.y_mean);
618
619 let mut new_x = Array2::zeros((n + 1, d));
620 for i in 0..n {
621 for j in 0..d {
622 new_x[[i, j]] = x_denorm[[i, j]];
623 }
624 }
625 for j in 0..d {
626 new_x[[n, j]] = x[j];
627 }
628
629 let mut new_y = Array1::zeros(n + 1);
630 for i in 0..n {
631 new_y[i] = y_denorm[i];
632 }
633 new_y[n] = y;
634
635 (new_x, new_y)
636 } else {
637 let d = x.len();
638 let mut new_x = Array2::zeros((1, d));
639 for j in 0..d {
640 new_x[[0, j]] = x[j];
641 }
642 (new_x, Array1::from_vec(vec![y]))
643 };
644
645 self.fit(&new_x, &new_y)
646 }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n x 1` training matrix from scalar inputs.
    fn column(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).expect("Array creation failed")
    }

    /// Options that skip the hyperparameter search and use one fixed `theta`.
    fn fixed_theta(correlation: CorrelationFunction, nugget: f64, theta: f64) -> KrigingOptions {
        KrigingOptions {
            correlation,
            nugget: Some(nugget),
            optimize_hyperparams: false,
            initial_theta: Some(vec![theta]),
            ..Default::default()
        }
    }

    /// Fits a fixed-theta (`theta = 5`) model on the inputs `0, 0.5, 1`.
    fn fit_three_points(
        correlation: CorrelationFunction,
        nugget: f64,
        targets: [f64; 3],
    ) -> (KrigingSurrogate, OptimizeResult<()>) {
        let x_train = column(&[0.0, 0.5, 1.0]);
        let y_train = Array1::from_vec(targets.to_vec());
        let mut kriging = KrigingSurrogate::new(fixed_theta(correlation, nugget, 5.0));
        let outcome = kriging.fit(&x_train, &y_train);
        (kriging, outcome)
    }

    #[test]
    fn test_kriging_basic_interpolation() {
        let inputs = [0.0, 0.25, 0.5, 0.75, 1.0];
        let targets = [0.0, 0.25, 1.0, 0.75, 0.0];
        let x_train = column(&inputs);
        let y_train = Array1::from_vec(targets.to_vec());

        let mut kriging = KrigingSurrogate::new(fixed_theta(
            CorrelationFunction::SquaredExponential,
            1e-4,
            10.0,
        ));

        let result = kriging.fit(&x_train, &y_train);
        assert!(result.is_ok(), "Kriging fit failed: {:?}", result.err());

        for (i, (&xi, &yi)) in inputs.iter().zip(targets.iter()).enumerate() {
            let pred = kriging
                .predict(&Array1::from_vec(vec![xi]))
                .expect("Prediction failed");
            assert!(
                (pred - yi).abs() < 0.2,
                "Kriging interpolation error at {}: pred={}, actual={}",
                i,
                pred,
                yi
            );
        }
    }

    #[test]
    fn test_kriging_uncertainty() {
        let x_train = column(&[0.0, 0.33, 0.66, 1.0]);
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let mut kriging = KrigingSurrogate::new(fixed_theta(
            CorrelationFunction::SquaredExponential,
            1e-4,
            5.0,
        ));
        kriging.fit(&x_train, &y_train).expect("Fit failed");

        let std_at = |v: f64| {
            kriging
                .predict_with_uncertainty(&Array1::from_vec(vec![v]))
                .expect("Prediction failed")
                .1
        };
        let unc_near = std_at(0.33);
        let unc_far = std_at(2.0);

        assert!(
            unc_far > unc_near,
            "Far uncertainty ({}) should exceed near uncertainty ({})",
            unc_far,
            unc_near
        );
    }

    #[test]
    fn test_kriging_2d() {
        let corners = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let x_train = Array2::from_shape_vec((4, 2), corners).expect("Array creation failed");
        let y_train = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0]);

        let options = KrigingOptions {
            nugget: Some(1e-4),
            n_restarts: 2,
            ..Default::default()
        };
        let mut kriging = KrigingSurrogate::new(options);
        assert!(kriging.fit(&x_train, &y_train).is_ok());

        let pred = kriging.predict(&Array1::from_vec(vec![0.5, 0.5]));
        assert!(pred.is_ok());
        let val = pred.expect("2D prediction failed");
        assert!(val > -1.0 && val < 3.0, "Kriging 2D prediction: {}", val);
    }

    #[test]
    fn test_kriging_matern32() {
        let (kriging, outcome) =
            fit_three_points(CorrelationFunction::Matern32, 1e-4, [1.0, 2.0, 1.0]);
        assert!(outcome.is_ok());
        assert!(kriging.predict(&Array1::from_vec(vec![0.25])).is_ok());
    }

    #[test]
    fn test_kriging_matern52() {
        let (_, outcome) = fit_three_points(CorrelationFunction::Matern52, 1e-4, [0.0, 1.0, 0.0]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_kriging_exponential() {
        let (_, outcome) =
            fit_three_points(CorrelationFunction::Exponential, 1e-3, [0.0, 1.0, 0.0]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_kriging_update() {
        let (mut kriging, outcome) =
            fit_three_points(CorrelationFunction::SquaredExponential, 1e-4, [0.0, 1.0, 0.0]);
        outcome.expect("Fit failed");
        assert_eq!(kriging.n_samples(), 3);

        kriging
            .update(&Array1::from_vec(vec![0.25]), 0.5)
            .expect("Update failed");
        assert_eq!(kriging.n_samples(), 4);
    }

    #[test]
    fn test_kriging_power_exponential() {
        let (_, outcome) = fit_three_points(
            CorrelationFunction::PowerExponential { p: 1.5 },
            1e-3,
            [1.0, 0.5, 1.0],
        );
        assert!(outcome.is_ok());
    }
}