1use crate::{
10 FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
11};
12use scirs2_linalg::compat::ArrayLinalgExt;
13use scirs2_core::ndarray::{Array1, Array2, Axis};
15use scirs2_core::random::thread_rng;
16use sklears_core::error::{Result, SklearsError};
17use sklears_core::prelude::{Estimator, Fit, Float, Predict};
18use std::marker::PhantomData;
19
20use super::core_types::*;
21
/// Kernel ridge regression solved in an explicit approximate feature space.
///
/// Instead of working with the full kernel matrix, the input is mapped
/// through a kernel-approximation transformer (Nystroem, random Fourier
/// features, structured random features, or Fastfood) and an ordinary
/// ridge regression is solved on the transformed features.
///
/// The `State` type parameter (`Untrained` / `Trained`) encodes the fit
/// state at compile time.
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression<State = Untrained> {
    /// Which kernel approximation to use for the feature map.
    pub approximation_method: ApproximationMethod,
    /// L2 regularization strength of the ridge solve.
    pub alpha: Float,
    /// Linear-system solver used for the ridge problem.
    pub solver: Solver,
    /// Optional RNG seed forwarded to the feature transformer.
    pub random_state: Option<u64>,

    /// Learned weights in the approximate feature space (set by `fit`).
    pub(crate) weights_: Option<Array1<Float>>,
    /// Fitted feature transformer (set by `fit`).
    pub(crate) feature_transformer_: Option<FeatureTransformer>,

    /// Zero-sized marker carrying the `State` type.
    pub(crate) _state: PhantomData<State>,
}
39
40impl KernelRidgeRegression<Untrained> {
41 pub fn new(approximation_method: ApproximationMethod) -> Self {
43 Self {
44 approximation_method,
45 alpha: 1.0,
46 solver: Solver::Direct,
47 random_state: None,
48 weights_: None,
49 feature_transformer_: None,
50 _state: PhantomData,
51 }
52 }
53
54 pub fn alpha(mut self, alpha: Float) -> Self {
56 self.alpha = alpha;
57 self
58 }
59
60 pub fn solver(mut self, solver: Solver) -> Self {
62 self.solver = solver;
63 self
64 }
65
66 pub fn random_state(mut self, seed: u64) -> Self {
68 self.random_state = Some(seed);
69 self
70 }
71}
72
impl Estimator for KernelRidgeRegression<Untrained> {
    // All hyperparameters live as struct fields, so the trait-level
    // configuration is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
82
impl Fit<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Untrained> {
    type Fitted = KernelRidgeRegression<Trained>;

