// sklears_kernel_approximation/nystroem.rs
1//! Nyström method for kernel approximation
2use scirs2_core::ndarray::{Array1, Array2};
3use scirs2_core::random::rngs::StdRng as RealStdRng;
4use scirs2_core::random::seq::SliceRandom;
5use scirs2_core::random::Rng;
6use scirs2_core::random::{thread_rng, SeedableRng};
7use sklears_core::{
8    error::{Result, SklearsError},
9    prelude::{Fit, Transform},
10    traits::{Estimator, Trained, Untrained},
11    types::Float,
12};
13use std::marker::PhantomData;
14
15/// Compute the inverse of a symmetric positive definite matrix via Cholesky decomposition.
16///
17/// For an SPD matrix A = L L^T, computes A^{-1} = L^{-T} L^{-1}.
18/// This is O(n^3) with small constants and no external library overhead.
19fn cholesky_inverse(a: &Array2<f64>, n: usize) -> std::result::Result<Array2<f64>, String> {
20    // Step 1: Cholesky factorization A = L L^T
21    let mut l = Array2::<f64>::zeros((n, n));
22    for j in 0..n {
23        let mut sum = 0.0;
24        for k in 0..j {
25            sum += l[[j, k]] * l[[j, k]];
26        }
27        let diag = a[[j, j]] - sum;
28        if diag <= 0.0 {
29            return Err(format!(
30                "Matrix is not positive definite (diagonal element {} = {})",
31                j, diag
32            ));
33        }
34        l[[j, j]] = diag.sqrt();
35
36        for i in (j + 1)..n {
37            let mut s = 0.0;
38            for k in 0..j {
39                s += l[[i, k]] * l[[j, k]];
40            }
41            l[[i, j]] = (a[[i, j]] - s) / l[[j, j]];
42        }
43    }
44
45    // Step 2: Compute L^{-1} by forward substitution on identity columns
46    let mut l_inv = Array2::<f64>::zeros((n, n));
47    for col in 0..n {
48        // Solve L * x = e_col
49        for i in 0..n {
50            let mut sum = if i == col { 1.0 } else { 0.0 };
51            for k in 0..i {
52                sum -= l[[i, k]] * l_inv[[k, col]];
53            }
54            l_inv[[i, col]] = sum / l[[i, i]];
55        }
56    }
57
58    // Step 3: A^{-1} = L^{-T} * L^{-1} = (L^{-1})^T * L^{-1}
59    let result = l_inv.t().dot(&l_inv);
60
61    Ok(result)
62}
63
/// Sampling strategy for selecting landmark points in the Nyström approximation.
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Random uniform sampling
    Random,
    /// K-means clustering based sampling
    KMeans,
    /// Leverage score based sampling
    LeverageScore,
    /// Column norm based sampling
    ColumnNorm,
}
77
78/// Kernel type for Nystroem approximation
79#[derive(Debug, Clone)]
80/// Kernel
81pub enum Kernel {
82    /// Linear kernel: K(x,y) = x^T y
83    Linear,
84    /// RBF kernel: K(x,y) = exp(-gamma * ||x-y||²)
85    Rbf { gamma: Float },
86    /// Polynomial kernel: K(x,y) = (gamma * x^T y + coef0)^degree
87    Polynomial {
88        gamma: Float,
89
90        coef0: Float,
91
92        degree: u32,
93    },
94}
95
96impl Kernel {
97    /// Compute kernel matrix between X and Y
98    pub fn compute_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
99        let (n_x, _) = x.dim();
100        let (n_y, _) = y.dim();
101        let mut kernel_matrix = Array2::zeros((n_x, n_y));
102
103        match self {
104            Kernel::Linear => {
105                kernel_matrix = x.dot(&y.t());
106            }
107            Kernel::Rbf { gamma } => {
108                for i in 0..n_x {
109                    for j in 0..n_y {
110                        let diff = &x.row(i) - &y.row(j);
111                        let dist_sq = diff.dot(&diff);
112                        kernel_matrix[[i, j]] = (-gamma * dist_sq).exp();
113                    }
114                }
115            }
116            Kernel::Polynomial {
117                gamma,
118                coef0,
119                degree,
120            } => {
121                for i in 0..n_x {
122                    for j in 0..n_y {
123                        let dot_prod = x.row(i).dot(&y.row(j));
124                        kernel_matrix[[i, j]] = (gamma * dot_prod + coef0).powf(*degree as Float);
125                    }
126                }
127            }
128        }
129
130        kernel_matrix
131    }
132}
133
/// Nyström method for kernel approximation
///
/// General method for kernel approximation using eigendecomposition on a subset
/// of training data. Works with any kernel function and supports multiple
/// sampling strategies for improved approximation quality.
///
/// # Parameters
///
/// * `kernel` - Kernel function to approximate
/// * `n_components` - Number of samples to use for approximation
/// * `sampling_strategy` - Strategy for selecting landmark points
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::nystroem::{Nystroem, Kernel, SamplingStrategy};
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
///
/// let nystroem = Nystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
///     .sampling_strategy(SamplingStrategy::LeverageScore);
/// let fitted_nystroem = nystroem.fit(&X, &()).unwrap();
/// let X_transformed = fitted_nystroem.transform(&X).unwrap();
/// assert_eq!(X_transformed.shape(), &[3, 3]);
/// ```
#[derive(Debug, Clone)]
pub struct Nystroem<State = Untrained> {
    /// Kernel function
    pub kernel: Kernel,
    /// Number of components for approximation
    pub n_components: usize,
    /// Sampling strategy for landmark selection
    pub sampling_strategy: SamplingStrategy,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted attributes (populated by `fit`; `None` while untrained)
    /// Landmark rows selected during fitting.
    components_: Option<Array2<Float>>,
    /// Inverse of the regularized landmark kernel matrix K₁₁⁻¹.
    normalization_: Option<Array2<Float>>,
    /// Indices of the landmark rows within the training data.
    component_indices_: Option<Vec<usize>>,

