scirs2_stats/gaussian_process/
gp.rs1use super::kernel::Kernel;
6use super::prior::Prior;
7use crate::error::StatsResult;
8use scirs2_core::error::CoreError;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};
10
11#[derive(Clone)]
16pub struct GaussianProcess<K: Kernel, P: Prior> {
17 pub kernel: K,
19 pub prior: P,
21 x_train: Option<Array2<f64>>,
23 y_train_centered: Option<Array1<f64>>,
25 l_matrix: Option<Array2<f64>>,
27 alpha: Option<Array1<f64>>,
29 pub noise: f64,
31}
32
33impl<K: Kernel, P: Prior> GaussianProcess<K, P> {
34 pub fn new(kernel: K, prior: P, noise: f64) -> Self {
36 Self {
37 kernel,
38 prior,
39 x_train: None,
40 y_train_centered: None,
41 l_matrix: None,
42 alpha: None,
43 noise: noise.max(1e-10), }
45 }
46
47 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> StatsResult<()> {
49 if x.nrows() != y.len() {
50 return Err(
51 CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
52 "Number of samples in X and y must match",
53 ))
54 .into(),
55 );
56 }
57
58 if x.nrows() == 0 {
59 return Err(
60 CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
61 "Cannot fit with zero samples",
62 ))
63 .into(),
64 );
65 }
66
67 let prior_mean = self.prior.compute_vector(x);
69
70 let y_centered = y - &prior_mean;
72
73 let mut k = self.kernel.compute_matrix(x);
75
76 for i in 0..k.nrows() {
78 k[[i, i]] += self.noise;
79 }
80
81 let l = match cholesky_decomposition(&k) {
83 Ok(l) => l,
84 Err(_) => {
85 let jitter = 1e-6;
87 for i in 0..k.nrows() {
88 k[[i, i]] += jitter;
89 }
90 cholesky_decomposition(&k).map_err(|e| {
91 CoreError::ComputationError(scirs2_core::error::ErrorContext::new(format!(
92 "Cholesky decomposition failed: {}",
93 e
94 )))
95 })?
96 }
97 };
98
99 let alpha_1 = solve_lower_triangular(&l, &y_centered)?;
101
102 let alpha = solve_upper_triangular(&l.t().to_owned(), &alpha_1)?;
104
105 self.x_train = Some(x.clone());
107 self.y_train_centered = Some(y_centered);
108 self.l_matrix = Some(l);
109 self.alpha = Some(alpha);
110
111 Ok(())
112 }
113
114 pub fn predict(&self, x: &Array2<f64>) -> StatsResult<Array1<f64>> {
116 let (mean, _std) = self.predict_with_std(x)?;
117 Ok(mean)
118 }
119
120 pub fn predict_with_std(&self, x: &Array2<f64>) -> StatsResult<(Array1<f64>, Array1<f64>)> {
122 if self.x_train.is_none() || self.alpha.is_none() {
123 return Err(
124 CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
125 "GP must be fitted before making predictions",
126 ))
127 .into(),
128 );
129 }
130
131 let x_train = self.x_train.as_ref().expect("Operation failed");
132 let alpha = self.alpha.as_ref().expect("Operation failed");
133 let l = self.l_matrix.as_ref().expect("Operation failed");
134
135 let k_trans = self.kernel.compute_cross_matrix(x, x_train);
137
138 let mean_centered = k_trans.dot(alpha);
140 let prior_mean = self.prior.compute_vector(x);
141 let mean = mean_centered + prior_mean;
142
143 let k_trans_t = k_trans.t().to_owned();
146 let v = solve_lower_triangular_matrix(l, &k_trans_t)?;
147
148 let mut variance = Array1::zeros(x.nrows());
150 for i in 0..x.nrows() {
151 let k_self = self.kernel.compute(&x.row(i), &x.row(i));
153
154 let v_norm_sq: f64 = v.column(i).iter().map(|&x| x * x).sum();
156
157 variance[i] = (k_self - v_norm_sq + self.noise).max(0.0);
159 }
160
161 let std = variance.mapv(|x| x.sqrt());
162
163 Ok((mean, std))
164 }
165
166 pub fn predict_single(&self, x: &ArrayView1<f64>) -> StatsResult<f64> {
168 let x_mat = x.to_owned().insert_axis(Axis(0));
169 let pred = self.predict(&x_mat)?;
170 Ok(pred[0])
171 }
172
173 pub fn predict_variance_single(&self, x: &ArrayView1<f64>) -> StatsResult<f64> {
175 let x_mat = x.to_owned().insert_axis(Axis(0));
176 let (_mean, std) = self.predict_with_std(&x_mat)?;
177 Ok(std[0] * std[0])
178 }
179
180 pub fn log_marginal_likelihood(&self) -> StatsResult<f64> {
184 if self.y_train_centered.is_none() || self.l_matrix.is_none() {
185 return Err(
186 CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
187 "GP must be fitted before computing log marginal likelihood",
188 ))
189 .into(),
190 );
191 }
192
193 let y = self.y_train_centered.as_ref().expect("Operation failed");
194 let l = self.l_matrix.as_ref().expect("Operation failed");
195 let alpha = self.alpha.as_ref().expect("Operation failed");
196
197 let n = y.len() as f64;
198
199 let data_fit = -0.5 * y.dot(alpha);
201
202 let log_det: f64 = l.diag().iter().map(|&x| x.ln()).sum();
204 let complexity = -log_det;
205
206 let normalization = -0.5 * n * (2.0 * std::f64::consts::PI).ln();
208
209 Ok(data_fit + complexity + normalization)
210 }
211
212 pub fn n_train_samples(&self) -> usize {
214 self.x_train.as_ref().map_or(0, |x| x.nrows())
215 }
216}
217
218fn cholesky_decomposition(a: &Array2<f64>) -> Result<Array2<f64>, String> {
220 let n = a.nrows();
221 if n != a.ncols() {
222 return Err("Matrix must be square".to_string());
223 }
224
225 let mut l = Array2::zeros((n, n));
226
227 for i in 0..n {
228 for j in 0..=i {
229 let mut sum = 0.0;
230
231 if j == i {
232 for k in 0..j {
233 sum += l[[j, k]] * l[[j, k]];
234 }
235 let val = a[[j, j]] - sum;
236 if val <= 0.0 {
237 return Err(format!(
238 "Matrix is not positive definite (diagonal {} = {})",
239 j, val
240 ));
241 }
242 l[[j, j]] = val.sqrt();
243 } else {
244 for k in 0..j {
245 sum += l[[i, k]] * l[[j, k]];
246 }
247 l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
248 }
249 }
250 }
251
252 Ok(l)
253}
254
255fn solve_lower_triangular(l: &Array2<f64>, b: &Array1<f64>) -> StatsResult<Array1<f64>> {
257 let n = l.nrows();
258 let mut x = Array1::zeros(n);
259
260 for i in 0..n {
261 let mut sum = 0.0;
262 for j in 0..i {
263 sum += l[[i, j]] * x[j];
264 }
265 x[i] = (b[i] - sum) / l[[i, i]];
266 }
267
268 Ok(x)
269}
270
271fn solve_upper_triangular(u: &Array2<f64>, b: &Array1<f64>) -> StatsResult<Array1<f64>> {
273 let n = u.nrows();
274 let mut x = Array1::zeros(n);
275
276 for i in (0..n).rev() {
277 let mut sum = 0.0;
278 for j in (i + 1)..n {
279 sum += u[[i, j]] * x[j];
280 }
281 x[i] = (b[i] - sum) / u[[i, i]];
282 }
283
284 Ok(x)
285}
286
287fn solve_lower_triangular_matrix(l: &Array2<f64>, b: &Array2<f64>) -> StatsResult<Array2<f64>> {
289 let n = l.nrows();
290 let m = b.ncols();
291 let mut x = Array2::zeros((n, m));
292
293 for col in 0..m {
294 let b_col = b.column(col).to_owned();
295 let x_col = solve_lower_triangular(l, &b_col)?;
296 for row in 0..n {
297 x[[row, col]] = x_col[row];
298 }
299 }
300
301 Ok(x)
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gaussian_process::kernel::SquaredExponential;
    use crate::gaussian_process::prior::ZeroPrior;
    use scirs2_core::ndarray::{array, Array2};

    /// Arranges `values` as a single-feature input matrix, one sample per row.
    fn single_feature(values: Vec<f64>) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values).expect("Operation failed")
    }

    /// Predictions at the training inputs should stay close to the training targets.
    #[test]
    fn test_gp_fit_predict() {
        let mut gp =
            GaussianProcess::new(SquaredExponential::new(1.0, 1.0), ZeroPrior::new(), 0.01);

        let x_train = single_feature(vec![0.0, 1.0, 2.0]);
        let y_train = array![0.0, 1.0, 0.0];

        gp.fit(&x_train, &y_train).expect("Operation failed");
        let predictions = gp.predict(&x_train).expect("Operation failed");

        for (i, &target) in y_train.iter().enumerate() {
            assert!((predictions[i] - target).abs() < 0.1);
        }
    }

    /// The standard deviation between two training points is positive and bounded.
    #[test]
    fn test_gp_uncertainty() {
        let mut gp =
            GaussianProcess::new(SquaredExponential::new(1.0, 1.0), ZeroPrior::new(), 0.01);

        let x_train = single_feature(vec![0.0, 2.0]);
        let y_train = array![1.0, -1.0];
        gp.fit(&x_train, &y_train).expect("Operation failed");

        let x_test = single_feature(vec![1.0]);
        let (_mean, std) = gp.predict_with_std(&x_test).expect("Operation failed");

        let sigma = std[0];
        assert!(sigma > 0.0);
        assert!(sigma < 2.0);
    }
}