sklears_kernel_approximation/kernel_ridge_regression/basic_regression.rs

//! Basic Kernel Ridge Regression Implementation
//!
//! This module contains the core KernelRidgeRegression implementation, including:
//! - Basic kernel ridge regression with multiple approximation methods
//! - Support for different solvers (Direct, SVD, Conjugate Gradient)
//! - Online/incremental learning capabilities
//! - Comprehensive numerical linear algebra implementations

use crate::{
    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
};
use scirs2_core::ndarray::ndarray_linalg::solve::Solve;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::Rng;
use sklears_core::error::{Result, SklearsError};
use sklears_core::prelude::{Estimator, Fit, Float, Predict};
use std::marker::PhantomData;

use super::core_types::*;

/// Basic Kernel Ridge Regression
///
/// This is the main kernel ridge regression implementation that supports various
/// kernel approximation methods and solvers for different use cases.
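///
/// # Example
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this
/// module (marked `ignore` since it assumes the relevant types are already in
/// scope at the call site):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let y = array![1.0, 2.0, 3.0];
///
/// // Build the model with a random Fourier feature approximation.
/// let krr = KernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 50,
///     gamma: 0.1,
/// })
/// .alpha(0.1)
/// .random_state(42);
///
/// // Fit consumes the untrained model and returns a trained one.
/// let fitted = krr.fit(&x, &y)?;
/// let predictions = fitted.predict(&x)?;
/// assert_eq!(predictions.len(), 3);
/// ```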
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression<State = Untrained> {
    pub approximation_method: ApproximationMethod,
    pub alpha: Float,
    pub solver: Solver,
    pub random_state: Option<u64>,

    // Fitted parameters
    pub(crate) weights_: Option<Array1<Float>>,
    pub(crate) feature_transformer_: Option<FeatureTransformer>,

    pub(crate) _state: PhantomData<State>,
}

impl KernelRidgeRegression<Untrained> {
    /// Create a new kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            approximation_method,
            alpha: 1.0,
            solver: Solver::Direct,
            random_state: None,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }

    /// Set the regularization parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the solver method
    pub fn solver(mut self, solver: Solver) -> Self {
        self.solver = solver;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for KernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Untrained> {
    type Fitted = KernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        if y.len() != n_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Fit the feature transformer based on approximation method
        let feature_transformer = self.fit_feature_transformer(x)?;

        // Transform features
        let x_transformed = feature_transformer.transform(x)?;

        // Solve the ridge regression problem: (X^T X + alpha * I) w = X^T y
        let weights = self.solve_ridge_regression(&x_transformed, y)?;

        Ok(KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: Some(weights),
            feature_transformer_: Some(feature_transformer),
            _state: PhantomData,
        })
    }
}

impl KernelRidgeRegression<Untrained> {
    /// Fit the feature transformer based on the approximation method
    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
        match &self.approximation_method {
            ApproximationMethod::Nystroem {
                kernel,
                n_components,
                sampling_strategy,
            } => {
                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
                    .sampling_strategy(sampling_strategy.clone());

                if let Some(seed) = self.random_state {
                    nystroem = nystroem.random_state(seed);
                }

                let fitted = nystroem.fit(x, &())?;
                Ok(FeatureTransformer::Nystroem(fitted))
            }
            ApproximationMethod::RandomFourierFeatures {
                n_components,
                gamma,
            } => {
                let mut rbf_sampler = RBFSampler::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    rbf_sampler = rbf_sampler.random_state(seed);
                }

                let fitted = rbf_sampler.fit(x, &())?;
                Ok(FeatureTransformer::RBFSampler(fitted))
            }
            ApproximationMethod::StructuredRandomFeatures {
                n_components,
                gamma,
            } => {
                let mut structured_rff = StructuredRandomFeatures::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    structured_rff = structured_rff.random_state(seed);
                }

                let fitted = structured_rff.fit(x, &())?;
                Ok(FeatureTransformer::StructuredRFF(fitted))
            }
            ApproximationMethod::Fastfood {
                n_components,
                gamma,
            } => {
                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);

                if let Some(seed) = self.random_state {
                    fastfood = fastfood.random_state(seed);
                }

                let fitted = fastfood.fit(x, &())?;
                Ok(FeatureTransformer::Fastfood(fitted))
            }
        }
    }

    /// Solve the ridge regression problem
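    ///
    /// All solvers target the same regularized normal equations
    /// `(X^T X + alpha * I) w = X^T y` over the (already transformed) feature
    /// matrix; they differ only in how that system is solved: direct
    /// factorization, SVD, or iterative conjugate gradient.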
    fn solve_ridge_regression(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        let (_n_samples, _n_features) = x.dim();

        match &self.solver {
            Solver::Direct => self.solve_direct(x, y),
            Solver::SVD => self.solve_svd(x, y),
            Solver::ConjugateGradient { max_iter, tol } => {
                self.solve_conjugate_gradient(x, y, *max_iter, *tol)
            }
        }
    }

    /// Direct solver using normal equations
    fn solve_direct(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let (_, n_features) = x.dim();

        // Make owned working copies of the inputs for the linear-algebra
        // routines below (a SIMD-accelerated path is kept commented out for a
        // potential nightly-simd feature).
        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
        let y_f64 = Array1::from_vec(y.iter().copied().collect());

        // #[cfg(feature = "nightly-simd")]
        // let (gram_matrix, xty_f64, weights_f64) = {
        //     let xtx_f64 = simd_kernel::simd_gram_matrix_from_data(&x_f64.view());
        //     let mut gram_matrix = xtx_f64;
        //
        //     // SIMD-accelerated diagonal regularization
        //     for i in 0..n_features {
        //         gram_matrix[[i, i]] += self.alpha as f64;
        //     }
        //
        //     // SIMD-accelerated computation of X^T y
        //     let xty_f64 =
        //         simd_kernel::simd_matrix_vector_multiply(&x_f64.t().view(), &y_f64.view())?;
        //
        //     // SIMD-accelerated linear system solving
        //     let weights_f64 =
        //         simd_kernel::simd_ridge_coefficients(&gram_matrix.view(), &xty_f64.view(), 0.0)?;
        //
        //     (gram_matrix, xty_f64, weights_f64)
        // };

        // Use ndarray operations (optimized via BLAS backend)
        let weights_f64 = {
            // Matrix operations optimized through ndarray's BLAS backend
            let gram_matrix = x_f64.t().dot(&x_f64);
            let mut regularized_gram = gram_matrix;
            for i in 0..n_features {
                regularized_gram[[i, i]] += self.alpha;
            }
            let xty_f64 = x_f64.t().dot(&y_f64);

            // Solve the linear system (X^T X + αI) w = X^T y
            regularized_gram
                .solve(&xty_f64)
                .map_err(|e| SklearsError::InvalidParameter {
                    name: "regularization".to_string(),
                    reason: format!("Linear system solving failed: {:?}", e),
                })?
        };

        // Convert back to Float
        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
        Ok(weights)
    }

    /// SVD-based solver (more numerically stable)
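    ///
    /// With the thin SVD `X = U * S * V^T`, the ridge solution is
    /// `w = V * diag(s_i / (s_i^2 + alpha)) * U^T * y`, which avoids explicitly
    /// forming and inverting the Gram matrix `X^T X + alpha * I`.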
    fn solve_svd(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        // Use SVD of the design matrix X for numerical stability
        // Solve the regularized least squares: min ||Xw - y||² + α||w||²

        let (_n_samples, _n_features) = x.dim();

        // Compute SVD of X using power iteration method
        let (u, s, vt) = self.compute_svd(x)?;

        // Compute regularized pseudo-inverse
        let threshold = 1e-12;
        let mut s_reg_inv = Array1::zeros(s.len());
        for i in 0..s.len() {
            if s[i] > threshold {
                s_reg_inv[i] = s[i] / (s[i] * s[i] + self.alpha);
            }
        }

        // Solve: w = V * S_reg^(-1) * U^T * y
        let ut_y = u.t().dot(y);
        let mut temp = Array1::zeros(s.len());
        for i in 0..s.len() {
            temp[i] = s_reg_inv[i] * ut_y[i];
        }

        let weights = vt.t().dot(&temp);
        Ok(weights)
    }

    /// Conjugate Gradient solver (iterative, memory efficient)
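    ///
    /// Applies CG to the normal equations `(X^T X + alpha * I) w = X^T y`
    /// without materializing the Gram matrix: each iteration only needs the
    /// products `X * p` and `X^T * (X * p)`.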
    fn solve_conjugate_gradient(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<Array1<Float>> {
        let (_n_samples, n_features) = x.dim();

        // Initialize weights
        let mut w = Array1::zeros(n_features);

        // Compute initial residual: r = X^T y - (X^T X + alpha I) w.
        // Since w starts at zero, this reduces to r = X^T y.
        let xty = x.t().dot(y);
        let mut r = xty.clone();

        let mut p = r.clone();
        let mut rsold = r.dot(&r);

        for _iter in 0..max_iter {
            // Compute A * p where A = X^T X + alpha I
            let xtxp = x.t().dot(&x.dot(&p));
            let mut ap = xtxp;
            for i in 0..n_features {
                ap[i] += self.alpha * p[i];
            }

            // Compute step size
            let alpha_cg = rsold / p.dot(&ap);

            // Update weights
            w = w + alpha_cg * &p;

            // Update residual
            r = r - alpha_cg * &ap;

            let rsnew = r.dot(&r);

            // Check convergence
            if rsnew.sqrt() < tol {
                break;
            }

            // Update search direction
            let beta = rsnew / rsold;
            p = &r + beta * &p;
            rsold = rsnew;
        }

        Ok(w)
    }

    /// Solve linear system Ax = b using Gaussian elimination with partial pivoting
    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions must match for linear system solve".to_string(),
            ));
        }

        // Create augmented matrix [A|b]
        let mut aug = Array2::zeros((n, n + 1));
        for i in 0..n {
            for j in 0..n {
                aug[[i, j]] = a[[i, j]];
            }
            aug[[i, n]] = b[i];
        }

        // Forward elimination with partial pivoting
        for k in 0..n {
            // Find pivot
            let mut max_row = k;
            for i in (k + 1)..n {
                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
                    max_row = i;
                }
            }

            // Swap rows if needed
            if max_row != k {
                for j in 0..=n {
                    let temp = aug[[k, j]];
                    aug[[k, j]] = aug[[max_row, j]];
                    aug[[max_row, j]] = temp;
                }
            }

            // Check for zero pivot
            if aug[[k, k]].abs() < 1e-12 {
                return Err(SklearsError::InvalidInput(
                    "Matrix is singular or nearly singular".to_string(),
                ));
            }

            // Eliminate column
            for i in (k + 1)..n {
                let factor = aug[[i, k]] / aug[[k, k]];
                for j in k..=n {
                    aug[[i, j]] -= factor * aug[[k, j]];
                }
            }
        }

        // Back substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            let mut sum = aug[[i, n]];
            for j in (i + 1)..n {
                sum -= aug[[i, j]] * x[j];
            }
            x[i] = sum / aug[[i, i]];
        }

        Ok(x)
    }

    /// Compute SVD using power iteration method
    /// Returns (U, S, V^T) where X = U * S * V^T
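    ///
    /// The decomposition is recovered from the eigendecomposition of the
    /// smaller Gram matrix: for `n <= m`, `X^T X = V * S^2 * V^T` and
    /// `U = X * V * S^(-1)`; otherwise `X * X^T = U * S^2 * U^T` and
    /// `V = X^T * U * S^(-1)`.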
    fn compute_svd(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = x.dim();
        let min_dim = m.min(n);

        // For SVD, we compute eigendecomposition of X^T X (for V) and X X^T (for U)
        let xt = x.t();

        if n <= m {
            // Thin SVD: compute V from X^T X
            let xtx = xt.dot(x);
            let (eigenvals_v, eigenvecs_v) = self.compute_eigendecomposition_svd(&xtx)?;

            // Singular values are sqrt of eigenvalues
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_v.len() {
                if eigenvals_v[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_v[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            // Construct V matrix
            let mut v = Array2::zeros((n, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                v.column_mut(new_idx).assign(&eigenvecs_v.column(old_idx));
            }

            // Compute U = X * V * S^(-1)
            let mut u = Array2::zeros((m, min_dim));
            for j in 0..valid_indices.len() {
                let v_col = v.column(j);
                let xv = x.dot(&v_col);
                let u_col = &xv / singular_vals[j];
                u.column_mut(j).assign(&u_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        } else {
            // Wide matrix: compute U from X X^T
            let xxt = x.dot(&xt);
            let (eigenvals_u, eigenvecs_u) = self.compute_eigendecomposition_svd(&xxt)?;

            // Singular values are sqrt of eigenvalues
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_u.len() {
                if eigenvals_u[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_u[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            // Construct U matrix
            let mut u = Array2::zeros((m, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                u.column_mut(new_idx).assign(&eigenvecs_u.column(old_idx));
            }

            // Compute V = X^T * U * S^(-1)
            let mut v = Array2::zeros((n, min_dim));
            for j in 0..valid_indices.len() {
                let u_col = u.column(j);
                let xtu = xt.dot(&u_col);
                let v_col = &xtu / singular_vals[j];
                v.column_mut(j).assign(&v_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        }
    }

    /// Compute eigendecomposition for SVD computation
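    ///
    /// Uses power iteration with Hotelling deflation: once a dominant
    /// eigenpair `(lambda, v)` is found, `lambda * v * v^T` is subtracted from
    /// the working matrix so the next iteration converges to the next eigenpair.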
    fn compute_eigendecomposition_svd(
        &self,
        matrix: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration_svd(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Power iteration method for SVD eigendecomposition
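    ///
    /// Repeatedly applies the matrix to a normalized random start vector; the
    /// Rayleigh quotient `v^T A v` converges to the dominant eigenvalue and
    /// `v` to the corresponding eigenvector, assuming a gap between the two
    /// largest eigenvalue magnitudes.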
    fn power_iteration_svd(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize random vector
        let mut v = Array1::from_shape_fn(n, |_| thread_rng().gen::<Float>() - 0.5);

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
}

impl Predict<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let weights = self
            .weights_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "predict".to_string(),
            })?;

        let feature_transformer =
            self.feature_transformer_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "predict".to_string(),
                })?;

        // Transform features
        let x_transformed = feature_transformer.transform(x)?;

        // Prediction is a single matrix-vector product between the transformed
        // features and the fitted weights.
        let x_f64 = Array2::from_shape_fn(x_transformed.dim(), |(i, j)| x_transformed[[i, j]]);
        let weights_f64 = Array1::from_vec(weights.iter().copied().collect());

        // Matrix-vector multiplication optimized via BLAS backend
        let predictions_f64 = x_f64.dot(&weights_f64);

        // Convert back to Float
        let predictions =
            Array1::from_vec(predictions_f64.iter().map(|&val| val as Float).collect());

        Ok(predictions)
    }
}

/// Online/Incremental Kernel Ridge Regression
///
/// This variant allows for online updates to the model as new data arrives.
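///
/// # Example
///
/// A minimal sketch of incremental use, mirroring the unit tests below
/// (marked `ignore` since it assumes the inputs and relevant types are in
/// scope):
///
/// ```ignore
/// let model = OnlineKernelRidgeRegression::new(ApproximationMethod::RandomFourierFeatures {
///     n_components: 20,
///     gamma: 0.5,
/// })
/// .alpha(0.1)
/// .update_frequency(2);
///
/// // Initial batch fit, then incremental updates as new data arrives.
/// let fitted = model.fit(&x_initial, &y_initial)?;
/// let updated = fitted.partial_fit(&x_new, &y_new)?;
/// let predictions = updated.predict(&x_new)?;
/// ```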
#[derive(Debug, Clone)]
pub struct OnlineKernelRidgeRegression<State = Untrained> {
    /// Base kernel ridge regression
    pub base_model: KernelRidgeRegression<State>,
    /// Forgetting factor for online updates
    pub forgetting_factor: Float,
    /// Update frequency
    pub update_frequency: usize,

    // Online state
    update_count_: usize,
    accumulated_data_: Option<(Array2<Float>, Array1<Float>)>,

    _state: PhantomData<State>,
}

impl OnlineKernelRidgeRegression<Untrained> {
    /// Create a new online kernel ridge regression model
    pub fn new(approximation_method: ApproximationMethod) -> Self {
        Self {
            base_model: KernelRidgeRegression::new(approximation_method),
            forgetting_factor: 0.99,
            update_frequency: 100,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    /// Set forgetting factor
    pub fn forgetting_factor(mut self, factor: Float) -> Self {
        self.forgetting_factor = factor;
        self
    }

    /// Set update frequency
    pub fn update_frequency(mut self, frequency: usize) -> Self {
        self.update_frequency = frequency;
        self
    }

    /// Set alpha parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.base_model = self.base_model.alpha(alpha);
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.base_model = self.base_model.random_state(seed);
        self
    }
}

impl Estimator for OnlineKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Untrained> {
    type Fitted = OnlineKernelRidgeRegression<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        let fitted_base = self.base_model.fit(x, y)?;

        Ok(OnlineKernelRidgeRegression {
            base_model: fitted_base,
            forgetting_factor: self.forgetting_factor,
            update_frequency: self.update_frequency,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

impl OnlineKernelRidgeRegression<Trained> {
    /// Update the model with new data
    pub fn partial_fit(mut self, x_new: &Array2<Float>, y_new: &Array1<Float>) -> Result<Self> {
        // Accumulate new data
        match &self.accumulated_data_ {
            Some((x_acc, y_acc)) => {
                let x_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), x_acc.clone(), x_new.clone()];
                let y_combined =
                    scirs2_core::ndarray::concatenate![Axis(0), y_acc.clone(), y_new.clone()];
                self.accumulated_data_ = Some((x_combined, y_combined));
            }
            None => {
                self.accumulated_data_ = Some((x_new.clone(), y_new.clone()));
            }
        }

        self.update_count_ += 1;

        // Check if it's time to update
        if self.update_count_ % self.update_frequency == 0 {
            if let Some((ref x_acc, ref y_acc)) = self.accumulated_data_ {
                // Refit the model with accumulated data
                // In practice, you might want to implement a more sophisticated
                // online update algorithm here
                let updated_base = self.base_model.clone().into_untrained().fit(x_acc, y_acc)?;
                self.base_model = updated_base;
                self.accumulated_data_ = None;
            }
        }

        Ok(self)
    }

    /// Get the number of updates performed
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
}

impl Predict<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        self.base_model.predict(x)
    }
}

/// Helper trait to convert a trained model back to its untrained form
pub trait IntoUntrained<T> {
    fn into_untrained(self) -> T;
}

impl IntoUntrained<KernelRidgeRegression<Untrained>> for KernelRidgeRegression<Trained> {
    fn into_untrained(self) -> KernelRidgeRegression<Untrained> {
        KernelRidgeRegression {
            approximation_method: self.approximation_method,
            alpha: self.alpha,
            solver: self.solver,
            random_state: self.random_state,
            weights_: None,
            feature_transformer_: None,
            _state: PhantomData,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_kernel_ridge_regression_rff() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 50,
            gamma: 0.1,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    #[test]
    fn test_kernel_ridge_regression_nystroem() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::Nystroem {
            kernel: Kernel::Rbf { gamma: 1.0 },
            n_components: 3,
            sampling_strategy: SamplingStrategy::Random,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(1.0);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 3);
    }

    #[test]
    fn test_kernel_ridge_regression_fastfood() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![1.0, 2.0];

        let approximation = ApproximationMethod::Fastfood {
            n_components: 8,
            gamma: 0.5,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Test Direct solver
        let krr_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(0.1);
        let fitted_direct = krr_direct.fit(&x, &y).unwrap();
        let pred_direct = fitted_direct.predict(&x).unwrap();

        // Test SVD solver
        let krr_svd = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::SVD)
            .alpha(0.1);
        let fitted_svd = krr_svd.fit(&x, &y).unwrap();
        let pred_svd = fitted_svd.predict(&x).unwrap();

        // Test CG solver
        let krr_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            })
            .alpha(0.1);
        let fitted_cg = krr_cg.fit(&x, &y).unwrap();
        let pred_cg = fitted_cg.predict(&x).unwrap();

        assert_eq!(pred_direct.len(), 3);
        assert_eq!(pred_svd.len(), 3);
        assert_eq!(pred_cg.len(), 3);
    }

    #[test]
    fn test_online_kernel_ridge_regression() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let y_initial = array![1.0, 2.0];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];
        let y_new = array![3.0, 4.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let online_krr = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2);

        let fitted = online_krr.fit(&x_initial, &y_initial).unwrap();
        let updated = fitted.partial_fit(&x_new, &y_new).unwrap();

        assert_eq!(updated.update_count(), 1);

        let predictions = updated.predict(&x_initial).unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr1 = KernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = krr1.fit(&x, &y).unwrap();
        let pred1 = fitted1.predict(&x).unwrap();

        let krr2 = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = krr2.fit(&x, &y).unwrap();
        let pred2 = fitted2.predict(&x).unwrap();

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
918    }
919}