use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::{Rng, RngExt};

use crate::error::{OptimizeError, OptimizeResult};

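/// Covariance kernel for Gaussian-process surrogates.
///
/// Implementors provide the pointwise covariance `eval`, log-space
/// hyperparameter accessors, and `clone_box`; the matrix builders and
/// `n_params` have default implementations.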
pub trait SurrogateKernel: Send + Sync {
    /// Evaluates the covariance between two points.
    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64;

    /// Builds the symmetric covariance matrix `K[i, j] = k(x_i, x_j)`.
    fn covariance_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
        let n = x.nrows();
        let mut k = Array2::zeros((n, n));
        for i in 0..n {
            // Compute the lower triangle and mirror it: K is symmetric.
            for j in 0..=i {
                let kij = self.eval(&x.row(i), &x.row(j));
                k[[i, j]] = kij;
                if i != j {
                    k[[j, i]] = kij;
                }
            }
        }
        k
    }

    /// Builds the cross-covariance matrix between two sets of points.
    fn cross_covariance(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Array2<f64> {
        let n1 = x1.nrows();
        let n2 = x2.nrows();
        let mut k = Array2::zeros((n1, n2));
        for i in 0..n1 {
            for j in 0..n2 {
                k[[i, j]] = self.eval(&x1.row(i), &x2.row(j));
            }
        }
        k
    }

    /// Returns the hyperparameters in log space (for unconstrained optimization).
    fn get_log_params(&self) -> Vec<f64>;

    /// Sets the hyperparameters from log-space values.
    fn set_log_params(&mut self, params: &[f64]);

    /// Number of tunable hyperparameters.
    fn n_params(&self) -> usize {
        self.get_log_params().len()
    }

    /// Clones the kernel behind a trait object.
    fn clone_box(&self) -> Box<dyn SurrogateKernel>;

    /// Short human-readable kernel name.
    fn name(&self) -> &str;
}

impl Clone for Box<dyn SurrogateKernel> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

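/// Squared-exponential (RBF) kernel:
/// `k(x, x') = signal_variance * exp(-||x - x'||^2 / (2 * length_scale^2))`.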
#[derive(Debug, Clone)]
pub struct RbfKernel {
    pub length_scale: f64,
    pub signal_variance: f64,
}

impl RbfKernel {
    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
        Self {
            // Clamp so hyperparameters stay strictly positive.
            length_scale: length_scale.max(1e-10),
            signal_variance: signal_variance.max(1e-10),
        }
    }
}

impl Default for RbfKernel {
    fn default() -> Self {
        Self::new(1.0, 1.0)
    }
}

impl SurrogateKernel for RbfKernel {
    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        let sq_dist = squared_distance(x1, x2);
        self.signal_variance * (-0.5 * sq_dist / (self.length_scale * self.length_scale)).exp()
    }

    fn get_log_params(&self) -> Vec<f64> {
        vec![self.length_scale.ln(), self.signal_variance.ln()]
    }

    fn set_log_params(&mut self, params: &[f64]) {
        if params.len() >= 2 {
            self.length_scale = params[0].exp().max(1e-10);
            self.signal_variance = params[1].exp().max(1e-10);
        }
    }

    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
        Box::new(self.clone())
    }

    fn name(&self) -> &str {
        "RBF"
    }
}

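/// Smoothness parameter ν of the Matérn kernel (closed-form cases only).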
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MaternVariant {
    /// ν = 1/2: exponential kernel; rough, non-differentiable sample paths.
    OneHalf,
    /// ν = 3/2: once mean-square differentiable sample paths.
    ThreeHalves,
    /// ν = 5/2: twice mean-square differentiable sample paths.
    FiveHalves,
}

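/// Matérn kernel with one of the closed-form variants ν ∈ {1/2, 3/2, 5/2}.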
#[derive(Debug, Clone)]
pub struct MaternKernel {
    pub variant: MaternVariant,
    pub length_scale: f64,
    pub signal_variance: f64,
}

impl MaternKernel {
    pub fn new(variant: MaternVariant, length_scale: f64, signal_variance: f64) -> Self {
        Self {
            variant,
            length_scale: length_scale.max(1e-10),
            signal_variance: signal_variance.max(1e-10),
        }
    }

    pub fn one_half(length_scale: f64, signal_variance: f64) -> Self {
        Self::new(MaternVariant::OneHalf, length_scale, signal_variance)
    }

    pub fn three_halves(length_scale: f64, signal_variance: f64) -> Self {
        Self::new(MaternVariant::ThreeHalves, length_scale, signal_variance)
    }

    pub fn five_halves(length_scale: f64, signal_variance: f64) -> Self {
        Self::new(MaternVariant::FiveHalves, length_scale, signal_variance)
    }
}

impl Default for MaternKernel {
    fn default() -> Self {
        Self::new(MaternVariant::FiveHalves, 1.0, 1.0)
    }
}

impl SurrogateKernel for MaternKernel {
    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        let r = squared_distance(x1, x2).sqrt();
        let l = self.length_scale;
        let sv = self.signal_variance;

        match self.variant {
            // ν = 1/2: k = σ² exp(-r/ℓ)
            MaternVariant::OneHalf => sv * (-r / l).exp(),
            // ν = 3/2: k = σ² (1 + √3 r/ℓ) exp(-√3 r/ℓ)
            MaternVariant::ThreeHalves => {
                let sqrt3_r_l = 3.0_f64.sqrt() * r / l;
                sv * (1.0 + sqrt3_r_l) * (-sqrt3_r_l).exp()
            }
            // ν = 5/2: k = σ² (1 + √5 r/ℓ + 5r²/(3ℓ²)) exp(-√5 r/ℓ)
            MaternVariant::FiveHalves => {
                let sqrt5_r_l = 5.0_f64.sqrt() * r / l;
                let r2_l2 = r * r / (l * l);
                sv * (1.0 + sqrt5_r_l + 5.0 * r2_l2 / 3.0) * (-sqrt5_r_l).exp()
            }
        }
    }

    fn get_log_params(&self) -> Vec<f64> {
        vec![self.length_scale.ln(), self.signal_variance.ln()]
    }

    fn set_log_params(&mut self, params: &[f64]) {
        if params.len() >= 2 {
            self.length_scale = params[0].exp().max(1e-10);
            self.signal_variance = params[1].exp().max(1e-10);
        }
    }

    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
        Box::new(self.clone())
    }

    fn name(&self) -> &str {
        match self.variant {
            MaternVariant::OneHalf => "Matern12",
            MaternVariant::ThreeHalves => "Matern32",
            MaternVariant::FiveHalves => "Matern52",
        }
    }
}

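/// Rational quadratic kernel:
/// `k(x, x') = signal_variance * (1 + ||x - x'||^2 / (2 * alpha * length_scale^2))^(-alpha)`,
/// a scale mixture of RBF kernels that converges to the RBF as `alpha -> inf`.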
#[derive(Debug, Clone)]
pub struct RationalQuadraticKernel {
    pub length_scale: f64,
    pub signal_variance: f64,
    pub alpha: f64,
}

impl RationalQuadraticKernel {
    pub fn new(length_scale: f64, signal_variance: f64, alpha: f64) -> Self {
        Self {
            length_scale: length_scale.max(1e-10),
            signal_variance: signal_variance.max(1e-10),
            alpha: alpha.max(1e-10),
        }
    }
}

impl Default for RationalQuadraticKernel {
    fn default() -> Self {
        Self::new(1.0, 1.0, 1.0)
    }
}

impl SurrogateKernel for RationalQuadraticKernel {
    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        let sq_dist = squared_distance(x1, x2);
        let base = 1.0 + sq_dist / (2.0 * self.alpha * self.length_scale * self.length_scale);
        self.signal_variance * base.powf(-self.alpha)
    }

    fn get_log_params(&self) -> Vec<f64> {
        vec![
            self.length_scale.ln(),
            self.signal_variance.ln(),
            self.alpha.ln(),
        ]
    }

    fn set_log_params(&mut self, params: &[f64]) {
        if params.len() >= 3 {
            self.length_scale = params[0].exp().max(1e-10);
            self.signal_variance = params[1].exp().max(1e-10);
            self.alpha = params[2].exp().max(1e-10);
        }
    }

    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
        Box::new(self.clone())
    }

    fn name(&self) -> &str {
        "RationalQuadratic"
    }
}

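/// Pointwise sum of two kernels: `k(x, x') = k1(x, x') + k2(x, x')`.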
#[derive(Clone)]
pub struct SumKernel {
    pub kernel1: Box<dyn SurrogateKernel>,
    pub kernel2: Box<dyn SurrogateKernel>,
}

impl SumKernel {
    pub fn new(k1: Box<dyn SurrogateKernel>, k2: Box<dyn SurrogateKernel>) -> Self {
        Self {
            kernel1: k1,
            kernel2: k2,
        }
    }
}

impl SurrogateKernel for SumKernel {
    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        self.kernel1.eval(x1, x2) + self.kernel2.eval(x1, x2)
    }

    fn get_log_params(&self) -> Vec<f64> {
        let mut p = self.kernel1.get_log_params();
        p.extend(self.kernel2.get_log_params());
        p
    }

    fn set_log_params(&mut self, params: &[f64]) {
        // The first `n1` entries parameterize `kernel1`, the rest `kernel2`.
        let n1 = self.kernel1.n_params();
        if params.len() >= n1 {
            self.kernel1.set_log_params(&params[..n1]);
        }
        if params.len() > n1 {
            self.kernel2.set_log_params(&params[n1..]);
        }
    }

    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
        Box::new(self.clone())
    }

    fn name(&self) -> &str {
        "Sum"
    }
}

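// Composition sketch (illustrative only), using `SumKernel` from above:
//
//     let k = SumKernel::new(
//         Box::new(RbfKernel::new(1.0, 1.0)),
//         Box::new(MaternKernel::five_halves(0.5, 0.2)),
//     );
//     assert_eq!(k.n_params(), 4); // both children's log-params, concatenated

/// Pointwise product of two kernels: `k(x, x') = k1(x, x') * k2(x, x')`.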
#[derive(Clone)]
pub struct ProductKernel {
    pub kernel1: Box<dyn SurrogateKernel>,
    pub kernel2: Box<dyn SurrogateKernel>,
}

impl ProductKernel {
    pub fn new(k1: Box<dyn SurrogateKernel>, k2: Box<dyn SurrogateKernel>) -> Self {
        Self {
            kernel1: k1,
            kernel2: k2,
        }
    }
}

impl SurrogateKernel for ProductKernel {
    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        self.kernel1.eval(x1, x2) * self.kernel2.eval(x1, x2)
    }

    fn get_log_params(&self) -> Vec<f64> {
        let mut p = self.kernel1.get_log_params();
        p.extend(self.kernel2.get_log_params());
        p
    }

    fn set_log_params(&mut self, params: &[f64]) {
        let n1 = self.kernel1.n_params();
        if params.len() >= n1 {
            self.kernel1.set_log_params(&params[..n1]);
        }
        if params.len() > n1 {
            self.kernel2.set_log_params(&params[n1..]);
        }
    }

    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
        Box::new(self.clone())
    }

    fn name(&self) -> &str {
        "Product"
    }
}

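/// Configuration for [`GpSurrogate`].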
#[derive(Clone)]
pub struct GpSurrogateConfig {
    /// Observation-noise variance added to the covariance diagonal.
    pub noise_variance: f64,
    /// Whether `fit` tunes kernel hyperparameters by maximizing the LML.
    pub optimize_hyperparams: bool,
    /// Number of random restarts for hyperparameter optimization.
    pub n_restarts: usize,
    /// Maximum coordinate-search sweeps per restart.
    pub max_opt_iters: usize,
}

impl Default for GpSurrogateConfig {
    fn default() -> Self {
        Self {
            noise_variance: 1e-6,
            optimize_hyperparams: true,
            n_restarts: 3,
            max_opt_iters: 50,
        }
    }
}

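/// Gaussian-process surrogate with an exact Cholesky-based solver.
///
/// After fitting, `l_factor` holds the lower-triangular Cholesky factor `L` of
/// `K + noise_variance * I` and `alpha` holds `(K + noise_variance * I)^{-1} y`
/// for the standardized targets, so predictions reduce to triangular solves.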
pub struct GpSurrogate {
    kernel: Box<dyn SurrogateKernel>,
    config: GpSurrogateConfig,
    x_train: Option<Array2<f64>>,
    y_train: Option<Array1<f64>>,
    /// Mean and standard deviation used to standardize the targets.
    y_mean: f64,
    y_std: f64,
    /// Lower-triangular Cholesky factor of the noisy covariance matrix.
    l_factor: Option<Array2<f64>>,
    /// Solution of `(K + noise * I) alpha = y_std`.
    alpha: Option<Array1<f64>>,
}

impl GpSurrogate {
    /// Creates an unfitted surrogate with the given kernel and configuration.
    pub fn new(kernel: Box<dyn SurrogateKernel>, config: GpSurrogateConfig) -> Self {
        Self {
            kernel,
            config,
            x_train: None,
            y_train: None,
            y_mean: 0.0,
            y_std: 1.0,
            l_factor: None,
            alpha: None,
        }
    }

    /// Convenience constructor: RBF kernel with the default configuration.
    pub fn default_rbf() -> Self {
        Self::new(Box::new(RbfKernel::default()), GpSurrogateConfig::default())
    }

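    /// Fits the GP to inputs `x` (shape `(n, d)`) and targets `y` (length `n`).
    ///
    /// Targets are standardized internally. When `optimize_hyperparams` is set
    /// and at least three samples are available, kernel hyperparameters are
    /// tuned by maximizing the log marginal likelihood before factorization.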
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
        if x.nrows() != y.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "x has {} rows but y has {} elements",
                x.nrows(),
                y.len()
            )));
        }
        if x.nrows() == 0 {
            return Err(OptimizeError::InvalidInput(
                "Cannot fit GP with zero training samples".to_string(),
            ));
        }

        self.x_train = Some(x.clone());
        self.y_train = Some(y.clone());

        // Target statistics for standardization; fall back to unit scale
        // when the targets are (numerically) constant.
        self.y_mean = y.iter().sum::<f64>() / y.len() as f64;
        let variance = y.iter().map(|&v| (v - self.y_mean).powi(2)).sum::<f64>() / y.len() as f64;
        self.y_std = if variance > 1e-12 {
            variance.sqrt()
        } else {
            1.0
        };

        if self.config.optimize_hyperparams && x.nrows() >= 3 {
            self.optimize_hyperparameters()?;
        }

        self.update_model()
    }

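    /// Appends new observations to the training set and refits from scratch.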
    pub fn update(&mut self, x_new: &Array2<f64>, y_new: &Array1<f64>) -> OptimizeResult<()> {
        if x_new.nrows() != y_new.len() {
            return Err(OptimizeError::InvalidInput(
                "x_new must have as many rows as y_new has elements".to_string(),
            ));
        }

        match (&self.x_train, &self.y_train) {
            (Some(xt), Some(yt)) => {
                // Stack the existing and new samples row by row.
                let mut x_all = Array2::zeros((xt.nrows() + x_new.nrows(), xt.ncols()));
                for i in 0..xt.nrows() {
                    for j in 0..xt.ncols() {
                        x_all[[i, j]] = xt[[i, j]];
                    }
                }
                for i in 0..x_new.nrows() {
                    for j in 0..x_new.ncols() {
                        x_all[[xt.nrows() + i, j]] = x_new[[i, j]];
                    }
                }
                let mut y_all = Array1::zeros(yt.len() + y_new.len());
                for i in 0..yt.len() {
                    y_all[i] = yt[i];
                }
                for i in 0..y_new.len() {
                    y_all[yt.len() + i] = y_new[i];
                }
                self.fit(&x_all, &y_all)
            }
            _ => self.fit(x_new, y_new),
        }
    }

    /// Posterior mean at each row of `x_test`.
    pub fn predict_mean(&self, x_test: &Array2<f64>) -> OptimizeResult<Array1<f64>> {
        let (mean, _) = self.predict(x_test)?;
        Ok(mean)
    }

    /// Posterior variance at each row of `x_test`.
    pub fn predict_variance(&self, x_test: &Array2<f64>) -> OptimizeResult<Array1<f64>> {
        let (_, var) = self.predict(x_test)?;
        Ok(var)
    }

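    /// Posterior mean and variance at each row of `x_test`.
    ///
    /// With `k* = K(x_test, X)`, the mean is `k* alpha` (de-standardized) and
    /// the variance of row `i` is `k(x_i, x_i) - ||v||^2` with `v = L^{-1} k*_i`,
    /// clamped at zero and rescaled by the target variance.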
    pub fn predict(&self, x_test: &Array2<f64>) -> OptimizeResult<(Array1<f64>, Array1<f64>)> {
        let x_train = self.x_train.as_ref().ok_or_else(|| {
            OptimizeError::ComputationError("GP must be fitted before prediction".to_string())
        })?;
        let alpha = self
            .alpha
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;
        let l_factor = self
            .l_factor
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;

        // Cross-covariance between test and training points: k* = K(x_test, X).
        let k_star = self.kernel.cross_covariance(x_test, x_train);

        // Standardized posterior mean, then map back to the original scale.
        let mean_std = k_star.dot(alpha);
        let mean = mean_std.mapv(|v| v * self.y_std + self.y_mean);

        let n_test = x_test.nrows();
        let mut variance = Array1::zeros(n_test);

        for i in 0..n_test {
            let k_self = self.kernel.eval(&x_test.row(i), &x_test.row(i));

            // v = L^{-1} k*_i, so v^T v = k*_i^T (K + noise I)^{-1} k*_i.
            let k_col = k_star.row(i).to_owned();
            let v = forward_solve(l_factor, &k_col)?;

            let v_sq_sum: f64 = v.iter().map(|&vi| vi * vi).sum();
            let var = (k_self - v_sq_sum).max(0.0);
            variance[i] = var * self.y_std * self.y_std;
        }

        Ok((mean, variance))
    }

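    /// Single-point convenience wrapper; returns `(mean, standard deviation)`.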
    pub fn predict_single(&self, x: &ArrayView1<f64>) -> OptimizeResult<(f64, f64)> {
        let x_mat = x
            .to_owned()
            .into_shape_with_order((1, x.len()))
            .map_err(|e| OptimizeError::ComputationError(format!("Shape error: {}", e)))?;
        let (mean, var) = self.predict(&x_mat)?;
        Ok((mean[0], var[0].max(0.0).sqrt()))
    }

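    /// Log marginal likelihood of the (standardized) training targets:
    /// `-1/2 y^T alpha - sum_i ln L_ii - n/2 ln(2 pi)`.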
    pub fn log_marginal_likelihood(&self) -> OptimizeResult<f64> {
        let y_train = self
            .y_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP must be fitted".to_string()))?;
        let l_factor = self
            .l_factor
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;
        let alpha = self
            .alpha
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;

        let y_std = self.standardize_y(y_train);
        let n = y_std.len() as f64;

        let data_fit = -0.5 * y_std.dot(alpha);

        // 1/2 ln det(K + noise I) equals the sum of log Cholesky diagonals.
        let log_det: f64 = l_factor.diag().iter().map(|&v| v.abs().ln()).sum();

        let norm = -0.5 * n * (2.0 * std::f64::consts::PI).ln();

        Ok(data_fit - log_det + norm)
    }

    /// Borrows the kernel.
    pub fn kernel(&self) -> &dyn SurrogateKernel {
        self.kernel.as_ref()
    }

    /// Mutably borrows the kernel.
    pub fn kernel_mut(&mut self) -> &mut dyn SurrogateKernel {
        self.kernel.as_mut()
    }

    /// Number of stored training samples (zero before fitting).
    pub fn n_train(&self) -> usize {
        self.x_train.as_ref().map_or(0, |x| x.nrows())
    }

    /// Maps targets to zero mean / unit scale using the statistics from `fit`.
    fn standardize_y(&self, y: &Array1<f64>) -> Array1<f64> {
        y.mapv(|v| (v - self.y_mean) / self.y_std)
    }

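    /// Rebuilds the Cholesky factor and `alpha = (K + noise I)^{-1} y_std`,
    /// escalating diagonal jitter when the factorization fails.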
    fn update_model(&mut self) -> OptimizeResult<()> {
        let x_train = self
            .x_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?;
        let y_train = self
            .y_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?;

        let y_std = self.standardize_y(y_train);

        let mut k = self.kernel.covariance_matrix(x_train);
        let n = k.nrows();

        for i in 0..n {
            k[[i, i]] += self.config.noise_variance;
        }

        let l = match cholesky(&k) {
            Ok(l) => l,
            Err(_) => {
                // Retry with progressively (cumulatively) larger jitter on the
                // diagonal until the factorization succeeds.
                let jitters = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2];
                let mut result = Err(OptimizeError::ComputationError(
                    "Cholesky failed with all jitter levels".to_string(),
                ));
                for &jitter in &jitters {
                    for i in 0..n {
                        k[[i, i]] += jitter;
                    }
                    match cholesky(&k) {
                        Ok(l) => {
                            result = Ok(l);
                            break;
                        }
                        Err(_) => continue,
                    }
                }
                result?
            }
        };

        // Solve (L L^T) alpha = y_std via two triangular solves.
        let alpha1 = forward_solve(&l, &y_std)?;
        let alpha = backward_solve_transpose(&l, &alpha1)?;

        self.l_factor = Some(l);
        self.alpha = Some(alpha);

        Ok(())
    }

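    /// Tunes kernel hyperparameters by random-restart coordinate search over
    /// log parameters, maximizing the log marginal likelihood. Log parameters
    /// are clamped to `[-5, 5]`.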
    fn optimize_hyperparameters(&mut self) -> OptimizeResult<()> {
        let x_train = self
            .x_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?
            .clone();
        let y_train = self
            .y_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?
            .clone();
        let y_std = self.standardize_y(&y_train);

        let n_params = self.kernel.n_params();
        if n_params == 0 {
            return Ok(());
        }

        let mut best_params = self.kernel.get_log_params();
        let mut best_lml = f64::NEG_INFINITY;

        if let Ok(lml) = self.eval_lml_at_params(&best_params, &x_train, &y_std) {
            best_lml = lml;
        }

        let mut rng = scirs2_core::random::rng();

        for restart in 0..self.config.n_restarts {
            // The first restart starts from the current parameters; later
            // restarts start from random points in log space.
            let init_params: Vec<f64> = if restart == 0 {
                best_params.clone()
            } else {
                (0..n_params).map(|_| rng.random_range(-2.0..2.0)).collect()
            };

            let mut current_params = init_params;
            for _iter in 0..self.config.max_opt_iters {
                let mut improved = false;
                for p in 0..n_params {
                    let original = current_params[p];

                    // Try a few fixed steps in both directions along this coordinate.
                    let steps = [0.1, 0.3, 1.0, -0.1, -0.3, -1.0];
                    let mut best_step_lml =
                        match self.eval_lml_at_params(&current_params, &x_train, &y_std) {
                            Ok(v) => v,
                            Err(_) => f64::NEG_INFINITY,
                        };
                    let mut best_step_val = original;

                    for &step in &steps {
                        current_params[p] = original + step;
                        current_params[p] = current_params[p].clamp(-5.0, 5.0);

                        if let Ok(lml) = self.eval_lml_at_params(&current_params, &x_train, &y_std)
                        {
                            if lml > best_step_lml {
                                best_step_lml = lml;
                                best_step_val = current_params[p];
                                improved = true;
                            }
                        }
                    }
                    current_params[p] = best_step_val;
                }
                if !improved {
                    break;
                }
            }

            if let Ok(lml) = self.eval_lml_at_params(&current_params, &x_train, &y_std) {
                if lml > best_lml {
                    best_lml = lml;
                    best_params = current_params;
                }
            }
        }

        self.kernel.set_log_params(&best_params);

        Ok(())
    }

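    /// Evaluates the log marginal likelihood for trial log parameters on a
    /// cloned kernel, leaving the model untouched.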
    fn eval_lml_at_params(
        &self,
        log_params: &[f64],
        x_train: &Array2<f64>,
        y_std: &Array1<f64>,
    ) -> OptimizeResult<f64> {
        let mut kernel = self.kernel.clone();
        kernel.set_log_params(log_params);

        let mut k = kernel.covariance_matrix(x_train);
        let n = k.nrows();
        for i in 0..n {
            k[[i, i]] += self.config.noise_variance;
        }

        let l = cholesky(&k)?;
        let alpha1 = forward_solve(&l, y_std)?;
        let alpha = backward_solve_transpose(&l, &alpha1)?;

        let n_f = n as f64;
        let data_fit = -0.5 * y_std.dot(&alpha);
        let log_det: f64 = l.diag().iter().map(|&v| v.abs().ln()).sum();
        let norm = -0.5 * n_f * (2.0 * std::f64::consts::PI).ln();

        Ok(data_fit - log_det + norm)
    }
}

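// End-to-end usage sketch (illustrative; mirrors the unit tests below):
//
//     let mut gp = GpSurrogate::default_rbf();
//     gp.fit(&x, &y)?;                          // x: (n, d), y: length n
//     let (mean, var) = gp.predict(&x_test)?;   // per-row posterior
//     let (mu, sigma) = gp.predict_single(&x_new.view())?;

/// Squared Euclidean distance; assumes `x1` and `x2` have equal length.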
fn squared_distance(x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
    let mut s = 0.0;
    for i in 0..x1.len() {
        let d = x1[i] - x2[i];
        s += d * d;
    }
    s
}

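/// Unblocked Cholesky factorization `A = L L^T` of a symmetric
/// positive-definite matrix; returns the lower-triangular factor `L`.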
fn cholesky(a: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
    let n = a.nrows();
    if n != a.ncols() {
        return Err(OptimizeError::ComputationError(
            "Cholesky: matrix must be square".to_string(),
        ));
    }
    let mut l = Array2::zeros((n, n));

    for i in 0..n {
        for j in 0..=i {
            let mut s = 0.0;
            for k in 0..j {
                s += l[[i, k]] * l[[j, k]];
            }
            if i == j {
                let diag = a[[i, i]] - s;
                if diag <= 0.0 {
                    return Err(OptimizeError::ComputationError(format!(
                        "Cholesky: matrix not positive-definite (diag[{}] = {:.6e})",
                        i, diag
                    )));
                }
                l[[i, j]] = diag.sqrt();
            } else {
                l[[i, j]] = (a[[i, j]] - s) / l[[j, j]];
            }
        }
    }
    Ok(l)
}

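/// Solves `L x = b` by forward substitution (`L` lower-triangular).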
fn forward_solve(l: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
    let n = l.nrows();
    let mut x = Array1::zeros(n);
    for i in 0..n {
        let mut s = 0.0;
        for j in 0..i {
            s += l[[i, j]] * x[j];
        }
        let diag = l[[i, i]];
        if diag.abs() < 1e-15 {
            return Err(OptimizeError::ComputationError(
                "Forward solve: near-zero diagonal".to_string(),
            ));
        }
        x[i] = (b[i] - s) / diag;
    }
    Ok(x)
}

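/// Solves `L^T x = b` by back substitution, reading `L` column-wise so the
/// transpose is never materialized.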
fn backward_solve_transpose(l: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
    let n = l.nrows();
    let mut x = Array1::zeros(n);
    for i in (0..n).rev() {
        let mut s = 0.0;
        for j in (i + 1)..n {
            s += l[[j, i]] * x[j];
        }
        let diag = l[[i, i]];
        if diag.abs() < 1e-15 {
            return Err(OptimizeError::ComputationError(
                "Backward solve: near-zero diagonal".to_string(),
            ));
        }
        x[i] = (b[i] - s) / diag;
    }
    Ok(x)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Training data: y = sin(x) sampled at x = 0, 1, 2, 3, 4 (3 decimals).
    fn make_train_data() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0]).expect("shape ok");
        let y = array![0.0, 0.841, 0.909, 0.141, -0.757];
        (x, y)
    }

    #[test]
    fn test_rbf_kernel_symmetry() {
        let k = RbfKernel::default();
        let a = array![1.0, 2.0];
        let b = array![3.0, 4.0];
        assert!((k.eval(&a.view(), &b.view()) - k.eval(&b.view(), &a.view())).abs() < 1e-14);
    }

    #[test]
    fn test_rbf_kernel_self_covariance() {
        let k = RbfKernel::new(1.0, 2.0);
        let a = array![1.0, 2.0];
        assert!((k.eval(&a.view(), &a.view()) - 2.0).abs() < 1e-14);
    }

    #[test]
    fn test_matern_variants() {
        let a = array![0.0];
        let b = array![1.0];

        for variant in &[
            MaternVariant::OneHalf,
            MaternVariant::ThreeHalves,
            MaternVariant::FiveHalves,
        ] {
            let k = MaternKernel::new(*variant, 1.0, 1.0);
            let val = k.eval(&a.view(), &b.view());
            assert!(val > 0.0 && val < 1.0, "Matern({:?}) = {}", variant, val);
            assert!((k.eval(&a.view(), &a.view()) - 1.0).abs() < 1e-14);
        }
    }

    #[test]
    fn test_rational_quadratic_kernel() {
        // With l = sv = alpha = 1 and r^2 = 1: (1 + 1/2)^(-1) = 2/3.
        let k = RationalQuadraticKernel::new(1.0, 1.0, 1.0);
        let a = array![0.0];
        let b = array![1.0];
        let val = k.eval(&a.view(), &b.view());
        assert!((val - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_rational_quadratic_approaches_rbf() {
        let rbf = RbfKernel::new(1.0, 1.0);
        let rq = RationalQuadraticKernel::new(1.0, 1.0, 1e6);
        let a = array![0.0, 1.0];
        let b = array![2.0, 3.0];

        let rbf_val = rbf.eval(&a.view(), &b.view());
        let rq_val = rq.eval(&a.view(), &b.view());
        assert!(
            (rbf_val - rq_val).abs() < 1e-4,
            "RBF={}, RQ(alpha=1e6)={}",
            rbf_val,
            rq_val
        );
    }

    #[test]
    fn test_sum_kernel() {
        let k1 = Box::new(RbfKernel::new(1.0, 1.0));
        let k2 = Box::new(MaternKernel::five_halves(1.0, 0.5));
        let sum = SumKernel::new(k1.clone(), k2.clone());

        let a = array![1.0];
        let b = array![2.0];
        let expected = k1.eval(&a.view(), &b.view()) + k2.eval(&a.view(), &b.view());
        assert!((sum.eval(&a.view(), &b.view()) - expected).abs() < 1e-14);
    }

    #[test]
    fn test_product_kernel() {
        let k1 = Box::new(RbfKernel::new(1.0, 1.0));
        let k2 = Box::new(MaternKernel::five_halves(1.0, 0.5));
        let prod = ProductKernel::new(k1.clone(), k2.clone());

        let a = array![1.0];
        let b = array![2.0];
        let expected = k1.eval(&a.view(), &b.view()) * k2.eval(&a.view(), &b.view());
        assert!((prod.eval(&a.view(), &b.view()) - expected).abs() < 1e-14);
    }

    #[test]
    fn test_gp_fit_predict_basic() {
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit should succeed");

        let (mean, var) = gp.predict(&x).expect("predict should succeed");
        for i in 0..y.len() {
            assert!(
                (mean[i] - y[i]).abs() < 0.15,
                "mean[{}]={:.4} vs y[{}]={:.4}",
                i,
                mean[i],
                i,
                y[i]
            );
            assert!(
                var[i] < 0.5,
                "var[{}]={:.4} should be small at training point",
                i,
                var[i]
            );
        }
    }

    #[test]
    fn test_gp_uncertainty_away_from_data() {
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit should succeed");

        let x_far = Array2::from_shape_vec((1, 1), vec![10.0]).expect("shape ok");
        let (_, var_far) = gp.predict(&x_far).expect("predict ok");

        let x_near = Array2::from_shape_vec((1, 1), vec![2.0]).expect("shape ok");
        let (_, var_near) = gp.predict(&x_near).expect("predict ok");

        assert!(
            var_far[0] > var_near[0],
            "var_far={:.4} should be > var_near={:.4}",
            var_far[0],
            var_near[0]
        );
    }

    #[test]
    fn test_gp_predict_single() {
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::default_rbf();
        gp.config.optimize_hyperparams = false;
        gp.config.noise_variance = 1e-4;
        gp.fit(&x, &y).expect("fit ok");

        let point = array![1.5];
        let (mean, std) = gp.predict_single(&point.view()).expect("predict_single ok");
        assert!(mean.is_finite());
        assert!(std >= 0.0);
    }

    #[test]
    fn test_gp_log_marginal_likelihood() {
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit ok");

        let lml = gp.log_marginal_likelihood().expect("lml ok");
        assert!(lml.is_finite(), "LML should be finite, got {}", lml);
    }

    #[test]
    fn test_gp_update_incremental() {
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit ok");
        assert_eq!(gp.n_train(), 5);

        let x_new = Array2::from_shape_vec((1, 1), vec![5.0]).expect("shape ok");
        let y_new = array![-0.959];
        gp.update(&x_new, &y_new).expect("update ok");
        assert_eq!(gp.n_train(), 6);
    }

    #[test]
    fn test_gp_hyperparameter_optimization() {
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: true,
                n_restarts: 2,
                max_opt_iters: 20,
                noise_variance: 1e-4,
            },
        );
        gp.fit(&x, &y).expect("fit with optimization ok");

        let x_test = Array2::from_shape_vec((1, 1), vec![1.5]).expect("shape ok");
        let (mean, var) = gp.predict(&x_test).expect("predict ok");
        assert!(mean[0].is_finite());
        assert!(var[0].is_finite());
    }

    #[test]
    fn test_cholesky_positive_definite() {
        let a = Array2::from_shape_vec((2, 2), vec![4.0, 2.0, 2.0, 3.0]).expect("shape ok");
        let l = cholesky(&a).expect("should succeed");
        let reconstructed = l.dot(&l.t());
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (reconstructed[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "Mismatch at [{},{}]",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_cholesky_non_pd_fails() {
        // Eigenvalues are 11 and -9, so this matrix is not positive-definite.
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 10.0, 10.0, 1.0]).expect("shape ok");
        assert!(cholesky(&a).is_err());
    }

    #[test]
    fn test_kernel_log_params_roundtrip() {
        let mut k = RbfKernel::new(2.5, 0.3);
        let params = k.get_log_params();
        k.set_log_params(&params);
        assert!((k.length_scale - 2.5).abs() < 1e-10);
        assert!((k.signal_variance - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_matern_kernel_covariance_matrix() {
        let k = MaternKernel::three_halves(1.0, 1.0);
        let x = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).expect("shape ok");
        let cov = k.covariance_matrix(&x);
        assert_eq!(cov.nrows(), 3);
        assert_eq!(cov.ncols(), 3);
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (cov[[i, j]] - cov[[j, i]]).abs() < 1e-14,
                    "Not symmetric at [{},{}]",
                    i,
                    j
                );
            }
        }
        for i in 0..3 {
            assert!(cov[[i, i]] > 0.0);
        }
    }

    #[test]
    fn test_gp_multidimensional() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0, 0.5, 0.5],
        )
        .expect("shape ok");
        let y = array![0.0, 1.0, 1.0, 1.0, 1.0, 0.5];

        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit ok");

        let x_test = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape ok");
        let (mean, _) = gp.predict(&x_test).expect("predict ok");
        assert!(
            mean[0].abs() < 0.3,
            "Prediction at origin should be close to 0, got {}",
            mean[0]
        );
    }
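
    // Illustrative sketch of the `set_log_params` splitting contract for
    // composed kernels: the first child's entries come first, then the second's.
    #[test]
    fn test_sum_kernel_param_split() {
        let mut sum = SumKernel::new(
            Box::new(RbfKernel::new(1.0, 1.0)),
            Box::new(MaternKernel::five_halves(1.0, 1.0)),
        );
        let target = [2.0_f64.ln(), 3.0_f64.ln(), 4.0_f64.ln(), 5.0_f64.ln()];
        sum.set_log_params(&target);
        let got = sum.get_log_params();
        assert_eq!(got.len(), 4);
        for (g, t) in got.iter().zip(target.iter()) {
            assert!((g - t).abs() < 1e-12);
        }
    }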
}