// File: scirs2_optimize/bayesian/gp.rs
1//! Gaussian Process surrogate model for Bayesian optimization.
2//!
3//! Provides a self-contained GP regression implementation with:
4//! - Multiple kernel functions: RBF, Matern (1/2, 3/2, 5/2), Rational Quadratic
5//! - Composite kernels: Sum and Product
6//! - Efficient Cholesky-based prediction for mean and variance
7//! - Hyperparameter optimization via type-II maximum likelihood (marginal likelihood)
8//!
9//! The GP is specifically designed as a surrogate model for Bayesian optimization,
10//! prioritising numerically robust prediction of both mean and uncertainty.
11
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
13use scirs2_core::random::{Rng, RngExt};
14
15use crate::error::{OptimizeError, OptimizeResult};
16
17// ---------------------------------------------------------------------------
18// Kernel trait & implementations
19// ---------------------------------------------------------------------------
20
21/// Trait for covariance (kernel) functions used by the GP surrogate.
22pub trait SurrogateKernel: Send + Sync {
23    /// Evaluate the kernel between two input vectors.
24    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64;
25
26    /// Compute the full covariance matrix for a set of inputs.
27    fn covariance_matrix(&self, x: &Array2<f64>) -> Array2<f64> {
28        let n = x.nrows();
29        let mut k = Array2::zeros((n, n));
30        for i in 0..n {
31            for j in 0..=i {
32                let kij = self.eval(&x.row(i), &x.row(j));
33                k[[i, j]] = kij;
34                if i != j {
35                    k[[j, i]] = kij;
36                }
37            }
38        }
39        k
40    }
41
42    /// Compute the cross-covariance matrix between two sets of inputs.
43    fn cross_covariance(&self, x1: &Array2<f64>, x2: &Array2<f64>) -> Array2<f64> {
44        let n1 = x1.nrows();
45        let n2 = x2.nrows();
46        let mut k = Array2::zeros((n1, n2));
47        for i in 0..n1 {
48            for j in 0..n2 {
49                k[[i, j]] = self.eval(&x1.row(i), &x2.row(j));
50            }
51        }
52        k
53    }
54
55    /// Return current hyperparameters as a flat vector (log-scale).
56    fn get_log_params(&self) -> Vec<f64>;
57
58    /// Set hyperparameters from a flat vector (log-scale).
59    fn set_log_params(&mut self, params: &[f64]);
60
61    /// Number of hyperparameters.
62    fn n_params(&self) -> usize {
63        self.get_log_params().len()
64    }
65
66    /// Clone the kernel into a boxed trait object.
67    fn clone_box(&self) -> Box<dyn SurrogateKernel>;
68
69    /// Name of the kernel (for debug display).
70    fn name(&self) -> &str;
71}
72
// Allow `Box<dyn SurrogateKernel>` to be cloned via the object-safe
// `clone_box` hook (`Clone` itself is not object-safe, so the boxed trait
// object cannot derive it directly).
impl Clone for Box<dyn SurrogateKernel> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
78
79// ---------------------------------------------------------------------------
80// Squared Exponential (RBF) kernel
81// ---------------------------------------------------------------------------
82
/// Squared Exponential / RBF kernel.
///
/// k(x, x') = sigma^2 * exp(-0.5 * ||x - x'||^2 / length_scale^2)
///
/// Infinitely differentiable; produces very smooth GP sample functions.
#[derive(Debug, Clone)]
pub struct RbfKernel {
    /// Length scale (controls how quickly correlation decays with distance)
    pub length_scale: f64,
    /// Signal variance (marginal variance of the process, k(x, x))
    pub signal_variance: f64,
}

impl RbfKernel {
    /// Create a new RBF kernel.
    ///
    /// Both parameters are clamped to a floor of 1e-10 so the kernel stays
    /// strictly positive and division by zero is impossible.
    pub fn new(length_scale: f64, signal_variance: f64) -> Self {
        Self {
            length_scale: length_scale.max(1e-10),
            signal_variance: signal_variance.max(1e-10),
        }
    }
}

impl Default for RbfKernel {
    /// Unit length scale and unit signal variance.
    fn default() -> Self {
        Self::new(1.0, 1.0)
    }
}
109
110impl SurrogateKernel for RbfKernel {
111    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
112        let sq_dist = squared_distance(x1, x2);
113        self.signal_variance * (-0.5 * sq_dist / (self.length_scale * self.length_scale)).exp()
114    }
115
116    fn get_log_params(&self) -> Vec<f64> {
117        vec![self.length_scale.ln(), self.signal_variance.ln()]
118    }
119
120    fn set_log_params(&mut self, params: &[f64]) {
121        if params.len() >= 2 {
122            self.length_scale = params[0].exp().max(1e-10);
123            self.signal_variance = params[1].exp().max(1e-10);
124        }
125    }
126
127    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
128        Box::new(self.clone())
129    }
130
131    fn name(&self) -> &str {
132        "RBF"
133    }
134}
135
136// ---------------------------------------------------------------------------
137// Matern kernel family
138// ---------------------------------------------------------------------------
139
/// Variant of the Matern kernel.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MaternVariant {
    /// nu = 1/2 (exponential kernel)
    OneHalf,
    /// nu = 3/2
    ThreeHalves,
    /// nu = 5/2
    FiveHalves,
}