    /// Fit the model: fit the kernel-approximation transformer on `x`, map
    /// `x` into the approximate feature space, and solve the ridge
    /// regression there for the weight vector.
    ///
    /// # Errors
    /// Returns `SklearsError::InvalidInput` when `x` and `y` disagree on
    /// the number of samples; propagates transformer and solver failures.
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Fit the configured approximation (Nystroem / RFF / structured
        // RFF / Fastfood) on the training inputs.
        let feature_transformer = self.fit_feature_transformer(x)?;

        // Explicit feature map z(x) whose inner products approximate the
        // kernel.
        let x_transformed = feature_transformer.transform(x)?;

        // Plain ridge regression in the transformed space.
        let weights = self.solve_ridge_regression(&x_transformed, y)?;

        Ok(KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            _state: PhantomData,
        })
    }
}
115
impl KernelRidgeRegression<Untrained> {
    /// Fit the feature transformer selected by `approximation_method` on
    /// the training inputs, forwarding `random_state` (when set) so the
    /// randomized feature maps are reproducible.
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            // Nystroem: landmark-based low-rank kernel approximation using
            // the configured sampling strategy.
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());

                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }

                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            // Random Fourier features with bandwidth parameter `gamma`.
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rbf_sampler = RBFSampler::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    rbf_sampler = rbf_sampler.random_state(seed);
                }

                let fitted = rbf_sampler.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            // Structured random feature variant of the above.
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut structured_rff = StructuredRandomFeatures::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    structured_rff = structured_rff.random_state(seed);
                }

                let fitted = structured_rff.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            // Fastfood transform.
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }

                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Dispatch the ridge solve `(XᵀX + αI) w = Xᵀy` to the configured
    /// solver.
    fn solve_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        let (_n_samples, _n_features) = x.dim();

        match &self.solver {
            Solver::Direct => self.solve_direct(x, y),
            Solver::SVD => self.solve_svd(x, y),
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_conjugate_gradient(x, y, *max_iter, *tol)
            }
        }
    }

    /// Direct solve of the regularized normal equations
    /// `(XᵀX + αI) w = Xᵀy` via the linear-algebra backend's `solve`.
    fn solve_direct(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (_, n_features) = x.dim();

        // NOTE(review): the `_f64` names are historical — these are
        // element-wise copies with the same `Float` element type, and the
        // `as Float` cast below is effectively a no-op; confirm the copies
        // can be dropped.
        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
        let y_f64 = Array1::from_vec(y.iter().copied().collect());

        let weights_f64 = {
            // Gram matrix XᵀX, then ridge term alpha added on the diagonal.
            let gram_matrix = x_f64.t().dot(&x_f64);
            let mut regularized_gram = gram_matrix;
            for i in 0..n_features {
                regularized_gram[[i, i]] += self.alpha;
            }
            let xty_f64 = x_f64.t().dot(&y_f64);

            regularized_gram
                .solve(&xty_f64)
                .map_err(|e| SklearsError::InvalidParameter {
                    name: "regularization".to_string(),
                    reason: format!("Linear system solving failed: {:?}", e),
                })?
        };

        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
        Ok(weights)
    }

    /// SVD-based ridge solve: with `X = U Σ Vᵀ` the solution is
    /// `w = V diag(σᵢ / (σᵢ² + α)) Uᵀ y`, which avoids explicitly forming
    /// XᵀX.
    fn solve_svd(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (_n_samples, _n_features) = x.dim();

        let (u, s, vt) = self.compute_svd(x)?;

        // Regularized inverse of the singular values; entries at or below
        // the threshold are treated as zero (rank truncation).
        let threshold = 1e-12;
        let mut s_reg_inv = Array1::zeros(s.len());
        for i in 0..s.len() {
            if s[i] > threshold {
                s_reg_inv[i] = s[i] / (s[i] * s[i] + self.alpha);
            }
        }

        // w = Vᵀᵀ · (diag(s_reg_inv) · Uᵀy)
        let ut_y = u.t().dot(y);
        let mut temp = Array1::zeros(s.len());
        for i in 0..s.len() {
            temp[i] = s_reg_inv[i] * ut_y[i];
        }

        let weights = vt.t().dot(&temp);
        Ok(weights)
    }

    /// Conjugate-gradient solve of `(XᵀX + αI) w = Xᵀy` without forming
    /// XᵀX explicitly (only matrix-vector products with X and Xᵀ).
    ///
    /// Starts from `w = 0`, so the initial residual equals `Xᵀy`; stops
    /// after `max_iter` iterations or when the residual 2-norm falls below
    /// `tol`.
    fn solve_conjugate_gradient(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array1<Float>> {
        let (_n_samples, n_features) = x.dim();

        let mut w = Array1::zeros(n_features);

        // Residual r = b - A·w = Xᵀy for w = 0.
        let xty = x.t().dot(y);
        let mut r = xty.clone();

        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            // A·p = (XᵀX + αI)·p computed as Xᵀ(Xp) + αp.
            let xtxp = x.t().dot(&x.dot(&p));
            let mut ap = xtxp;
            for i in 0..n_features {
                ap[i] += self.alpha * p[i];
            }

            // Step length along the current search direction.
            let alpha_cg = rsold / p.dot(&ap);

            w = w + alpha_cg * &p;

            r = r - alpha_cg * &ap;

            let rsnew = r.dot(&r);

            // Converged once the residual norm is below tolerance.
            if rsnew.sqrt() < tol {
                break;
            }

            // Next search direction, conjugate to the previous ones.
            let beta = rsnew / rsold;
            p = &r + beta * &p;
            rsold = rsnew;
        }

        Ok(w)
    }

    /// Solve `A x = b` by Gaussian elimination with partial pivoting on an
    /// augmented matrix.
    ///
    /// NOTE(review): this helper appears unused in this portion of the file
    /// (`solve_direct` uses the backend `solve` instead) — confirm whether
    /// it is still needed.
    ///
    /// # Errors
    /// Returns `InvalidInput` when the dimensions are inconsistent or when
    /// a pivot's magnitude drops below `1e-12` (singular matrix).
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions must match for linear system solve".to_string(),
            ));
        }

        // Build the augmented matrix [A | b].
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial (row) pivoting.
        for k in 0..n {
            let mut max_row = k;
            for i in (k + 1)..n {
                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
                    max_row = i;
                }
            }

            // Swap the pivot row into place.
            if max_row != k {
                for j in 0..=n {
                    let temp = aug[[k, j]];
                    aug[[k, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[k, k]].abs() < 1e-12 {
                return Err(SklearsError::InvalidInput(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }

            // Eliminate column k below the pivot.
            for i in (k + 1)..n {
                let factor = aug[[i, k]] / aug[[k, k]];
                for j in k..=n {
                    aug[[i, j]] -= factor * aug[[k, j]];
                }
            }
        }

        // Back substitution on the upper-triangular system.
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = aug[[i, n]];
            for j in (i + 1)..n {
                sum -= aug[[i, j]] * x[j];
            }
            x[i] = sum / aug[[i, i]];
        }

        Ok(x)
    }

    /// Thin SVD `X ≈ U Σ Vᵀ` computed through the eigendecomposition of the
    /// smaller Gram matrix (`XᵀX` when n ≤ m, otherwise `XXᵀ`); singular
    /// values are the square roots of the eigenvalues, and the other
    /// singular-vector set is recovered via `u = Xv/σ` (resp. `v = Xᵀu/σ`).
    ///
    /// Eigenpairs with eigenvalue ≤ 1e-12 are skipped; the corresponding
    /// trailing columns of `U`/`V` (and singular values) remain zero.
    fn compute_svd(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = x.dim();
        let min_dim = m.min(n);

        let xt = x.t();

        if n <= m {
            // Fewer columns: eigendecompose the n×n matrix XᵀX to get V
            // and the squared singular values.
            let xtx = xt.dot(x);
            let (eigenvals_v, eigenvecs_v) = self.compute_eigendecomposition_svd(&xtx)?;

            // Keep up to min_dim numerically nonzero eigenpairs.
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_v.len() {
                if eigenvals_v[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_v[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            let mut v = Array2::zeros((n, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                v.column_mut(new_idx).assign(&eigenvecs_v.column(old_idx));
            }

            // Left singular vectors: u_j = X v_j / σ_j.
            let mut u = Array2::zeros((m, min_dim));
            for j in 0..valid_indices.len() {
                let v_col = v.column(j);
                let xv = x.dot(&v_col);
                let u_col = &xv / singular_vals[j];
                u.column_mut(j).assign(&u_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        } else {
            // Fewer rows: eigendecompose the m×m matrix XXᵀ to get U.
            let xxt = x.dot(&xt);
            let (eigenvals_u, eigenvecs_u) = self.compute_eigendecomposition_svd(&xxt)?;

            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_u.len() {
                if eigenvals_u[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_u[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            let mut u = Array2::zeros((m, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                u.column_mut(new_idx).assign(&eigenvecs_u.column(old_idx));
            }

            // Right singular vectors: v_j = Xᵀ u_j / σ_j.
            let mut v = Array2::zeros((n, min_dim));
            for j in 0..valid_indices.len() {
                let u_col = u.column(j);
                let xtu = xt.dot(&u_col);
                let v_col = &xtu / singular_vals[j];
                v.column_mut(j).assign(&v_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        }
    }

    /// Eigendecomposition of a square matrix by repeated power iteration
    /// with Hotelling deflation; eigenpairs are returned sorted by
    /// descending eigenvalue.
    ///
    /// NOTE(review): deflation accumulates error in later (smaller)
    /// eigenpairs; callers filter eigenvalues ≤ 1e-12, which limits the
    /// impact — confirm the accuracy is acceptable for large inputs.
    fn compute_eigendecomposition_svd(
        &self,
        matrix: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            let (eigenval, eigenvec) = self.power_iteration_svd(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Hotelling deflation: A ← A − λ v vᵀ removes the found pair so
            // the next power iteration converges to the next eigenpair.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenpairs by descending eigenvalue.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| {
            eigenvals[j]
                .partial_cmp(&eigenvals[i])
                .expect("operation should succeed")
        });

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Dominant eigenpair of `matrix` via power iteration.
    ///
    /// Stops early when both the eigenvalue change and the element-wise
    /// vector change fall below `tol`, or when the iterate collapses to
    /// (near) zero norm.
    ///
    /// NOTE(review): the starting vector is drawn from `thread_rng()` and
    /// does not use `self.random_state`, so SVD-solver results are not
    /// reproducible across runs even when a seed is set — confirm whether
    /// this is intended.
    fn power_iteration_svd(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Random start in [-0.5, 0.5) to make an exactly-orthogonal start
        // unlikely.
        let mut v = Array1::from_shape_fn(n, |_| thread_rng().random::<Float>() - 0.5);

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh-quotient eigenvalue estimate (v has unit norm).
            let new_eigenval = v.dot(&w);

            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Convergence measured on both the eigenvalue and the vector.
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
}
586
587impl Predict<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Trained> {
588 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
589 let weights = self
590 .weights_
591 .as_ref()
592 .ok_or_else(|| SklearsError::NotFitted {
593 operation: "predict".to_string(),
594 })?;
595
596 let feature_transformer =
597 self.feature_transformer_
598 .as_ref()
599 .ok_or_else(|| SklearsError::NotFitted {
600 operation: "predict".to_string(),
601 })?;
602
603 let x_transformed = feature_transformer.transform(x)?;
605
606 let x_f64 = Array2::from_shape_fn(x_transformed.dim(), |(i, j)| x_transformed[[i, j]]);
608 let weights_f64 = Array1::from_vec(weights.iter().copied().collect());
609
610 let predictions_f64 = x_f64.dot(&weights_f64);
612
613 let predictions =
615 Array1::from_vec(predictions_f64.iter().map(|&val| val as Float).collect());
616
617 Ok(predictions)
618 }
619}
620
/// Online wrapper around [`KernelRidgeRegression`] that buffers incoming
/// mini-batches and periodically refits the base model on everything
/// accumulated since the last refit.
#[derive(Debug, Clone)]
pub struct OnlineKernelRidgeRegression<State = Untrained> {
    /// Underlying kernel ridge model that is (re)fitted on buffered data.
    pub base_model: KernelRidgeRegression<State>,
    /// Forgetting factor. NOTE(review): currently stored but never read in
    /// this file — confirm its intended use before relying on it.
    pub forgetting_factor: Float,
    /// Number of `partial_fit` calls between full refits.
    pub update_frequency: usize,

    // Number of `partial_fit` calls made so far.
    update_count_: usize,
    // Batches accumulated since the last refit (features, targets).
    accumulated_data_: Option<(Array2<Float>, Array1<Float>)>,

    // Zero-sized marker carrying the `State` type.
    _state: PhantomData<State>,
}
639
640impl OnlineKernelRidgeRegression<Untrained> {
641 pub fn new(approximation_method: ApproximationMethod) -> Self {
643 Self {
644 base_model: KernelRidgeRegression::new(approximation_method),
645 forgetting_factor: 0.99,
646 update_frequency: 100,
647 update_count_: 0,
648 accumulated_data_: None,
649 _state: PhantomData,
650 }
651 }
652
653 pub fn forgetting_factor(mut self, factor: Float) -> Self {
655 self.forgetting_factor = factor;
656 self
657 }
658
659 pub fn update_frequency(mut self, frequency: usize) -> Self {
661 self.update_frequency = frequency;
662 self
663 }
664
665 pub fn alpha(mut self, alpha: Float) -> Self {
667 self.base_model = self.base_model.alpha(alpha);
668 self
669 }
670
671 pub fn random_state(mut self, seed: u64) -> Self {
673 self.base_model = self.base_model.random_state(seed);
674 self
675 }
676}
677
impl Estimator for OnlineKernelRidgeRegression<Untrained> {
    // Hyperparameters live as struct fields, so the trait-level
    // configuration is the unit type (mirrors the base model's impl).
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
687
impl Fit<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Untrained> {
    type Fitted = OnlineKernelRidgeRegression<Trained>;

    /// Initial fit: trains the base model on `(x, y)` and resets the
    /// online bookkeeping (update counter and accumulation buffer).
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let fitted_base = self.base_model.fit(x, y)?;

        Ok(OnlineKernelRidgeRegression {
            base_model: fitted_base,
            forgetting_factor: self.forgetting_factor,
            update_frequency: self.update_frequency,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}
704
705impl OnlineKernelRidgeRegression<Trained> {
706 pub fn partial_fit(mut self, x_new: &Array2<Float>, y_new: &Array1<Float>) -> Result<Self> {
708 match &self.accumulated_data_ {
710 Some((x_acc, y_acc)) => {
711 let x_combined =
712 scirs2_core::ndarray::concatenate![Axis(0), x_acc.clone(), x_new.clone()];
713 let y_combined =
714 scirs2_core::ndarray::concatenate![Axis(0), y_acc.clone(), y_new.clone()];
715 self.accumulated_data_ = Some((x_combined, y_combined));
716 }
717 None => {
718 self.accumulated_data_ = Some((x_new.clone(), y_new.clone()));
719 }
720 }
721
722 self.update_count_ += 1;
723
724 if self.update_count_ % self.update_frequency == 0 {
726 if let Some((ref x_acc, ref y_acc)) = self.accumulated_data_ {
727 let updated_base = self.base_model.clone().into_untrained().fit(x_acc, y_acc)?;
731 self.base_model = updated_base;
732 self.accumulated_data_ = None;
733 }
734 }
735
736 Ok(self)
737 }
738
739 pub fn update_count(&self) -> usize {
741 self.update_count_
742 }
743}
744
impl Predict<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Trained> {
    /// Delegate prediction to the most recently (re)fitted base model.
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        self.base_model.predict(x)
    }
}
750
/// Conversion from a trained estimator back to its untrained counterpart,
/// discarding fitted state while keeping the hyperparameters.
pub trait IntoUntrained<T> {
    /// Consume the trained value and return its untrained form.
    fn into_untrained(self) -> T;
}
755
756impl IntoUntrained<KernelRidgeRegression<Untrained>> for KernelRidgeRegression<Trained> {
757 fn into_untrained(self) -> KernelRidgeRegression<Untrained> {
758 KernelRidgeRegression {
759 approximation_method: self.approximation_method,
760 alpha: self.alpha,
761 solver: self.solver,
762 random_state: self.random_state,
763 weights_: None,
764 feature_transformer_: None,
765 _state: PhantomData,
766 }
767 }
768}
769
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fit + predict with random Fourier features; checks shape and that
    /// all predictions are finite.
    #[test]
    fn test_kernel_ridge_regression_rff() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 50,
            gamma: 0.1,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 4);
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    /// Fit + predict with the Nystroem approximation (RBF kernel, random
    /// landmark sampling); checks the output shape only.
    #[test]
    fn test_kernel_ridge_regression_nystroem() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::Nystroem {
            kernel: Kernel::Rbf { gamma: 1.0 },
            n_components: 3,
            sampling_strategy: SamplingStrategy::Random,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(1.0);
        let fitted = krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 3);
    }

    /// Fit + predict with the Fastfood transform; checks the output shape.
    #[test]
    fn test_kernel_ridge_regression_fastfood() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![1.0, 2.0];

        let approximation = ApproximationMethod::Fastfood {
            n_components: 8,
            gamma: 0.5,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 2);
    }

    /// Smoke test: all three solvers (direct, SVD, conjugate gradient)
    /// fit and predict without error on the same data.
    #[test]
    fn test_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(0.1);
        let fitted_direct = krr_direct.fit(&x, &y).expect("operation should succeed");
        let pred_direct = fitted_direct.predict(&x).expect("operation should succeed");

        let krr_svd = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::SVD)
            .alpha(0.1);
        let fitted_svd = krr_svd.fit(&x, &y).expect("operation should succeed");
        let pred_svd = fitted_svd.predict(&x).expect("operation should succeed");

        let krr_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            })
            .alpha(0.1);
        let fitted_cg = krr_cg.fit(&x, &y).expect("operation should succeed");
        let pred_cg = fitted_cg.predict(&x).expect("operation should succeed");

        assert_eq!(pred_direct.len(), 3);
        assert_eq!(pred_svd.len(), 3);
        assert_eq!(pred_cg.len(), 3);
    }

    /// Online model: initial fit, then one partial_fit; with
    /// update_frequency = 2 the refit has not triggered yet, so only the
    /// counter and prediction shape are checked.
    #[test]
    fn test_online_kernel_ridge_regression() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let y_initial = array![1.0, 2.0];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];
        let y_new = array![3.0, 4.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let online_krr = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2);

        let fitted = online_krr
            .fit(&x_initial, &y_initial)
            .expect("operation should succeed");
        let updated = fitted
            .partial_fit(&x_new, &y_new)
            .expect("operation should succeed");

        assert_eq!(updated.update_count(), 1);

        let predictions = updated
            .predict(&x_initial)
            .expect("operation should succeed");
        assert_eq!(predictions.len(), 2);
    }

    /// Two models with the same seed must produce (near-)identical
    /// predictions. Note this exercises the default Direct solver; the
    /// SVD solver's power iteration is not seeded.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr1 = KernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = krr1.fit(&x, &y).expect("operation should succeed");
        let pred1 = fitted1.predict(&x).expect("operation should succeed");

        let krr2 = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = krr2.fit(&x, &y).expect("operation should succeed");
        let pred2 = fitted2.predict(&x).expect("operation should succeed");

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }
}
929}