Skip to main content

scirs2_stats/gaussian_process/
gp.rs

1//! Core Gaussian Process implementation
2//!
3//! This module provides the fundamental GP regression functionality.
4
5use super::kernel::Kernel;
6use super::prior::Prior;
7use crate::error::StatsResult;
8use scirs2_core::error::CoreError;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
10
11/// Gaussian Process for regression
12///
13/// A Gaussian Process defines a distribution over functions. Given training data,
14/// it can make predictions with uncertainty estimates.
15#[derive(Clone)]
16pub struct GaussianProcess<K: Kernel, P: Prior> {
17    /// Kernel (covariance) function
18    pub kernel: K,
19    /// Prior mean function
20    pub prior: P,
21    /// Training inputs
22    x_train: Option<Array2<f64>>,
23    /// Training outputs (mean-subtracted)
24    y_train_centered: Option<Array1<f64>>,
25    /// Cholesky decomposition of K(X, X) + noise * I
26    l_matrix: Option<Array2<f64>>,
27    /// alpha = L^T \ (L \ y)
28    alpha: Option<Array1<f64>>,
29    /// Noise level (observation noise)
30    pub noise: f64,
31}
32
33impl<K: Kernel, P: Prior> GaussianProcess<K, P> {
34    /// Create a new Gaussian Process
35    pub fn new(kernel: K, prior: P, noise: f64) -> Self {
36        Self {
37            kernel,
38            prior,
39            x_train: None,
40            y_train_centered: None,
41            l_matrix: None,
42            alpha: None,
43            noise: noise.max(1e-10), // Ensure numerical stability
44        }
45    }
46
47    /// Fit the Gaussian Process to training data
48    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> StatsResult<()> {
49        if x.nrows() != y.len() {
50            return Err(
51                CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
52                    "Number of samples in X and y must match",
53                ))
54                .into(),
55            );
56        }
57
58        if x.nrows() == 0 {
59            return Err(
60                CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
61                    "Cannot fit with zero samples",
62                ))
63                .into(),
64            );
65        }
66
67        // Compute prior mean
68        let prior_mean = self.prior.compute_vector(x);
69
70        // Center the targets
71        let y_centered = y - &prior_mean;
72
73        // Compute covariance matrix K(X, X)
74        let mut k = self.kernel.compute_matrix(x);
75
76        // Add noise to diagonal for numerical stability
77        for i in 0..k.nrows() {
78            k[[i, i]] += self.noise;
79        }
80
81        // Cholesky decomposition: K = L L^T
82        let l = match cholesky_decomposition(&k) {
83            Ok(l) => l,
84            Err(_) => {
85                // If Cholesky fails, add more jitter
86                let jitter = 1e-6;
87                for i in 0..k.nrows() {
88                    k[[i, i]] += jitter;
89                }
90                cholesky_decomposition(&k).map_err(|e| {
91                    CoreError::ComputationError(scirs2_core::error::ErrorContext::new(format!(
92                        "Cholesky decomposition failed: {}",
93                        e
94                    )))
95                })?
96            }
97        };
98
99        // Solve L alpha_1 = y  (forward substitution)
100        let alpha_1 = solve_lower_triangular(&l, &y_centered)?;
101
102        // Solve L^T alpha = alpha_1  (backward substitution)
103        let alpha = solve_upper_triangular(&l.t().to_owned(), &alpha_1)?;
104
105        // Store results
106        self.x_train = Some(x.clone());
107        self.y_train_centered = Some(y_centered);
108        self.l_matrix = Some(l);
109        self.alpha = Some(alpha);
110
111        Ok(())
112    }
113
114    /// Predict mean values for new inputs
115    pub fn predict(&self, x: &Array2<f64>) -> StatsResult<Array1<f64>> {
116        let (mean, _std) = self.predict_with_std(x)?;
117        Ok(mean)
118    }
119
120    /// Predict with mean and standard deviation
121    pub fn predict_with_std(&self, x: &Array2<f64>) -> StatsResult<(Array1<f64>, Array1<f64>)> {
122        if self.x_train.is_none() || self.alpha.is_none() {
123            return Err(
124                CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
125                    "GP must be fitted before making predictions",
126                ))
127                .into(),
128            );
129        }
130
131        let x_train = self.x_train.as_ref().expect("Operation failed");
132        let alpha = self.alpha.as_ref().expect("Operation failed");
133        let l = self.l_matrix.as_ref().expect("Operation failed");
134
135        // Compute K(X_test, X_train)
136        let k_trans = self.kernel.compute_cross_matrix(x, x_train);
137
138        // Mean: K(X_test, X_train) @ alpha + prior
139        let mean_centered = k_trans.dot(alpha);
140        let prior_mean = self.prior.compute_vector(x);
141        let mean = mean_centered + prior_mean;
142
143        // Variance calculation
144        // v = L \ K(X_test, X_train)^T
145        let k_trans_t = k_trans.t().to_owned();
146        let v = solve_lower_triangular_matrix(l, &k_trans_t)?;
147
148        // Compute variance for each test point
149        let mut variance = Array1::zeros(x.nrows());
150        for i in 0..x.nrows() {
151            // K(x_test[i], x_test[i])
152            let k_self = self.kernel.compute(&x.row(i), &x.row(i));
153
154            // ||v[i]||^2
155            let v_norm_sq: f64 = v.column(i).iter().map(|&x| x * x).sum();
156
157            // var = k_self - ||v||^2 + noise
158            variance[i] = (k_self - v_norm_sq + self.noise).max(0.0);
159        }
160
161        let std = variance.mapv(|x| x.sqrt());
162
163        Ok((mean, std))
164    }
165
166    /// Predict mean for a single point
167    pub fn predict_single(&self, x: &ArrayView1<f64>) -> StatsResult<f64> {
168        let x_mat = x.to_owned().insert_axis(Axis(0));
169        let pred = self.predict(&x_mat)?;
170        Ok(pred[0])
171    }
172
173    /// Predict variance for a single point
174    pub fn predict_variance_single(&self, x: &ArrayView1<f64>) -> StatsResult<f64> {
175        let x_mat = x.to_owned().insert_axis(Axis(0));
176        let (_mean, std) = self.predict_with_std(&x_mat)?;
177        Ok(std[0] * std[0])
178    }
179
180    /// Compute log marginal likelihood
181    ///
182    /// This is useful for hyperparameter optimization.
183    pub fn log_marginal_likelihood(&self) -> StatsResult<f64> {
184        if self.y_train_centered.is_none() || self.l_matrix.is_none() {
185            return Err(
186                CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
187                    "GP must be fitted before computing log marginal likelihood",
188                ))
189                .into(),
190            );
191        }
192
193        let y = self.y_train_centered.as_ref().expect("Operation failed");
194        let l = self.l_matrix.as_ref().expect("Operation failed");
195        let alpha = self.alpha.as_ref().expect("Operation failed");
196
197        let n = y.len() as f64;
198
199        // Compute data fit: -0.5 * y^T @ alpha
200        let data_fit = -0.5 * y.dot(alpha);
201
202        // Compute complexity penalty: -sum(log(diag(L)))
203        let log_det: f64 = l.diag().iter().map(|&x| x.ln()).sum();
204        let complexity = -log_det;
205
206        // Normalization constant: -n/2 * log(2π)
207        let normalization = -0.5 * n * (2.0 * std::f64::consts::PI).ln();
208
209        Ok(data_fit + complexity + normalization)
210    }
211
212    /// Get number of training samples
213    pub fn n_train_samples(&self) -> usize {
214        self.x_train.as_ref().map_or(0, |x| x.nrows())
215    }
216}
217
218/// Cholesky decomposition: A = L L^T where L is lower triangular
219fn cholesky_decomposition(a: &Array2<f64>) -> Result<Array2<f64>, String> {
220    let n = a.nrows();
221    if n != a.ncols() {
222        return Err("Matrix must be square".to_string());
223    }
224
225    let mut l = Array2::zeros((n, n));
226
227    for i in 0..n {
228        for j in 0..=i {
229            let mut sum = 0.0;
230
231            if j == i {
232                for k in 0..j {
233                    sum += l[[j, k]] * l[[j, k]];
234                }
235                let val = a[[j, j]] - sum;
236                if val <= 0.0 {
237                    return Err(format!(
238                        "Matrix is not positive definite (diagonal {} = {})",
239                        j, val
240                    ));
241                }
242                l[[j, j]] = val.sqrt();
243            } else {
244                for k in 0..j {
245                    sum += l[[i, k]] * l[[j, k]];
246                }
247                l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
248            }
249        }
250    }
251
252    Ok(l)
253}
254
255/// Solve L x = b where L is lower triangular
256fn solve_lower_triangular(l: &Array2<f64>, b: &Array1<f64>) -> StatsResult<Array1<f64>> {
257    let n = l.nrows();
258    let mut x = Array1::zeros(n);
259
260    for i in 0..n {
261        let mut sum = 0.0;
262        for j in 0..i {
263            sum += l[[i, j]] * x[j];
264        }
265        x[i] = (b[i] - sum) / l[[i, i]];
266    }
267
268    Ok(x)
269}
270
271/// Solve U x = b where U is upper triangular
272fn solve_upper_triangular(u: &Array2<f64>, b: &Array1<f64>) -> StatsResult<Array1<f64>> {
273    let n = u.nrows();
274    let mut x = Array1::zeros(n);
275
276    for i in (0..n).rev() {
277        let mut sum = 0.0;
278        for j in (i + 1)..n {
279            sum += u[[i, j]] * x[j];
280        }
281        x[i] = (b[i] - sum) / u[[i, i]];
282    }
283
284    Ok(x)
285}
286
287/// Solve L X = B where L is lower triangular and B is a matrix
288fn solve_lower_triangular_matrix(l: &Array2<f64>, b: &Array2<f64>) -> StatsResult<Array2<f64>> {
289    let n = l.nrows();
290    let m = b.ncols();
291    let mut x = Array2::zeros((n, m));
292
293    for col in 0..m {
294        let b_col = b.column(col).to_owned();
295        let x_col = solve_lower_triangular(l, &b_col)?;
296        for row in 0..n {
297            x[[row, col]] = x_col[row];
298        }
299    }
300
301    Ok(x)
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gaussian_process::kernel::SquaredExponential;
    use crate::gaussian_process::prior::ZeroPrior;
    use scirs2_core::ndarray::{array, Array2};

    /// With a small noise level the posterior mean at the training inputs
    /// should stay close to the training targets.
    #[test]
    fn test_gp_fit_predict() {
        let mut gp = GaussianProcess::new(
            SquaredExponential::new(1.0, 1.0),
            ZeroPrior::new(),
            0.01,
        );

        // Three one-dimensional training points.
        let x_train =
            Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).expect("Operation failed");
        let y_train = array![0.0, 1.0, 0.0];

        gp.fit(&x_train, &y_train).expect("Operation failed");

        // Evaluate the model at the training inputs themselves.
        let predictions = gp.predict(&x_train).expect("Operation failed");

        for i in 0..3 {
            let residual = (predictions[i] - y_train[i]).abs();
            assert!(residual < 0.1);
        }
    }

    /// The predictive standard deviation between two training points should
    /// be strictly positive but bounded.
    #[test]
    fn test_gp_uncertainty() {
        let mut gp = GaussianProcess::new(
            SquaredExponential::new(1.0, 1.0),
            ZeroPrior::new(),
            0.01,
        );

        let x_train = Array2::from_shape_vec((2, 1), vec![0.0, 2.0]).expect("Operation failed");
        let y_train = array![1.0, -1.0];

        gp.fit(&x_train, &y_train).expect("Operation failed");

        // Query the midpoint between the two training inputs.
        let x_test = Array2::from_shape_vec((1, 1), vec![1.0]).expect("Operation failed");
        let (_mean, std) = gp.predict_with_std(&x_test).expect("Operation failed");

        let sigma = std[0];
        assert!(sigma > 0.0);
        assert!(sigma < 2.0); // Not too large
    }
}