/// Matern kernel with selectable smoothness parameter nu.
///
/// - nu = 1/2: k = sigma^2 * exp(-r / l)
///   (exponential kernel; sample paths continuous but nowhere differentiable)
/// - nu = 3/2: k = sigma^2 * (1 + sqrt(3)*r/l) * exp(-sqrt(3)*r/l)
///   (once mean-square differentiable)
/// - nu = 5/2: k = sigma^2 * (1 + sqrt(5)*r/l + 5*r^2/(3*l^2)) * exp(-sqrt(5)*r/l)
///   (twice mean-square differentiable)
#[derive(Debug, Clone)]
pub struct MaternKernel {
    /// Smoothness parameter
    pub variant: MaternVariant,
    /// Length scale
    pub length_scale: f64,
    /// Signal variance
    pub signal_variance: f64,
}

impl MaternKernel {
    /// Create a Matern kernel with the given smoothness variant.
    ///
    /// Length scale and signal variance are clamped to a 1e-10 floor to keep
    /// the kernel numerically valid.
    pub fn new(variant: MaternVariant, length_scale: f64, signal_variance: f64) -> Self {
        Self {
            variant,
            length_scale: length_scale.max(1e-10),
            signal_variance: signal_variance.max(1e-10),
        }
    }

    /// Create a Matern-1/2 kernel.
    pub fn one_half(length_scale: f64, signal_variance: f64) -> Self {
        Self::new(MaternVariant::OneHalf, length_scale, signal_variance)
    }

    /// Create a Matern-3/2 kernel.
    pub fn three_halves(length_scale: f64, signal_variance: f64) -> Self {
        Self::new(MaternVariant::ThreeHalves, length_scale, signal_variance)
    }

    /// Create a Matern-5/2 kernel.
    pub fn five_halves(length_scale: f64, signal_variance: f64) -> Self {
        Self::new(MaternVariant::FiveHalves, length_scale, signal_variance)
    }
}

impl Default for MaternKernel {
    /// Matern-5/2 with unit parameters.
    fn default() -> Self {
        Self::new(MaternVariant::FiveHalves, 1.0, 1.0)
    }
}
196
197impl SurrogateKernel for MaternKernel {
198    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
199        let r = squared_distance(x1, x2).sqrt();
200        let l = self.length_scale;
201        let sv = self.signal_variance;
202
203        match self.variant {
204            MaternVariant::OneHalf => sv * (-r / l).exp(),
205            MaternVariant::ThreeHalves => {
206                let sqrt3_r_l = 3.0_f64.sqrt() * r / l;
207                sv * (1.0 + sqrt3_r_l) * (-sqrt3_r_l).exp()
208            }
209            MaternVariant::FiveHalves => {
210                let sqrt5_r_l = 5.0_f64.sqrt() * r / l;
211                let r2_l2 = r * r / (l * l);
212                sv * (1.0 + sqrt5_r_l + 5.0 * r2_l2 / 3.0) * (-sqrt5_r_l).exp()
213            }
214        }
215    }
216
217    fn get_log_params(&self) -> Vec<f64> {
218        vec![self.length_scale.ln(), self.signal_variance.ln()]
219    }
220
221    fn set_log_params(&mut self, params: &[f64]) {
222        if params.len() >= 2 {
223            self.length_scale = params[0].exp().max(1e-10);
224            self.signal_variance = params[1].exp().max(1e-10);
225        }
226    }
227
228    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
229        Box::new(self.clone())
230    }
231
232    fn name(&self) -> &str {
233        match self.variant {
234            MaternVariant::OneHalf => "Matern12",
235            MaternVariant::ThreeHalves => "Matern32",
236            MaternVariant::FiveHalves => "Matern52",
237        }
238    }
239}
240
241// ---------------------------------------------------------------------------
242// Rational Quadratic kernel
243// ---------------------------------------------------------------------------
244
/// Rational Quadratic kernel.
///
/// k(x, x') = sigma^2 * (1 + ||x - x'||^2 / (2 * alpha * l^2))^(-alpha)
///
/// This can be seen as a scale mixture of RBF kernels with different length
/// scales. The parameter `alpha` controls the relative weighting of large-scale
/// vs small-scale variations; as alpha -> infinity it converges to the RBF
/// kernel.
#[derive(Debug, Clone)]
pub struct RationalQuadraticKernel {
    /// Length scale
    pub length_scale: f64,
    /// Signal variance
    pub signal_variance: f64,
    /// Shape parameter (alpha); larger alpha => closer to RBF
    pub alpha: f64,
}

impl RationalQuadraticKernel {
    /// Create a new Rational Quadratic kernel.
    ///
    /// All three parameters are clamped to a 1e-10 floor to keep the kernel
    /// numerically valid.
    pub fn new(length_scale: f64, signal_variance: f64, alpha: f64) -> Self {
        Self {
            length_scale: length_scale.max(1e-10),
            signal_variance: signal_variance.max(1e-10),
            alpha: alpha.max(1e-10),
        }
    }
}

impl Default for RationalQuadraticKernel {
    /// Unit length scale, signal variance, and alpha.
    fn default() -> Self {
        Self::new(1.0, 1.0, 1.0)
    }
}
277
278impl SurrogateKernel for RationalQuadraticKernel {
279    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
280        let sq_dist = squared_distance(x1, x2);
281        let base = 1.0 + sq_dist / (2.0 * self.alpha * self.length_scale * self.length_scale);
282        self.signal_variance * base.powf(-self.alpha)
283    }
284
285    fn get_log_params(&self) -> Vec<f64> {
286        vec![
287            self.length_scale.ln(),
288            self.signal_variance.ln(),
289            self.alpha.ln(),
290        ]
291    }
292
293    fn set_log_params(&mut self, params: &[f64]) {
294        if params.len() >= 3 {
295            self.length_scale = params[0].exp().max(1e-10);
296            self.signal_variance = params[1].exp().max(1e-10);
297            self.alpha = params[2].exp().max(1e-10);
298        }
299    }
300
301    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
302        Box::new(self.clone())
303    }
304
305    fn name(&self) -> &str {
306        "RationalQuadratic"
307    }
308}
309
310// ---------------------------------------------------------------------------
311// Composite kernels: Sum and Product
312// ---------------------------------------------------------------------------
313
/// Sum of two kernels: k(x,x') = k1(x,x') + k2(x,x')
///
/// The sum of two valid covariance functions is itself a valid covariance
/// function; hyperparameters of both components are concatenated.
#[derive(Clone)]
pub struct SumKernel {
    pub kernel1: Box<dyn SurrogateKernel>,
    pub kernel2: Box<dyn SurrogateKernel>,
}

