Skip to main content

sklears_kernel_approximation/kernel_ridge_regression/
basic_regression.rs

1//! Basic Kernel Ridge Regression Implementation
2//!
3//! This module contains the core KernelRidgeRegression implementation including:
4//! - Basic kernel ridge regression with multiple approximation methods
5//! - Support for different solvers (Direct, SVD, Conjugate Gradient)
6//! - Online/Incremental learning capabilities
7//! - Comprehensive numerical linear algebra implementations
8
9use crate::{
10    FastfoodTransform, Nystroem, RBFSampler, StructuredRandomFeatures, Trained, Untrained,
11};
12use scirs2_linalg::compat::ArrayLinalgExt;
13// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
14use scirs2_core::ndarray::{Array1, Array2, Axis};
15use scirs2_core::random::thread_rng;
16use sklears_core::error::{Result, SklearsError};
17use sklears_core::prelude::{Estimator, Fit, Float, Predict};
18use std::marker::PhantomData;
19
20use super::core_types::*;
21
/// Basic Kernel Ridge Regression
///
/// This is the main kernel ridge regression implementation that supports various
/// kernel approximation methods and solvers for different use cases.
///
/// The `State` type parameter (`Untrained`/`Trained`) encodes the fit state at
/// compile time, so `predict` is only reachable after a successful `fit`.
#[derive(Debug, Clone)]
pub struct KernelRidgeRegression<State = Untrained> {
    /// Kernel approximation strategy (Nystroem, RFF, structured RFF, Fastfood).
    pub approximation_method: ApproximationMethod,
    /// Ridge (L2) regularization strength; larger values shrink weights more.
    pub alpha: Float,
    /// Linear-system solver used when fitting the ridge weights.
    pub solver: Solver,
    /// Optional seed for reproducible random feature generation.
    pub random_state: Option<u64>,

    // Fitted parameters (populated only once the model is `Trained`)
    // Learned weight vector in the approximate feature space.
    pub(crate) weights_: Option<Array1<Float>>,
    // Fitted transformer that maps raw inputs into that feature space.
    pub(crate) feature_transformer_: Option<FeatureTransformer>,

