use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::{thread_rng, Rng};
use scirs2_linalg::compat::ArrayLinalgExt;
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

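/// Kernel ridge regression trained on an explicit, approximate kernel feature map.
///
/// The input is first transformed with the configured [`ApproximationMethod`]
/// (Nystroem, random Fourier features, structured random features, or Fastfood),
/// and a linear ridge model is then solved in the transformed feature space.
///
/// A minimal usage sketch (mirroring the tests at the bottom of this file):
///
/// ```ignore
/// let method = ApproximationMethod::RandomFourierFeatures { n_components: 100, gamma: 0.5 };
/// let model = KernelRidgeRegression::new(method).alpha(0.1).random_state(42);
/// let fitted = model.fit(&x, &y)?;
/// let y_pred = fitted.predict(&x)?;
/// ```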
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression<State = Untrained> {
    pub approximation_method: ApproximationMethod,
    pub alpha: Float,
    pub solver: Solver,
    pub random_state: Option<u64>,

    pub(crate) weights_: Option<Array1<Float>>,
    pub(crate) feature_transformer_: Option<FeatureTransformer>,

    pub(crate) _state: PhantomData<State>,
}

impl KernelRidgeRegression<Untrained> {
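    /// Creates an untrained model with `alpha = 1.0`, the `Direct` solver, and no fixed seed.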
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            solver: Solver::Direct,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }

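    /// Sets the ridge regularization strength (default: `1.0`).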
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

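    /// Selects the linear-system solver used during `fit`.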
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

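    /// Fixes the random seed used when constructing the feature map, for reproducible fits.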
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for KernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Untrained> {
    type Fitted = KernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Fit the feature-map approximation, project the inputs, then solve the
        // regularized least-squares problem in the approximate feature space.
        let feature_transformer = self.fit_feature_transformer(x)?;
        let x_transformed = feature_transformer.transform(x)?;
        let weights = self.solve_ridge_regression(&x_transformed, y)?;

        Ok(KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            _state: PhantomData,
        })
    }
}

impl KernelRidgeRegression<Untrained> {
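    /// Builds and fits the configured feature-map approximation on the training data.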
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());

                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }

                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rbf_sampler = RBFSampler::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    rbf_sampler = rbf_sampler.random_state(seed);
                }

                let fitted = rbf_sampler.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut structured_rff = StructuredRandomFeatures::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    structured_rff = structured_rff.random_state(seed);
                }

                let fitted = structured_rff.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }

                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

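    /// Dispatches the feature-space ridge problem to the configured solver.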
    fn solve_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        match &self.solver {
            Solver::Direct => self.solve_direct(x, y),
            Solver::SVD => self.solve_svd(x, y),
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_conjugate_gradient(x, y, *max_iter, *tol)
            }
        }
    }

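    /// Direct solver: forms the normal equations `(X^T X + alpha * I) w = X^T y`
    /// and solves them with a dense linear solve.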
    fn solve_direct(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (_, n_features) = x.dim();

        // Form the regularized Gram matrix X^T X + alpha * I and the right-hand side X^T y.
        let mut regularized_gram = x.t().dot(x);
        for i in 0..n_features {
            regularized_gram[[i, i]] += self.alpha;
        }
        let xty = x.t().dot(y);

        let weights = regularized_gram
            .solve(&xty)
            .map_err(|e| SklearsError::InvalidParameter {
                name: "regularization".to_string(),
                reason: format!("Linear system solving failed: {:?}", e),
            })?;

        Ok(weights)
    }

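    /// SVD solver: with `X = U S V^T`, the ridge solution is
    /// `w = V diag(s_i / (s_i^2 + alpha)) U^T y`.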
    fn solve_svd(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (u, s, vt) = self.compute_svd(x)?;

        // Regularized pseudo-inverse of the singular values: s_i / (s_i^2 + alpha),
        // with singular values below the threshold treated as zero.
        let threshold = 1e-12;
        let mut s_reg_inv = Array1::zeros(s.len());
        for i in 0..s.len() {
            if s[i] > threshold {
                s_reg_inv[i] = s[i] / (s[i] * s[i] + self.alpha);
            }
        }

        // w = V * diag(s_reg_inv) * U^T y
        let ut_y = u.t().dot(y);
        let mut temp = Array1::zeros(s.len());
        for i in 0..s.len() {
            temp[i] = s_reg_inv[i] * ut_y[i];
        }

        let weights = vt.t().dot(&temp);
        Ok(weights)
    }

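    /// Conjugate-gradient solver for `(X^T X + alpha * I) w = X^T y`; applies the
    /// operator as `X^T (X p) + alpha * p` without materializing the Gram matrix.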
    fn solve_conjugate_gradient(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array1<Float>> {
        let (_n_samples, n_features) = x.dim();

        // Start from w = 0, so the initial residual is r = X^T y.
        let mut w = Array1::zeros(n_features);
        let xty = x.t().dot(y);
        let mut r = xty.clone();
        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            // Apply the regularized operator: A p = X^T (X p) + alpha * p.
            let xtxp = x.t().dot(&x.dot(&p));
            let mut ap = xtxp;
            for i in 0..n_features {
                ap[i] += self.alpha * p[i];
            }

            let alpha_cg = rsold / p.dot(&ap);
            w = w + alpha_cg * &p;
            r = r - alpha_cg * &ap;

            let rsnew = r.dot(&r);
            if rsnew.sqrt() < tol {
                break;
            }

            let beta = rsnew / rsold;
            p = &r + beta * &p;
            rsold = rsnew;
        }

        Ok(w)
    }

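    /// Gaussian elimination with partial pivoting on the augmented matrix `[A | b]`.
    /// Note: the solvers above do not currently call this routine.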
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions must match for linear system solve".to_string(),
            ));
        }

        // Build the augmented matrix [A | b].
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting.
        for k in 0..n {
            let mut max_row = k;
            for i in (k + 1)..n {
                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
                    max_row = i;
                }
            }

            if max_row != k {
                for j in 0..=n {
                    let temp = aug[[k, j]];
                    aug[[k, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[k, k]].abs() < 1e-12 {
                return Err(SklearsError::InvalidInput(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }

            for i in (k + 1)..n {
                let factor = aug[[i, k]] / aug[[k, k]];
                for j in k..=n {
                    aug[[i, j]] -= factor * aug[[k, j]];
                }
            }
        }

        // Back substitution.
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = aug[[i, n]];
            for j in (i + 1)..n {
                sum -= aug[[i, j]] * x[j];
            }
            x[i] = sum / aug[[i, i]];
        }

        Ok(x)
    }

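    /// Thin SVD built from an eigendecomposition of the smaller Gram matrix:
    /// `X^T X` when `n <= m` (then `u_i = X v_i / s_i`), otherwise `X X^T`
    /// (then `v_i = X^T u_i / s_i`).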
    fn compute_svd(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = x.dim();
        let min_dim = m.min(n);

        let xt = x.t();

        if n <= m {
            // Eigendecomposition of X^T X gives V and the squared singular values.
            let xtx = xt.dot(x);
            let (eigenvals_v, eigenvecs_v) = self.compute_eigendecomposition_svd(&xtx)?;

            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_v.len() {
                if eigenvals_v[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_v[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            let mut v = Array2::zeros((n, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                v.column_mut(new_idx).assign(&eigenvecs_v.column(old_idx));
            }

            // Recover the left singular vectors: u_i = X v_i / s_i.
            let mut u = Array2::zeros((m, min_dim));
            for j in 0..valid_indices.len() {
                let v_col = v.column(j);
                let xv = x.dot(&v_col);
                let u_col = &xv / singular_vals[j];
                u.column_mut(j).assign(&u_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        } else {
            // Eigendecomposition of X X^T gives U and the squared singular values.
            let xxt = x.dot(&xt);
            let (eigenvals_u, eigenvecs_u) = self.compute_eigendecomposition_svd(&xxt)?;

            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_u.len() {
                if eigenvals_u[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_u[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            let mut u = Array2::zeros((m, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                u.column_mut(new_idx).assign(&eigenvecs_u.column(old_idx));
            }

            // Recover the right singular vectors: v_i = X^T u_i / s_i.
            let mut v = Array2::zeros((n, min_dim));
            for j in 0..valid_indices.len() {
                let u_col = u.column(j);
                let xtu = xt.dot(&u_col);
                let v_col = &xtu / singular_vals[j];
                v.column_mut(j).assign(&v_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        }
    }

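    /// Eigendecomposition of a symmetric matrix via repeated power iteration with
    /// Hotelling deflation, returning eigenpairs sorted by descending eigenvalue.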
    fn compute_eigendecomposition_svd(
        &self,
        matrix: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            let (eigenval, eigenvec) = self.power_iteration_svd(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Hotelling deflation: subtract the found component so the next power
            // iteration converges to the next-largest eigenpair.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenpairs by descending eigenvalue.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| {
            eigenvals[j]
                .partial_cmp(&eigenvals[i])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

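    /// Power iteration for the dominant eigenpair of a symmetric matrix.
    ///
    /// Note: the starting vector is drawn from `thread_rng()`, so this path does not
    /// honor `random_state`.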
    fn power_iteration_svd(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Random, roughly zero-centered starting vector.
        let mut rng = thread_rng();
        let mut v = Array1::from_shape_fn(n, |_| rng.gen::<Float>() - 0.5);

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh quotient estimate of the eigenvalue (v is unit-norm).
            let new_eigenval = v.dot(&w);

            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
}

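// Prediction applies the fitted feature map, then the learned linear weights.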
impl Predict<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let feature_transformer = self
            .feature_transformer_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let x_transformed = feature_transformer.transform(x)?;
        let predictions = x_transformed.dot(weights);

        Ok(predictions)
    }
}

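/// Online variant of [`KernelRidgeRegression`] that buffers incoming batches and
/// periodically refits the underlying model.
///
/// A minimal usage sketch (mirroring the tests at the bottom of this file):
///
/// ```ignore
/// let online = OnlineKernelRidgeRegression::new(method).alpha(0.1).update_frequency(2);
/// let fitted = online.fit(&x0, &y0)?;
/// let updated = fitted.partial_fit(&x1, &y1)?;
/// ```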
#[derive(Debug, Clone)]
pub struct OnlineKernelRidgeRegression<State = Untrained> {
    /// The wrapped kernel ridge model.
    pub base_model: KernelRidgeRegression<State>,
    /// Down-weighting factor intended for older data (stored, but not yet applied
    /// during refits).
    pub forgetting_factor: Float,
    /// Number of `partial_fit` calls between full refits.
    pub update_frequency: usize,

    update_count_: usize,
    accumulated_data_: Option<(Array2<Float>, Array1<Float>)>,

    _state: PhantomData<State>,
}

impl OnlineKernelRidgeRegression<Untrained> {
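    /// Creates an untrained online model with `forgetting_factor = 0.99` and
    /// `update_frequency = 100`.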
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            base_model: KernelRidgeRegression::new(approximation_method),
            forgetting_factor: 0.99,
            update_frequency: 100,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

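    /// Sets the forgetting factor (default: `0.99`).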
    pub fn forgetting_factor(mut self, factor: Float) -> Self {
        self.forgetting_factor = factor;
        self
    }

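    /// Sets how many `partial_fit` calls accumulate before a refit (default: `100`).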
    pub fn update_frequency(mut self, frequency: usize) -> Self {
        self.update_frequency = frequency;
        self
    }

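    /// Sets the ridge regularization strength on the underlying model.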
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.base_model = self.base_model.alpha(alpha);
        self
    }

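    /// Fixes the random seed on the underlying model.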
    pub fn random_state(mut self, seed: u64) -> Self {
        self.base_model = self.base_model.random_state(seed);
        self
    }
}

impl Estimator for OnlineKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Untrained> {
    type Fitted = OnlineKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let fitted_base = self.base_model.fit(x, y)?;

        Ok(OnlineKernelRidgeRegression {
            base_model: fitted_base,
            forgetting_factor: self.forgetting_factor,
            update_frequency: self.update_frequency,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

impl OnlineKernelRidgeRegression<Trained> {
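    /// Buffers the new batch and, once `update_frequency` calls have accumulated,
    /// refits the base model on the buffered data and clears the buffer.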
    pub fn partial_fit(mut self, x_new: &Array2<Float>, y_new: &Array1<Float>) -> Result<Self> {
        // Append the new batch to the buffer (or start a new buffer).
        match &self.accumulated_data_ {
            Some((x_acc, y_acc)) => {
                let x_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), x_acc.clone(), x_new.clone()];
                let y_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), y_acc.clone(), y_new.clone()];
                self.accumulated_data_ = Some((x_combined, y_combined));
            }
            None => {
                self.accumulated_data_ = Some((x_new.clone(), y_new.clone()));
            }
        }

        self.update_count_ += 1;

        // Refit on the buffered data once enough updates have accumulated.
        if self.update_count_ % self.update_frequency == 0 {
            if let Some((x_acc, y_acc)) = self.accumulated_data_.take() {
                let updated_base = self.base_model.clone().into_untrained().fit(&x_acc, &y_acc)?;
                self.base_model = updated_base;
            }
        }

        Ok(self)
    }

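    /// Returns the number of `partial_fit` calls made since the initial fit.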
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
}

impl Predict<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        self.base_model.predict(x)
    }
}

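/// Converts a trained model back into its untrained form, keeping the configuration
/// but discarding the fitted weights and feature transformer.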
pub trait IntoUntrained<T> {
    fn into_untrained(self) -> T;
}

impl IntoUntrained<KernelRidgeRegression<Untrained>> for KernelRidgeRegression<Trained> {
    fn into_untrained(self) -> KernelRidgeRegression<Untrained> {
        KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_kernel_ridge_regression_rff() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 50,
            gamma: 0.1,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    #[test]
    fn test_kernel_ridge_regression_nystroem() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::Nystroem {
            kernel: Kernel::Rbf { gamma: 1.0 },
            n_components: 3,
            sampling_strategy: SamplingStrategy::Random,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(1.0);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
    }

    #[test]
    fn test_kernel_ridge_regression_fastfood() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![1.0, 2.0];

        let approximation = ApproximationMethod::Fastfood {
            n_components: 8,
            gamma: 0.5,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(0.1);
        let fitted_direct = krr_direct.fit(&x, &y).unwrap();
        let pred_direct = fitted_direct.predict(&x).unwrap();

        let krr_svd = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::SVD)
            .alpha(0.1);
        let fitted_svd = krr_svd.fit(&x, &y).unwrap();
        let pred_svd = fitted_svd.predict(&x).unwrap();

        let krr_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            })
            .alpha(0.1);
        let fitted_cg = krr_cg.fit(&x, &y).unwrap();
        let pred_cg = fitted_cg.predict(&x).unwrap();

        assert_eq!(pred_direct.len(), 3);
        assert_eq!(pred_svd.len(), 3);
        assert_eq!(pred_cg.len(), 3);
    }

    #[test]
    fn test_online_kernel_ridge_regression() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let y_initial = array![1.0, 2.0];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];
        let y_new = array![3.0, 4.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let online_krr = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2);

        let fitted = online_krr.fit(&x_initial, &y_initial).unwrap();
        let updated = fitted.partial_fit(&x_new, &y_new).unwrap();

        assert_eq!(updated.update_count(), 1);

        let predictions = updated.predict(&x_initial).unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr1 = KernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let krr2 = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }
}