use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_core::ndarray::ndarray_linalg::solve::Solve;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::Rng;
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

/// Kernel ridge regression on an explicit, approximate kernel feature map.
///
/// The typestate parameter distinguishes the untrained builder from a
/// fitted model that holds learned weights and a feature transformer.
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression<State = Untrained> {
    /// Kernel approximation used to build the explicit feature map.
    pub approximation_method: ApproximationMethod,
    /// L2 regularization strength.
    pub alpha: Float,
    /// Solver for the regularized linear system.
    pub solver: Solver,
    /// Seed for the random feature map, for reproducibility.
    pub random_state: Option<u64>,

    pub(crate) weights_: Option<Array1<Float>>,
    pub(crate) feature_transformer_: Option<FeatureTransformer>,

    pub(crate) _state: PhantomData<State>,
}

impl KernelRidgeRegression<Untrained> {
    /// Creates a new model with the given kernel approximation,
    /// `alpha = 1.0`, and the direct solver.
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            solver: Solver::Direct,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }

    /// Sets the regularization strength.
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the solver for the ridge linear system.
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Sets the random seed used when fitting the feature transformer.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for KernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Untrained> {
    type Fitted = KernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Fit the kernel approximation on the training inputs.
        let feature_transformer = self.fit_feature_transformer(x)?;

        // Map the inputs into the approximate feature space.
        let x_transformed = feature_transformer.transform(x)?;

        // Solve the ridge regression problem in that feature space.
        let weights = self.solve_ridge_regression(&x_transformed, y)?;

        Ok(KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            _state: PhantomData,
        })
    }
}

impl KernelRidgeRegression<Untrained> {
    /// Builds and fits the feature transformer selected by
    /// `approximation_method`, propagating `random_state` when set.
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());

                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }

                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rbf_sampler = RBFSampler::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    rbf_sampler = rbf_sampler.random_state(seed);
                }

                let fitted = rbf_sampler.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut structured_rff = StructuredRandomFeatures::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    structured_rff = structured_rff.random_state(seed);
                }

                let fitted = structured_rff.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }

                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Dispatches to the configured solver for the ridge linear system.
    fn solve_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        match &self.solver {
            Solver::Direct => self.solve_direct(x, y),
            Solver::SVD => self.solve_svd(x, y),
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_conjugate_gradient(x, y, *max_iter, *tol)
            }
        }
    }

    /// Solves the normal equations `(XᵀX + αI) w = Xᵀy` with a direct
    /// factorization-based solve.
    fn solve_direct(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (_, n_features) = x.dim();

        // Form the regularized Gram matrix XᵀX + αI.
        let gram_matrix = x.t().dot(x);
        let mut regularized_gram = gram_matrix;
        for i in 0..n_features {
            regularized_gram[[i, i]] += self.alpha;
        }
        let xty = x.t().dot(y);

        let weights = regularized_gram
            .solve(&xty)
            .map_err(|e| SklearsError::InvalidParameter {
                name: "regularization".to_string(),
                reason: format!("Linear system solving failed: {:?}", e),
            })?;

        Ok(weights)
    }

    /// Solves the ridge problem through the (approximate) SVD `X = U Σ Vᵀ`:
    /// the regularized solution is `w = V diag(σᵢ / (σᵢ² + α)) Uᵀ y`.
    fn solve_svd(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (u, s, vt) = self.compute_svd(x)?;

        // Tikhonov filter factors σᵢ / (σᵢ² + α), skipping numerically
        // zero singular values.
        let threshold = 1e-12;
        let mut s_reg_inv = Array1::zeros(s.len());
        for i in 0..s.len() {
            if s[i] > threshold {
                s_reg_inv[i] = s[i] / (s[i] * s[i] + self.alpha);
            }
        }

        let ut_y = u.t().dot(y);
        let mut temp = Array1::zeros(s.len());
        for i in 0..s.len() {
            temp[i] = s_reg_inv[i] * ut_y[i];
        }

        let weights = vt.t().dot(&temp);
        Ok(weights)
    }

    /// Solves `(XᵀX + αI) w = Xᵀy` by conjugate gradients without forming
    /// the Gram matrix explicitly.
    fn solve_conjugate_gradient(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array1<Float>> {
        let (_n_samples, n_features) = x.dim();

        // Start from w = 0, so the initial residual is the right-hand side Xᵀy.
        let mut w = Array1::zeros(n_features);
        let xty = x.t().dot(y);
        let mut r = xty.clone();
        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            // Matrix-vector product A p = XᵀX p + α p, computed as two
            // rectangular products.
            let xtxp = x.t().dot(&x.dot(&p));
            let mut ap = xtxp;
            for i in 0..n_features {
                ap[i] += self.alpha * p[i];
            }

            let alpha_cg = rsold / p.dot(&ap);

            w = w + alpha_cg * &p;
            r = r - alpha_cg * &ap;

            let rsnew = r.dot(&r);
            if rsnew.sqrt() < tol {
                break;
            }

            let beta = rsnew / rsold;
            p = &r + beta * &p;
            rsold = rsnew;
        }

        Ok(w)
    }

    /// Solves a dense linear system `A x = b` by Gaussian elimination with
    /// partial pivoting.
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions must match for linear system solve".to_string(),
            ));
        }

        // Build the augmented matrix [A | b].
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting.
        for k in 0..n {
            // Select the pivot row with the largest absolute value in column k.
            let mut max_row = k;
            for i in (k + 1)..n {
                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
                    max_row = i;
                }
            }

            if max_row != k {
                for j in 0..=n {
                    let temp = aug[[k, j]];
                    aug[[k, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[k, k]].abs() < 1e-12 {
                return Err(SklearsError::InvalidInput(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }

            // Eliminate column k below the pivot.
            for i in (k + 1)..n {
                let factor = aug[[i, k]] / aug[[k, k]];
                for j in k..=n {
                    aug[[i, j]] -= factor * aug[[k, j]];
                }
            }
        }

        // Back substitution.
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = aug[[i, n]];
            for j in (i + 1)..n {
                sum -= aug[[i, j]] * x[j];
            }
            x[i] = sum / aug[[i, i]];
        }

        Ok(x)
    }

    /// Computes a thin SVD of `x` via an eigendecomposition of the smaller
    /// Gram matrix: `XᵀX` when `n <= m` (eigenvectors give V), otherwise
    /// `XXᵀ` (eigenvectors give U). Singular values are the square roots of
    /// the positive eigenvalues; the remaining factor is recovered from
    /// `X v / σ` (or `Xᵀ u / σ`).
    fn compute_svd(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = x.dim();
        let min_dim = m.min(n);

        let xt = x.t();

        if n <= m {
            let xtx = xt.dot(x);
            let (eigenvals_v, eigenvecs_v) = self.compute_eigendecomposition_svd(&xtx)?;

            // Keep eigenpairs with numerically positive eigenvalues.
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_v.len() {
                if eigenvals_v[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_v[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            let mut v = Array2::zeros((n, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                v.column_mut(new_idx).assign(&eigenvecs_v.column(old_idx));
            }

            // Recover the left singular vectors as u = X v / σ.
            let mut u = Array2::zeros((m, min_dim));
            for j in 0..valid_indices.len() {
                let v_col = v.column(j);
                let xv = x.dot(&v_col);
                let u_col = &xv / singular_vals[j];
                u.column_mut(j).assign(&u_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        } else {
            let xxt = x.dot(&xt);
            let (eigenvals_u, eigenvecs_u) = self.compute_eigendecomposition_svd(&xxt)?;

            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_u.len() {
                if eigenvals_u[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_u[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            let mut u = Array2::zeros((m, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                u.column_mut(new_idx).assign(&eigenvecs_u.column(old_idx));
            }

            // Recover the right singular vectors as v = Xᵀ u / σ.
            let mut v = Array2::zeros((n, min_dim));
            for j in 0..valid_indices.len() {
                let u_col = u.column(j);
                let xtu = xt.dot(&u_col);
                let v_col = &xtu / singular_vals[j];
                v.column_mut(j).assign(&v_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        }
    }

    /// Computes a full eigendecomposition of a symmetric matrix by repeated
    /// power iteration with deflation, returning eigenvalues in descending
    /// order together with their eigenvectors as columns.
    fn compute_eigendecomposition_svd(
        &self,
        matrix: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            let (eigenval, eigenvec) = self.power_iteration_svd(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate: subtract λ v vᵀ so the next dominant eigenpair emerges.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenpairs by descending eigenvalue.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Estimates the dominant eigenpair of a symmetric matrix by power
    /// iteration, stopping when both the Rayleigh-quotient estimate and the
    /// iterate stabilize within `tol`.
    fn power_iteration_svd(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Random, normalized starting vector.
        let mut v = Array1::from_shape_fn(n, |_| thread_rng().gen::<Float>() - 0.5);

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh quotient (v has unit norm).
            let new_eigenval = v.dot(&w);

            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
}

impl Predict<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        // Map inputs into the approximate feature space, then apply the
        // learned linear model.
        let x_transformed = feature_transformer.transform(x)?;
        let predictions = x_transformed.dot(weights);

        Ok(predictions)
    }
}

/// Kernel ridge regression with mini-batch updates: new samples are buffered
/// and the underlying model is refit once enough batches have arrived.
#[derive(Debug, Clone)]
pub struct OnlineKernelRidgeRegression<State = Untrained> {
    /// Underlying kernel ridge regression model.
    pub base_model: KernelRidgeRegression<State>,
    /// Weight intended for down-weighting older samples (stored as
    /// configuration; see `partial_fit` for the current update rule).
    pub forgetting_factor: Float,
    /// Number of `partial_fit` calls between refits of the base model.
    pub update_frequency: usize,

    update_count_: usize,
    accumulated_data_: Option<(Array2<Float>, Array1<Float>)>,

    _state: PhantomData<State>,
}

impl OnlineKernelRidgeRegression<Untrained> {
    /// Creates a new online model with `forgetting_factor = 0.99` and a
    /// refit every 100 updates.
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            base_model: KernelRidgeRegression::new(approximation_method),
            forgetting_factor: 0.99,
            update_frequency: 100,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    /// Sets the forgetting factor.
    pub fn forgetting_factor(mut self, factor: Float) -> Self {
        self.forgetting_factor = factor;
        self
    }

    /// Sets how many `partial_fit` calls to buffer before refitting.
    pub fn update_frequency(mut self, frequency: usize) -> Self {
        self.update_frequency = frequency;
        self
    }

    /// Sets the regularization strength of the base model.
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.base_model = self.base_model.alpha(alpha);
        self
    }

    /// Sets the random seed of the base model.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.base_model = self.base_model.random_state(seed);
        self
    }
}

impl Estimator for OnlineKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Untrained> {
    type Fitted = OnlineKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let fitted_base = self.base_model.fit(x, y)?;

        Ok(OnlineKernelRidgeRegression {
            base_model: fitted_base,
            forgetting_factor: self.forgetting_factor,
            update_frequency: self.update_frequency,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

impl OnlineKernelRidgeRegression<Trained> {
    /// Buffers a new batch of samples. Every `update_frequency` calls, the
    /// base model is refit on the buffered data and the buffer is cleared.
    pub fn partial_fit(mut self, x_new: &Array2<Float>, y_new: &Array1<Float>) -> Result<Self> {
        // Append the new batch to the buffer.
        match &self.accumulated_data_ {
            Some((x_acc, y_acc)) => {
                let x_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), x_acc.clone(), x_new.clone()];
                let y_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), y_acc.clone(), y_new.clone()];
                self.accumulated_data_ = Some((x_combined, y_combined));
            }
            None => {
                self.accumulated_data_ = Some((x_new.clone(), y_new.clone()));
            }
        }

        self.update_count_ += 1;

        if self.update_count_ % self.update_frequency == 0 {
            if let Some((ref x_acc, ref y_acc)) = self.accumulated_data_ {
                // Refit from scratch on the buffered batches.
                let updated_base = self.base_model.clone().into_untrained().fit(x_acc, y_acc)?;
                self.base_model = updated_base;
                self.accumulated_data_ = None;
            }
        }

        Ok(self)
    }

    /// Returns how many `partial_fit` calls have been made since fitting.
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
}

impl Predict<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        self.base_model.predict(x)
    }
}

/// Converts a trained model back into its untrained counterpart, discarding
/// fitted state so it can be refit on new data.
pub trait IntoUntrained<T> {
    fn into_untrained(self) -> T;
}

impl IntoUntrained<KernelRidgeRegression<Untrained>> for KernelRidgeRegression<Trained> {
    fn into_untrained(self) -> KernelRidgeRegression<Untrained> {
        KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_kernel_ridge_regression_rff() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 50,
            gamma: 0.1,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

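    // A minimal sketch (not from the original suite) exercising the
    // `IntoUntrained` round-trip that `partial_fit` relies on: converting a
    // fitted model back to untrained drops the learned state, and the
    // result can be refit on the same data.
    #[test]
    fn test_into_untrained_roundtrip() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let fitted = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .fit(&x, &y)
            .unwrap();

        let untrained = fitted.into_untrained();
        assert!(untrained.weights_.is_none());
        assert!(untrained.feature_transformer_.is_none());

        let refitted = untrained.fit(&x, &y).unwrap();
        assert_eq!(refitted.predict(&x).unwrap().len(), 3);
    }
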
    #[test]
    fn test_kernel_ridge_regression_nystroem() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::Nystroem {
            kernel: Kernel::Rbf { gamma: 1.0 },
            n_components: 3,
            sampling_strategy: SamplingStrategy::Random,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(1.0);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
    }

    #[test]
    fn test_kernel_ridge_regression_fastfood() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![1.0, 2.0];

        let approximation = ApproximationMethod::Fastfood {
            n_components: 8,
            gamma: 0.5,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(0.1);
        let fitted_direct = krr_direct.fit(&x, &y).unwrap();
        let pred_direct = fitted_direct.predict(&x).unwrap();

        let krr_svd = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::SVD)
            .alpha(0.1);
        let fitted_svd = krr_svd.fit(&x, &y).unwrap();
        let pred_svd = fitted_svd.predict(&x).unwrap();

        let krr_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            })
            .alpha(0.1);
        let fitted_cg = krr_cg.fit(&x, &y).unwrap();
        let pred_cg = fitted_cg.predict(&x).unwrap();

        assert_eq!(pred_direct.len(), 3);
        assert_eq!(pred_svd.len(), 3);
        assert_eq!(pred_cg.len(), 3);
    }

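    // A minimal sketch (not from the original suite): with a fixed random
    // state the feature map is deterministic, so the conjugate-gradient
    // solver should closely reproduce the direct solution of the same
    // regularized system on this small, well-conditioned problem.
    #[test]
    fn test_cg_matches_direct_solver() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let pred_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(1.0)
            .random_state(7)
            .fit(&x, &y)
            .unwrap()
            .predict(&x)
            .unwrap();

        let pred_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 500,
                tol: 1e-12,
            })
            .alpha(1.0)
            .random_state(7)
            .fit(&x, &y)
            .unwrap()
            .predict(&x)
            .unwrap();

        for i in 0..pred_direct.len() {
            assert!((pred_direct[i] - pred_cg[i]).abs() < 1e-4);
        }
    }
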
    #[test]
    fn test_online_kernel_ridge_regression() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let y_initial = array![1.0, 2.0];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];
        let y_new = array![3.0, 4.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let online_krr = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2);

        let fitted = online_krr.fit(&x_initial, &y_initial).unwrap();
        let updated = fitted.partial_fit(&x_new, &y_new).unwrap();

        assert_eq!(updated.update_count(), 1);

        let predictions = updated.predict(&x_initial).unwrap();
        assert_eq!(predictions.len(), 2);
    }

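    // A minimal sketch (not from the original suite) checking that
    // `partial_fit` refits the base model once `update_frequency` calls
    // have accumulated: with a frequency of 2, the second call triggers a
    // refit and the model still produces finite predictions.
    #[test]
    fn test_online_refit_after_update_frequency() {
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let fitted = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2)
            .fit(&array![[1.0, 2.0], [2.0, 3.0]], &array![1.0, 2.0])
            .unwrap();

        let updated = fitted
            .partial_fit(&array![[3.0, 4.0]], &array![3.0])
            .unwrap()
            .partial_fit(&array![[4.0, 5.0]], &array![4.0])
            .unwrap();

        assert_eq!(updated.update_count(), 2);
        let predictions = updated.predict(&array![[1.0, 2.0]]).unwrap();
        assert!(predictions.iter().all(|p| p.is_finite()));
    }
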
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr1 = KernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let krr2 = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }
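
    // A minimal sketch (not from the original suite) validating the
    // Gaussian-elimination helper `solve_linear_system` on a hand-checkable
    // diagonal system, plus its singular-matrix error path.
    #[test]
    fn test_solve_linear_system() {
        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 2,
            gamma: 1.0,
        };
        let krr = KernelRidgeRegression::new(approximation);

        // [[2, 0], [0, 4]] x = [2, 8]  =>  x = [1, 2].
        let a = array![[2.0, 0.0], [0.0, 4.0]];
        let b = array![2.0, 8.0];
        let x = krr.solve_linear_system(&a, &b).unwrap();
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);

        // A rank-deficient matrix should be rejected.
        let singular = array![[1.0, 1.0], [1.0, 1.0]];
        assert!(krr.solve_linear_system(&singular, &b).is_err());
    }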
}