impl SumKernel {
    /// Combine two kernels additively.
    pub fn new(k1: Box<dyn SurrogateKernel>, k2: Box<dyn SurrogateKernel>) -> Self {
        Self {
            kernel1: k1,
            kernel2: k2,
        }
    }
}
329
330impl SurrogateKernel for SumKernel {
331    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
332        self.kernel1.eval(x1, x2) + self.kernel2.eval(x1, x2)
333    }
334
335    fn get_log_params(&self) -> Vec<f64> {
336        let mut p = self.kernel1.get_log_params();
337        p.extend(self.kernel2.get_log_params());
338        p
339    }
340
341    fn set_log_params(&mut self, params: &[f64]) {
342        let n1 = self.kernel1.n_params();
343        if params.len() >= n1 {
344            self.kernel1.set_log_params(&params[..n1]);
345        }
346        if params.len() > n1 {
347            self.kernel2.set_log_params(&params[n1..]);
348        }
349    }
350
351    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
352        Box::new(self.clone())
353    }
354
355    fn name(&self) -> &str {
356        "Sum"
357    }
358}
359
/// Product of two kernels: k(x,x') = k1(x,x') * k2(x,x')
///
/// The product of two valid covariance functions is itself a valid covariance
/// function; hyperparameters of both components are concatenated.
#[derive(Clone)]
pub struct ProductKernel {
    pub kernel1: Box<dyn SurrogateKernel>,
    pub kernel2: Box<dyn SurrogateKernel>,
}

impl ProductKernel {
    /// Combine two kernels multiplicatively.
    pub fn new(k1: Box<dyn SurrogateKernel>, k2: Box<dyn SurrogateKernel>) -> Self {
        Self {
            kernel1: k1,
            kernel2: k2,
        }
    }
}
375
376impl SurrogateKernel for ProductKernel {
377    fn eval(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
378        self.kernel1.eval(x1, x2) * self.kernel2.eval(x1, x2)
379    }
380
381    fn get_log_params(&self) -> Vec<f64> {
382        let mut p = self.kernel1.get_log_params();
383        p.extend(self.kernel2.get_log_params());
384        p
385    }
386
387    fn set_log_params(&mut self, params: &[f64]) {
388        let n1 = self.kernel1.n_params();
389        if params.len() >= n1 {
390            self.kernel1.set_log_params(&params[..n1]);
391        }
392        if params.len() > n1 {
393            self.kernel2.set_log_params(&params[n1..]);
394        }
395    }
396
397    fn clone_box(&self) -> Box<dyn SurrogateKernel> {
398        Box::new(self.clone())
399    }
400
401    fn name(&self) -> &str {
402        "Product"
403    }
404}
405
406// ---------------------------------------------------------------------------
407// Gaussian Process Surrogate
408// ---------------------------------------------------------------------------
409
/// Configuration for the GP surrogate.
#[derive(Clone)]
pub struct GpSurrogateConfig {
    /// Noise variance added to the diagonal for numerical stability.
    pub noise_variance: f64,
    /// Whether to optimise kernel hyperparameters via marginal likelihood.
    pub optimize_hyperparams: bool,
    /// Number of random restarts for hyperparameter optimization.
    pub n_restarts: usize,
    /// Maximum number of coordinate-descent sweeps per restart of the
    /// hyperparameter search (the optimiser is a coordinate-wise step
    /// search, not L-BFGS).
    pub max_opt_iters: usize,
}

