use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_core::ndarray::ndarray_linalg::solve::Solve;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

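/// Kernel ridge regression trained on an explicit approximate feature map
/// (Nystroem, random Fourier features, structured random features, or Fastfood)
/// instead of the full kernel matrix, so the linear system that is solved has
/// size `n_components` rather than `n_samples`.
///
/// Illustrative usage (mirrors the tests at the bottom of this module):
///
/// ```ignore
/// let krr = KernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 50,
///     gamma: 0.1,
/// })
/// .alpha(0.1);
/// let fitted = krr.fit(&x, &y)?;
/// let predictions = fitted.predict(&x)?;
/// ```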
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression<State = Untrained> {
    pub approximation_method: ApproximationMethod,
    pub alpha: Float,
    pub solver: Solver,
    pub random_state: Option<u64>,

    pub(crate) weights_: Option<Array1<Float>>,
    pub(crate) feature_transformer_: Option<FeatureTransformer>,

    pub(crate) _state: PhantomData<State>,
}

impl KernelRidgeRegression<Untrained> {
    /// Creates an untrained model with the given kernel approximation and default
    /// settings (`alpha = 1.0`, direct solver, no fixed random seed).
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            solver: Solver::Direct,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }

    /// Sets the L2 regularization strength.
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the solver used for the ridge linear system.
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Sets the random seed used by the feature transformer, for reproducible fits.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for KernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

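// Fitting proceeds in three steps: (1) fit the chosen feature transformer on X,
// (2) map X into the approximate feature space, and (3) solve the resulting
// ridge regression problem for the weights.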
impl Fit<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Untrained> {
    type Fitted = KernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        let feature_transformer = self.fit_feature_transformer(x)?;

        let x_transformed = feature_transformer.transform(x)?;

        let weights = self.solve_ridge_regression(&x_transformed, y)?;

        Ok(KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            _state: PhantomData,
        })
    }
}

impl KernelRidgeRegression<Untrained> {
    /// Builds and fits the feature transformer selected by `approximation_method`,
    /// forwarding the random seed when one was supplied.
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());

                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }

                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rbf_sampler = RBFSampler::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    rbf_sampler = rbf_sampler.random_state(seed);
                }

                let fitted = rbf_sampler.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut structured_rff = StructuredRandomFeatures::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    structured_rff = structured_rff.random_state(seed);
                }

                let fitted = structured_rff.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }

                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Dispatches the transformed ridge problem to the configured solver.
    fn solve_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        match &self.solver {
            Solver::Direct => self.solve_direct(x, y),
            Solver::SVD => self.solve_svd(x, y),
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_conjugate_gradient(x, y, *max_iter, *tol)
            }
        }
    }

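    // Direct solver: forms the regularized normal equations
    //     (X^T X + alpha * I) w = X^T y
    // in f64 and solves them with `ndarray_linalg`'s dense `solve`.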
    fn solve_direct(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (_, n_features) = x.dim();

        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]] as f64);
        let y_f64 = Array1::from_vec(y.iter().map(|&val| val as f64).collect());

        let weights_f64 = {
            let gram_matrix = x_f64.t().dot(&x_f64);
            let mut regularized_gram = gram_matrix;
            for i in 0..n_features {
                regularized_gram[[i, i]] += self.alpha as f64;
            }
            let xty_f64 = x_f64.t().dot(&y_f64);

            regularized_gram
                .solve(&xty_f64)
                .map_err(|e| SklearsError::InvalidParameter {
                    name: "regularization".to_string(),
                    reason: format!("Linear system solving failed: {:?}", e),
                })?
        };

        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
        Ok(weights)
    }

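    // SVD solver: with X = U S V^T, the ridge solution is
    //     w = V diag(s_i / (s_i^2 + alpha)) U^T y,
    // which stays well-behaved for small singular values.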
    fn solve_svd(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (u, s, vt) = self.compute_svd(x)?;

        let threshold = 1e-12;
        let mut s_reg_inv = Array1::zeros(s.len());
        for i in 0..s.len() {
            if s[i] > threshold {
                s_reg_inv[i] = s[i] / (s[i] * s[i] + self.alpha);
            }
        }

        let ut_y = u.t().dot(y);
        let mut temp = Array1::zeros(s.len());
        for i in 0..s.len() {
            temp[i] = s_reg_inv[i] * ut_y[i];
        }

        let weights = vt.t().dot(&temp);
        Ok(weights)
    }

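    // Conjugate gradient solver: iteratively solves the same normal equations
    //     (X^T X + alpha * I) w = X^T y
    // without materializing X^T X, stopping once the residual norm falls below `tol`.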
    fn solve_conjugate_gradient(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array1<Float>> {
        let (_, n_features) = x.dim();

        let mut w = Array1::zeros(n_features);

        // With w = 0, the initial residual is simply X^T y.
        let xty = x.t().dot(y);
        let mut r = xty.clone();

        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            // Apply the system matrix: (X^T X + alpha * I) p, computed as X^T (X p) + alpha * p.
            let xtxp = x.t().dot(&x.dot(&p));
            let mut ap = xtxp;
            for i in 0..n_features {
                ap[i] += self.alpha * p[i];
            }

            let alpha_cg = rsold / p.dot(&ap);

            w = w + alpha_cg * &p;

            r = r - alpha_cg * &ap;

            let rsnew = r.dot(&r);

            if rsnew.sqrt() < tol {
                break;
            }

            let beta = rsnew / rsold;
            p = &r + beta * &p;
            rsold = rsnew;
        }

        Ok(w)
    }

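    // Fallback dense solver: Gaussian elimination with partial pivoting on the
    // augmented matrix [A | b], implemented without any external linear algebra.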
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions must match for linear system solve".to_string(),
            ));
        }

        // Build the augmented matrix [A | b].
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting.
        for k in 0..n {
            let mut max_row = k;
            for i in (k + 1)..n {
                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
                    max_row = i;
                }
            }

            if max_row != k {
                for j in 0..=n {
                    let temp = aug[[k, j]];
                    aug[[k, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            if aug[[k, k]].abs() < 1e-12 {
                return Err(SklearsError::InvalidInput(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }

            for i in (k + 1)..n {
                let factor = aug[[i, k]] / aug[[k, k]];
                for j in k..=n {
                    aug[[i, j]] -= factor * aug[[k, j]];
                }
            }
        }

        // Back substitution.
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = aug[[i, n]];
            for j in (i + 1)..n {
                sum -= aug[[i, j]] * x[j];
            }
            x[i] = sum / aug[[i, i]];
        }

        Ok(x)
    }

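    // Thin SVD via eigendecomposition: diagonalize the smaller of X^T X (n <= m)
    // or X X^T (n > m), take square roots of the positive eigenvalues as singular
    // values, and recover the remaining singular vectors from X.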
    fn compute_svd(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = x.dim();
        let min_dim = m.min(n);

        let xt = x.t();

        if n <= m {
            // Eigendecompose the n x n Gram matrix X^T X to obtain V and the singular values.
            let xtx = xt.dot(x);
            let (eigenvals_v, eigenvecs_v) = self.compute_eigendecomposition_svd(&xtx)?;

            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_v.len() {
                if eigenvals_v[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_v[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            let mut v = Array2::zeros((n, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                v.column_mut(new_idx).assign(&eigenvecs_v.column(old_idx));
            }

            // Recover the left singular vectors: u_j = X v_j / s_j.
            let mut u = Array2::zeros((m, min_dim));
            for j in 0..valid_indices.len() {
                let v_col = v.column(j);
                let xv = x.dot(&v_col);
                let u_col = &xv / singular_vals[j];
                u.column_mut(j).assign(&u_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        } else {
            // Eigendecompose the m x m matrix X X^T to obtain U and the singular values.
            let xxt = x.dot(&xt);
            let (eigenvals_u, eigenvecs_u) = self.compute_eigendecomposition_svd(&xxt)?;

            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_u.len() {
                if eigenvals_u[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_u[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            let mut u = Array2::zeros((m, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                u.column_mut(new_idx).assign(&eigenvecs_u.column(old_idx));
            }

            // Recover the right singular vectors: v_j = X^T u_j / s_j.
            let mut v = Array2::zeros((n, min_dim));
            for j in 0..valid_indices.len() {
                let u_col = u.column(j);
                let xtu = xt.dot(&u_col);
                let v_col = &xtu / singular_vals[j];
                v.column_mut(j).assign(&v_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        }
    }

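    // Eigendecomposition by repeated power iteration with deflation: extract the
    // dominant eigenpair, subtract its rank-one contribution, repeat, then sort
    // the eigenpairs in descending order of eigenvalue.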
    fn compute_eigendecomposition_svd(
        &self,
        matrix: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            let (eigenval, eigenvec) = self.power_iteration_svd(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

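    // Power iteration: repeatedly apply the matrix to a random unit vector and
    // renormalize; the iterate converges to the dominant eigenvector and the
    // Rayleigh quotient v^T (M v) to the dominant eigenvalue.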
    fn power_iteration_svd(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        let mut v = Array1::from_shape_fn(n, |_| thread_rng().gen::<Float>() - 0.5);

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            let new_eigenval = v.dot(&w);

            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
}

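// Prediction maps X through the fitted feature transformer and applies the
// learned weights with a single dense matrix-vector product.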
impl Predict<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        let x_transformed = feature_transformer.transform(x)?;

        let x_f64 =
            Array2::from_shape_fn(x_transformed.dim(), |(i, j)| x_transformed[[i, j]] as f64);
        let weights_f64 = Array1::from_vec(weights.iter().map(|&val| val as f64).collect());

        let predictions_f64 = x_f64.dot(&weights_f64);

        let predictions =
            Array1::from_vec(predictions_f64.iter().map(|&val| val as Float).collect());

        Ok(predictions)
    }
}

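/// Online variant of `KernelRidgeRegression`: new mini-batches are accumulated by
/// `partial_fit`, and the underlying model is refit from the accumulated data
/// every `update_frequency` updates.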
#[derive(Debug, Clone)]
pub struct OnlineKernelRidgeRegression<State = Untrained> {
    pub base_model: KernelRidgeRegression<State>,
    pub forgetting_factor: Float,
    pub update_frequency: usize,

    update_count_: usize,
    accumulated_data_: Option<(Array2<Float>, Array1<Float>)>,

    _state: PhantomData<State>,
}

impl OnlineKernelRidgeRegression<Untrained> {
    /// Creates an untrained online model with default settings
    /// (`forgetting_factor = 0.99`, `update_frequency = 100`).
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            base_model: KernelRidgeRegression::new(approximation_method),
            forgetting_factor: 0.99,
            update_frequency: 100,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    /// Sets the forgetting factor.
    pub fn forgetting_factor(mut self, factor: Float) -> Self {
        self.forgetting_factor = factor;
        self
    }

    /// Sets how many `partial_fit` calls are accumulated before the model is refit.
    pub fn update_frequency(mut self, frequency: usize) -> Self {
        self.update_frequency = frequency;
        self
    }

    /// Sets the L2 regularization strength of the underlying model.
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.base_model = self.base_model.alpha(alpha);
        self
    }

    /// Sets the random seed of the underlying model's feature transformer.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.base_model = self.base_model.random_state(seed);
        self
    }
}

impl Estimator for OnlineKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Untrained> {
    type Fitted = OnlineKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let fitted_base = self.base_model.fit(x, y)?;

        Ok(OnlineKernelRidgeRegression {
            base_model: fitted_base,
            forgetting_factor: self.forgetting_factor,
            update_frequency: self.update_frequency,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

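// Incremental updates: each `partial_fit` call appends the new batch to the
// accumulated data; once `update_frequency` calls have been made, the base model
// is converted back to its untrained state and refit on everything accumulated.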
impl OnlineKernelRidgeRegression<Trained> {
    /// Appends a new batch of samples and periodically refits the base model.
    pub fn partial_fit(mut self, x_new: &Array2<Float>, y_new: &Array1<Float>) -> Result<Self> {
        match &self.accumulated_data_ {
            Some((x_acc, y_acc)) => {
                let x_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), x_acc.clone(), x_new.clone()];
                let y_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), y_acc.clone(), y_new.clone()];
                self.accumulated_data_ = Some((x_combined, y_combined));
            }
            None => {
                self.accumulated_data_ = Some((x_new.clone(), y_new.clone()));
            }
        }

        self.update_count_ += 1;

        if self.update_count_ % self.update_frequency == 0 {
            if let Some((ref x_acc, ref y_acc)) = self.accumulated_data_ {
                let updated_base = self.base_model.clone().into_untrained().fit(x_acc, y_acc)?;
                self.base_model = updated_base;
                self.accumulated_data_ = None;
            }
        }

        Ok(self)
    }

    /// Returns the number of `partial_fit` calls made so far.
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
}

impl Predict<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        self.base_model.predict(x)
    }
}

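/// Conversion from a trained model back to its untrained counterpart, discarding
/// the fitted state so the model can be refit from scratch (used by
/// `OnlineKernelRidgeRegression::partial_fit`).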
pub trait IntoUntrained<T> {
    fn into_untrained(self) -> T;
}

impl IntoUntrained<KernelRidgeRegression<Untrained>> for KernelRidgeRegression<Trained> {
    fn into_untrained(self) -> KernelRidgeRegression<Untrained> {
        KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_kernel_ridge_regression_rff() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 50,
            gamma: 0.1,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    #[test]
    fn test_kernel_ridge_regression_nystroem() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::Nystroem {
            kernel: Kernel::Rbf { gamma: 1.0 },
            n_components: 3,
            sampling_strategy: SamplingStrategy::Random,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(1.0);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
    }

    #[test]
    fn test_kernel_ridge_regression_fastfood() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![1.0, 2.0];

        let approximation = ApproximationMethod::Fastfood {
            n_components: 8,
            gamma: 0.5,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(0.1);
        let fitted_direct = krr_direct.fit(&x, &y).unwrap();
        let pred_direct = fitted_direct.predict(&x).unwrap();

        let krr_svd = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::SVD)
            .alpha(0.1);
        let fitted_svd = krr_svd.fit(&x, &y).unwrap();
        let pred_svd = fitted_svd.predict(&x).unwrap();

        let krr_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            })
            .alpha(0.1);
        let fitted_cg = krr_cg.fit(&x, &y).unwrap();
        let pred_cg = fitted_cg.predict(&x).unwrap();

        assert_eq!(pred_direct.len(), 3);
        assert_eq!(pred_svd.len(), 3);
        assert_eq!(pred_cg.len(), 3);
    }

    #[test]
    fn test_online_kernel_ridge_regression() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let y_initial = array![1.0, 2.0];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];
        let y_new = array![3.0, 4.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let online_krr = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2);

        let fitted = online_krr.fit(&x_initial, &y_initial).unwrap();
        let updated = fitted.partial_fit(&x_new, &y_new).unwrap();

        assert_eq!(updated.update_count(), 1);

        let predictions = updated.predict(&x_initial).unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr1 = KernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let krr2 = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }
}