// scirs2_transform/kernel/kernel_ridge.rs

//! Kernel Ridge Regression
//!
//! Kernel Ridge Regression (KRR) combines Ridge Regression with the kernel trick.
//! It learns a linear function in the kernel-induced feature space that corresponds
//! to a nonlinear function in the original space.
//!
//! ## Algorithm
//!
//! The KRR solution is: alpha = (K + lambda * I)^{-1} y
//! Prediction: y_pred = K_test * alpha
//!
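//! For reference, a minimal sketch of the closed-form solve (illustrative only;
//! the estimator below performs this internally):
//!
//! ```rust,no_run
//! use scirs2_core::ndarray::{Array1, Array2};
//! use scirs2_linalg::solve;
//!
//! let k: Array2<f64> = Array2::eye(5); // stand-in for a Gram matrix K
//! let y: Array1<f64> = Array1::zeros(5);
//! let lambda = 1.0;
//! let mut k_reg = k.clone();
//! for i in 0..5 {
//!     k_reg[[i, i]] += lambda; // K + lambda * I
//! }
//! // Dual coefficients: alpha = (K + lambda * I)^{-1} y
//! let alpha = solve(&k_reg.view(), &y.view(), None).expect("system should be solvable");
//! ```
//!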
//! ## Features
//!
//! - Tikhonov-regularized kernel regression
//! - Leave-one-out cross-validation in closed form (no per-fold re-fitting)
//! - Multiple-output support (each output trained independently)
//! - Support for all kernel types

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_linalg::solve;

use super::kernels::{cross_gram_matrix, gram_matrix, KernelType};
use crate::error::{Result, TransformError};

/// Kernel Ridge Regression
///
/// # Example
///
/// ```rust,no_run
/// use scirs2_transform::kernel::{KernelRidgeRegression, KernelType};
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let x = Array2::<f64>::zeros((50, 3));
/// let y = Array1::<f64>::zeros(50);
/// let mut krr = KernelRidgeRegression::new(1.0, KernelType::RBF { gamma: 0.1 });
/// krr.fit(&x, &y).expect("should succeed");
/// let predictions = krr.predict(&x).expect("should succeed");
/// ```
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression {
    /// Regularization parameter (lambda)
    alpha: f64,
    /// Kernel function type
    kernel: KernelType,
    /// Dual coefficients (solution in kernel space)
    dual_coef: Option<Array2<f64>>,
    /// Training data
    training_data: Option<Array2<f64>>,
    /// Training kernel matrix (for LOO-CV)
    k_train: Option<Array2<f64>>,
    /// Number of outputs
    n_outputs: usize,
}

impl KernelRidgeRegression {
    /// Create a new KernelRidgeRegression
    ///
    /// # Arguments
    /// * `alpha` - Regularization parameter (lambda). Larger values = more regularization.
    /// * `kernel` - The kernel function to use
    pub fn new(alpha: f64, kernel: KernelType) -> Self {
        KernelRidgeRegression {
            alpha,
            kernel,
            dual_coef: None,
            training_data: None,
            k_train: None,
            n_outputs: 0,
        }
    }

    /// Set the regularization parameter
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Get the dual coefficients
    pub fn dual_coef(&self) -> Option<&Array2<f64>> {
        self.dual_coef.as_ref()
    }

    /// Get the kernel type
    pub fn kernel(&self) -> &KernelType {
        &self.kernel
    }

    /// Get the regularization parameter
    pub fn regularization(&self) -> f64 {
        self.alpha
    }

    /// Fit the model with a single output target
    ///
    /// # Arguments
    /// * `x` - Training data, shape (n_samples, n_features)
    /// * `y` - Target values, shape (n_samples,)
    pub fn fit<S1, S2>(&mut self, x: &ArrayBase<S1, Ix2>, y: &ArrayBase<S2, Ix1>) -> Result<()>
    where
        S1: Data,
        S2: Data,
        S1::Elem: Float + NumCast,
        S2::Elem: Float + NumCast,
    {
        let n_samples = x.nrows();
        if n_samples == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }
        if n_samples != y.len() {
            return Err(TransformError::InvalidInput(format!(
                "x has {} samples but y has {} elements",
                n_samples,
                y.len()
            )));
        }

        // Convert y to a column matrix for uniform handling
        let y_f64: Array1<f64> = y.mapv(|v| NumCast::from(v).unwrap_or(0.0));
        let mut y_mat = Array2::zeros((n_samples, 1));
        for i in 0..n_samples {
            y_mat[[i, 0]] = y_f64[i];
        }

        self.fit_multi(x, &y_mat.view())
    }

    /// Fit the model with multiple output targets
    ///
    /// # Arguments
    /// * `x` - Training data, shape (n_samples, n_features)
    /// * `y` - Target values, shape (n_samples, n_outputs)
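    ///
    /// A minimal multi-output sketch (illustrative shapes; zeros stand in for real data):
    ///
    /// ```rust,no_run
    /// use scirs2_transform::kernel::{KernelRidgeRegression, KernelType};
    /// use scirs2_core::ndarray::Array2;
    ///
    /// let x = Array2::<f64>::zeros((20, 2));
    /// let y = Array2::<f64>::zeros((20, 2)); // two outputs per sample
    /// let mut krr = KernelRidgeRegression::new(0.1, KernelType::RBF { gamma: 1.0 });
    /// krr.fit_multi(&x, &y).expect("fit should succeed");
    /// let preds = krr.predict_multi(&x).expect("predict should succeed");
    /// assert_eq!(preds.shape(), &[20, 2]);
    /// ```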
    pub fn fit_multi<S1, S2>(
        &mut self,
        x: &ArrayBase<S1, Ix2>,
        y: &ArrayBase<S2, Ix2>,
    ) -> Result<()>
    where
        S1: Data,
        S2: Data,
        S1::Elem: Float + NumCast,
        S2::Elem: Float + NumCast,
    {
        let n_samples = x.nrows();
        let n_outputs = y.ncols();

        if n_samples == 0 {
            return Err(TransformError::InvalidInput("Empty input data".to_string()));
        }
        if n_samples != y.nrows() {
            return Err(TransformError::InvalidInput(format!(
                "x has {} samples but y has {} rows",
                n_samples,
                y.nrows()
            )));
        }
        if self.alpha < 0.0 {
            return Err(TransformError::InvalidInput(
                "Regularization parameter alpha must be non-negative".to_string(),
            ));
        }

        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
        let y_f64: Array2<f64> = y.mapv(|v| NumCast::from(v).unwrap_or(0.0));

        // Compute kernel matrix
        let k = gram_matrix(&x_f64.view(), &self.kernel)?;

        // Add regularization: K + alpha * I
        let mut k_reg = k.clone();
        for i in 0..n_samples {
            k_reg[[i, i]] += self.alpha;
        }

        // Solve (K + alpha*I) * alpha_coef = Y for each output
        let mut dual_coef = Array2::zeros((n_samples, n_outputs));
        for out in 0..n_outputs {
            let y_col = y_f64.column(out).to_owned();
            let coef = solve(&k_reg.view(), &y_col.view(), None).map_err(|e| {
                TransformError::ComputationError(format!(
                    "Failed to solve kernel system for output {}: {}",
                    out, e
                ))
            })?;

            for i in 0..n_samples {
                dual_coef[[i, out]] = coef[i];
            }
        }

        self.dual_coef = Some(dual_coef);
        self.training_data = Some(x_f64);
        self.k_train = Some(k);
        self.n_outputs = n_outputs;

        Ok(())
    }

    /// Predict for new data (single output)
    ///
    /// # Arguments
    /// * `x` - Test data, shape (n_test, n_features)
    ///
    /// # Returns
    /// * Predictions, shape (n_test,) for single output
    pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        // Reject multi-output models up front, before computing any predictions.
        if self.n_outputs > 1 {
            return Err(TransformError::InvalidInput(
                "Model was fitted with multiple outputs. Use predict_multi instead.".to_string(),
            ));
        }
        let predictions = self.predict_multi(x)?;
        Ok(predictions.column(0).to_owned())
    }

    /// Predict for new data (multiple outputs)
    ///
    /// # Arguments
    /// * `x` - Test data, shape (n_test, n_features)
    ///
    /// # Returns
    /// * Predictions, shape (n_test, n_outputs)
    pub fn predict_multi<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let dual_coef = self
            .dual_coef
            .as_ref()
            .ok_or_else(|| TransformError::NotFitted("KRR not fitted".to_string()))?;
        let training_data = self
            .training_data
            .as_ref()
            .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;

        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));

        // Compute kernel between test and training data
        let k_test = cross_gram_matrix(&x_f64.view(), &training_data.view(), &self.kernel)?;

        // Predict: y = K_test * dual_coef
        let n_test = x_f64.nrows();
        let n_train = training_data.nrows();
        let mut predictions = Array2::zeros((n_test, self.n_outputs));

        for i in 0..n_test {
            for out in 0..self.n_outputs {
                let mut pred = 0.0;
                for j in 0..n_train {
                    pred += k_test[[i, j]] * dual_coef[[j, out]];
                }
                predictions[[i, out]] = pred;
            }
        }

        Ok(predictions)
    }

    /// Leave-one-out cross-validation in closed form
    ///
    /// Computes the LOO-CV predictions and error without explicitly
    /// re-fitting the model n times. Uses the formula:
    ///
    /// LOO_residual_i = alpha_i / [(K + lambda*I)^{-1}]_{ii}
    ///
    /// which requires only the diagonal of a single matrix inverse.
    ///
    /// # Returns
    /// * `(loo_predictions, loo_mse)` - LOO predictions for each sample and mean squared error
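    ///
    /// A usage sketch (illustrative; assumes the model was fitted first):
    ///
    /// ```rust,no_run
    /// use scirs2_transform::kernel::{KernelRidgeRegression, KernelType};
    /// use scirs2_core::ndarray::{Array1, Array2};
    ///
    /// let x = Array2::<f64>::zeros((20, 2));
    /// let y = Array1::<f64>::zeros(20);
    /// let mut krr = KernelRidgeRegression::new(1.0, KernelType::RBF { gamma: 0.5 });
    /// krr.fit(&x, &y).expect("fit should succeed");
    /// let (loo_preds, loo_mse) = krr.loo_cv().expect("LOO-CV should succeed");
    /// assert_eq!(loo_preds.shape(), &[20, 1]);
    /// ```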
    pub fn loo_cv(&self) -> Result<(Array2<f64>, f64)> {
        let dual_coef = self
            .dual_coef
            .as_ref()
            .ok_or_else(|| TransformError::NotFitted("KRR not fitted".to_string()))?;
        let k_train = self.k_train.as_ref().ok_or_else(|| {
            TransformError::NotFitted("Training kernel not available".to_string())
        })?;

        let n = k_train.nrows();

        // Compute (K + alpha*I)^{-1}
        // We need the diagonal of the inverse
        // Solve (K + alpha*I) * X = I to get the inverse
        let mut k_reg = k_train.clone();
        for i in 0..n {
            k_reg[[i, i]] += self.alpha;
        }

        // Extract the diagonal of the inverse column by column.
        // Note: each `solve` call refactorizes `k_reg` (an O(n^3) operation), so this
        // loop is O(n^4) overall; reusing a single factorization would restore O(n^3).
        let mut k_inv_diag = Array1::zeros(n);
        for col in 0..n {
            let mut e = Array1::zeros(n);
            e[col] = 1.0;
            let inv_col = solve(&k_reg.view(), &e.view(), None).map_err(|err| {
                TransformError::ComputationError(format!(
                    "Failed to compute inverse for LOO-CV: {}",
                    err
                ))
            })?;
            k_inv_diag[col] = inv_col[col];
        }

        // LOO residual: r_i = alpha_i / (K_inv)_{ii}
        // LOO prediction: y_loo_i = y_i - r_i
        // where y_i = K * alpha are the training predictions
        let mut y_train = Array2::zeros((n, self.n_outputs));
        for i in 0..n {
            for out in 0..self.n_outputs {
                let mut pred = 0.0;
                for j in 0..n {
                    pred += k_train[[i, j]] * dual_coef[[j, out]];
                }
                y_train[[i, out]] = pred;
            }
        }

        let mut loo_predictions = Array2::zeros((n, self.n_outputs));
        let mut total_sq_error = 0.0;

        for i in 0..n {
            let h_ii = k_inv_diag[i];
            if h_ii.abs() < 1e-15 {
                // Degenerate case: fall back to the training prediction
                for out in 0..self.n_outputs {
                    loo_predictions[[i, out]] = y_train[[i, out]];
                }
                continue;
            }

            for out in 0..self.n_outputs {
                let residual = dual_coef[[i, out]] / h_ii;
                loo_predictions[[i, out]] = y_train[[i, out]] - residual;
                total_sq_error += residual * residual;
            }
        }

        let loo_mse = total_sq_error / (n as f64 * self.n_outputs as f64);

        Ok((loo_predictions, loo_mse))
    }

    /// Automatic selection of the regularization parameter via LOO-CV
    ///
    /// Tries multiple alpha values and selects the one with the lowest LOO-CV error.
    ///
    /// # Arguments
    /// * `x` - Training data
    /// * `y` - Target values (single output)
    /// * `kernel` - The kernel function to use
    /// * `alpha_values` - Candidate regularization parameters
    ///
    /// # Returns
    /// * `(best_alpha, best_mse)` - Best alpha and corresponding LOO-CV MSE
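    ///
    /// A selection sketch (illustrative values; zeros stand in for real data):
    ///
    /// ```rust,no_run
    /// use scirs2_transform::kernel::{KernelRidgeRegression, KernelType};
    /// use scirs2_core::ndarray::{Array1, Array2};
    ///
    /// let x = Array2::<f64>::zeros((20, 2));
    /// let y = Array1::<f64>::zeros(20);
    /// let kernel = KernelType::RBF { gamma: 0.5 };
    /// let alphas = [0.001, 0.01, 0.1, 1.0, 10.0];
    /// let (best_alpha, best_mse) =
    ///     KernelRidgeRegression::auto_select_alpha(&x, &y, &kernel, &alphas)
    ///         .expect("selection should succeed");
    /// println!("best alpha = {best_alpha}, LOO MSE = {best_mse}");
    /// ```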
    pub fn auto_select_alpha<S1, S2>(
        x: &ArrayBase<S1, Ix2>,
        y: &ArrayBase<S2, Ix1>,
        kernel: &KernelType,
        alpha_values: &[f64],
    ) -> Result<(f64, f64)>
    where
        S1: Data,
        S2: Data,
        S1::Elem: Float + NumCast,
        S2::Elem: Float + NumCast,
    {
        if alpha_values.is_empty() {
            return Err(TransformError::InvalidInput(
                "alpha_values must not be empty".to_string(),
            ));
        }

        let mut best_alpha = alpha_values[0];
        let mut best_mse = f64::INFINITY;

        for &alpha in alpha_values {
            let mut krr = KernelRidgeRegression::new(alpha, kernel.clone());
            if krr.fit(x, y).is_err() {
                continue;
            }

            match krr.loo_cv() {
                Ok((_, mse)) => {
                    if mse < best_mse {
                        best_mse = mse;
                        best_alpha = alpha;
                    }
                }
                Err(_) => continue,
            }
        }

        if best_mse.is_infinite() {
            return Err(TransformError::ComputationError(
                "All alpha values failed in LOO-CV".to_string(),
            ));
        }

        Ok((best_alpha, best_mse))
    }

    /// Compute the R-squared score on the given data
    ///
    /// # Arguments
    /// * `x` - Input data, shape (n_samples, n_features)
    /// * `y_true` - True target values
    ///
    /// # Returns
    /// * R-squared score
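    ///
    /// A scoring sketch (illustrative; assumes a fitted single-output model):
    ///
    /// ```rust,no_run
    /// use scirs2_transform::kernel::{KernelRidgeRegression, KernelType};
    /// use scirs2_core::ndarray::{Array1, Array2};
    ///
    /// let x = Array2::<f64>::zeros((10, 2));
    /// let y = Array1::<f64>::zeros(10);
    /// let mut krr = KernelRidgeRegression::new(0.1, KernelType::Linear);
    /// krr.fit(&x, &y).expect("fit should succeed");
    /// let r2 = krr.score(&x, &y).expect("scoring should succeed");
    /// println!("R^2 = {r2}");
    /// ```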
    pub fn score<S>(&self, x: &ArrayBase<S, Ix2>, y_true: &Array1<f64>) -> Result<f64>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let y_pred = self.predict(x)?;

        let n = y_true.len();
        if n != y_pred.len() {
            return Err(TransformError::InvalidInput(
                "Predictions and true values have different lengths".to_string(),
            ));
        }

        let y_mean = y_true.sum() / n as f64;

        // R^2 = 1 - SS_res / SS_tot
        let mut ss_res = 0.0;
        let mut ss_tot = 0.0;
        for i in 0..n {
            let residual = y_true[i] - y_pred[i];
            ss_res += residual * residual;
            let deviation = y_true[i] - y_mean;
            ss_tot += deviation * deviation;
        }

        if ss_tot < 1e-15 {
            // All targets are (numerically) the same
            Ok(if ss_res < 1e-15 { 1.0 } else { 0.0 })
        } else {
            Ok(1.0 - ss_res / ss_tot)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    fn make_regression_data(n: usize) -> (Array2<f64>, Array1<f64>) {
        let mut x_data = Vec::with_capacity(n * 2);
        let mut y_data = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f64 / n as f64 * 4.0;
            x_data.push(t);
            x_data.push(t * t);
            y_data.push((t * std::f64::consts::PI).sin() + 0.1 * (i as f64 * 0.1));
        }
        let x = Array::from_shape_vec((n, 2), x_data).expect("Failed");
        let y = Array::from_vec(y_data);
        (x, y)
    }

    #[test]
    fn test_krr_basic_fit_predict() {
        let (x, y) = make_regression_data(30);
        let mut krr = KernelRidgeRegression::new(1.0, KernelType::RBF { gamma: 0.5 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        assert_eq!(predictions.len(), 30);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_krr_linear_kernel() {
        let (x, y) = make_regression_data(20);
        let mut krr = KernelRidgeRegression::new(0.1, KernelType::Linear);
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        assert_eq!(predictions.len(), 20);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_krr_polynomial_kernel() {
        let (x, y) = make_regression_data(20);
        let kernel = KernelType::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let mut krr = KernelRidgeRegression::new(0.5, kernel);
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        assert_eq!(predictions.len(), 20);
    }

    #[test]
    fn test_krr_multi_output() {
        let n = 20;
        let mut x_data = Vec::with_capacity(n * 2);
        let mut y_data = Vec::with_capacity(n * 2);
        for i in 0..n {
            let t = i as f64 / n as f64;
            x_data.push(t);
            x_data.push(t * t);
            y_data.push(t.sin());
            y_data.push(t.cos());
        }
        let x = Array::from_shape_vec((n, 2), x_data).expect("Failed");
        let y = Array::from_shape_vec((n, 2), y_data).expect("Failed");

        let mut krr = KernelRidgeRegression::new(0.1, KernelType::RBF { gamma: 1.0 });
        krr.fit_multi(&x, &y).expect("KRR multi-fit failed");

        let predictions = krr.predict_multi(&x).expect("KRR predict_multi failed");
        assert_eq!(predictions.shape(), &[n, 2]);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_krr_loo_cv() {
        let (x, y) = make_regression_data(20);
        let mut krr = KernelRidgeRegression::new(1.0, KernelType::RBF { gamma: 0.5 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let (loo_preds, loo_mse) = krr.loo_cv().expect("LOO-CV failed");
        assert_eq!(loo_preds.shape(), &[20, 1]);
        assert!(loo_mse >= 0.0);
        assert!(loo_mse.is_finite());
    }

    #[test]
    fn test_krr_auto_alpha() {
        let (x, y) = make_regression_data(20);
        let kernel = KernelType::RBF { gamma: 0.5 };
        let alphas = vec![0.001, 0.01, 0.1, 1.0, 10.0];

        let (best_alpha, best_mse) =
            KernelRidgeRegression::auto_select_alpha(&x.view(), &y.view(), &kernel, &alphas)
                .expect("Auto alpha failed");

        assert!(best_alpha > 0.0);
        assert!(best_mse >= 0.0);
        assert!(best_mse.is_finite());
    }

    #[test]
    fn test_krr_r_squared() {
        let (x, y) = make_regression_data(30);
        let mut krr = KernelRidgeRegression::new(0.1, KernelType::RBF { gamma: 1.0 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let r2 = krr.score(&x, &y).expect("R2 score failed");
        // On training data with RBF kernel, R2 should be high
        assert!(r2 > 0.5, "R2 should be > 0.5 on training data, got {}", r2);
        assert!(r2 <= 1.0 + 1e-10);
    }

    #[test]
    fn test_krr_out_of_sample() {
        let (x_train, y_train) = make_regression_data(30);
        let mut krr = KernelRidgeRegression::new(0.5, KernelType::RBF { gamma: 0.5 });
        krr.fit(&x_train, &y_train).expect("KRR fit failed");

        let x_test =
            Array::from_shape_vec((3, 2), vec![0.5, 0.25, 1.0, 1.0, 2.0, 4.0]).expect("Failed");

        let predictions = krr.predict(&x_test).expect("KRR predict failed");
        assert_eq!(predictions.len(), 3);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_krr_empty_data() {
        let x: Array2<f64> = Array2::zeros((0, 3));
        let y: Array1<f64> = Array1::zeros(0);
        let mut krr = KernelRidgeRegression::new(1.0, KernelType::Linear);
        assert!(krr.fit(&x, &y).is_err());
    }

    #[test]
    fn test_krr_mismatched_samples() {
        let x = Array::from_shape_vec((5, 2), vec![1.0; 10]).expect("Failed");
        let y = Array::from_vec(vec![1.0; 3]);
        let mut krr = KernelRidgeRegression::new(1.0, KernelType::Linear);
        assert!(krr.fit(&x, &y).is_err());
    }

    #[test]
    fn test_krr_not_fitted() {
        let krr = KernelRidgeRegression::new(1.0, KernelType::Linear);
        let x = Array::from_shape_vec((3, 2), vec![1.0; 6]).expect("Failed");
        assert!(krr.predict(&x).is_err());
    }

    #[test]
    fn test_krr_laplacian_kernel() {
        let (x, y) = make_regression_data(20);
        let mut krr = KernelRidgeRegression::new(0.5, KernelType::Laplacian { gamma: 0.5 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        assert_eq!(predictions.len(), 20);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_krr_high_regularization() {
        let (x, y) = make_regression_data(20);
        let mut krr = KernelRidgeRegression::new(1000.0, KernelType::RBF { gamma: 1.0 });
        krr.fit(&x, &y).expect("KRR fit failed");

        // High regularization should make predictions closer to the mean
        let predictions = krr.predict(&x).expect("KRR predict failed");
        let pred_var: f64 = {
            let mean = predictions.sum() / predictions.len() as f64;
            predictions
                .iter()
                .map(|&p| (p - mean) * (p - mean))
                .sum::<f64>()
                / predictions.len() as f64
        };
        // Variance should be small with high regularization
        assert!(
            pred_var < 1.0,
            "High regularization should reduce prediction variance, got {}",
            pred_var
        );
    }
}