impl Default for GpSurrogateConfig {
    /// Small jitter, hyperparameter optimisation on, 3 restarts, 50 sweeps.
    fn default() -> Self {
        Self {
            noise_variance: 1e-6,
            optimize_hyperparams: true,
            n_restarts: 3,
            max_opt_iters: 50,
        }
    }
}
433
/// Gaussian Process surrogate model for Bayesian optimization.
///
/// Maintains training data and the fitted model (Cholesky factor + alpha vector)
/// for efficient prediction. Targets are standardised internally; predictions
/// are de-standardised back to the original scale.
pub struct GpSurrogate {
    /// Kernel function
    kernel: Box<dyn SurrogateKernel>,
    /// Configuration
    config: GpSurrogateConfig,
    /// Training inputs (n_train x n_dims)
    x_train: Option<Array2<f64>>,
    /// Training targets (n_train,), stored on the original (unstandardised) scale
    y_train: Option<Array1<f64>>,
    /// Mean of training targets (for standardization)
    y_mean: f64,
    /// Std of training targets (for standardization); `fit` pins it to 1.0
    /// when target variance is negligible
    y_std: f64,
    /// Lower-triangular Cholesky factor of K + noise*I
    l_factor: Option<Array2<f64>>,
    /// Alpha = L^T \ (L \ y_centered), i.e. (K + noise*I)^{-1} y_standardised
    alpha: Option<Array1<f64>>,
}
456
457impl GpSurrogate {
    /// Create a new GP surrogate with the given kernel.
    ///
    /// The model starts unfitted: `fit` (or `update`) must be called before
    /// any prediction is possible.
    pub fn new(kernel: Box<dyn SurrogateKernel>, config: GpSurrogateConfig) -> Self {
        Self {
            kernel,
            config,
            x_train: None,
            y_train: None,
            y_mean: 0.0,
            y_std: 1.0,
            l_factor: None,
            alpha: None,
        }
    }

    /// Create a GP surrogate with an RBF kernel and default configuration.
    pub fn default_rbf() -> Self {
        Self::new(Box::new(RbfKernel::default()), GpSurrogateConfig::default())
    }
476
    /// Fit the GP to training data.
    ///
    /// Stores the data, standardises the targets, optionally optimises kernel
    /// hyperparameters via marginal likelihood maximisation, then factorises
    /// the covariance matrix for prediction.
    ///
    /// # Errors
    ///
    /// Returns `OptimizeError::InvalidInput` when `x` and `y` disagree in
    /// length or when there are zero samples; propagates numerical failures
    /// from the Cholesky factorisation.
    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> OptimizeResult<()> {
        if x.nrows() != y.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "x has {} rows but y has {} elements",
                x.nrows(),
                y.len()
            )));
        }
        if x.nrows() == 0 {
            return Err(OptimizeError::InvalidInput(
                "Cannot fit GP with zero training samples".to_string(),
            ));
        }

        self.x_train = Some(x.clone());
        self.y_train = Some(y.clone());

        // Standardise targets to zero mean / unit variance. For near-constant
        // targets the std is pinned at 1.0 to avoid dividing by ~zero.
        self.y_mean = y.iter().sum::<f64>() / y.len() as f64;
        let variance = y.iter().map(|&v| (v - self.y_mean).powi(2)).sum::<f64>() / y.len() as f64;
        self.y_std = if variance > 1e-12 {
            variance.sqrt()
        } else {
            1.0
        };

        // Optimise hyperparameters if requested; skipped below 3 samples.
        if self.config.optimize_hyperparams && x.nrows() >= 3 {
            self.optimize_hyperparameters()?;
        }

        // Compute Cholesky factor and alpha with the (possibly updated) kernel.
        // NOTE(review): if this errors, x_train/y_train were already replaced
        // while l_factor/alpha may still belong to a previous fit.
        self.update_model()
    }
514
515    /// Add new observations incrementally and refit the model.
516    pub fn update(&mut self, x_new: &Array2<f64>, y_new: &Array1<f64>) -> OptimizeResult<()> {
517        if x_new.nrows() != y_new.len() {
518            return Err(OptimizeError::InvalidInput(
519                "x_new and y_new must have same number of rows".to_string(),
520            ));
521        }
522
523        match (&self.x_train, &self.y_train) {
524            (Some(xt), Some(yt)) => {
525                let mut x_all = Array2::zeros((xt.nrows() + x_new.nrows(), xt.ncols()));
526                for i in 0..xt.nrows() {
527                    for j in 0..xt.ncols() {
528                        x_all[[i, j]] = xt[[i, j]];
529                    }
530                }
531                for i in 0..x_new.nrows() {
532                    for j in 0..x_new.ncols() {
533                        x_all[[xt.nrows() + i, j]] = x_new[[i, j]];
534                    }
535                }
536                let mut y_all = Array1::zeros(yt.len() + y_new.len());
537                for i in 0..yt.len() {
538                    y_all[i] = yt[i];
539                }
540                for i in 0..y_new.len() {
541                    y_all[yt.len() + i] = y_new[i];
542                }
543                self.fit(&x_all, &y_all)
544            }
545            _ => self.fit(x_new, y_new),
546        }
547    }
548
549    /// Predict the GP mean at test points.
550    pub fn predict_mean(&self, x_test: &Array2<f64>) -> OptimizeResult<Array1<f64>> {
551        let (mean, _) = self.predict(x_test)?;
552        Ok(mean)
553    }
554
555    /// Predict the GP variance at test points.
556    pub fn predict_variance(&self, x_test: &Array2<f64>) -> OptimizeResult<Array1<f64>> {
557        let (_, var) = self.predict(x_test)?;
558        Ok(var)
559    }
560
    /// Predict both mean and variance at test points.
    ///
    /// Returns values on the original target scale. The variance is that of
    /// the latent function (the configured noise variance is not added) and
    /// is clamped at zero before rescaling.
    pub fn predict(&self, x_test: &Array2<f64>) -> OptimizeResult<(Array1<f64>, Array1<f64>)> {
        let x_train = self.x_train.as_ref().ok_or_else(|| {
            OptimizeError::ComputationError("GP must be fitted before prediction".to_string())
        })?;
        let alpha = self
            .alpha
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;
        let l_factor = self
            .l_factor
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;

        // K(X_test, X_train)
        let k_star = self.kernel.cross_covariance(x_test, x_train);

        // Standardised mean: k_star @ alpha
        let mean_std = k_star.dot(alpha);

        // De-standardise back to the original target scale
        let mean = mean_std.mapv(|v| v * self.y_std + self.y_mean);

        // Variance: k(x*, x*) - k_star @ K^{-1} @ k_star^T
        // Using: v = L \ k_star^T, variance = k_self - ||v||^2
        let n_test = x_test.nrows();
        let mut variance = Array1::zeros(n_test);

        for i in 0..n_test {
            // Prior variance at the test point (kernel diagonal, no noise term)
            let k_self = self.kernel.eval(&x_test.row(i), &x_test.row(i));

            // Solve L v = k_star[i, :] using forward substitution
            let k_col = k_star.row(i).to_owned();
            let v = forward_solve(l_factor, &k_col)?;

            let v_sq_sum: f64 = v.iter().map(|&vi| vi * vi).sum();
            // Clamp at zero: round-off can make the subtraction slightly negative
            let var = (k_self - v_sq_sum).max(0.0);
            // Rescale by y_std^2 to undo the target standardisation
            variance[i] = var * self.y_std * self.y_std;
        }

        Ok((mean, variance))
    }