    // Zero-sized marker tying the struct to its compile-time fit state.
    pub(crate) _state: PhantomData<State>,
}
39
40impl KernelRidgeRegression<Untrained> {
41    /// Create a new kernel ridge regression model
42    pub fn new(approximation_method: ApproximationMethod) -> Self {
43        Self {
44            approximation_method,
45            alpha: 1.0,
46            solver: Solver::Direct,
47            random_state: None,
48            weights_: None,
49            feature_transformer_: None,
50            _state: PhantomData,
51        }
52    }
53
54    /// Set the regularization parameter
55    pub fn alpha(mut self, alpha: Float) -> Self {
56        self.alpha = alpha;
57        self
58    }
59
60    /// Set the solver method
61    pub fn solver(mut self, solver: Solver) -> Self {
62        self.solver = solver;
63        self
64    }
65
66    /// Set random state for reproducibility
67    pub fn random_state(mut self, seed: u64) -> Self {
68        self.random_state = Some(seed);
69        self
70    }
71}
72
/// Marker `Estimator` implementation; this model takes no external config.
impl Estimator for KernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is promoted to a `'static` reference, so returning it is sound.
        &()
    }
}
82
83impl Fit<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Untrained> {
84    type Fitted = KernelRidgeRegression<Trained>;
85
86    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
87        let (n_samples, _) = x.dim();
88
89        if y.len() != n_samples {
90            return Err(SklearsError::InvalidInput(
91                "Number of samples in X and y must match".to_string(),
92            ));
93        }
94
95        // Fit the feature transformer based on approximation method
96        let feature_transformer = self.fit_feature_transformer(x)?;
97
98        // Transform features
99        let x_transformed = feature_transformer.transform(x)?;
100
101        // Solve the ridge regression problem: (X^T X + alpha * I) w = X^T y
102        let weights = self.solve_ridge_regression(&x_transformed, y)?;
103
104        Ok(KernelRidgeRegression {
105            approximation_method: self.approximation_method,
106            alpha: self.alpha,
107            solver: self.solver,
108            random_state: self.random_state,
109            weights_: Some(weights),
110            feature_transformer_: Some(feature_transformer),
111            _state: PhantomData,
112        })
113    }
114}
115
116impl KernelRidgeRegression<Untrained> {
117    /// Fit the feature transformer based on the approximation method
118    fn fit_feature_transformer(&self, x: &Array2<Float>) -> Result<FeatureTransformer> {
119        match &self.approximation_method {
120            ApproximationMethod::Nystroem {
121                kernel,
122                n_components,
123                sampling_strategy,
124            } => {
125                let mut nystroem = Nystroem::new(kernel.clone(), *n_components)
126                    .sampling_strategy(sampling_strategy.clone());
127
128                if let Some(seed) = self.random_state {
129                    nystroem = nystroem.random_state(seed);
130                }
131
132                let fitted = nystroem.fit(x, &())?;
133                Ok(FeatureTransformer::Nystroem(fitted))
134            }
135            ApproximationMethod::RandomFourierFeatures {
136                n_components,
137                gamma,
138            } => {
139                let mut rbf_sampler = RBFSampler::new(*n_components).gamma(*gamma);
140
141                if let Some(seed) = self.random_state {
142                    rbf_sampler = rbf_sampler.random_state(seed);
143                }
144
145                let fitted = rbf_sampler.fit(x, &())?;
146                Ok(FeatureTransformer::RBFSampler(fitted))
147            }
148            ApproximationMethod::StructuredRandomFeatures {
149                n_components,
150                gamma,
151            } => {
152                let mut structured_rff = StructuredRandomFeatures::new(*n_components).gamma(*gamma);
153
154                if let Some(seed) = self.random_state {
155                    structured_rff = structured_rff.random_state(seed);
156                }
157
158                let fitted = structured_rff.fit(x, &())?;
159                Ok(FeatureTransformer::StructuredRFF(fitted))
160            }
161            ApproximationMethod::Fastfood {
162                n_components,
163                gamma,
164            } => {
165                let mut fastfood = FastfoodTransform::new(*n_components).gamma(*gamma);
166
167                if let Some(seed) = self.random_state {
168                    fastfood = fastfood.random_state(seed);
169                }
170
171                let fitted = fastfood.fit(x, &())?;
172                Ok(FeatureTransformer::Fastfood(fitted))
173            }
174        }
175    }
176
177    /// Solve the ridge regression problem
178    fn solve_ridge_regression(
179        &self,
180        x: &Array2<Float>,
181        y: &Array1<Float>,
182    ) -> Result<Array1<Float>> {
183        let (_n_samples, _n_features) = x.dim();
184
185        match &self.solver {
186            Solver::Direct => self.solve_direct(x, y),
187            Solver::SVD => self.solve_svd(x, y),
188            Solver::ConjugateGradient { max_iter, tol } => {
189                self.solve_conjugate_gradient(x, y, *max_iter, *tol)
190            }
191        }
192    }
193
194    /// Direct solver using normal equations
195    fn solve_direct(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
196        let (_, n_features) = x.dim();
197
198        // SIMD-accelerated computation of X^T X + alpha * I using optimized operations
199        let x_f64 = Array2::from_shape_fn(x.dim(), |(i, j)| x[[i, j]]);
200        let y_f64 = Array1::from_vec(y.iter().copied().collect());
201
202        // #[cfg(feature = "nightly-simd")]
203        // let (gram_matrix, xty_f64, weights_f64) = {
204        //     let xtx_f64 = simd_kernel::simd_gram_matrix_from_data(&x_f64.view());
205        //     let mut gram_matrix = xtx_f64;
206        //
207        //     // SIMD-accelerated diagonal regularization
208        //     for i in 0..n_features {
209        //         gram_matrix[[i, i]] += self.alpha as f64;
210        //     }
211        //
212        //     // SIMD-accelerated computation of X^T y
213        //     let xty_f64 =
214        //         simd_kernel::simd_matrix_vector_multiply(&x_f64.t().view(), &y_f64.view())?;
215        //
216        //     // SIMD-accelerated linear system solving
217        //     let weights_f64 =
218        //         simd_kernel::simd_ridge_coefficients(&gram_matrix.view(), &xty_f64.view(), 0.0)?;
219        //
220        //     (gram_matrix, xty_f64, weights_f64)
221        // };
222
223        // Use ndarray operations (optimized via BLAS backend)
224        let weights_f64 = {
225            // Matrix operations optimized through ndarray's BLAS backend
226            let gram_matrix = x_f64.t().dot(&x_f64);
227            let mut regularized_gram = gram_matrix;
228            for i in 0..n_features {
229                regularized_gram[[i, i]] += self.alpha;
230            }
231            let xty_f64 = x_f64.t().dot(&y_f64);
232
233            // Solve the linear system (X^T X + αI) w = X^T y
234            regularized_gram
235                .solve(&xty_f64)
236                .map_err(|e| SklearsError::InvalidParameter {
237                    name: "regularization".to_string(),
238                    reason: format!("Linear system solving failed: {:?}", e),
239                })?
240        };
241
242        // Convert back to Float
243        let weights = Array1::from_vec(weights_f64.iter().map(|&val| val as Float).collect());
244        Ok(weights)
245    }
246
247    /// SVD-based solver (more numerically stable)
248    fn solve_svd(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
249        // Use SVD of the design matrix X for numerical stability
250        // Solve the regularized least squares: min ||Xw - y||² + α||w||²
251
252        let (_n_samples, _n_features) = x.dim();
253
254        // Compute SVD of X using power iteration method
255        let (u, s, vt) = self.compute_svd(x)?;
256
257        // Compute regularized pseudo-inverse
258        let threshold = 1e-12;
259        let mut s_reg_inv = Array1::zeros(s.len());
260        for i in 0..s.len() {
261            if s[i] > threshold {
262                s_reg_inv[i] = s[i] / (s[i] * s[i] + self.alpha);
263            }
264        }
265
266        // Solve: w = V * S_reg^(-1) * U^T * y
267        let ut_y = u.t().dot(y);
268        let mut temp = Array1::zeros(s.len());
269        for i in 0..s.len() {
270            temp[i] = s_reg_inv[i] * ut_y[i];
271        }
272
273        let weights = vt.t().dot(&temp);
274        Ok(weights)
275    }
276
277    /// Conjugate Gradient solver (iterative, memory efficient)
278    fn solve_conjugate_gradient(
279        &self,
280        x: &Array2<Float>,
281        y: &Array1<Float>,
282        max_iter: usize,
283        tol: Float,
284    ) -> Result<Array1<Float>> {
285        let (_n_samples, n_features) = x.dim();
286
287        // Initialize weights
288        let mut w = Array1::zeros(n_features);
289
290        // Compute initial residual: r = X^T y - (X^T X + alpha I) w
291        let xty = x.t().dot(y);
292        let mut r = xty.clone();
293
294        let mut p = r.clone();
295        let mut rsold = r.dot(&r);
296
297        for _iter in 0..max_iter {
298            // Compute A * p where A = X^T X + alpha I
299            let xtxp = x.t().dot(&x.dot(&p));
300            let mut ap = xtxp;
301            for i in 0..n_features {
302                ap[i] += self.alpha * p[i];
303            }
304
305            // Compute step size
306            let alpha_cg = rsold / p.dot(&ap);
307
308            // Update weights
309            w = w + alpha_cg * &p;
310
311            // Update residual
312            r = r - alpha_cg * &ap;
313
314            let rsnew = r.dot(&r);
315
316            // Check convergence
317            if rsnew.sqrt() < tol {
318                break;
319            }
320
321            // Update search direction
322            let beta = rsnew / rsold;
323            p = &r + beta * &p;
324            rsold = rsnew;
325        }
326
327        Ok(w)
328    }
329
330    /// Solve linear system Ax = b using Gaussian elimination with partial pivoting
331    fn solve_linear_system(&self, a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
332        let n = a.nrows();
333        if n != a.ncols() || n != b.len() {
334            return Err(SklearsError::InvalidInput(
335                "Matrix dimensions must match for linear system solve".to_string(),
336            ));
337        }
338
339        // Create augmented matrix [A|b]
340        let mut aug = Array2::zeros((n, n + 1));
341        for i in 0..n {
342            for j in 0..n {
343                aug[[i, j]] = a[[i, j]];
344            }
345            aug[[i, n]] = b[i];
346        }
347
348        // Forward elimination with partial pivoting
349        for k in 0..n {
350            // Find pivot
351            let mut max_row = k;
352            for i in (k + 1)..n {
353                if aug[[i, k]].abs() > aug[[max_row, k]].abs() {
354                    max_row = i;
355                }
356            }
357
358            // Swap rows if needed
359            if max_row != k {
360                for j in 0..=n {
361                    let temp = aug[[k, j]];
362                    aug[[k, j]] = aug[[max_row, j]];
363                    aug[[max_row, j]] = temp;
364                }
365            }
366
367            // Check for zero pivot
368            if aug[[k, k]].abs() < 1e-12 {
369                return Err(SklearsError::InvalidInput(
370                    "Matrix is singular or nearly singular".to_string(),
371                ));
372            }
373
374            // Eliminate column
375            for i in (k + 1)..n {
376                let factor = aug[[i, k]] / aug[[k, k]];
377                for j in k..=n {
378                    aug[[i, j]] -= factor * aug[[k, j]];
379                }
380            }
381        }
382
383        // Back substitution
384        let mut x = Array1::zeros(n);
385        for i in (0..n).rev() {
386            let mut sum = aug[[i, n]];
387            for j in (i + 1)..n {
388                sum -= aug[[i, j]] * x[j];
389            }
390            x[i] = sum / aug[[i, i]];
391        }
392
393        Ok(x)
394    }
395
    /// Compute SVD using power iteration method
    /// Returns (U, S, V^T) where X = U * S * V^T
    ///
    /// Works through the symmetric eigenproblem of the smaller Gram matrix:
    /// X^T X (n x n) when X is tall, X X^T (m x m) when it is wide. Singular
    /// values are the square roots of the non-negative eigenvalues, and the
    /// remaining singular-vector factor is recovered as X V S^-1 (resp.
    /// X^T U S^-1). Eigenvalues at or below 1e-12 are treated as zero and
    /// skipped, so for rank-deficient input the trailing columns of U/V and
    /// entries of S remain zero.
    fn compute_svd(
        &self,
        x: &Array2<Float>,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = x.dim();
        let min_dim = m.min(n);

        // For SVD, we compute eigendecomposition of X^T X (for V) and X X^T (for U)
        let xt = x.t();

        if n <= m {
            // Thin SVD: compute V from X^T X
            let xtx = xt.dot(x);
            let (eigenvals_v, eigenvecs_v) = self.compute_eigendecomposition_svd(&xtx)?;

            // Singular values are sqrt of eigenvalues.
            // Only numerically positive eigenvalues are kept; `valid_indices`
            // remembers which eigenvector columns they came from.
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_v.len() {
                if eigenvals_v[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_v[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            // Construct V matrix from the retained eigenvector columns.
            let mut v = Array2::zeros((n, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                v.column_mut(new_idx).assign(&eigenvecs_v.column(old_idx));
            }

            // Compute U = X * V * S^(-1); division is safe because only
            // singular values above the threshold are iterated here.
            let mut u = Array2::zeros((m, min_dim));
            for j in 0..valid_indices.len() {
                let v_col = v.column(j);
                let xv = x.dot(&v_col);
                let u_col = &xv / singular_vals[j];
                u.column_mut(j).assign(&u_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        } else {
            // Wide matrix: compute U from X X^T
            let xxt = x.dot(&xt);
            let (eigenvals_u, eigenvecs_u) = self.compute_eigendecomposition_svd(&xxt)?;

            // Singular values are sqrt of eigenvalues (same filtering as above).
            let mut singular_vals = Array1::zeros(min_dim);
            let mut valid_indices = Vec::new();
            for i in 0..eigenvals_u.len() {
                if eigenvals_u[i] > 1e-12 {
                    singular_vals[valid_indices.len()] = eigenvals_u[i].sqrt();
                    valid_indices.push(i);
                    if valid_indices.len() >= min_dim {
                        break;
                    }
                }
            }

            // Construct U matrix from the retained eigenvector columns.
            let mut u = Array2::zeros((m, min_dim));
            for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
                u.column_mut(new_idx).assign(&eigenvecs_u.column(old_idx));
            }

            // Compute V = X^T * U * S^(-1)
            let mut v = Array2::zeros((n, min_dim));
            for j in 0..valid_indices.len() {
                let u_col = u.column(j);
                let xtu = xt.dot(&u_col);
                let v_col = &xtu / singular_vals[j];
                v.column_mut(j).assign(&v_col);
            }

            Ok((u, singular_vals, v.t().to_owned()))
        }
    }
478
    /// Compute eigendecomposition for SVD computation
    ///
    /// Repeatedly runs power iteration to extract the dominant eigenpair and
    /// then deflates the matrix (A <- A - λ v v^T) so the next dominant pair
    /// can be found. Results are returned sorted by descending eigenvalue.
    ///
    /// NOTE(review): deflation like this assumes a symmetric input matrix
    /// (callers pass Gram matrices X^T X / X X^T, which are symmetric), and
    /// accuracy degrades for later eigenpairs as deflation error accumulates
    /// — confirm this precision is acceptable for the SVD solver path.
    fn compute_eigendecomposition_svd(
        &self,
        matrix: &Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration_svd(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        // (partial_cmp only fails on NaN, which indicates a broken iteration)
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| {
            eigenvals[j]
                .partial_cmp(&eigenvals[i])
                .expect("operation should succeed")
        });

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }
533
    /// Power iteration method for SVD eigendecomposition
    ///
    /// Iterates v <- A v / ||A v|| from a random start and tracks the Rayleigh
    /// quotient v^T A v as the eigenvalue estimate. Returns early when both
    /// the eigenvalue and the vector stop changing (within `tol`); otherwise
    /// returns the last iterate after `max_iter` steps.
    ///
    /// NOTE(review): the start vector uses `thread_rng()` and ignores the
    /// model's `random_state`, so the SVD solver path is not reproducible
    /// even when a seed is set — confirm whether this is intended.
    fn power_iteration_svd(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize random vector (components roughly in [-0.5, 0.5))
        let mut v = Array1::from_shape_fn(n, |_| thread_rng().random::<Float>() - 0.5);

        // Normalize; an (astronomically unlikely) zero start vector is an error
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient (current eigenvalue estimate)
            let new_eigenval = v.dot(&w);

            // Normalize; a near-zero image means v is (numerically) in the
            // null space — stop and return the current estimate
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence of both the eigenvalue and the vector
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
585}
586
587impl Predict<Array2<Float>, Array1<Float>> for KernelRidgeRegression<Trained> {
588    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
589        let weights = self
590            .weights_
591            .as_ref()
592            .ok_or_else(|| SklearsError::NotFitted {
593                operation: "predict".to_string(),
594            })?;
595
596        let feature_transformer =
597            self.feature_transformer_
598                .as_ref()
599                .ok_or_else(|| SklearsError::NotFitted {
600                    operation: "predict".to_string(),
601                })?;
602
603        // Transform features
604        let x_transformed = feature_transformer.transform(x)?;
605
606        // SIMD-accelerated prediction computation using optimized matrix-vector multiplication
607        let x_f64 = Array2::from_shape_fn(x_transformed.dim(), |(i, j)| x_transformed[[i, j]]);
608        let weights_f64 = Array1::from_vec(weights.iter().copied().collect());
609
610        // Matrix-vector multiplication optimized via BLAS backend
611        let predictions_f64 = x_f64.dot(&weights_f64);
612
613        // Convert back to Float
614        let predictions =
615            Array1::from_vec(predictions_f64.iter().map(|&val| val as Float).collect());
616
617        Ok(predictions)
618    }
619}
620
/// Online/Incremental Kernel Ridge Regression
///
/// This variant allows for online updates to the model as new data arrives.
#[derive(Debug, Clone)]
pub struct OnlineKernelRidgeRegression<State = Untrained> {
    /// Base kernel ridge regression
    pub base_model: KernelRidgeRegression<State>,
    /// Forgetting factor for online updates
    /// (NOTE(review): stored but not read by `partial_fit` in this file —
    /// confirm how old samples are meant to be down-weighted)
    pub forgetting_factor: Float,
    /// Update frequency
    /// (number of `partial_fit` calls between refits of the base model)
    pub update_frequency: usize,

    // Online state
    // Number of `partial_fit` calls since the initial fit.
    update_count_: usize,
    // Buffered (X, y) batches awaiting the next periodic refit.
    accumulated_data_: Option<(Array2<Float>, Array1<Float>)>,

    // Compile-time fit-state marker.
    _state: PhantomData<State>,
}
639
640impl OnlineKernelRidgeRegression<Untrained> {
641    /// Create a new online kernel ridge regression model
642    pub fn new(approximation_method: ApproximationMethod) -> Self {
643        Self {
644            base_model: KernelRidgeRegression::new(approximation_method),
645            forgetting_factor: 0.99,
646            update_frequency: 100,
647            update_count_: 0,
648            accumulated_data_: None,
649            _state: PhantomData,
650        }
651    }
652
653    /// Set forgetting factor
654    pub fn forgetting_factor(mut self, factor: Float) -> Self {
655        self.forgetting_factor = factor;
656        self
657    }
658
659    /// Set update frequency
660    pub fn update_frequency(mut self, frequency: usize) -> Self {
661        self.update_frequency = frequency;
662        self
663    }
664
665    /// Set alpha parameter
666    pub fn alpha(mut self, alpha: Float) -> Self {
667        self.base_model = self.base_model.alpha(alpha);
668        self
669    }
670
671    /// Set random state
672    pub fn random_state(mut self, seed: u64) -> Self {
673        self.base_model = self.base_model.random_state(seed);
674        self
675    }
676}
677
/// Marker `Estimator` implementation; the online model takes no external config.
impl Estimator for OnlineKernelRidgeRegression<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is promoted to a `'static` reference, so returning it is sound.
        &()
    }
}
687
688impl Fit<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Untrained> {
689    type Fitted = OnlineKernelRidgeRegression<Trained>;
690
691    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
692        let fitted_base = self.base_model.fit(x, y)?;
693
694        Ok(OnlineKernelRidgeRegression {
695            base_model: fitted_base,
696            forgetting_factor: self.forgetting_factor,
697            update_frequency: self.update_frequency,
698            update_count_: 0,
699            accumulated_data_: None,
700            _state: PhantomData,
701        })
702    }
703}
704
705impl OnlineKernelRidgeRegression<Trained> {
706    /// Update the model with new data
707    pub fn partial_fit(mut self, x_new: &Array2<Float>, y_new: &Array1<Float>) -> Result<Self> {
708        // Accumulate new data
709        match &self.accumulated_data_ {
710            Some((x_acc, y_acc)) => {
711                let x_combined =
712                    scirs2_core::ndarray::concatenate![Axis(0), x_acc.clone(), x_new.clone()];
713                let y_combined =
714                    scirs2_core::ndarray::concatenate![Axis(0), y_acc.clone(), y_new.clone()];
715                self.accumulated_data_ = Some((x_combined, y_combined));
716            }
717            None => {
718                self.accumulated_data_ = Some((x_new.clone(), y_new.clone()));
719            }
720        }
721
722        self.update_count_ += 1;
723
724        // Check if it's time to update
725        if self.update_count_ % self.update_frequency == 0 {
726            if let Some((ref x_acc, ref y_acc)) = self.accumulated_data_ {
727                // Refit the model with accumulated data
728                // In practice, you might want to implement a more sophisticated
729                // online update algorithm here
730                let updated_base = self.base_model.clone().into_untrained().fit(x_acc, y_acc)?;
731                self.base_model = updated_base;
732                self.accumulated_data_ = None;
733            }
734        }
735
736        Ok(self)
737    }
738
739    /// Get the number of updates performed
740    pub fn update_count(&self) -> usize {
741        self.update_count_
742    }
743}
744
impl Predict<Array2<Float>, Array1<Float>> for OnlineKernelRidgeRegression<Trained> {
    /// Delegate prediction to the wrapped (most recently refit) base model.
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        self.base_model.predict(x)
    }
}
750
/// Helper trait to convert a trained model back to its untrained form
/// (dropping fitted state) so it can be refit on new data.
pub trait IntoUntrained<T> {
    /// Consume `self` and return the equivalent untrained model.
    fn into_untrained(self) -> T;
}
755
756impl IntoUntrained<KernelRidgeRegression<Untrained>> for KernelRidgeRegression<Trained> {
757    fn into_untrained(self) -> KernelRidgeRegression<Untrained> {
758        KernelRidgeRegression {
759            approximation_method: self.approximation_method,
760            alpha: self.alpha,
761            solver: self.solver,
762            random_state: self.random_state,
763            weights_: None,
764            feature_transformer_: None,
765            _state: PhantomData,
766        }
767    }
768}
769
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// RFF approximation: fit/predict succeed and yield one finite
    /// prediction per input row.
    #[test]
    fn test_kernel_ridge_regression_rff() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 4.0, 9.0, 16.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 50,
            gamma: 0.1,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 4);
        // Check that predictions are reasonable
        for pred in predictions.iter() {
            assert!(pred.is_finite());
        }
    }

    /// Nystroem approximation with an RBF kernel: end-to-end fit/predict.
    #[test]
    fn test_kernel_ridge_regression_nystroem() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::Nystroem {
            kernel: Kernel::Rbf { gamma: 1.0 },
            n_components: 3,
            sampling_strategy: SamplingStrategy::Random,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(1.0);
        let fitted = krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 3);
    }

    /// Fastfood approximation (power-of-two input width): end-to-end fit/predict.
    #[test]
    fn test_kernel_ridge_regression_fastfood() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
        let y = array![1.0, 2.0];

        let approximation = ApproximationMethod::Fastfood {
            n_components: 8,
            gamma: 0.5,
        };

        let krr = KernelRidgeRegression::new(approximation).alpha(0.1);
        let fitted = krr.fit(&x, &y).expect("operation should succeed");
        let predictions = fitted.predict(&x).expect("operation should succeed");

        assert_eq!(predictions.len(), 2);
    }

    /// All three solvers (Direct, SVD, CG) fit and predict without errors
    /// on the same data/approximation.
    #[test]
    fn test_different_solvers() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        // Test Direct solver
        let krr_direct = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::Direct)
            .alpha(0.1);
        let fitted_direct = krr_direct.fit(&x, &y).expect("operation should succeed");
        let pred_direct = fitted_direct.predict(&x).expect("operation should succeed");

        // Test SVD solver
        let krr_svd = KernelRidgeRegression::new(approximation.clone())
            .solver(Solver::SVD)
            .alpha(0.1);
        let fitted_svd = krr_svd.fit(&x, &y).expect("operation should succeed");
        let pred_svd = fitted_svd.predict(&x).expect("operation should succeed");

        // Test CG solver
        let krr_cg = KernelRidgeRegression::new(approximation)
            .solver(Solver::ConjugateGradient {
                max_iter: 100,
                tol: 1e-6,
            })
            .alpha(0.1);
        let fitted_cg = krr_cg.fit(&x, &y).expect("operation should succeed");
        let pred_cg = fitted_cg.predict(&x).expect("operation should succeed");

        assert_eq!(pred_direct.len(), 3);
        assert_eq!(pred_svd.len(), 3);
        assert_eq!(pred_cg.len(), 3);
    }

    /// Online variant: partial_fit buffers data, bumps the update counter,
    /// and the model still predicts afterwards.
    #[test]
    fn test_online_kernel_ridge_regression() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let y_initial = array![1.0, 2.0];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];
        let y_new = array![3.0, 4.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 20,
            gamma: 0.5,
        };

        let online_krr = OnlineKernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .update_frequency(2);

        let fitted = online_krr
            .fit(&x_initial, &y_initial)
            .expect("operation should succeed");
        let updated = fitted
            .partial_fit(&x_new, &y_new)
            .expect("operation should succeed");

        assert_eq!(updated.update_count(), 1);

        let predictions = updated
            .predict(&x_initial)
            .expect("operation should succeed");
        assert_eq!(predictions.len(), 2);
    }

    /// Two models with the same seed produce (near-)identical predictions.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let approximation = ApproximationMethod::RandomFourierFeatures {
            n_components: 10,
            gamma: 1.0,
        };

        let krr1 = KernelRidgeRegression::new(approximation.clone())
            .alpha(0.1)
            .random_state(42);
        let fitted1 = krr1.fit(&x, &y).expect("operation should succeed");
        let pred1 = fitted1.predict(&x).expect("operation should succeed");

        let krr2 = KernelRidgeRegression::new(approximation)
            .alpha(0.1)
            .random_state(42);
        let fitted2 = krr2.fit(&x, &y).expect("operation should succeed");
        let pred2 = fitted2.predict(&x).expect("operation should succeed");

        assert_eq!(pred1.len(), pred2.len());
        for i in 0..pred1.len() {
            assert!((pred1[i] - pred2[i]).abs() < 1e-10);
        }
    }
}