scirs2_stats/gaussian_process/
regression.rs1use super::gp::GaussianProcess;
6use super::kernel::{Kernel, SquaredExponential, SumKernel, WhiteKernel};
7use super::prior::{Prior, ZeroPrior};
8use crate::error::StatsResult;
9use scirs2_core::error::CoreError;
10use scirs2_core::ndarray::ArrayStatCompat;
11use scirs2_core::ndarray::{Array1, Array2};
12
13pub struct GaussianProcessRegressor<K: Kernel> {
33 gp: GaussianProcess<SumKernel<K, WhiteKernel>, ZeroPrior>,
35 user_kernel: K,
37 alpha: f64,
39 normalize_y: bool,
41 y_train_mean: Option<f64>,
43 y_train_std: Option<f64>,
45}
46
47impl<K: Kernel> GaussianProcessRegressor<K> {
48 pub fn new(kernel: K) -> Self {
58 Self::with_options(kernel, 1e-10, false)
59 }
60
61 pub fn with_options(kernel: K, alpha: f64, normalize_y: bool) -> Self {
69 let noise_kernel = WhiteKernel::new(alpha);
70 let combined_kernel = SumKernel::new(kernel.clone(), noise_kernel);
71 let prior = ZeroPrior::new();
72 let gp = GaussianProcess::new(combined_kernel, prior, 0.0);
73
74 Self {
75 gp,
76 user_kernel: kernel,
77 alpha,
78 normalize_y,
79 y_train_mean: None,
80 y_train_std: None,
81 }
82 }
83
84 pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> StatsResult<()> {
95 if x.nrows() != y.len() {
96 return Err(
97 CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
98 "Number of samples in X and y must match",
99 ))
100 .into(),
101 );
102 }
103
104 let y_normalized = if self.normalize_y {
106 let mean = y.mean_or(0.0);
107 let std = y.std(0.0);
108 let std = if std < 1e-10 { 1.0 } else { std };
109
110 self.y_train_mean = Some(mean);
111 self.y_train_std = Some(std);
112
113 (y - mean) / std
114 } else {
115 y.clone()
116 };
117
118 self.gp.fit(x, &y_normalized)
119 }
120
121 pub fn predict(&self, x: &Array2<f64>) -> StatsResult<Array1<f64>> {
131 let predictions = self.gp.predict(x)?;
132
133 Ok(if self.normalize_y {
135 let mean = self.y_train_mean.unwrap_or(0.0);
136 let std = self.y_train_std.unwrap_or(1.0);
137 predictions * std + mean
138 } else {
139 predictions
140 })
141 }
142
143 pub fn predict_with_std(&self, x: &Array2<f64>) -> StatsResult<(Array1<f64>, Array1<f64>)> {
154 let (mean, std) = self.gp.predict_with_std(x)?;
155
156 if self.normalize_y {
158 let y_mean = self.y_train_mean.unwrap_or(0.0);
159 let y_std = self.y_train_std.unwrap_or(1.0);
160 Ok((mean * y_std + y_mean, std * y_std))
161 } else {
162 Ok((mean, std))
163 }
164 }
165
166 pub fn kernel(&self) -> &K {
168 &self.user_kernel
169 }
170
171 pub fn kernel_mut(&mut self) -> &mut K {
173 &mut self.user_kernel
174 }
175
176 pub fn log_marginal_likelihood(&self) -> StatsResult<f64> {
178 self.gp.log_marginal_likelihood()
179 }
180
181 pub fn n_train_samples(&self) -> usize {
183 self.gp.n_train_samples()
184 }
185
186 pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> StatsResult<f64> {
197 let y_pred = self.predict(x)?;
198
199 if y.len() != y_pred.len() {
200 return Err(
201 CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
202 "Prediction and true values must have same length",
203 ))
204 .into(),
205 );
206 }
207
208 let y_mean = y.mean_or(0.0);
210 let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
211 let ss_res: f64 = y
212 .iter()
213 .zip(y_pred.iter())
214 .map(|(&yi, &yp)| (yi - yp).powi(2))
215 .sum();
216
217 if ss_tot < 1e-10 {
218 return Ok(1.0); }
220
221 Ok(1.0 - ss_res / ss_tot)
222 }
223}
224
225pub fn default_gp_regressor() -> GaussianProcessRegressor<SquaredExponential> {
229 GaussianProcessRegressor::new(SquaredExponential::default())
230}
231
232#[cfg(test)]
233mod tests {
234 use super::*;
235 use scirs2_core::ndarray::{array, Array2};
236
/// Fits on five 1-D points and checks that the prediction between two of them
/// falls in a plausible range.
#[test]
fn test_gpr_basic() {
    // Turns a flat list of 1-D inputs into an (n, 1) design matrix.
    let column = |values: Vec<f64>| {
        Array2::from_shape_vec((values.len(), 1), values).expect("Operation failed")
    };

    let mut gpr = GaussianProcessRegressor::new(SquaredExponential::default());

    let x_train = column(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    let y_train = array![0.0, 1.0, 1.5, 1.0, 0.0];
    gpr.fit(&x_train, &y_train).expect("Operation failed");

    let x_test = column(vec![2.5]);
    let predictions = gpr.predict(&x_test).expect("Operation failed");

    let predicted = predictions[0];
    assert!(predicted > 0.5);
    assert!(predicted < 2.0);
}
254
/// Checks that predictive standard deviations are positive and that the
/// uncertainty when extrapolating (x = 5, beyond the training range) is not
/// meaningfully smaller than when interpolating (x = 1, between two training
/// points).
#[test]
fn test_gpr_with_std() {
    let kernel = SquaredExponential::default();
    let mut gpr = GaussianProcessRegressor::new(kernel);

    let x_train =
        Array2::from_shape_vec((3, 1), vec![0.0, 2.0, 4.0]).expect("Operation failed");
    let y_train = array![1.0, 0.0, 1.0];

    gpr.fit(&x_train, &y_train).expect("Operation failed");

    let x_test = Array2::from_shape_vec((2, 1), vec![1.0, 5.0]).expect("Operation failed");
    // Only the standard deviations are inspected here.
    let (_mean, std) = gpr.predict_with_std(&x_test).expect("Operation failed");

    assert!(std.iter().all(|&s| s > 0.0));

    // The previous form `a > b || a - b < 0.1` could never fail for finite
    // values; this keeps the 0.1 tolerance but can actually detect a
    // regression where extrapolation looks far more certain than interpolation.
    assert!(std[1] > std[0] - 0.1);
}
275
/// Fits with target normalization on large-valued targets and checks that
/// predictions at the training inputs come back on the original scale.
#[test]
fn test_gpr_normalize() {
    let mut gpr =
        GaussianProcessRegressor::with_options(SquaredExponential::default(), 1e-10, true);

    let x_train =
        Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).expect("Operation failed");
    let y_train = array![100.0, 200.0, 150.0];
    gpr.fit(&x_train, &y_train).expect("Operation failed");

    let predictions = gpr.predict(&x_train).expect("Operation failed");

    // Each prediction must land within 20 units of its training target.
    for (i, &target) in y_train.iter().enumerate() {
        let error = (predictions[i] - target).abs();
        assert!(error < 20.0);
    }
}
294
/// Fits on five points and checks that the R² score on the training data is
/// high.
#[test]
fn test_gpr_score() {
    let mut gpr = GaussianProcessRegressor::new(SquaredExponential::default());

    let inputs = Array2::from_shape_vec((5, 1), vec![0.0, 1.0, 2.0, 3.0, 4.0])
        .expect("Operation failed");
    let targets = array![0.0, 1.0, 2.0, 1.5, 0.5];

    gpr.fit(&inputs, &targets).expect("Operation failed");

    let r_squared = gpr.score(&inputs, &targets).expect("Operation failed");
    assert!(r_squared > 0.8);
}
311}