603
    /// Predict mean and standard deviation (not variance) at a single point.
    pub fn predict_single(&self, x: &ArrayView1<f64>) -> OptimizeResult<(f64, f64)> {
        // View the single point as a 1-row matrix so `predict` can be reused.
        let x_mat = x
            .to_owned()
            .into_shape_with_order((1, x.len()))
            .map_err(|e| OptimizeError::ComputationError(format!("Shape error: {}", e)))?;
        let (mean, var) = self.predict(&x_mat)?;
        // `max(0.0)` guards the sqrt; `predict` already clamps, so this is defensive.
        Ok((mean[0], var[0].max(0.0).sqrt()))
    }
613
    /// Compute the log marginal likelihood of the current model.
    ///
    /// Uses the standard Cholesky identity on the standardised targets:
    /// log p(y) = -0.5 * y^T alpha - sum(log diag(L)) - 0.5 * n * log(2*pi).
    pub fn log_marginal_likelihood(&self) -> OptimizeResult<f64> {
        let y_train = self
            .y_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP must be fitted".to_string()))?;
        let l_factor = self
            .l_factor
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;
        let alpha = self
            .alpha
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("GP model not fitted".to_string()))?;

        // `y_std` here is the *standardised* target vector, not a std-dev.
        let y_std = &self.standardize_y(y_train);
        let n = y_std.len() as f64;

        // -0.5 * y^T alpha  (data-fit term)
        let data_fit = -0.5 * y_std.dot(alpha);

        // sum(log(diag(L))) = 0.5 * log|K + noise*I|; abs() is defensive —
        // the Cholesky diagonal is positive by construction.
        let log_det: f64 = l_factor.diag().iter().map(|&v| v.abs().ln()).sum();

        // -0.5 * n * log(2 pi)  (normalisation term)
        let norm = -0.5 * n * (2.0 * std::f64::consts::PI).ln();

        Ok(data_fit - log_det + norm)
    }
643
    /// Return a reference to the current kernel (e.g. to inspect the
    /// hyperparameters chosen by `fit`).
    pub fn kernel(&self) -> &dyn SurrogateKernel {
        self.kernel.as_ref()
    }

    /// Return a mutable reference to the kernel. Changing hyperparameters
    /// does not refit the model; call `fit` again afterwards.
    pub fn kernel_mut(&mut self) -> &mut dyn SurrogateKernel {
        self.kernel.as_mut()
    }

    /// Number of training samples (0 if the GP has never been fitted).
    pub fn n_train(&self) -> usize {
        self.x_train.as_ref().map_or(0, |x| x.nrows())
    }
658
659    // -----------------------------------------------------------------------
660    // Internal helpers
661    // -----------------------------------------------------------------------
662
    /// Standardise y-values: (y - mean) / std.
    ///
    /// Safe against division by zero: `fit` pins `y_std` to 1.0 when the
    /// target variance is negligible.
    fn standardize_y(&self, y: &Array1<f64>) -> Array1<f64> {
        y.mapv(|v| (v - self.y_mean) / self.y_std)
    }
667
    /// Recompute Cholesky factor and alpha from current kernel + data.
    ///
    /// Factorises K + noise*I = L L^T and solves
    /// alpha = (K + noise*I)^{-1} y_standardised via two triangular solves.
    fn update_model(&mut self) -> OptimizeResult<()> {
        let x_train = self
            .x_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?;
        let y_train = self
            .y_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?;

        let y_std = self.standardize_y(y_train);

        // Build covariance matrix
        let mut k = self.kernel.covariance_matrix(x_train);
        let n = k.nrows();

        // Add configured noise to the diagonal for numerical stability
        for i in 0..n {
            k[[i, i]] += self.config.noise_variance;
        }

        // Cholesky decomposition with jitter fallback
        let l = match cholesky(&k) {
            Ok(l) => l,
            Err(_) => {
                // Escalating jitter. Each attempt adds on top of previously
                // added jitter (k is not reset), so the effective diagonal
                // boost is cumulative across attempts.
                let jitters = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2];
                let mut result = Err(OptimizeError::ComputationError(
                    "Cholesky failed with all jitter levels".to_string(),
                ));
                for &jitter in &jitters {
                    for i in 0..n {
                        k[[i, i]] += jitter;
                    }
                    match cholesky(&k) {
                        Ok(l) => {
                            result = Ok(l);
                            break;
                        }
                        Err(_) => continue,
                    }
                }
                result?
            }
        };

        // Solve L alpha1 = y_std (forward substitution)
        let alpha1 = forward_solve(&l, &y_std)?;
        // Solve L^T alpha = alpha1 (backward substitution)
        let alpha = backward_solve_transpose(&l, &alpha1)?;

        self.l_factor = Some(l);
        self.alpha = Some(alpha);

        Ok(())
    }
