Skip to main content

scirs2_stats/gaussian_process/
regression.rs

1//! High-level Gaussian Process Regression API
2//!
3//! This module provides a scikit-learn compatible interface for GP regression.
4
5use super::gp::GaussianProcess;
6use super::kernel::{Kernel, SquaredExponential, SumKernel, WhiteKernel};
7use super::prior::{Prior, ZeroPrior};
8use crate::error::StatsResult;
9use scirs2_core::error::CoreError;
10use scirs2_core::ndarray::ArrayStatCompat;
11use scirs2_core::ndarray::{Array1, Array2};
12
13/// Gaussian Process Regressor with scikit-learn compatible API
14///
15/// # Examples
16///
17/// ```
18/// use scirs2_stats::gaussian_process::{GaussianProcessRegressor, SquaredExponential};
19/// use scirs2_core::ndarray::{array, Array2};
20///
21/// let kernel = SquaredExponential::default();
22/// let mut gpr = GaussianProcessRegressor::new(kernel);
23///
24/// let x_train = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, 2.0]).expect("Operation failed");
25/// let y_train = array![0.0, 1.0, 0.5];
26///
27/// gpr.fit(&x_train, &y_train).expect("Operation failed");
28///
29/// let x_test = Array2::from_shape_vec((1, 1), vec![1.5]).expect("Operation failed");
30/// let predictions = gpr.predict(&x_test).expect("Operation failed");
31/// ```
32pub struct GaussianProcessRegressor<K: Kernel> {
33    /// The underlying Gaussian Process
34    gp: GaussianProcess<SumKernel<K, WhiteKernel>, ZeroPrior>,
35    /// User-provided kernel (before adding noise)
36    user_kernel: K,
37    /// Alpha parameter for regularization
38    alpha: f64,
39    /// Whether to normalize target values
40    normalize_y: bool,
41    /// Mean of training targets (for normalization)
42    y_train_mean: Option<f64>,
43    /// Std of training targets (for normalization)
44    y_train_std: Option<f64>,
45}
46
47impl<K: Kernel> GaussianProcessRegressor<K> {
48    /// Create a new Gaussian Process Regressor
49    ///
50    /// # Arguments
51    ///
52    /// * `kernel` - The covariance kernel
53    ///
54    /// # Returns
55    ///
56    /// A new GaussianProcessRegressor with default settings
57    pub fn new(kernel: K) -> Self {
58        Self::with_options(kernel, 1e-10, false)
59    }
60
61    /// Create a new GP Regressor with custom options
62    ///
63    /// # Arguments
64    ///
65    /// * `kernel` - The covariance kernel
66    /// * `alpha` - Noise level / regularization parameter
67    /// * `normalize_y` - Whether to normalize target values
68    pub fn with_options(kernel: K, alpha: f64, normalize_y: bool) -> Self {
69        let noise_kernel = WhiteKernel::new(alpha);
70        let combined_kernel = SumKernel::new(kernel.clone(), noise_kernel);
71        let prior = ZeroPrior::new();
72        let gp = GaussianProcess::new(combined_kernel, prior, 0.0);
73
74        Self {
75            gp,
76            user_kernel: kernel,
77            alpha,
78            normalize_y,
79            y_train_mean: None,
80            y_train_std: None,
81        }
82    }
83
84    /// Fit the Gaussian Process model
85    ///
86    /// # Arguments
87    ///
88    /// * `x` - Training features (n_samples, n_features)
89    /// * `y` - Training targets (n_samples,)
90    ///
91    /// # Returns
92    ///
93    /// Result indicating success or failure
94    pub fn fit(&mut self, x: &Array2<f64>, y: &Array1<f64>) -> StatsResult<()> {
95        if x.nrows() != y.len() {
96            return Err(
97                CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
98                    "Number of samples in X and y must match",
99                ))
100                .into(),
101            );
102        }
103
104        // Normalize y if requested
105        let y_normalized = if self.normalize_y {
106            let mean = y.mean_or(0.0);
107            let std = y.std(0.0);
108            let std = if std < 1e-10 { 1.0 } else { std };
109
110            self.y_train_mean = Some(mean);
111            self.y_train_std = Some(std);
112
113            (y - mean) / std
114        } else {
115            y.clone()
116        };
117
118        self.gp.fit(x, &y_normalized)
119    }
120
121    /// Predict mean values
122    ///
123    /// # Arguments
124    ///
125    /// * `x` - Test features (n_samples, n_features)
126    ///
127    /// # Returns
128    ///
129    /// Predicted mean values
130    pub fn predict(&self, x: &Array2<f64>) -> StatsResult<Array1<f64>> {
131        let predictions = self.gp.predict(x)?;
132
133        // Denormalize if needed
134        Ok(if self.normalize_y {
135            let mean = self.y_train_mean.unwrap_or(0.0);
136            let std = self.y_train_std.unwrap_or(1.0);
137            predictions * std + mean
138        } else {
139            predictions
140        })
141    }
142
143    /// Predict with uncertainty estimates
144    ///
145    /// # Arguments
146    ///
147    /// * `x` - Test features (n_samples, n_features)
148    /// * `return_std` - Whether to return standard deviation
149    ///
150    /// # Returns
151    ///
152    /// (predictions, standard_deviations) if return_std is true
153    pub fn predict_with_std(&self, x: &Array2<f64>) -> StatsResult<(Array1<f64>, Array1<f64>)> {
154        let (mean, std) = self.gp.predict_with_std(x)?;
155
156        // Denormalize if needed
157        if self.normalize_y {
158            let y_mean = self.y_train_mean.unwrap_or(0.0);
159            let y_std = self.y_train_std.unwrap_or(1.0);
160            Ok((mean * y_std + y_mean, std * y_std))
161        } else {
162            Ok((mean, std))
163        }
164    }
165
166    /// Get the kernel
167    pub fn kernel(&self) -> &K {
168        &self.user_kernel
169    }
170
171    /// Get the kernel (mutable)
172    pub fn kernel_mut(&mut self) -> &mut K {
173        &mut self.user_kernel
174    }
175
176    /// Compute log marginal likelihood
177    pub fn log_marginal_likelihood(&self) -> StatsResult<f64> {
178        self.gp.log_marginal_likelihood()
179    }
180
181    /// Get number of training samples
182    pub fn n_train_samples(&self) -> usize {
183        self.gp.n_train_samples()
184    }
185
186    /// Score the model using R² metric
187    ///
188    /// # Arguments
189    ///
190    /// * `x` - Test features
191    /// * `y` - True test targets
192    ///
193    /// # Returns
194    ///
195    /// R² score
196    pub fn score(&self, x: &Array2<f64>, y: &Array1<f64>) -> StatsResult<f64> {
197        let y_pred = self.predict(x)?;
198
199        if y.len() != y_pred.len() {
200            return Err(
201                CoreError::InvalidInput(scirs2_core::error::ErrorContext::new(
202                    "Prediction and true values must have same length",
203                ))
204                .into(),
205            );
206        }
207
208        // Compute R²
209        let y_mean = y.mean_or(0.0);
210        let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
211        let ss_res: f64 = y
212            .iter()
213            .zip(y_pred.iter())
214            .map(|(&yi, &yp)| (yi - yp).powi(2))
215            .sum();
216
217        if ss_tot < 1e-10 {
218            return Ok(1.0); // Perfect prediction if variance is zero
219        }
220
221        Ok(1.0 - ss_res / ss_tot)
222    }
223}
224
225/// Create a default GP regressor with RBF kernel
226///
227/// This is a convenience function for the most common use case.
228pub fn default_gp_regressor() -> GaussianProcessRegressor<SquaredExponential> {
229    GaussianProcessRegressor::new(SquaredExponential::default())
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Build an `(n, 1)` design matrix from one-dimensional input values.
    fn column(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).expect("Operation failed")
    }

    #[test]
    fn test_gpr_basic() {
        let mut model = GaussianProcessRegressor::new(SquaredExponential::default());

        let inputs = column(&[0.0, 1.0, 2.0, 3.0, 4.0]);
        let targets = array![0.0, 1.0, 1.5, 1.0, 0.0];
        model.fit(&inputs, &targets).expect("Operation failed");

        let predicted = model.predict(&column(&[2.5])).expect("Operation failed");

        // Interpolating between neighbouring targets should stay in a plausible band.
        let value = predicted[0];
        assert!(value > 0.5 && value < 2.0);
    }

    #[test]
    fn test_gpr_with_std() {
        let mut model = GaussianProcessRegressor::new(SquaredExponential::default());

        let inputs = column(&[0.0, 2.0, 4.0]);
        let targets = array![1.0, 0.0, 1.0];
        model.fit(&inputs, &targets).expect("Operation failed");

        let (_mean, sigma) = model
            .predict_with_std(&column(&[1.0, 5.0]))
            .expect("Operation failed");

        // Every prediction carries some uncertainty.
        assert!(sigma.iter().all(|&s| s > 0.0));

        // The point outside the training range should be at least about as uncertain.
        let (near, far) = (sigma[0], sigma[1]);
        assert!(far > near || far.abs() - near.abs() < 0.1);
    }

    #[test]
    fn test_gpr_normalize() {
        let mut model =
            GaussianProcessRegressor::with_options(SquaredExponential::default(), 1e-10, true);

        let inputs = column(&[0.0, 1.0, 2.0]);
        let targets = array![100.0, 200.0, 150.0]; // Large values
        model.fit(&inputs, &targets).expect("Operation failed");

        let predicted = model.predict(&inputs).expect("Operation failed");

        // Normalization should let the model reproduce large-valued training targets.
        for (p, t) in predicted.iter().zip(targets.iter()) {
            assert!((p - t).abs() < 20.0);
        }
    }

    #[test]
    fn test_gpr_score() {
        let mut model = GaussianProcessRegressor::new(SquaredExponential::default());

        let inputs = column(&[0.0, 1.0, 2.0, 3.0, 4.0]);
        let targets = array![0.0, 1.0, 2.0, 1.5, 0.5];
        model.fit(&inputs, &targets).expect("Operation failed");

        let r2 = model.score(&inputs, &targets).expect("Operation failed");

        // In-sample R² should be close to 1.
        assert!(r2 > 0.8);
    }
}