    /// Zero-sized marker encoding the Untrained/Trained typestate.
    _state: PhantomData<State>,
}
181
182impl Nystroem<Untrained> {
183    /// Create a new Nystroem approximator
184    pub fn new(kernel: Kernel, n_components: usize) -> Self {
185        Self {
186            kernel,
187            n_components,
188            sampling_strategy: SamplingStrategy::Random,
189            random_state: None,
190            components_: None,
191            normalization_: None,
192            component_indices_: None,
193            _state: PhantomData,
194        }
195    }
196
197    /// Set the sampling strategy
198    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
199        self.sampling_strategy = strategy;
200        self
201    }
202
203    /// Set random state for reproducibility
204    pub fn random_state(mut self, seed: u64) -> Self {
205        self.random_state = Some(seed);
206        self
207    }
208}
209
impl Estimator for Nystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // Nystroem carries its settings directly on the struct; the unit type
    // stands in for a dedicated config object.
    fn config(&self) -> &Self::Config {
        &()
    }
}
219
220impl Nystroem<Untrained> {
221    /// Select component indices based on sampling strategy
222    fn select_components(
223        &self,
224        x: &Array2<Float>,
225        n_components: usize,
226        rng: &mut RealStdRng,
227    ) -> Result<Vec<usize>> {
228        let (n_samples, _) = x.dim();
229
230        match &self.sampling_strategy {
231            SamplingStrategy::Random => {
232                let mut indices: Vec<usize> = (0..n_samples).collect();
233                indices.shuffle(rng);
234                Ok(indices[..n_components].to_vec())
235            }
236            SamplingStrategy::KMeans => {
237                // Simple k-means based sampling
238                self.kmeans_sampling(x, n_components, rng)
239            }
240            SamplingStrategy::LeverageScore => {
241                // Leverage score based sampling
242                self.leverage_score_sampling(x, n_components, rng)
243            }
244            SamplingStrategy::ColumnNorm => {
245                // Column norm based sampling
246                self.column_norm_sampling(x, n_components, rng)
247            }
248        }
249    }
250
    /// Simple k-means based sampling.
    ///
    /// Runs a fixed, small number of Lloyd iterations on `x`, then returns the
    /// index of the data point closest to each final center. Duplicate picks
    /// are removed and any shortfall is topped up with random indices.
    ///
    /// NOTE(review): assumes `n_components <= n_samples` (the caller, `fit`,
    /// clamps it); otherwise the initial `indices[..n_components]` slice
    /// panics and the top-up loop could not terminate.
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers from a random sample of distinct rows.
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // Run a few iterations of k-means (fixed at 5 — this is a cheap
        // sampling heuristic, not a converged clustering).
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign points to nearest centers (squared Euclidean distance).
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Update each center to the mean of its assigned points; empty
            // clusters keep their previous center.
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Map each final center back to its nearest actual data point; these
        // row indices become the landmark candidates.
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        // Distinct centers can share a nearest point; dedup requires sorting
        // first (so the returned indices come out in ascending order).
        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Fill remaining slots randomly if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }
338
339    /// Leverage score based sampling
340    fn leverage_score_sampling(
341        &self,
342        x: &Array2<Float>,
343        n_components: usize,
344        _rng: &mut RealStdRng,
345    ) -> Result<Vec<usize>> {
346        let (n_samples, _) = x.dim();
347
348        // Compute leverage scores (diagonal of hat matrix)
349        // For simplicity, we approximate using row norms as proxy
350        let mut scores = Vec::new();
351        for i in 0..n_samples {
352            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
353            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
354        }
355
356        // Sample based on scores using cumulative distribution
357        let total_score: Float = scores.iter().sum();
358        if total_score <= 0.0 {
359            return Err(SklearsError::InvalidInput(
360                "All scores are zero or negative".to_string(),
361            ));
362        }
363
364        // Create cumulative distribution
365        let mut cumulative = Vec::with_capacity(scores.len());
366        let mut sum = 0.0;
367        for &score in &scores {
368            sum += score / total_score;
369            cumulative.push(sum);
370        }
371
372        let mut selected_indices = Vec::new();
373        for _ in 0..n_components {
374            let r = thread_rng().gen::<Float>();
375            // Find index where cumulative probability >= r
376            let mut idx = cumulative
377                .iter()
378                .position(|&cum| cum >= r)
379                .unwrap_or(scores.len() - 1);
380
381            // Ensure no duplicates
382            while selected_indices.contains(&idx) {
383                let r = thread_rng().gen::<Float>();
384                idx = cumulative
385                    .iter()
386                    .position(|&cum| cum >= r)
387                    .unwrap_or(scores.len() - 1);
388            }
389            selected_indices.push(idx);
390        }
391
392        Ok(selected_indices)
393    }
394
395    /// Column norm based sampling
396    fn column_norm_sampling(
397        &self,
398        x: &Array2<Float>,
399        n_components: usize,
400        rng: &mut RealStdRng,
401    ) -> Result<Vec<usize>> {
402        let (n_samples, _) = x.dim();
403
404        // Compute row norms
405        let mut norms = Vec::new();
406        for i in 0..n_samples {
407            let norm = x.row(i).dot(&x.row(i)).sqrt();
408            norms.push(norm + 1e-10);
409        }
410
411        // Sort by norm and take diverse selection
412        let mut indices_with_norms: Vec<(usize, Float)> = norms
413            .iter()
414            .enumerate()
415            .map(|(i, &norm)| (i, norm))
416            .collect();
417        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
418
419        let mut selected_indices = Vec::new();
420        let step = n_samples.max(1) / n_components.max(1);
421
422        for i in 0..n_components {
423            let idx = (i * step).min(n_samples - 1);
424            selected_indices.push(indices_with_norms[idx].0);
425        }
426
427        // Fill remaining with random if needed
428        while selected_indices.len() < n_components {
429            let random_idx = rng.gen_range(0..n_samples);
430            if !selected_indices.contains(&random_idx) {
431                selected_indices.push(random_idx);
432            }
433        }
434
435        Ok(selected_indices)
436    }
437
438    /// Compute eigendecomposition using scirs2-linalg's eigh for symmetric matrices
439    /// Returns (eigenvalues, eigenvectors) sorted in descending order by eigenvalue magnitude
440    fn compute_eigendecomposition(
441        &self,
442        matrix: &Array2<Float>,
443        _rng: &mut RealStdRng,
444    ) -> Result<(Array1<Float>, Array2<Float>)> {
445        use scirs2_linalg::compat::{ArrayLinalgExt, UPLO};
446
447        let n = matrix.nrows();
448
449        if n != matrix.ncols() {
450            return Err(SklearsError::InvalidInput(
451                "Matrix must be square for eigendecomposition".to_string(),
452            ));
453        }
454
455        // Use scirs2-linalg's eigh for symmetric eigendecomposition (much faster than power iteration)
456        let (eigenvals, eigenvecs) = matrix
457            .eigh(UPLO::Lower)
458            .map_err(|e| SklearsError::InvalidInput(format!("Eigendecomposition failed: {}", e)))?;
459
460        // eigh returns eigenvalues in ascending order, we need descending
461        let mut sorted_eigenvals = Array1::zeros(n);
462        let mut sorted_eigenvecs = Array2::zeros((n, n));
463
464        for i in 0..n {
465            sorted_eigenvals[i] = eigenvals[n - 1 - i];
466            sorted_eigenvecs
467                .column_mut(i)
468                .assign(&eigenvecs.column(n - 1 - i));
469        }
470
471        Ok((sorted_eigenvals, sorted_eigenvecs))
472    }
473
    /// Power iteration method to find the dominant eigenvalue and eigenvector.
    ///
    /// NOTE(review): appears unused in this file — `compute_eigendecomposition`
    /// uses scirs2-linalg's `eigh` instead; presumably kept as a
    /// dependency-free fallback. Confirm before removing.
    ///
    /// Returns the Rayleigh-quotient estimate of the dominant eigenvalue and
    /// the corresponding normalized eigenvector. If the tolerance is not
    /// reached within `max_iter` iterations, the last iterate is returned.
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
        rng: &mut RealStdRng,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize random vector centered around zero.
        let mut v = Array1::from_shape_fn(n, |_| rng.gen::<Float>() - 0.5);

        // Normalize; reject the (vanishingly unlikely) near-zero start vector.
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Rayleigh quotient v^T A v estimates the eigenvalue for unit v.
            let new_eigenval = v.dot(&w);

            // A near-zero image means v is (numerically) in the null space;
            // stop iterating rather than divide by ~0.
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Converged when both the eigenvalue and the vector (L1 distance)
            // have stopped moving.
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
526}
527
528impl Fit<Array2<Float>, ()> for Nystroem<Untrained> {
529    type Fitted = Nystroem<Trained>;
530
531    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
532        let (n_samples, _) = x.dim();
533
534        if self.n_components > n_samples {
535            eprintln!(
536                "Warning: n_components ({}) > n_samples ({})",
537                self.n_components, n_samples
538            );
539        }
540
541        let n_components_actual = self.n_components.min(n_samples);
542
543        let mut rng = if let Some(seed) = self.random_state {
544            RealStdRng::seed_from_u64(seed)
545        } else {
546            RealStdRng::from_seed(thread_rng().gen())
547        };
548
549        // Select component indices using specified strategy
550        let component_indices = self.select_components(x, n_components_actual, &mut rng)?;
551
552        // Extract component samples
553        let mut components = Array2::zeros((n_components_actual, x.ncols()));
554        for (i, &idx) in component_indices.iter().enumerate() {
555            components.row_mut(i).assign(&x.row(idx));
556        }
557
558        // Compute kernel matrix K_11 on sampled points
559        let k11: Array2<f64> = self.kernel.compute_kernel(&components, &components);
560
561        // Proper Nyström approximation
562        // K ≈ K₁₂ K₁₁⁻¹ K₁₂ᵀ where K₁₁⁻¹ is the inverse of landmark kernel matrix
563        let eps = 1e-8;
564
565        // Add regularization to diagonal for numerical stability (ensures positive definiteness)
566        let mut k11_reg = k11;
567        for i in 0..n_components_actual {
568            k11_reg[[i, i]] += eps;
569        }
570
571        // Compute inverse via Cholesky decomposition (fast for SPD matrices)
572        // K11_reg = L L^T, so K11_reg^{-1} = L^{-T} L^{-1}
573        let normalization = cholesky_inverse(&k11_reg, n_components_actual).map_err(|e| {
574            SklearsError::InvalidInput(format!("Failed to invert regularized kernel matrix: {}", e))
575        })?;
576
577        Ok(Nystroem {
578            kernel: self.kernel,
579            n_components: self.n_components,
580            sampling_strategy: self.sampling_strategy,
581            random_state: self.random_state,
582            components_: Some(components),
583            normalization_: Some(normalization),
584            component_indices_: Some(component_indices),
585            _state: PhantomData,
586        })
587    }
588}
589
590impl Transform<Array2<Float>, Array2<Float>> for Nystroem<Trained> {
591    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
592        let components = self.components_.as_ref().unwrap();
593        let normalization = self.normalization_.as_ref().unwrap();
594
595        if x.ncols() != components.ncols() {
596            return Err(SklearsError::InvalidInput(format!(
597                "X has {} features, but Nystroem was fitted with {} features",
598                x.ncols(),
599                components.ncols()
600            )));
601        }
602
603        // Compute kernel matrix K(X, components)
604        let k_x_components = self.kernel.compute_kernel(x, components);
605
606        // Apply normalization: K(X, components) @ normalization
607        let result = k_x_components.dot(normalization);
608
609        Ok(result)
610    }
611}
612
613impl Nystroem<Trained> {
614    /// Get the selected component samples
615    pub fn components(&self) -> &Array2<Float> {
616        self.components_.as_ref().unwrap()
617    }
618
619    /// Get the component indices
620    pub fn component_indices(&self) -> &[usize] {
621        self.component_indices_.as_ref().unwrap()
622    }
623
624    /// Get the normalization matrix
625    pub fn normalization(&self) -> &Array2<Float> {
626        self.normalization_.as_ref().unwrap()
627    }
628}
629
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Linear kernel end-to-end: output has one row per input sample and at
    // most n_components columns.
    #[test]
    fn test_nystroem_linear_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem = Nystroem::new(Kernel::Linear, 3);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        assert!(x_transformed.ncols() <= 3); // At most n_components output features
    }

    // RBF kernel smoke test with fewer components than samples.
    #[test]
    fn test_nystroem_rbf_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let nystroem = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Polynomial kernel smoke test.
    #[test]
    fn test_nystroem_polynomial_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let kernel = Kernel::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let nystroem = Nystroem::new(kernel, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Two fits with the same random_state must produce numerically equal
    // transforms.
    #[test]
    fn test_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem1 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let nystroem2 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        // Results should be very similar with same random state (allowing for numerical precision)
        assert_eq!(result1.shape(), result2.shape());
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    // Transforming data with the wrong feature count must error, not panic.
    #[test]
    fn test_nystroem_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0],];

        let x_test = array![
            [1.0, 2.0, 3.0], // Wrong number of features
        ];

        let nystroem = Nystroem::new(Kernel::Linear, 2);
        let fitted = nystroem.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    // Every sampling strategy should fit and transform without error on the
    // same data.
    #[test]
    fn test_nystroem_sampling_strategies() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        // Test Random sampling
        let nystroem_random = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted_random = nystroem_random.fit(&x, &()).unwrap();
        let result_random = fitted_random.transform(&x).unwrap();
        assert_eq!(result_random.nrows(), 8);

        // Test K-means sampling
        let nystroem_kmeans = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::KMeans)
            .random_state(42);
        let fitted_kmeans = nystroem_kmeans.fit(&x, &()).unwrap();
        let result_kmeans = fitted_kmeans.transform(&x).unwrap();
        assert_eq!(result_kmeans.nrows(), 8);

        // Test Leverage score sampling
        let nystroem_leverage = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted_leverage = nystroem_leverage.fit(&x, &()).unwrap();
        let result_leverage = fitted_leverage.transform(&x).unwrap();
        assert_eq!(result_leverage.nrows(), 8);

        // Test Column norm sampling
        let nystroem_norm = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::ColumnNorm)
            .random_state(42);
        let fitted_norm = nystroem_norm.fit(&x, &()).unwrap();
        let result_norm = fitted_norm.transform(&x).unwrap();
        assert_eq!(result_norm.nrows(), 8);
    }

    // RBF kernel + leverage-score sampling: exact output shape and all
    // values finite.
    #[test]
    fn test_nystroem_rbf_with_different_sampling() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let kernel = Kernel::Rbf { gamma: 0.1 };

        // Test with leverage score sampling
        let nystroem = Nystroem::new(kernel, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.shape(), &[8, 4]);

        // Check that all values are finite
        for val in result.iter() {
            assert!(val.is_finite());
        }
    }

    // Numerical stability on a symmetric point configuration (rank-deficient
    // linear kernel among the landmarks, handled by the diagonal ridge).
    #[test]
    fn test_nystroem_improved_eigendecomposition() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let nystroem = Nystroem::new(Kernel::Linear, 3)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 4);
        assert!(result.ncols() <= 3);

        // Check numerical stability
        for val in result.iter() {
            assert!(val.is_finite());
        }
    }
}