725
    /// Optimise kernel hyperparameters by maximising the log marginal likelihood.
    ///
    /// Uses a simple coordinate-wise fixed-step search on each log-parameter
    /// with random restarts (not golden-section or gradient-based). This
    /// avoids depending on external optimisers and keeps the implementation
    /// self-contained.
    fn optimize_hyperparameters(&mut self) -> OptimizeResult<()> {
        let x_train = self
            .x_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?
            .clone();
        let y_train = self
            .y_train
            .as_ref()
            .ok_or_else(|| OptimizeError::ComputationError("No training data".to_string()))?
            .clone();
        let y_std = self.standardize_y(&y_train);

        let n_params = self.kernel.n_params();
        if n_params == 0 {
            return Ok(());
        }

        let mut best_params = self.kernel.get_log_params();
        let mut best_lml = f64::NEG_INFINITY;

        // Evaluate current params so the search can never end up worse than
        // the starting point.
        if let Ok(lml) = self.eval_lml_at_params(&best_params, &x_train, &y_std) {
            best_lml = lml;
        }

        let mut rng = scirs2_core::random::rng();

        // Random restarts: restart 0 refines the current parameters; later
        // restarts start from random log-parameters in [-2, 2).
        for restart in 0..self.config.n_restarts {
            let init_params: Vec<f64> = if restart == 0 {
                best_params.clone()
            } else {
                (0..n_params).map(|_| rng.random_range(-2.0..2.0)).collect()
            };

            // Coordinate-wise optimization
            let mut current_params = init_params;
            for _iter in 0..self.config.max_opt_iters {
                let mut improved = false;
                for p in 0..n_params {
                    let original = current_params[p];

                    // Try a few fixed steps in each direction and keep the best
                    let steps = [0.1, 0.3, 1.0, -0.1, -0.3, -1.0];
                    let mut best_step_lml =
                        match self.eval_lml_at_params(&current_params, &x_train, &y_std) {
                            Ok(v) => v,
                            Err(_) => f64::NEG_INFINITY,
                        };
                    let mut best_step_val = original;

                    for &step in &steps {
                        current_params[p] = original + step;
                        // Clamp log-parameters to [-5, 5] to keep the kernel
                        // in a numerically safe range
                        current_params[p] = current_params[p].clamp(-5.0, 5.0);

                        if let Ok(lml) = self.eval_lml_at_params(&current_params, &x_train, &y_std)
                        {
                            if lml > best_step_lml {
                                best_step_lml = lml;
                                best_step_val = current_params[p];
                                improved = true;
                            }
                        }
                    }
                    current_params[p] = best_step_val;
                }
                // Stop this restart once a full sweep yields no improvement
                if !improved {
                    break;
                }
            }

            // Keep the restart's result only if it beats the best so far
            if let Ok(lml) = self.eval_lml_at_params(&current_params, &x_train, &y_std) {
                if lml > best_lml {
                    best_lml = lml;
                    best_params = current_params;
                }
            }
        }

        // Commit the best parameters found
        self.kernel.set_log_params(&best_params);

        Ok(())
    }
818
    /// Evaluate the log marginal likelihood at given (log) kernel parameters
    /// without mutating the surrogate.
    ///
    /// Works on a cloned kernel. There is no jitter fallback here: a failed
    /// Cholesky simply rejects the candidate (callers treat `Err` as -inf).
    fn eval_lml_at_params(
        &self,
        log_params: &[f64],
        x_train: &Array2<f64>,
        y_std: &Array1<f64>,
    ) -> OptimizeResult<f64> {
        let mut kernel = self.kernel.clone();
        kernel.set_log_params(log_params);

        // K + noise*I for the candidate hyperparameters
        let mut k = kernel.covariance_matrix(x_train);
        let n = k.nrows();
        for i in 0..n {
            k[[i, i]] += self.config.noise_variance;
        }

        let l = cholesky(&k)?;
        let alpha1 = forward_solve(&l, y_std)?;
        let alpha = backward_solve_transpose(&l, &alpha1)?;

        // Same three-term LML decomposition as `log_marginal_likelihood`
        let n_f = n as f64;
        let data_fit = -0.5 * y_std.dot(&alpha);
        let log_det: f64 = l.diag().iter().map(|&v| v.abs().ln()).sum();
        let norm = -0.5 * n_f * (2.0 * std::f64::consts::PI).ln();

        Ok(data_fit - log_det + norm)
    }
