sklears_kernel_approximation/kernel_ridge_regression/basic_regression.rs

//! Basic Kernel Ridge Regression Implementation
//!
//! This module contains the core KernelRidgeRegression implementation including:
//! - Basic kernel ridge regression with multiple approximation methods
//! - Support for different solvers (Direct, SVD, Conjugate Gradient)
//! - Online/Incremental learning capabilities
//! - Comprehensive numerical linear algebra implementations

use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_core::ndarray::ndarray_linalg::solve::Solve;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::{thread_rng, Rng};
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

/// Basic Kernel Ridge Regression
///
/// This is the main kernel ridge regression implementation that supports various
/// kernel approximation methods and solvers for different use cases.
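///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` since it assumes the surrounding crate
/// setup and the `ApproximationMethod` variants from `core_types`):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let y = array![1.0, 2.0, 3.0];
///
/// let krr = KernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 50,
///     gamma: 0.5,
/// })
/// .alpha(0.1)
/// .random_state(42);
///
/// let fitted = krr.fit(&x, &y)?;
/// let predictions = fitted.predict(&x)?;
/// assert_eq!(predictions.len(), 3);
/// ```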
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression<State = Untrained> {
    pub approximation_method: ApproximationMethod,
    pub alpha: Float,
    pub solver: Solver,
    pub random_state: Option<u64>,

    // Fitted parameters
    pub(crate) weights_: Option<Array1<Float>>,
    pub(crate) feature_transformer_: Option<FeatureTransformer>,

    pub(crate) _state: PhantomData<State>,
}

impl KernelRidgeRegression<Untrained> {
    /// Create a new kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            solver: Solver::Direct,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }

    /// Set the regularization parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the solver method
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for KernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Untrained> {
    type Fitted = KernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Fit the feature transformer based on approximation method
        let feature_transformer = self.fit_feature_transformer(x)?;

        // Transform features
        let x_transformed = feature_transformer.transform(x)?;

        // Solve the ridge regression problem: (X^T X + alpha * I) w = X^T y
        let weights = self.solve_ridge_regression(&x_transformed, y)?;

        Ok(KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            _state: PhantomData,
        })
    }
}

impl KernelRidgeRegression<Untrained> {
    /// Fit the feature transformer based on the approximation method
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());

                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }

                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rbf_sampler = RBFSampler::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    rbf_sampler = rbf_sampler.random_state(seed);
                }

                let fitted = rbf_sampler.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut structured_rff = StructuredRandomFeatures::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    structured_rff = structured_rff.random_state(seed);
                }

                let fitted = structured_rff.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }

                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve the ridge regression problem
    fn solve_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        let (_n_samples, _n_features) = x.dim();

        match &self.solver {
            Solver::Direct => self.solve_direct(x, y),
            Solver::SVD => self.solve_svd(x, y),
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_conjugate_gradient(x, y, *max_iter, *tol)
            }
        }
    }

    /// Direct solver using normal equations
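    ///
    /// Forms the Gram matrix of the transformed features and solves the normal
    /// equations `(X^T X + alpha * I) w = X^T y` with a dense linear solve.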
    fn solve_direct(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (_, n_features) = x.dim();

        // Compute X^T X + alpha * I in f64 (an optional SIMD path is sketched below but currently disabled)
        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]] as f64);
        let y_f64 = Array1::from_vec(y.iter().map(|&val| val as f64).collect());

        // #[cfg(feature = "nightly-simd")]
        // let (gram_matrix, xty_f64, weights_f64) = {
        //     let xtx_f64 = simd_kernel::simd_gram_matrix_from_data(&x_f64.view());
        //     let mut gram_matrix = xtx_f64;
        //
        //     // SIMD-accelerated diagonal regularization
        //     for i in 0..n_features {
        //         gram_matrix[[i, i]] += self.alpha as f64;
        //     }
        //
        //     // SIMD-accelerated computation of X^T y
        //     let xty_f64 =
        //         simd_kernel::simd_matrix_vector_multiply(&x_f64.t().view(), &y_f64.view())?;
        //
        //     // SIMD-accelerated linear system solving
        //     let weights_f64 =
        //         simd_kernel::simd_ridge_coefficients(&gram_matrix.view(), &xty_f64.view(), 0.0)?;
        //
        //     (gram_matrix, xty_f64, weights_f64)
        // };

        // Use standard implementation (SIMD temporarily disabled)
        let weights_f64 = {
            // Standard implementation without SIMD
            let gram_matrix = x_f64.t().dot(&x_f64);
            let mut regularized_gram = gram_matrix;
            for i in 0..n_features {
                regularized_gram[[i, i]] += self.alpha as f64;
            }
            let xty_f64 = x_f64.t().dot(&y_f64);

            // Solve the linear system (X^T X + αI) w = X^T y
            regularized_gram
                .solve(&xty_f64)
                .map_err(|e| SklearsError::InvalidParameter {
                    name: "regularization".to_string(),
                    reason: format!("Linear system solving failed: {:?}", e),
                })?
        };

        // Convert back to Float
        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
        Ok(weights)
    }

    /// SVD-based solver (more numerically stable)
    fn solve_svd(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        // Use SVD of the design matrix X for numerical stability
        // Solve the regularized least squares: min ||Xw - y||² + α||w||²
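        // With the thin SVD X = U * S * V^T, the penalized solution has the closed
        // form w = V * diag(s_i / (s_i^2 + alpha)) * U^T * y; singular values below
        // a small threshold are treated as zero.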

        let (_n_samples, _n_features) = x.dim();

        // Compute SVD of X using power iteration method
        let (u, s, vt) = self.compute_svd(x)?;

        // Compute regularized pseudo-inverse
        let threshold = 1e-12;
        let mut s_reg_inv = Array1::zeros(s.len());
        for i in 0..s.len() {
            if s[i] > threshold {
                s_reg_inv[i] = s[i] / (s[i] * s[i] + self.alpha);
            }
        }

        // Solve: w = V * S_reg^(-1) * U^T * y
        let ut_y = u.t().dot(y);
        let mut temp = Array1::zeros(s.len());
        for i in 0..s.len() {
            temp[i] = s_reg_inv[i] * ut_y[i];
        }

        let weights = vt.t().dot(&temp);
        Ok(weights)
    }

    /// Conjugate Gradient solver (iterative, memory efficient)
    fn solve_conjugate_gradient(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array1<Float>> {
        let (_n_samples, n_features) = x.dim();

        // Initialize weights
        let mut w = Array1::zeros(n_features);

        // Compute initial residual: r = X^T y - (X^T X + alpha I) w
        let xty = x.t().dot(y);
        let mut r = xty.clone();

        let mut p = r.clone();
        let mut rsold = r.dot(&r);

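        // Standard conjugate gradient on the normal equations with A = X^T X + alpha * I
        // (A is never formed explicitly; only matrix-vector products with X and X^T are used):
        //   alpha_k = (r^T r) / (p^T A p)
        //   w <- w + alpha_k * p,   r <- r - alpha_k * A p
        //   beta_k = (r_new^T r_new) / (r^T r),   p <- r_new + beta_k * p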
        for _iter in 0..max_iter {
            // Compute A * p where A = X^T X + alpha I
            let xtxp = x.t().dot(&x.dot(&p));
            let mut ap = xtxp;
            for i in 0..n_features {
                ap[i] += self.alpha * p[i];
            }

            // Compute step size
            let alpha_cg = rsold / p.dot(&ap);

            // Update weights
            w = w + alpha_cg * &p;

            // Update residual
            r = r - alpha_cg * &ap;

            let rsnew = r.dot(&r);

            // Check convergence
            if rsnew.sqrt() < tol {
                break;
            }

            // Update search direction
            let beta = rsnew / rsold;
            p = &r + beta * &p;
            rsold = rsnew;
        }

        Ok(w)
    }

    /// Solve linear system Ax = b using Gaussian elimination with partial pivoting
    #[allow(dead_code)]
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions must match for linear system solve".to_string(),
            ));
        }

        // Create augmented matrix [A|b]
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting
        for k in 0..n {
            // Find pivot
            let mut max_row = k;
            for i in (k + 1)..n {
                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
                    max_row = i;
                }
            }

            // Swap rows if needed
            if max_row != k {
                for j in 0..=n {
                    let temp = aug[[k, j]];
                    aug[[k, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            // Check for zero pivot
            if aug[[k, k]].abs() < 1e-12 {
                return Err(SklearsError::InvalidInput(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }

            // Eliminate column
            for i in (k + 1)..n {
                let factor = aug[[i, k]] / aug[[k, k]];
                for j in k..=n {
                    aug[[i, j]] -= factor * aug[[k, j]];
                }
            }
        }

        // Back substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = aug[[i, n]];
            for j in (i + 1)..n {
                sum -= aug[[i, j]] * x[j];
            }
            x[i] = sum / aug[[i, i]];
        }

        Ok(x)
    }

    /// Compute SVD using power iteration method
    /// Returns (U, S, V^T) where X = U * S * V^T
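    ///
    /// Strategy: eigendecompose the smaller Gram matrix (`X^T X` when `n <= m`,
    /// otherwise `X X^T`) via power iteration with deflation, take square roots of
    /// the eigenvalues as singular values, and recover the remaining factor from `X`.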
    fn compute_svd(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = x.dim();
        let min_dim = m.min(n);

        // For SVD, we compute eigendecomposition of X^T X (for V) and X X^T (for U)
        let xt = x.t();

        if n <= m {
            // Thin SVD: compute V from X^T X
            let xtx = xt.dot(x);
            let (eigenvals_v, eigenvecs_v) = self.compute_eigendecomposition_svd(&xtx)?;

            // Singular values are sqrt of eigenvalues
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_v.len() {
                if eigenvals_v[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_v[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            // Construct V matrix
            let mut v = Array2::zeros((n, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                v.column_mut(new_idx).assign(&eigenvecs_v.column(old_idx));
            }

            // Compute U = X * V * S^(-1)
            let mut u = Array2::zeros((m, min_dim));
            for j in 0..valid_indices.len() {
                let v_col = v.column(j);
                let xv = x.dot(&v_col);
                let u_col = &xv / singular_vals[j];
                u.column_mut(j).assign(&u_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        } else {
            // Wide matrix: compute U from X X^T
            let xxt = x.dot(&xt);
            let (eigenvals_u, eigenvecs_u) = self.compute_eigendecomposition_svd(&xxt)?;

            // Singular values are sqrt of eigenvalues
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_u.len() {
                if eigenvals_u[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_u[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            // Construct U matrix
            let mut u = Array2::zeros((m, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                u.column_mut(new_idx).assign(&eigenvecs_u.column(old_idx));
            }

            // Compute V = X^T * U * S^(-1)
            let mut v = Array2::zeros((n, min_dim));
            for j in 0..valid_indices.len() {
                let u_col = u.column(j);
                let xtu = xt.dot(&u_col);
                let v_col = &xtu / singular_vals[j];
                v.column_mut(j).assign(&v_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        }
    }

    /// Compute eigendecomposition for SVD computation
    fn compute_eigendecomposition_svd(
        &self,
        matrix: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration_svd(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Power iteration method for SVD eigendecomposition
    fn power_iteration_svd(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize random vector
        let mut v = Array1::from_shape_fn(n, |_| thread_rng().gen::<Float>() - 0.5);

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
}

impl Predict<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        // Transform features
        let x_transformed = feature_transformer.transform(x)?;

        // Compute predictions in f64 (an optional SIMD path is sketched below but currently disabled)
        let x_f64 =
            Array2::from_shape_fn(x_transformed.dim(), |(i, j)| x_transformed[[i, j]] as f64);
        let weights_f64 = Array1::from_vec(weights.iter().map(|&val| val as f64).collect());

        // #[cfg(feature = "nightly-simd")]
        // let predictions_f64 =
        //     simd_kernel::simd_matrix_vector_multiply(&x_f64.view(), &weights_f64.view())?;

        // Use standard implementation (SIMD temporarily disabled)
        let predictions_f64 = x_f64.dot(&weights_f64);

        // Convert back to Float
        let predictions =
            Array1::from_vec(predictions_f64.iter().map(|&val| val as Float).collect());

        Ok(predictions)
    }
}

/// Online/Incremental Kernel Ridge Regression
///
/// This variant allows for online updates to the model as new data arrives.
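///
/// # Examples
///
/// A minimal usage sketch (marked `ignore`; it mirrors the unit tests below and
/// assumes the same `ApproximationMethod` setup):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let online = OnlineKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 20,
///     gamma: 0.5,
/// })
/// .alpha(0.1)
/// .update_frequency(2);
///
/// let fitted = online.fit(&array![[1.0, 2.0], [2.0, 3.0]], &array![1.0, 2.0])?;
/// let updated = fitted.partial_fit(&array![[3.0, 4.0]], &array![3.0])?;
/// let predictions = updated.predict(&array![[1.0, 2.0]])?;
/// ```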
#[derive(Debug, Clone)]
pub struct OnlineKernelRidgeRegression<State = Untrained> {
    /// Base kernel ridge regression
    pub base_model: KernelRidgeRegression<State>,
    /// Forgetting factor for online updates
    pub forgetting_factor: Float,
    /// Update frequency
    pub update_frequency: usize,

    // Online state
    update_count_: usize,
    accumulated_data_: Option<(Array2<Float>, Array1<Float>)>,

    _state: PhantomData<State>,
}

impl OnlineKernelRidgeRegression<Untrained> {
    /// Create a new online kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            base_model: KernelRidgeRegression::new(approximation_method),
            forgetting_factor: 0.99,
            update_frequency: 100,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    /// Set forgetting factor
    pub fn forgetting_factor(mut self, factor: Float) -> Self {
        self.forgetting_factor = factor;
        self
    }

    /// Set update frequency
    pub fn update_frequency(mut self, frequency: usize) -> Self {
        self.update_frequency = frequency;
        self
    }

    /// Set alpha parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.base_model = self.base_model.alpha(alpha);
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.base_model = self.base_model.random_state(seed);
        self
    }
}

impl Estimator for OnlineKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Untrained> {
    type Fitted = OnlineKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let fitted_base = self.base_model.fit(x, y)?;

        Ok(OnlineKernelRidgeRegression {
            base_model: fitted_base,
            forgetting_factor: self.forgetting_factor,
            update_frequency: self.update_frequency,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

impl OnlineKernelRidgeRegression<Trained> {
    /// Update the model with new data
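    ///
    /// New samples are buffered and, every `update_frequency` calls, the underlying
    /// model is refit on the accumulated batch (the buffer is then cleared). Note
    /// that `forgetting_factor` is not applied by this simple refit strategy.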
    pub fn partial_fit(mut self, x_new: &Array2<Float>, y_new: &Array1<Float>) -> Result<Self> {
        // Accumulate new data
        match &self.accumulated_data_ {
            Some((x_acc, y_acc)) => {
                let x_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), x_acc.clone(), x_new.clone()];
                let y_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), y_acc.clone(), y_new.clone()];
                self.accumulated_data_ = Some((x_combined, y_combined));
            }
            None => {
                self.accumulated_data_ = Some((x_new.clone(), y_new.clone()));
            }
        }

        self.update_count_ += 1;

        // Check if it's time to update
        if self.update_count_ % self.update_frequency == 0 {
            if let Some((ref x_acc, ref y_acc)) = self.accumulated_data_ {
                // Refit the model with accumulated data
                // In practice, you might want to implement a more sophisticated
                // online update algorithm here
                let updated_base = self.base_model.clone().into_untrained().fit(x_acc, y_acc)?;
                self.base_model = updated_base;
                self.accumulated_data_ = None;
            }
        }

        Ok(self)
    }

    /// Get the number of updates performed
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
}

impl Predict<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        self.base_model.predict(x)
    }
}

// Helper trait to convert trained model to untrained
pub trait IntoUntrained<T> {
    fn into_untrained(self) -> T;
}

impl IntoUntrained<KernelRidgeRegression<Untrained>> for KernelRidgeRegression<Trained> {
    fn into_untrained(self) -> KernelRidgeRegression<Untrained> {
        KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_kernel_ridge_regression_rff() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 50,
            gamma: 0.1,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    #[test]
    fn test_kernel_ridge_regression_nystroem() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::Nystroem {
            kernel: Kernel::Rbf { gamma: 1.0 },
            n_components: 3,
            sampling_strategy: SamplingStrategy::Random,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(1.0);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
    }

    #[test]
    fn test_kernel_ridge_regression_fastfood() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![1.0, 2.0];

        let approximation = ApproximationMethod::Fastfood {
            n_components: 8,
            gamma: 0.5,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Test Direct solver
        let krr_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(0.1);
        let fitted_direct = krr_direct.fit(&x, &y).unwrap();
        let pred_direct = fitted_direct.predict(&x).unwrap();

        // Test SVD solver
        let krr_svd = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::SVD)
            .alpha(0.1);
        let fitted_svd = krr_svd.fit(&x, &y).unwrap();
        let pred_svd = fitted_svd.predict(&x).unwrap();

        // Test CG solver
        let krr_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            })
            .alpha(0.1);
        let fitted_cg = krr_cg.fit(&x, &y).unwrap();
        let pred_cg = fitted_cg.predict(&x).unwrap();

        assert_eq!(pred_direct.len(), 3);
        assert_eq!(pred_svd.len(), 3);
        assert_eq!(pred_cg.len(), 3);
    }

    #[test]
    fn test_online_kernel_ridge_regression() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let y_initial = array![1.0, 2.0];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];
        let y_new = array![3.0, 4.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let online_krr = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2);

        let fitted = online_krr.fit(&x_initial, &y_initial).unwrap();
        let updated = fitted.partial_fit(&x_new, &y_new).unwrap();

        assert_eq!(updated.update_count(), 1);

        let predictions = updated.predict(&x_initial).unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr1 = KernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let krr2 = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }
923}