// sklears_kernel_approximation/nystroem.rs

//! Nyström method for kernel approximation
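//!
//! Given `n` training samples and `m <= n` selected landmarks, the Nyström
//! method approximates the full `n × n` kernel matrix as
//! `K ≈ K_nm · K_mm⁺ · K_nmᵀ`, where `K_mm` is the kernel matrix over the
//! landmarks, `K_nm = K(X, landmarks)`, and `⁺` denotes the pseudo-inverse.
//! The fitted transform maps inputs to `Φ(X) = K_nm · K_mm^(-1/2)`, an
//! `n × m` feature matrix whose inner products reproduce this approximation.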
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{thread_rng, Rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Transform},
    traits::{Estimator, Trained, Untrained},
    types::Float,
};
use std::marker::PhantomData;

/// Sampling strategy for Nyström approximation
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Random uniform sampling
    Random,
    /// K-means clustering based sampling
    KMeans,
    /// Leverage score based sampling
    LeverageScore,
    /// Column norm based sampling
    ColumnNorm,
}

/// Kernel type for Nyström approximation
#[derive(Debug, Clone)]
pub enum Kernel {
    /// Linear kernel: K(x,y) = x^T y
    Linear,
    /// RBF kernel: K(x,y) = exp(-gamma * ||x-y||²)
    Rbf { gamma: Float },
    /// Polynomial kernel: K(x,y) = (gamma * x^T y + coef0)^degree
    Polynomial {
        gamma: Float,
        coef0: Float,
        degree: u32,
    },
}

impl Kernel {
    /// Compute the kernel matrix between the rows of `x` and the rows of `y`
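    ///
    /// A minimal sketch of evaluating an RBF kernel matrix (the values follow
    /// directly from the formula above):
    ///
    /// ```rust,ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let kernel = Kernel::Rbf { gamma: 0.5 };
    /// let x = array![[0.0, 0.0], [1.0, 0.0]];
    /// let k = kernel.compute_kernel(&x, &x);
    /// // Diagonal entries: exp(0) = 1. Off-diagonal: exp(-0.5 * 1.0) ≈ 0.6065.
    /// assert!((k[[0, 0]] - 1.0).abs() < 1e-12);
    /// assert!((k[[0, 1]] - (-0.5f64).exp()).abs() < 1e-12);
    /// ```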
    pub fn compute_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
        let (n_x, _) = x.dim();
        let (n_y, _) = y.dim();

        match self {
            // Linear kernel is a single matrix product
            Kernel::Linear => x.dot(&y.t()),
            Kernel::Rbf { gamma } => {
                let mut kernel_matrix = Array2::zeros((n_x, n_y));
                for i in 0..n_x {
                    for j in 0..n_y {
                        let diff = &x.row(i) - &y.row(j);
                        let dist_sq = diff.dot(&diff);
                        kernel_matrix[[i, j]] = (-gamma * dist_sq).exp();
                    }
                }
                kernel_matrix
            }
            Kernel::Polynomial {
                gamma,
                coef0,
                degree,
            } => {
                let mut kernel_matrix = Array2::zeros((n_x, n_y));
                for i in 0..n_x {
                    for j in 0..n_y {
                        let dot_prod = x.row(i).dot(&y.row(j));
                        kernel_matrix[[i, j]] = (gamma * dot_prod + coef0).powi(*degree as i32);
                    }
                }
                kernel_matrix
            }
        }
    }
}

/// Nyström method for kernel approximation
///
/// General method for kernel approximation using eigendecomposition on a subset
/// of training data. Works with any kernel function and supports multiple
/// sampling strategies for improved approximation quality.
///
/// # Parameters
///
/// * `kernel` - Kernel function to approximate
/// * `n_components` - Number of landmark samples to use for the approximation
/// * `sampling_strategy` - Strategy for selecting landmark points
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::nystroem::{Nystroem, Kernel, SamplingStrategy};
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
///
/// let nystroem = Nystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
///     .sampling_strategy(SamplingStrategy::LeverageScore);
/// let fitted_nystroem = nystroem.fit(&X, &()).unwrap();
/// let X_transformed = fitted_nystroem.transform(&X).unwrap();
/// assert_eq!(X_transformed.shape(), &[3, 3]);
/// ```
#[derive(Debug, Clone)]
pub struct Nystroem<State = Untrained> {
    /// Kernel function
    pub kernel: Kernel,
    /// Number of components for approximation
    pub n_components: usize,
    /// Sampling strategy for landmark selection
    pub sampling_strategy: SamplingStrategy,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted attributes
    components_: Option<Array2<Float>>,
    normalization_: Option<Array2<Float>>,
    component_indices_: Option<Vec<usize>>,

    _state: PhantomData<State>,
}

impl Nystroem<Untrained> {
    /// Create a new Nystroem approximator
    pub fn new(kernel: Kernel, n_components: usize) -> Self {
        Self {
            kernel,
            n_components,
            sampling_strategy: SamplingStrategy::Random,
            random_state: None,
            components_: None,
            normalization_: None,
            component_indices_: None,
            _state: PhantomData,
        }
    }

    /// Set the sampling strategy
    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
        self.sampling_strategy = strategy;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for Nystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Nystroem<Untrained> {
    /// Select component indices based on sampling strategy
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => {
                // Simple k-means based sampling
                self.kmeans_sampling(x, n_components, rng)
            }
            SamplingStrategy::LeverageScore => {
                // Leverage score based sampling
                self.leverage_score_sampling(x, n_components, rng)
            }
            SamplingStrategy::ColumnNorm => {
                // Column norm based sampling
                self.column_norm_sampling(x, n_components, rng)
            }
        }
    }

    /// Simple k-means based sampling
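    ///
    /// Centers are initialized from a random subset of rows, refined with a
    /// fixed five iterations of Lloyd's algorithm, and the index of the
    /// training point closest to each final center is returned (deduplicated,
    /// with any shortfall filled by random indices).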
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers from a random subset of rows
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // Run a few iterations of k-means
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign points to nearest centers
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Update centers as the mean of their assigned points
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + &x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Find the training point closest to each final center
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Fill remaining slots randomly if deduplication removed entries
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }

    /// Leverage score based sampling
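    ///
    /// Exact leverage scores are the diagonal entries of the hat matrix
    /// `X (XᵀX)⁻¹ Xᵀ`, which are expensive to compute, so this routine uses
    /// row norms as a cheap proxy: index `i` is drawn with probability
    /// proportional to `‖xᵢ‖`, resampling to avoid duplicates.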
    fn leverage_score_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute leverage scores (diagonal of hat matrix).
        // For simplicity, we approximate them using row norms as a proxy.
        let mut scores = Vec::new();
        for i in 0..n_samples {
            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
            scores.push(row_norm + 1e-10); // Small epsilon for numerical stability
        }

        // Sample based on scores using the cumulative distribution
        let total_score: Float = scores.iter().sum();
        if total_score <= 0.0 {
            return Err(SklearsError::InvalidInput(
                "All scores are zero or negative".to_string(),
            ));
        }

        // Create the cumulative distribution
        let mut cumulative = Vec::with_capacity(scores.len());
        let mut sum = 0.0;
        for &score in &scores {
            sum += score / total_score;
            cumulative.push(sum);
        }

        let mut selected_indices = Vec::new();
        for _ in 0..n_components {
            // Use the seeded RNG (not thread_rng) so random_state stays reproducible
            let r = rng.gen::<Float>();
            // Find the first index whose cumulative probability >= r
            let mut idx = cumulative
                .iter()
                .position(|&cum| cum >= r)
                .unwrap_or(scores.len() - 1);

            // Resample until we get an index that has not been selected yet
            while selected_indices.contains(&idx) {
                let r = rng.gen::<Float>();
                idx = cumulative
                    .iter()
                    .position(|&cum| cum >= r)
                    .unwrap_or(scores.len() - 1);
            }
            selected_indices.push(idx);
        }

        Ok(selected_indices)
    }

    /// Column norm based sampling
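    ///
    /// Despite the name, norms are taken over the rows of `x` (one per sample).
    /// Rows are ranked by descending norm and an evenly strided subset of that
    /// ranking is selected, so both high- and low-norm samples are represented.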
    fn column_norm_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute row norms (one per sample)
        let mut norms = Vec::new();
        for i in 0..n_samples {
            let norm = x.row(i).dot(&x.row(i)).sqrt();
            norms.push(norm + 1e-10);
        }

        // Sort by norm (descending) and take an evenly strided selection
        let mut indices_with_norms: Vec<(usize, Float)> = norms
            .iter()
            .enumerate()
            .map(|(i, &norm)| (i, norm))
            .collect();
        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut selected_indices = Vec::new();
        let step = n_samples.max(1) / n_components.max(1);

        for i in 0..n_components {
            let idx = (i * step).min(n_samples - 1);
            selected_indices.push(indices_with_norms[idx].0);
        }

        // Safety net: fill any remaining slots randomly
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices)
    }

    /// Compute eigendecomposition using the power iteration method
    /// Returns (eigenvalues, eigenvectors) for a symmetric matrix
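    ///
    /// Uses deflation: once the dominant pair `(λ, v)` is found, the matrix is
    /// updated as `A ← A − λ v vᵀ`, removing that component so the next round
    /// of power iteration converges to the next-largest eigenvalue.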
    fn compute_eigendecomposition(
        &self,
        matrix: &Array2<Float>,
        rng: &mut RealStdRng,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use the deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for the k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8, rng)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate the matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Power iteration method to find the dominant eigenvalue and eigenvector
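    ///
    /// Repeats `v ← A·v / ‖A·v‖`, estimating the eigenvalue with the Rayleigh
    /// quotient `λ = vᵀ A v`, and stops once both the eigenvalue and the
    /// vector change by less than `tol`, or after `max_iter` iterations.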
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
        rng: &mut RealStdRng,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize a random vector
        let mut v = Array1::from_shape_fn(n, |_| rng.gen::<Float>() - 0.5);

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply the matrix
            let w = matrix.dot(&v);

            // Compute the Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
}

impl Fit<Array2<Float>, ()> for Nystroem<Untrained> {
    type Fitted = Nystroem<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        if self.n_components > n_samples {
            eprintln!(
                "Warning: n_components ({}) > n_samples ({})",
                self.n_components, n_samples
            );
        }

        let n_components_actual = self.n_components.min(n_samples);

        let mut rng = if let Some(seed) = self.random_state {
            RealStdRng::seed_from_u64(seed)
        } else {
            RealStdRng::from_seed(thread_rng().gen())
        };

        // Select component indices using the specified strategy
        let component_indices = self.select_components(x, n_components_actual, &mut rng)?;

        // Extract the landmark samples
        let mut components = Array2::zeros((n_components_actual, x.ncols()));
        for (i, &idx) in component_indices.iter().enumerate() {
            components.row_mut(i).assign(&x.row(idx));
        }

        // Compute the kernel matrix K₁₁ on the sampled landmarks
        let k11: Array2<Float> = self.kernel.compute_kernel(&components, &components);

        // Nyström approximation: K ≈ K₂₁ K₁₁⁺ K₂₁ᵀ, where K₂₁ = K(X, landmarks)
        // and K₁₁⁺ is the pseudo-inverse of the landmark kernel matrix. The
        // feature map is Φ(X) = K₂₁ K₁₁^(-1/2), so Φ(X) Φ(X)ᵀ recovers the
        // approximation above.
        let eps = 1e-12;

        // Add a small regularization to the diagonal for numerical stability
        let mut k11_reg = k11.clone();
        for i in 0..n_components_actual {
            k11_reg[[i, i]] += eps;
        }

        // Compute the eigendecomposition; for a symmetric positive semi-definite
        // matrix, power iteration with deflation is sufficient.
        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(&k11_reg, &mut rng)?;

        // Filter out small eigenvalues for numerical stability
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        // Construct the inverse square root K₁₁^(-1/2) = V Λ^(-1/2) Vᵀ over the
        // retained eigenvalues (using Λ⁻¹ here would square the inverse once the
        // feature map's Gram matrix is formed, so the square root is the correct
        // normalizer)
        let mut normalization = Array2::zeros((n_components_actual, n_components_actual));

        for i in 0..n_components_actual {
            for j in 0..n_components_actual {
                let mut sum = 0.0;
                for &k in &valid_indices {
                    sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k].sqrt();
                }
                normalization[[i, j]] = sum;
            }
        }

        Ok(Nystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>, Array2<Float>> for Nystroem<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let components = self.components_.as_ref().unwrap();
        let normalization = self.normalization_.as_ref().unwrap();

        if x.ncols() != components.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "X has {} features, but Nystroem was fitted with {} features",
                x.ncols(),
                components.ncols()
            )));
        }

        // Compute the kernel matrix K(X, components)
        let k_x_components = self.kernel.compute_kernel(x, components);

        // Map into the approximate feature space: Φ(X) = K(X, components) @ K₁₁^(-1/2)
        let result = k_x_components.dot(normalization);

        Ok(result)
    }
}

impl Nystroem<Trained> {
    /// Get the selected component samples
    pub fn components(&self) -> &Array2<Float> {
        self.components_.as_ref().unwrap()
    }

    /// Get the component indices
    pub fn component_indices(&self) -> &[usize] {
        self.component_indices_.as_ref().unwrap()
    }

    /// Get the normalization matrix
    pub fn normalization(&self) -> &Array2<Float> {
        self.normalization_.as_ref().unwrap()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_nystroem_linear_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let nystroem = Nystroem::new(Kernel::Linear, 3);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        // The output always has n_components columns; eigenvalue filtering
        // affects the rank of the map, not its width.
        assert_eq!(x_transformed.ncols(), 3);
    }

    #[test]
    fn test_nystroem_rbf_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let nystroem = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert_eq!(x_transformed.ncols(), 2);
    }

    #[test]
    fn test_nystroem_polynomial_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let kernel = Kernel::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let nystroem = Nystroem::new(kernel, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert_eq!(x_transformed.ncols(), 2);
    }

    #[test]
    fn test_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let nystroem1 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let nystroem2 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        // Results should match with the same random state (up to numerical precision)
        assert_eq!(result1.shape(), result2.shape());
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_nystroem_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];

        let x_test = array![
            [1.0, 2.0, 3.0], // Wrong number of features
        ];

        let nystroem = Nystroem::new(Kernel::Linear, 2);
        let fitted = nystroem.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    #[test]
    fn test_nystroem_sampling_strategies() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        // Test Random sampling
        let nystroem_random = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted_random = nystroem_random.fit(&x, &()).unwrap();
        let result_random = fitted_random.transform(&x).unwrap();
        assert_eq!(result_random.nrows(), 8);

        // Test K-means sampling
        let nystroem_kmeans = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::KMeans)
            .random_state(42);
        let fitted_kmeans = nystroem_kmeans.fit(&x, &()).unwrap();
        let result_kmeans = fitted_kmeans.transform(&x).unwrap();
        assert_eq!(result_kmeans.nrows(), 8);

        // Test Leverage score sampling
        let nystroem_leverage = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted_leverage = nystroem_leverage.fit(&x, &()).unwrap();
        let result_leverage = fitted_leverage.transform(&x).unwrap();
        assert_eq!(result_leverage.nrows(), 8);

        // Test Column norm sampling
        let nystroem_norm = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::ColumnNorm)
            .random_state(42);
        let fitted_norm = nystroem_norm.fit(&x, &()).unwrap();
        let result_norm = fitted_norm.transform(&x).unwrap();
        assert_eq!(result_norm.nrows(), 8);
    }

    #[test]
    fn test_nystroem_rbf_with_different_sampling() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let kernel = Kernel::Rbf { gamma: 0.1 };

        // Test with leverage score sampling
        let nystroem = Nystroem::new(kernel, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.shape(), &[8, 4]);

        // Check that all values are finite
        for val in result.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_nystroem_improved_eigendecomposition() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let nystroem = Nystroem::new(Kernel::Linear, 3)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 4);
        assert_eq!(result.ncols(), 3);

        // Check numerical stability
        for val in result.iter() {
            assert!(val.is_finite());
        }
    }
}