sklears_kernel_approximation/kernel_ridge_regression/basic_regression.rs

//! Basic Kernel Ridge Regression Implementation
//!
//! This module contains the core KernelRidgeRegression implementation, including:
//! - Basic kernel ridge regression with multiple kernel approximation methods
//! - Support for different solvers (Direct, SVD, Conjugate Gradient)
//! - Online/incremental learning capabilities
//! - Self-contained numerical linear algebra routines (Gaussian elimination,
//!   power-iteration SVD)

use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_linalg::compat::ArrayLinalgExt;
// ArrayLinalgExt supplies the `solve` method used below; the SVD path relies on the
// in-module power-iteration implementation rather than an external SVD routine.
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::Rng;
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

/// Basic Kernel Ridge Regression
///
/// This is the main kernel ridge regression implementation that supports various
/// kernel approximation methods and solvers for different use cases.
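///
/// # Examples
///
/// A minimal usage sketch mirroring the tests at the bottom of this module. It is
/// marked `ignore` because it assumes the builder types are in scope (e.g. via
/// crate-level re-exports):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![1.0, 4.0, 9.0, 16.0];
///
/// // Random Fourier features approximate an RBF kernel with 50 components.
/// let krr = KernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 50,
///     gamma: 0.1,
/// })
/// .alpha(0.1)
/// .random_state(42);
///
/// let fitted = krr.fit(&x, &y).unwrap();
/// let predictions = fitted.predict(&x).unwrap();
/// assert_eq!(predictions.len(), 4);
/// ```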
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression<State = Untrained> {
    pub approximation_method: ApproximationMethod,
    pub alpha: Float,
    pub solver: Solver,
    pub random_state: Option<u64>,

    // Fitted parameters
    pub(crate) weights_: Option<Array1<Float>>,
    pub(crate) feature_transformer_: Option<FeatureTransformer>,

    pub(crate) _state: PhantomData<State>,
}

impl KernelRidgeRegression<Untrained> {
    /// Create a new kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            solver: Solver::Direct,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }

    /// Set the regularization parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the solver method
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for KernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Untrained> {
    type Fitted = KernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Fit the feature transformer based on approximation method
        let feature_transformer = self.fit_feature_transformer(x)?;

        // Transform features
        let x_transformed = feature_transformer.transform(x)?;

        // Solve the ridge regression problem: (X^T X + alpha * I) w = X^T y
        let weights = self.solve_ridge_regression(&x_transformed, y)?;

        Ok(KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            _state: PhantomData,
        })
    }
}

impl KernelRidgeRegression<Untrained> {
    /// Fit the feature transformer based on the approximation method
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());

                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }

                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rbf_sampler = RBFSampler::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    rbf_sampler = rbf_sampler.random_state(seed);
                }

                let fitted = rbf_sampler.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut structured_rff = StructuredRandomFeatures::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    structured_rff = structured_rff.random_state(seed);
                }

                let fitted = structured_rff.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }

                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve the ridge regression problem
    fn solve_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        let (_n_samples, _n_features) = x.dim();

        match &self.solver {
            Solver::Direct => self.solve_direct(x, y),
            Solver::SVD => self.solve_svd(x, y),
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_conjugate_gradient(x, y, *max_iter, *tol)
            }
        }
    }

    /// Direct solver using normal equations
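    ///
    /// Minimizing the regularized least-squares objective
    /// `||X w - y||^2 + alpha * ||w||^2` and setting its gradient to zero gives the
    /// normal equations `(X^T X + alpha * I) w = X^T y`, which this method forms
    /// explicitly and hands to a dense linear solver.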
    fn solve_direct(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (_, n_features) = x.dim();

        // Copy inputs into owned working arrays. A feature-gated SIMD path is kept
        // below (commented out); the active path uses ndarray's BLAS-backed ops.
        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
        let y_f64 = Array1::from_vec(y.iter().copied().collect());

        // #[cfg(feature = "nightly-simd")]
        // let (gram_matrix, xty_f64, weights_f64) = {
        //     let xtx_f64 = simd_kernel::simd_gram_matrix_from_data(&x_f64.view());
        //     let mut gram_matrix = xtx_f64;
        //
        //     // SIMD-accelerated diagonal regularization
        //     for i in 0..n_features {
        //         gram_matrix[[i, i]] += self.alpha as f64;
        //     }
        //
        //     // SIMD-accelerated computation of X^T y
        //     let xty_f64 =
        //         simd_kernel::simd_matrix_vector_multiply(&x_f64.t().view(), &y_f64.view())?;
        //
        //     // SIMD-accelerated linear system solving
        //     let weights_f64 =
        //         simd_kernel::simd_ridge_coefficients(&gram_matrix.view(), &xty_f64.view(), 0.0)?;
        //
        //     (gram_matrix, xty_f64, weights_f64)
        // };

        // Use ndarray operations (optimized via BLAS backend)
        let weights_f64 = {
            // Form the regularized Gram matrix X^T X + alpha * I
            let gram_matrix = x_f64.t().dot(&x_f64);
            let mut regularized_gram = gram_matrix;
            for i in 0..n_features {
                regularized_gram[[i, i]] += self.alpha;
            }
            let xty_f64 = x_f64.t().dot(&y_f64);

            // Solve the linear system (X^T X + αI) w = X^T y
            regularized_gram
                .solve(&xty_f64)
                .map_err(|e| SklearsError::InvalidParameter {
                    name: "regularization".to_string(),
                    reason: format!("Linear system solving failed: {:?}", e),
                })?
        };

        // Copy the solution into the output array
        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
        Ok(weights)
    }

    /// SVD-based solver (more numerically stable)
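    ///
    /// With the thin SVD `X = U * diag(s) * V^T`, the ridge solution has the closed
    /// form `w = V * diag(s_i / (s_i^2 + alpha)) * U^T y`. The filter factors
    /// `s_i / (s_i^2 + alpha)` shrink directions with small singular values instead
    /// of dividing by them, which is why this path behaves better than the normal
    /// equations on ill-conditioned feature matrices.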
    fn solve_svd(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        // Use SVD of the design matrix X for numerical stability
        // Solve the regularized least squares: min ||Xw - y||² + α||w||²

        let (_n_samples, _n_features) = x.dim();

        // Compute SVD of X using power iteration method
        let (u, s, vt) = self.compute_svd(x)?;

        // Compute regularized pseudo-inverse
        let threshold = 1e-12;
        let mut s_reg_inv = Array1::zeros(s.len());
        for i in 0..s.len() {
            if s[i] > threshold {
                s_reg_inv[i] = s[i] / (s[i] * s[i] + self.alpha);
            }
        }

        // Solve: w = V * S_reg^(-1) * U^T * y
        let ut_y = u.t().dot(y);
        let mut temp = Array1::zeros(s.len());
        for i in 0..s.len() {
            temp[i] = s_reg_inv[i] * ut_y[i];
        }

        let weights = vt.t().dot(&temp);
        Ok(weights)
    }

    /// Conjugate Gradient solver (iterative, memory efficient)
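    ///
    /// Applies conjugate gradients to the symmetric positive-definite system
    /// `(X^T X + alpha * I) w = X^T y`. The Gram matrix is never materialized:
    /// each iteration evaluates `X^T (X p) + alpha * p`, so the cost per iteration
    /// is two matrix-vector products against `X`.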
    fn solve_conjugate_gradient(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array1<Float>> {
        let (_n_samples, n_features) = x.dim();

        // Initialize weights
        let mut w = Array1::zeros(n_features);

        // Initial residual: r = X^T y - (X^T X + alpha I) w, which is X^T y since w = 0
        let xty = x.t().dot(y);
        let mut r = xty.clone();

        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            // Compute A * p where A = X^T X + alpha I
            let xtxp = x.t().dot(&x.dot(&p));
            let mut ap = xtxp;
            for i in 0..n_features {
                ap[i] += self.alpha * p[i];
            }

            // Compute step size
            let alpha_cg = rsold / p.dot(&ap);

            // Update weights
            w = w + alpha_cg * &p;

            // Update residual
            r = r - alpha_cg * &ap;

            let rsnew = r.dot(&r);

            // Check convergence
            if rsnew.sqrt() < tol {
                break;
            }

            // Update search direction
            let beta = rsnew / rsold;
            p = &r + beta * &p;
            rsold = rsnew;
        }

        Ok(w)
    }

    /// Solve linear system Ax = b using Gaussian elimination with partial pivoting
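    ///
    /// Dense O(n^3) elimination on the augmented matrix `[A | b]`: rows are swapped
    /// to keep the largest available pivot, and a pivot smaller than `1e-12` is
    /// reported as a singular system. This routine is self-contained and does not
    /// depend on an external LAPACK backend.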
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions must match for linear system solve".to_string(),
            ));
        }

        // Create augmented matrix [A|b]
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting
        for k in 0..n {
            // Find pivot
            let mut max_row = k;
            for i in (k + 1)..n {
                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
                    max_row = i;
                }
            }

            // Swap rows if needed
            if max_row != k {
                for j in 0..=n {
                    let temp = aug[[k, j]];
                    aug[[k, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            // Check for zero pivot
            if aug[[k, k]].abs() < 1e-12 {
                return Err(SklearsError::InvalidInput(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }

            // Eliminate column
            for i in (k + 1)..n {
                let factor = aug[[i, k]] / aug[[k, k]];
                for j in k..=n {
                    aug[[i, j]] -= factor * aug[[k, j]];
                }
            }
        }

        // Back substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = aug[[i, n]];
            for j in (i + 1)..n {
                sum -= aug[[i, j]] * x[j];
            }
            x[i] = sum / aug[[i, i]];
        }

        Ok(x)
    }

    /// Compute SVD using power iteration method
    /// Returns (U, S, V^T) where X = U * S * V^T
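    ///
    /// The decomposition is built from the smaller Gram matrix: for a tall matrix
    /// (`n <= m`) the eigenvectors of `X^T X` give `V` and the singular values are
    /// the square roots of its eigenvalues, after which `U = X V S^(-1)`; for a wide
    /// matrix the roles are swapped and `V = X^T U S^(-1)`. Components whose
    /// eigenvalue falls below `1e-12` are dropped.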
    fn compute_svd(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = x.dim();
        let min_dim = m.min(n);

        // For SVD, we compute eigendecomposition of X^T X (for V) and X X^T (for U)
        let xt = x.t();

        if n <= m {
            // Thin SVD: compute V from X^T X
            let xtx = xt.dot(x);
            let (eigenvals_v, eigenvecs_v) = self.compute_eigendecomposition_svd(&xtx)?;

            // Singular values are sqrt of eigenvalues
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_v.len() {
                if eigenvals_v[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_v[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            // Construct V matrix
            let mut v = Array2::zeros((n, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                v.column_mut(new_idx).assign(&eigenvecs_v.column(old_idx));
            }

            // Compute U = X * V * S^(-1)
            let mut u = Array2::zeros((m, min_dim));
            for j in 0..valid_indices.len() {
                let v_col = v.column(j);
                let xv = x.dot(&v_col);
                let u_col = &xv / singular_vals[j];
                u.column_mut(j).assign(&u_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        } else {
            // Wide matrix: compute U from X X^T
            let xxt = x.dot(&xt);
            let (eigenvals_u, eigenvecs_u) = self.compute_eigendecomposition_svd(&xxt)?;

            // Singular values are sqrt of eigenvalues
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_u.len() {
                if eigenvals_u[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_u[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            // Construct U matrix
            let mut u = Array2::zeros((m, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                u.column_mut(new_idx).assign(&eigenvecs_u.column(old_idx));
            }

            // Compute V = X^T * U * S^(-1)
            let mut v = Array2::zeros((n, min_dim));
            for j in 0..valid_indices.len() {
                let u_col = u.column(j);
                let xtu = xt.dot(&u_col);
                let v_col = &xtu / singular_vals[j];
                v.column_mut(j).assign(&v_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        }
    }

    /// Compute eigendecomposition for SVD computation
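    ///
    /// Eigenpairs are extracted one at a time: power iteration finds the dominant
    /// pair of the current matrix, and the matrix is then deflated via
    /// `A <- A - lambda * v * v^T` so the next iteration converges to the next
    /// eigenpair. Results are finally sorted by descending eigenvalue.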
    fn compute_eigendecomposition_svd(
        &self,
        matrix: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration_svd(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Power iteration method for SVD eigendecomposition
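    ///
    /// Repeatedly applies the matrix to a start vector and renormalizes; the
    /// eigenvalue estimate is the Rayleigh quotient `v^T A v`. Note that the start
    /// vector is drawn from `thread_rng()` rather than the model's `random_state`,
    /// so this path is not seeded.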
    fn power_iteration_svd(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize random vector
        let mut v = Array1::from_shape_fn(n, |_| thread_rng().gen::<Float>() - 0.5);

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
}

impl Predict<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        // Transform features
        let x_transformed = feature_transformer.transform(x)?;

        // Prediction is the matrix-vector product X_transformed · w, delegated to
        // ndarray's BLAS-backed `dot`.
        let x_f64 = Array2::from_shape_fn(x_transformed.dim(), |(i, j)| x_transformed[[i, j]]);
        let weights_f64 = Array1::from_vec(weights.iter().copied().collect());

        let predictions_f64 = x_f64.dot(&weights_f64);

        // Copy the result into the output array
        let predictions =
            Array1::from_vec(predictions_f64.iter().map(|&val| val as Float).collect());

        Ok(predictions)
    }
}

/// Online/Incremental Kernel Ridge Regression
///
/// This variant allows for online updates to the model as new data arrives.
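///
/// # Examples
///
/// A minimal sketch mirroring `test_online_kernel_ridge_regression` below. It is
/// marked `ignore` because it assumes the builder types are in scope:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let online_krr = OnlineKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 20,
///     gamma: 0.5,
/// })
/// .alpha(0.1)
/// .update_frequency(2);
///
/// let fitted = online_krr
///     .fit(&array![[1.0, 2.0], [2.0, 3.0]], &array![1.0, 2.0])
///     .unwrap();
/// // Buffer new samples; the model is refit once `update_frequency` calls accumulate.
/// let updated = fitted.partial_fit(&array![[3.0, 4.0]], &array![3.0]).unwrap();
/// let predictions = updated.predict(&array![[3.0, 4.0]]).unwrap();
/// ```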
#[derive(Debug, Clone)]
pub struct OnlineKernelRidgeRegression<State = Untrained> {
    /// Base kernel ridge regression
    pub base_model: KernelRidgeRegression<State>,
    /// Forgetting factor for online updates
    pub forgetting_factor: Float,
    /// Update frequency
    pub update_frequency: usize,

    // Online state
    update_count_: usize,
    accumulated_data_: Option<(Array2<Float>, Array1<Float>)>,

    _state: PhantomData<State>,
}

impl OnlineKernelRidgeRegression<Untrained> {
    /// Create a new online kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            base_model: KernelRidgeRegression::new(approximation_method),
            forgetting_factor: 0.99,
            update_frequency: 100,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    /// Set forgetting factor
    pub fn forgetting_factor(mut self, factor: Float) -> Self {
        self.forgetting_factor = factor;
        self
    }

    /// Set update frequency
    pub fn update_frequency(mut self, frequency: usize) -> Self {
        self.update_frequency = frequency;
        self
    }

    /// Set alpha parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.base_model = self.base_model.alpha(alpha);
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.base_model = self.base_model.random_state(seed);
        self
    }
}

impl Estimator for OnlineKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Untrained> {
    type Fitted = OnlineKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let fitted_base = self.base_model.fit(x, y)?;

        Ok(OnlineKernelRidgeRegression {
            base_model: fitted_base,
            forgetting_factor: self.forgetting_factor,
            update_frequency: self.update_frequency,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

impl OnlineKernelRidgeRegression<Trained> {
    /// Update the model with new data
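    ///
    /// New samples are appended to an internal buffer. Once the number of
    /// `partial_fit` calls reaches a multiple of `update_frequency`, the base model
    /// is refit from scratch on the buffered data and the buffer is cleared. The
    /// `forgetting_factor` field is not applied by this refit path.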
    pub fn partial_fit(mut self, x_new: &Array2<Float>, y_new: &Array1<Float>) -> Result<Self> {
        // Accumulate new data
        match &self.accumulated_data_ {
            Some((x_acc, y_acc)) => {
                let x_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), x_acc.clone(), x_new.clone()];
                let y_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), y_acc.clone(), y_new.clone()];
                self.accumulated_data_ = Some((x_combined, y_combined));
            }
            None => {
                self.accumulated_data_ = Some((x_new.clone(), y_new.clone()));
            }
        }

        self.update_count_ += 1;

        // Check if it's time to update
        if self.update_count_ % self.update_frequency == 0 {
            if let Some((ref x_acc, ref y_acc)) = self.accumulated_data_ {
                // Refit the model with accumulated data
                // In practice, you might want to implement a more sophisticated
                // online update algorithm here
                let updated_base = self.base_model.clone().into_untrained().fit(x_acc, y_acc)?;
                self.base_model = updated_base;
                self.accumulated_data_ = None;
            }
        }

        Ok(self)
    }

    /// Get the number of updates performed
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
}

impl Predict<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        self.base_model.predict(x)
    }
}

// Helper trait to convert trained model to untrained
pub trait IntoUntrained<T> {
    fn into_untrained(self) -> T;
}

impl IntoUntrained<KernelRidgeRegression<Untrained>> for KernelRidgeRegression<Trained> {
    fn into_untrained(self) -> KernelRidgeRegression<Untrained> {
        KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_kernel_ridge_regression_rff() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 50,
            gamma: 0.1,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    #[test]
    fn test_kernel_ridge_regression_nystroem() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::Nystroem {
            kernel: Kernel::Rbf { gamma: 1.0 },
            n_components: 3,
            sampling_strategy: SamplingStrategy::Random,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(1.0);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
    }

    #[test]
    fn test_kernel_ridge_regression_fastfood() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![1.0, 2.0];

        let approximation = ApproximationMethod::Fastfood {
            n_components: 8,
            gamma: 0.5,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Test Direct solver
        let krr_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(0.1);
        let fitted_direct = krr_direct.fit(&x, &y).unwrap();
        let pred_direct = fitted_direct.predict(&x).unwrap();

        // Test SVD solver
        let krr_svd = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::SVD)
            .alpha(0.1);
        let fitted_svd = krr_svd.fit(&x, &y).unwrap();
        let pred_svd = fitted_svd.predict(&x).unwrap();

        // Test CG solver
        let krr_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            })
            .alpha(0.1);
        let fitted_cg = krr_cg.fit(&x, &y).unwrap();
        let pred_cg = fitted_cg.predict(&x).unwrap();

        assert_eq!(pred_direct.len(), 3);
        assert_eq!(pred_svd.len(), 3);
        assert_eq!(pred_cg.len(), 3);
    }

    #[test]
    fn test_online_kernel_ridge_regression() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let y_initial = array![1.0, 2.0];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];
        let y_new = array![3.0, 4.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let online_krr = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2);

        let fitted = online_krr.fit(&x_initial, &y_initial).unwrap();
        let updated = fitted.partial_fit(&x_new, &y_new).unwrap();

        assert_eq!(updated.update_count(), 1);

        let predictions = updated.predict(&x_initial).unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr1 = KernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let krr2 = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }
}