847}
848
849// ---------------------------------------------------------------------------
850// Linear algebra helpers (pure Rust, no unwrap)
851// ---------------------------------------------------------------------------
852
853/// Squared Euclidean distance between two vectors.
854fn squared_distance(x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
855    let mut s = 0.0;
856    for i in 0..x1.len() {
857        let d = x1[i] - x2[i];
858        s += d * d;
859    }
860    s
861}
862
863/// Cholesky decomposition: A = L L^T  (lower triangular).
864fn cholesky(a: &Array2<f64>) -> OptimizeResult<Array2<f64>> {
865    let n = a.nrows();
866    if n != a.ncols() {
867        return Err(OptimizeError::ComputationError(
868            "Cholesky: matrix must be square".to_string(),
869        ));
870    }
871    let mut l = Array2::zeros((n, n));
872
873    for i in 0..n {
874        for j in 0..=i {
875            let mut s = 0.0;
876            for k in 0..j {
877                s += l[[i, k]] * l[[j, k]];
878            }
879            if i == j {
880                let diag = a[[i, i]] - s;
881                if diag <= 0.0 {
882                    return Err(OptimizeError::ComputationError(format!(
883                        "Cholesky: matrix not positive-definite (diag[{}] = {:.6e})",
884                        i, diag
885                    )));
886                }
887                l[[i, j]] = diag.sqrt();
888            } else {
889                l[[i, j]] = (a[[i, j]] - s) / l[[j, j]];
890            }
891        }
892    }
893    Ok(l)
894}
895
896/// Forward substitution: solve L x = b where L is lower-triangular.
897fn forward_solve(l: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
898    let n = l.nrows();
899    let mut x = Array1::zeros(n);
900    for i in 0..n {
901        let mut s = 0.0;
902        for j in 0..i {
903            s += l[[i, j]] * x[j];
904        }
905        let diag = l[[i, i]];
906        if diag.abs() < 1e-15 {
907            return Err(OptimizeError::ComputationError(
908                "Forward solve: near-zero diagonal".to_string(),
909            ));
910        }
911        x[i] = (b[i] - s) / diag;
912    }
913    Ok(x)
914}
915
916/// Backward substitution: solve L^T x = b where L is lower-triangular.
917fn backward_solve_transpose(l: &Array2<f64>, b: &Array1<f64>) -> OptimizeResult<Array1<f64>> {
918    let n = l.nrows();
919    let mut x = Array1::zeros(n);
920    for i in (0..n).rev() {
921        let mut s = 0.0;
922        for j in (i + 1)..n {
923            s += l[[j, i]] * x[j]; // L^T[i,j] = L[j,i]
924        }
925        let diag = l[[i, i]];
926        if diag.abs() < 1e-15 {
927            return Err(OptimizeError::ComputationError(
928                "Backward solve: near-zero diagonal".to_string(),
929            ));
930        }
931        x[i] = (b[i] - s) / diag;
932    }
933    Ok(x)
934}
935
936// ---------------------------------------------------------------------------
937// Tests
938// ---------------------------------------------------------------------------
939
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Shared fixture: f(x) = sin(x) at x = 0..=4, targets rounded to 3 decimals.
    fn make_train_data() -> (Array2<f64>, Array1<f64>) {
        // f(x) = sin(x)  sampled at 5 points
        let x = Array2::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0]).expect("shape ok");
        let y = array![0.0, 0.841, 0.909, 0.141, -0.757];
        (x, y)
    }

    #[test]
    fn test_rbf_kernel_symmetry() {
        // Any valid covariance function must satisfy k(a, b) == k(b, a).
        let k = RbfKernel::default();
        let a = array![1.0, 2.0];
        let b = array![3.0, 4.0];
        assert!((k.eval(&a.view(), &b.view()) - k.eval(&b.view(), &a.view())).abs() < 1e-14);
    }

    #[test]
    fn test_rbf_kernel_self_covariance() {
        // RbfKernel::new(length_scale, signal_variance) — k(x, x) = signal_variance.
        let k = RbfKernel::new(1.0, 2.0);
        let a = array![1.0, 2.0];
        // Self-covariance should be signal_variance
        assert!((k.eval(&a.view(), &a.view()) - 2.0).abs() < 1e-14);
    }

    #[test]
    fn test_matern_variants() {
        // Each smoothness variant must yield a correlation strictly inside (0, 1)
        // at unit distance and exactly signal_variance at zero distance.
        let a = array![0.0];
        let b = array![1.0];

        for variant in &[
            MaternVariant::OneHalf,
            MaternVariant::ThreeHalves,
            MaternVariant::FiveHalves,
        ] {
            let k = MaternKernel::new(*variant, 1.0, 1.0);
            let val = k.eval(&a.view(), &b.view());
            assert!(val > 0.0 && val < 1.0, "Matern({:?}) = {}", variant, val);
            // Self-covariance = signal_variance
            assert!((k.eval(&a.view(), &a.view()) - 1.0).abs() < 1e-14);
        }
    }

    #[test]
    fn test_rational_quadratic_kernel() {
        // With sigma^2 = l = alpha = 1 and r^2 = 1: (1 + r^2 / (2 alpha l^2))^(-alpha).
        let k = RationalQuadraticKernel::new(1.0, 1.0, 1.0);
        let a = array![0.0];
        let b = array![1.0];
        let val = k.eval(&a.view(), &b.view());
        // Should be (1 + 0.5)^{-1} = 2/3
        assert!((val - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_rational_quadratic_approaches_rbf() {
        // As alpha -> infinity, RQ -> RBF
        let rbf = RbfKernel::new(1.0, 1.0);
        let rq = RationalQuadraticKernel::new(1.0, 1.0, 1e6);
        let a = array![0.0, 1.0];
        let b = array![2.0, 3.0];

        let rbf_val = rbf.eval(&a.view(), &b.view());
        let rq_val = rq.eval(&a.view(), &b.view());
        assert!(
            (rbf_val - rq_val).abs() < 1e-4,
            "RBF={}, RQ(alpha=1e6)={}",
            rbf_val,
            rq_val
        );
    }

    #[test]
    fn test_sum_kernel() {
        // SumKernel must evaluate to the pointwise sum of its components.
        let k1 = Box::new(RbfKernel::new(1.0, 1.0));
        let k2 = Box::new(MaternKernel::five_halves(1.0, 0.5));
        let sum = SumKernel::new(k1.clone(), k2.clone());

        let a = array![1.0];
        let b = array![2.0];
        let expected = k1.eval(&a.view(), &b.view()) + k2.eval(&a.view(), &b.view());
        assert!((sum.eval(&a.view(), &b.view()) - expected).abs() < 1e-14);
    }

    #[test]
    fn test_product_kernel() {
        // ProductKernel must evaluate to the pointwise product of its components.
        let k1 = Box::new(RbfKernel::new(1.0, 1.0));
        let k2 = Box::new(MaternKernel::five_halves(1.0, 0.5));
        let prod = ProductKernel::new(k1.clone(), k2.clone());

        let a = array![1.0];
        let b = array![2.0];
        let expected = k1.eval(&a.view(), &b.view()) * k2.eval(&a.view(), &b.view());
        assert!((prod.eval(&a.view(), &b.view()) - expected).abs() < 1e-14);
    }

    #[test]
    fn test_gp_fit_predict_basic() {
        // With low noise the posterior should nearly interpolate the training
        // targets and report small variance at the training inputs.
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit should succeed");

        // Predict at training points -> should be close to training values
        let (mean, var) = gp.predict(&x).expect("predict should succeed");
        for i in 0..y.len() {
            assert!(
                (mean[i] - y[i]).abs() < 0.15,
                "mean[{}]={:.4} vs y[{}]={:.4}",
                i,
                mean[i],
                i,
                y[i]
            );
            // Variance should be small at training points
            assert!(
                var[i] < 0.5,
                "var[{}]={:.4} should be small at training point",
                i,
                var[i]
            );
        }
    }

    #[test]
    fn test_gp_uncertainty_away_from_data() {
        // Predictive variance must grow with distance from the training set —
        // the core property Bayesian optimization relies on for exploration.
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit should succeed");

        // Predict far from training data
        let x_far = Array2::from_shape_vec((1, 1), vec![10.0]).expect("shape ok");
        let (_, var_far) = gp.predict(&x_far).expect("predict ok");

        // Predict at training data
        let x_near = Array2::from_shape_vec((1, 1), vec![2.0]).expect("shape ok");
        let (_, var_near) = gp.predict(&x_near).expect("predict ok");

        // Uncertainty should be higher far from data
        assert!(
            var_far[0] > var_near[0],
            "var_far={:.4} should be > var_near={:.4}",
            var_far[0],
            var_near[0]
        );
    }

    #[test]
    fn test_gp_predict_single() {
        // predict_single returns (mean, std); std must be non-negative.
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::default_rbf();
        gp.config.optimize_hyperparams = false;
        gp.config.noise_variance = 1e-4;
        gp.fit(&x, &y).expect("fit ok");

        let point = array![1.5];
        let (mean, std) = gp.predict_single(&point.view()).expect("predict_single ok");
        assert!(mean.is_finite());
        assert!(std >= 0.0);
    }

    #[test]
    fn test_gp_log_marginal_likelihood() {
        // Smoke test: the LML must be computable and finite after fitting.
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit ok");

        let lml = gp.log_marginal_likelihood().expect("lml ok");
        assert!(lml.is_finite(), "LML should be finite, got {}", lml);
    }

    #[test]
    fn test_gp_update_incremental() {
        // update() appends new observations to an already-fitted model.
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit ok");
        assert_eq!(gp.n_train(), 5);

        // Add one more point
        let x_new = Array2::from_shape_vec((1, 1), vec![5.0]).expect("shape ok");
        let y_new = array![-0.959];
        gp.update(&x_new, &y_new).expect("update ok");
        assert_eq!(gp.n_train(), 6);
    }

    #[test]
    fn test_gp_hyperparameter_optimization() {
        // Exercises the type-II ML path with restarts; only checks it
        // terminates and yields finite predictions (the optimum itself is
        // seed-dependent).
        let (x, y) = make_train_data();
        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: true,
                n_restarts: 2,
                max_opt_iters: 20,
                noise_variance: 1e-4,
            },
        );
        gp.fit(&x, &y).expect("fit with optimization ok");

        // Just verify it completes without error and produces finite predictions
        let x_test = Array2::from_shape_vec((1, 1), vec![1.5]).expect("shape ok");
        let (mean, var) = gp.predict(&x_test).expect("predict ok");
        assert!(mean[0].is_finite());
        assert!(var[0].is_finite());
    }

    #[test]
    fn test_cholesky_positive_definite() {
        // SPD 2x2 input; the factor must reconstruct the original matrix.
        let a = Array2::from_shape_vec((2, 2), vec![4.0, 2.0, 2.0, 3.0]).expect("shape ok");
        let l = cholesky(&a).expect("should succeed");
        // Verify L L^T = A
        let reconstructed = l.dot(&l.t());
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (reconstructed[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "Mismatch at [{},{}]",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_cholesky_non_pd_fails() {
        // Symmetric but indefinite (eigenvalues 11 and -9) -> must be rejected.
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 10.0, 10.0, 1.0]).expect("shape ok");
        assert!(cholesky(&a).is_err());
    }

    #[test]
    fn test_kernel_log_params_roundtrip() {
        // set_log_params(get_log_params()) must be the identity transform.
        let mut k = RbfKernel::new(2.5, 0.3);
        let params = k.get_log_params();
        k.set_log_params(&params);
        assert!((k.length_scale - 2.5).abs() < 1e-10);
        assert!((k.signal_variance - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_matern_kernel_covariance_matrix() {
        // covariance_matrix must return a symmetric matrix with positive diagonal.
        let k = MaternKernel::three_halves(1.0, 1.0);
        let x = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).expect("shape ok");
        let cov = k.covariance_matrix(&x);
        assert_eq!(cov.nrows(), 3);
        assert_eq!(cov.ncols(), 3);
        // Symmetric
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (cov[[i, j]] - cov[[j, i]]).abs() < 1e-14,
                    "Not symmetric at [{},{}]",
                    i,
                    j
                );
            }
        }
        // Positive diagonal
        for i in 0..3 {
            assert!(cov[[i, i]] > 0.0);
        }
    }

    #[test]
    fn test_gp_multidimensional() {
        // 2D function: f(x,y) = x^2 + y^2
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0, 0.5, 0.5],
        )
        .expect("shape ok");
        let y = array![0.0, 1.0, 1.0, 1.0, 1.0, 0.5];

        let mut gp = GpSurrogate::new(
            Box::new(RbfKernel::default()),
            GpSurrogateConfig {
                optimize_hyperparams: false,
                noise_variance: 1e-4,
                ..Default::default()
            },
        );
        gp.fit(&x, &y).expect("fit ok");

        // Predict at origin (should be close to 0)
        let x_test = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape ok");
        let (mean, _) = gp.predict(&x_test).expect("predict ok");
        assert!(
            mean[0].abs() < 0.3,
            "Prediction at origin should be close to 0, got {}",
            mean[0]
        );
    }
}