1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2};
20use scirs2_core::numeric::{Float, NumCast};
21use scirs2_linalg::solve;
22
23use super::kernels::{cross_gram_matrix, gram_matrix, KernelType};
24use crate::error::{Result, TransformError};
25
/// Kernel ridge regression solved in dual form via the training Gram matrix.
///
/// Fitting solves `(K + alpha * I) * dual_coef = y`; prediction evaluates the
/// kernel between test and training samples and applies the dual coefficients.
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression {
    // Regularization strength; validated non-negative in `fit_multi`.
    alpha: f64,
    // Kernel used for all Gram-matrix evaluations.
    kernel: KernelType,
    // Learned dual coefficients, shape (n_samples, n_outputs); None until fitted.
    dual_coef: Option<Array2<f64>>,
    // Training inputs (as f64), retained for test-time kernel evaluation.
    training_data: Option<Array2<f64>>,
    // Unregularized training Gram matrix, cached for LOO-CV.
    k_train: Option<Array2<f64>>,
    // Number of output columns used at fit time; 0 before fitting.
    n_outputs: usize,
}
55
56impl KernelRidgeRegression {
57 pub fn new(alpha: f64, kernel: KernelType) -> Self {
63 KernelRidgeRegression {
64 alpha,
65 kernel,
66 dual_coef: None,
67 training_data: None,
68 k_train: None,
69 n_outputs: 0,
70 }
71 }
72
73 pub fn with_alpha(mut self, alpha: f64) -> Self {
75 self.alpha = alpha;
76 self
77 }
78
    /// Dual coefficients learned by `fit`/`fit_multi`, or `None` if unfitted.
    pub fn dual_coef(&self) -> Option<&Array2<f64>> {
        self.dual_coef.as_ref()
    }
83
    /// The kernel used for all Gram-matrix evaluations.
    pub fn kernel(&self) -> &KernelType {
        &self.kernel
    }
88
    /// The regularization strength `alpha`.
    pub fn regularization(&self) -> f64 {
        self.alpha
    }
93
94 pub fn fit<S1, S2>(&mut self, x: &ArrayBase<S1, Ix2>, y: &ArrayBase<S2, Ix1>) -> Result<()>
100 where
101 S1: Data,
102 S2: Data,
103 S1::Elem: Float + NumCast,
104 S2::Elem: Float + NumCast,
105 {
106 let n_samples = x.nrows();
107 if n_samples == 0 {
108 return Err(TransformError::InvalidInput("Empty input data".to_string()));
109 }
110 if n_samples != y.len() {
111 return Err(TransformError::InvalidInput(format!(
112 "x has {} samples but y has {} elements",
113 n_samples,
114 y.len()
115 )));
116 }
117
118 let y_f64: Array1<f64> = y.mapv(|v| NumCast::from(v).unwrap_or(0.0));
120 let mut y_mat = Array2::zeros((n_samples, 1));
121 for i in 0..n_samples {
122 y_mat[[i, 0]] = y_f64[i];
123 }
124
125 self.fit_multi(x, &y_mat.view())
126 }
127
128 pub fn fit_multi<S1, S2>(
134 &mut self,
135 x: &ArrayBase<S1, Ix2>,
136 y: &ArrayBase<S2, Ix2>,
137 ) -> Result<()>
138 where
139 S1: Data,
140 S2: Data,
141 S1::Elem: Float + NumCast,
142 S2::Elem: Float + NumCast,
143 {
144 let n_samples = x.nrows();
145 let n_outputs = y.ncols();
146
147 if n_samples == 0 {
148 return Err(TransformError::InvalidInput("Empty input data".to_string()));
149 }
150 if n_samples != y.nrows() {
151 return Err(TransformError::InvalidInput(format!(
152 "x has {} samples but y has {} rows",
153 n_samples,
154 y.nrows()
155 )));
156 }
157 if self.alpha < 0.0 {
158 return Err(TransformError::InvalidInput(
159 "Regularization parameter alpha must be non-negative".to_string(),
160 ));
161 }
162
163 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
164 let y_f64: Array2<f64> = y.mapv(|v| NumCast::from(v).unwrap_or(0.0));
165
166 let k = gram_matrix(&x_f64.view(), &self.kernel)?;
168
169 let mut k_reg = k.clone();
171 for i in 0..n_samples {
172 k_reg[[i, i]] += self.alpha;
173 }
174
175 let mut dual_coef = Array2::zeros((n_samples, n_outputs));
177 for out in 0..n_outputs {
178 let y_col = y_f64.column(out).to_owned();
179 let coef = solve(&k_reg.view(), &y_col.view(), None).map_err(|e| {
180 TransformError::ComputationError(format!(
181 "Failed to solve kernel system for output {}: {}",
182 out, e
183 ))
184 })?;
185
186 for i in 0..n_samples {
187 dual_coef[[i, out]] = coef[i];
188 }
189 }
190
191 self.dual_coef = Some(dual_coef);
192 self.training_data = Some(x_f64);
193 self.k_train = Some(k);
194 self.n_outputs = n_outputs;
195
196 Ok(())
197 }
198
199 pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>>
207 where
208 S: Data,
209 S::Elem: Float + NumCast,
210 {
211 let predictions = self.predict_multi(x)?;
212 if self.n_outputs == 1 {
213 Ok(predictions.column(0).to_owned())
214 } else {
215 Err(TransformError::InvalidInput(
216 "Model was fitted with multiple outputs. Use predict_multi instead.".to_string(),
217 ))
218 }
219 }
220
221 pub fn predict_multi<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
229 where
230 S: Data,
231 S::Elem: Float + NumCast,
232 {
233 let dual_coef = self
234 .dual_coef
235 .as_ref()
236 .ok_or_else(|| TransformError::NotFitted("KRR not fitted".to_string()))?;
237 let training_data = self
238 .training_data
239 .as_ref()
240 .ok_or_else(|| TransformError::NotFitted("Training data not available".to_string()))?;
241
242 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
243
244 let k_test = cross_gram_matrix(&x_f64.view(), &training_data.view(), &self.kernel)?;
246
247 let n_test = x_f64.nrows();
249 let n_train = training_data.nrows();
250 let mut predictions = Array2::zeros((n_test, self.n_outputs));
251
252 for i in 0..n_test {
253 for out in 0..self.n_outputs {
254 let mut pred = 0.0;
255 for j in 0..n_train {
256 pred += k_test[[i, j]] * dual_coef[[j, out]];
257 }
258 predictions[[i, out]] = pred;
259 }
260 }
261
262 Ok(predictions)
263 }
264
265 pub fn loo_cv(&self) -> Result<(Array2<f64>, f64)> {
277 let dual_coef = self
278 .dual_coef
279 .as_ref()
280 .ok_or_else(|| TransformError::NotFitted("KRR not fitted".to_string()))?;
281 let k_train = self.k_train.as_ref().ok_or_else(|| {
282 TransformError::NotFitted("Training kernel not available".to_string())
283 })?;
284
285 let n = k_train.nrows();
286
287 let mut k_reg = k_train.clone();
291 for i in 0..n {
292 k_reg[[i, i]] += self.alpha;
293 }
294
295 let mut k_inv_diag = Array1::zeros(n);
297 for col in 0..n {
298 let mut e = Array1::zeros(n);
299 e[col] = 1.0;
300 let inv_col = solve(&k_reg.view(), &e.view(), None).map_err(|e| {
301 TransformError::ComputationError(format!(
302 "Failed to compute inverse for LOO-CV: {}",
303 e
304 ))
305 })?;
306 k_inv_diag[col] = inv_col[col];
307 }
308
309 let mut y_train = Array2::zeros((n, self.n_outputs));
313 for i in 0..n {
314 for out in 0..self.n_outputs {
315 let mut pred = 0.0;
316 for j in 0..n {
317 pred += k_train[[i, j]] * dual_coef[[j, out]];
318 }
319 y_train[[i, out]] = pred;
320 }
321 }
322
323 let mut loo_predictions = Array2::zeros((n, self.n_outputs));
324 let mut total_sq_error = 0.0;
325
326 for i in 0..n {
327 let h_ii = k_inv_diag[i];
328 if h_ii.abs() < 1e-15 {
329 for out in 0..self.n_outputs {
331 loo_predictions[[i, out]] = y_train[[i, out]];
332 }
333 continue;
334 }
335
336 for out in 0..self.n_outputs {
337 let residual = dual_coef[[i, out]] / h_ii;
338 loo_predictions[[i, out]] = y_train[[i, out]] - residual;
339 total_sq_error += residual * residual;
340 }
341 }
342
343 let loo_mse = total_sq_error / (n as f64 * self.n_outputs as f64);
344
345 Ok((loo_predictions, loo_mse))
346 }
347
348 pub fn auto_select_alpha<S1, S2>(
360 x: &ArrayBase<S1, Ix2>,
361 y: &ArrayBase<S2, Ix1>,
362 kernel: &KernelType,
363 alpha_values: &[f64],
364 ) -> Result<(f64, f64)>
365 where
366 S1: Data,
367 S2: Data,
368 S1::Elem: Float + NumCast,
369 S2::Elem: Float + NumCast,
370 {
371 if alpha_values.is_empty() {
372 return Err(TransformError::InvalidInput(
373 "alpha_values must not be empty".to_string(),
374 ));
375 }
376
377 let mut best_alpha = alpha_values[0];
378 let mut best_mse = f64::INFINITY;
379
380 for &alpha in alpha_values {
381 let mut krr = KernelRidgeRegression::new(alpha, kernel.clone());
382 match krr.fit(x, y) {
383 Ok(()) => {}
384 Err(_) => continue,
385 }
386
387 match krr.loo_cv() {
388 Ok((_, mse)) => {
389 if mse < best_mse {
390 best_mse = mse;
391 best_alpha = alpha;
392 }
393 }
394 Err(_) => continue,
395 }
396 }
397
398 if best_mse.is_infinite() {
399 return Err(TransformError::ComputationError(
400 "All alpha values failed in LOO-CV".to_string(),
401 ));
402 }
403
404 Ok((best_alpha, best_mse))
405 }
406
407 pub fn score<S>(&self, x: &ArrayBase<S, Ix2>, y_true: &Array1<f64>) -> Result<f64>
415 where
416 S: Data,
417 S::Elem: Float + NumCast,
418 {
419 let y_pred = self.predict(x)?;
420
421 let n = y_true.len();
422 if n != y_pred.len() {
423 return Err(TransformError::InvalidInput(
424 "Predictions and true values have different lengths".to_string(),
425 ));
426 }
427
428 let y_mean = y_true.sum() / n as f64;
429
430 let mut ss_res = 0.0;
431 let mut ss_tot = 0.0;
432 for i in 0..n {
433 let residual = y_true[i] - y_pred[i];
434 ss_res += residual * residual;
435 let deviation = y_true[i] - y_mean;
436 ss_tot += deviation * deviation;
437 }
438
439 if ss_tot < 1e-15 {
440 Ok(if ss_res < 1e-15 { 1.0 } else { 0.0 })
442 } else {
443 Ok(1.0 - ss_res / ss_tot)
444 }
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds a deterministic regression fixture: `n` samples with features
    /// `(t, t^2)` for `t` in [0, 4), and a sinusoidal target plus a small
    /// linear drift so the target is not constant.
    fn make_regression_data(n: usize) -> (Array2<f64>, Array1<f64>) {
        let mut x_data = Vec::with_capacity(n * 2);
        let mut y_data = Vec::with_capacity(n);
        for i in 0..n {
            let t = i as f64 / n as f64 * 4.0;
            x_data.push(t);
            x_data.push(t * t);
            y_data.push((t * std::f64::consts::PI).sin() + 0.1 * (i as f64 * 0.1));
        }
        let x = Array::from_shape_vec((n, 2), x_data).expect("Failed");
        let y = Array::from_vec(y_data);
        (x, y)
    }

    // Fit + predict with an RBF kernel yields finite predictions of the
    // expected length on the training set.
    #[test]
    fn test_krr_basic_fit_predict() {
        let (x, y) = make_regression_data(30);
        let mut krr = KernelRidgeRegression::new(1.0, KernelType::RBF { gamma: 0.5 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        assert_eq!(predictions.len(), 30);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    // Same smoke test with a linear kernel.
    #[test]
    fn test_krr_linear_kernel() {
        let (x, y) = make_regression_data(20);
        let mut krr = KernelRidgeRegression::new(0.1, KernelType::Linear);
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        assert_eq!(predictions.len(), 20);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    // Smoke test with a degree-2 polynomial kernel.
    #[test]
    fn test_krr_polynomial_kernel() {
        let (x, y) = make_regression_data(20);
        let kernel = KernelType::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let mut krr = KernelRidgeRegression::new(0.5, kernel);
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        assert_eq!(predictions.len(), 20);
    }

    // fit_multi / predict_multi round-trip with two output columns
    // (sin and cos targets) produces a finite (n, 2) prediction matrix.
    #[test]
    fn test_krr_multi_output() {
        let n = 20;
        let mut x_data = Vec::with_capacity(n * 2);
        let mut y_data = Vec::with_capacity(n * 2);
        for i in 0..n {
            let t = i as f64 / n as f64;
            x_data.push(t);
            x_data.push(t * t);
            y_data.push(t.sin());
            y_data.push(t.cos());
        }
        let x = Array::from_shape_vec((n, 2), x_data).expect("Failed");
        let y = Array::from_shape_vec((n, 2), y_data).expect("Failed");

        let mut krr = KernelRidgeRegression::new(0.1, KernelType::RBF { gamma: 1.0 });
        krr.fit_multi(&x, &y).expect("KRR multi-fit failed");

        let predictions = krr.predict_multi(&x).expect("KRR predict_multi failed");
        assert_eq!(predictions.shape(), &[n, 2]);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    // LOO-CV on a fitted model returns an (n, 1) prediction matrix and a
    // finite, non-negative mean squared error.
    #[test]
    fn test_krr_loo_cv() {
        let (x, y) = make_regression_data(20);
        let mut krr = KernelRidgeRegression::new(1.0, KernelType::RBF { gamma: 0.5 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let (loo_preds, loo_mse) = krr.loo_cv().expect("LOO-CV failed");
        assert_eq!(loo_preds.shape(), &[20, 1]);
        assert!(loo_mse >= 0.0);
        assert!(loo_mse.is_finite());
    }

    // auto_select_alpha returns one of the supplied candidates with a
    // finite, non-negative LOO-CV score.
    #[test]
    fn test_krr_auto_alpha() {
        let (x, y) = make_regression_data(20);
        let kernel = KernelType::RBF { gamma: 0.5 };
        let alphas = vec![0.001, 0.01, 0.1, 1.0, 10.0];

        let (best_alpha, best_mse) =
            KernelRidgeRegression::auto_select_alpha(&x.view(), &y.view(), &kernel, &alphas)
                .expect("Auto alpha failed");

        assert!(best_alpha > 0.0);
        assert!(best_mse >= 0.0);
        assert!(best_mse.is_finite());
    }

    // With light regularization, the in-sample R^2 should be clearly
    // positive and never exceed 1 (up to floating-point tolerance).
    #[test]
    fn test_krr_r_squared() {
        let (x, y) = make_regression_data(30);
        let mut krr = KernelRidgeRegression::new(0.1, KernelType::RBF { gamma: 1.0 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let r2 = krr.score(&x, &y).expect("R2 score failed");
        assert!(r2 > 0.5, "R2 should be > 0.5 on training data, got {}", r2);
        assert!(r2 <= 1.0 + 1e-10);
    }

    // Predicting on points not in the training set still yields finite values.
    #[test]
    fn test_krr_out_of_sample() {
        let (x_train, y_train) = make_regression_data(30);
        let mut krr = KernelRidgeRegression::new(0.5, KernelType::RBF { gamma: 0.5 });
        krr.fit(&x_train, &y_train).expect("KRR fit failed");

        let x_test =
            Array::from_shape_vec((3, 2), vec![0.5, 0.25, 1.0, 1.0, 2.0, 4.0]).expect("Failed");

        let predictions = krr.predict(&x_test).expect("KRR predict failed");
        assert_eq!(predictions.len(), 3);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    // Fitting on zero samples must be rejected.
    #[test]
    fn test_krr_empty_data() {
        let x: Array2<f64> = Array2::zeros((0, 3));
        let y: Array1<f64> = Array1::zeros(0);
        let mut krr = KernelRidgeRegression::new(1.0, KernelType::Linear);
        assert!(krr.fit(&x, &y).is_err());
    }

    // Fitting with mismatched x/y sample counts must be rejected.
    #[test]
    fn test_krr_mismatched_samples() {
        let x = Array::from_shape_vec((5, 2), vec![1.0; 10]).expect("Failed");
        let y = Array::from_vec(vec![1.0; 3]);
        let mut krr = KernelRidgeRegression::new(1.0, KernelType::Linear);
        assert!(krr.fit(&x, &y).is_err());
    }

    // Predicting before fitting must fail (NotFitted).
    #[test]
    fn test_krr_not_fitted() {
        let krr = KernelRidgeRegression::new(1.0, KernelType::Linear);
        let x = Array::from_shape_vec((3, 2), vec![1.0; 6]).expect("Failed");
        assert!(krr.predict(&x).is_err());
    }

    // Smoke test with a Laplacian kernel.
    #[test]
    fn test_krr_laplacian_kernel() {
        let (x, y) = make_regression_data(20);
        let mut krr = KernelRidgeRegression::new(0.5, KernelType::Laplacian { gamma: 0.5 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        assert_eq!(predictions.len(), 20);
        for val in predictions.iter() {
            assert!(val.is_finite());
        }
    }

    // Very large alpha shrinks the solution toward a constant, so the
    // variance of the predictions should be small.
    #[test]
    fn test_krr_high_regularization() {
        let (x, y) = make_regression_data(20);
        let mut krr = KernelRidgeRegression::new(1000.0, KernelType::RBF { gamma: 1.0 });
        krr.fit(&x, &y).expect("KRR fit failed");

        let predictions = krr.predict(&x).expect("KRR predict failed");
        let pred_var: f64 = {
            let mean = predictions.sum() / predictions.len() as f64;
            predictions
                .iter()
                .map(|&p| (p - mean) * (p - mean))
                .sum::<f64>()
                / predictions.len() as f64
        };
        assert!(
            pred_var < 1.0,
            "High regularization should reduce prediction variance, got {}",
            pred_var
        );
    }
}