sklears_kernel_approximation/
nystroem.rs

1//! Nyström method for kernel approximation
2use scirs2_core::ndarray::{Array1, Array2};
3use scirs2_core::random::rngs::StdRng as RealStdRng;
4use scirs2_core::random::seq::SliceRandom;
5use scirs2_core::random::Rng;
6use scirs2_core::random::{thread_rng, SeedableRng};
7use sklears_core::{
8    error::{Result, SklearsError},
9    prelude::{Fit, Transform},
10    traits::{Estimator, Trained, Untrained},
11    types::Float,
12};
13use std::marker::PhantomData;
14
/// Strategy used to choose the landmark points (rows of the training data)
/// for the Nyström approximation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SamplingStrategy {
    /// Uniform random sampling of distinct rows.
    Random,
    /// Run a short k-means clustering and pick the row nearest to each centre.
    KMeans,
    /// Leverage-score sampling, approximated by drawing rows with probability
    /// proportional to their Euclidean norm.
    LeverageScore,
    /// Deterministic selection of evenly spaced rows from the samples ranked by
    /// Euclidean norm (largest first).
    ColumnNorm,
}
28
29/// Kernel type for Nystroem approximation
30#[derive(Debug, Clone)]
31/// Kernel
32pub enum Kernel {
33    /// Linear kernel: K(x,y) = x^T y
34    Linear,
35    /// RBF kernel: K(x,y) = exp(-gamma * ||x-y||²)
36    Rbf { gamma: Float },
37    /// Polynomial kernel: K(x,y) = (gamma * x^T y + coef0)^degree
38    Polynomial {
39        gamma: Float,
40
41        coef0: Float,
42
43        degree: u32,
44    },
45}
46
47impl Kernel {
48    /// Compute kernel matrix between X and Y
49    pub fn compute_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
50        let (n_x, _) = x.dim();
51        let (n_y, _) = y.dim();
52        let mut kernel_matrix = Array2::zeros((n_x, n_y));
53
54        match self {
55            Kernel::Linear => {
56                kernel_matrix = x.dot(&y.t());
57            }
58            Kernel::Rbf { gamma } => {
59                for i in 0..n_x {
60                    for j in 0..n_y {
61                        let diff = &x.row(i) - &y.row(j);
62                        let dist_sq = diff.dot(&diff);
63                        kernel_matrix[[i, j]] = (-gamma * dist_sq).exp();
64                    }
65                }
66            }
67            Kernel::Polynomial {
68                gamma,
69                coef0,
70                degree,
71            } => {
72                for i in 0..n_x {
73                    for j in 0..n_y {
74                        let dot_prod = x.row(i).dot(&y.row(j));
75                        kernel_matrix[[i, j]] = (gamma * dot_prod + coef0).powf(*degree as Float);
76                    }
77                }
78            }
79        }
80
81        kernel_matrix
82    }
83}
84
85/// Nyström method for kernel approximation
86///
87/// General method for kernel approximation using eigendecomposition on a subset
88/// of training data. Works with any kernel function and supports multiple
89/// sampling strategies for improved approximation quality.
90///
91/// # Parameters
92///
93/// * `kernel` - Kernel function to approximate
94/// * `n_components` - Number of samples to use for approximation (default: 100)
95/// * `sampling_strategy` - Strategy for selecting landmark points
96/// * `random_state` - Random seed for reproducibility
97///
98/// # Examples
99///
100/// ```rust,ignore
101/// use sklears_kernel_approximation::nystroem::{Nystroem, Kernel, SamplingStrategy};
102/// use sklears_core::traits::{Transform, Fit, Untrained}
103/// use scirs2_core::ndarray::array;
104///
105/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
106///
107/// let nystroem = Nystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
108///     .sampling_strategy(SamplingStrategy::LeverageScore);
109/// let fitted_nystroem = nystroem.fit(&X, &()).unwrap();
110/// let X_transformed = fitted_nystroem.transform(&X).unwrap();
111/// assert_eq!(X_transformed.shape(), &[3, 3]);
112/// ```
113#[derive(Debug, Clone)]
114/// Nystroem
115pub struct Nystroem<State = Untrained> {
116    /// Kernel function
117    pub kernel: Kernel,
118    /// Number of components for approximation
119    pub n_components: usize,
120    /// Sampling strategy for landmark selection
121    pub sampling_strategy: SamplingStrategy,
122    /// Random seed
123    pub random_state: Option<u64>,
124
125    // Fitted attributes
126    components_: Option<Array2<Float>>,
127    normalization_: Option<Array2<Float>>,
128    component_indices_: Option<Vec<usize>>,
129
130    _state: PhantomData<State>,
131}
132
133impl Nystroem<Untrained> {
134    /// Create a new Nystroem approximator
135    pub fn new(kernel: Kernel, n_components: usize) -> Self {
136        Self {
137            kernel,
138            n_components,
139            sampling_strategy: SamplingStrategy::Random,
140            random_state: None,
141            components_: None,
142            normalization_: None,
143            component_indices_: None,
144            _state: PhantomData,
145        }
146    }
147
148    /// Set the sampling strategy
149    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
150        self.sampling_strategy = strategy;
151        self
152    }
153
154    /// Set random state for reproducibility
155    pub fn random_state(mut self, seed: u64) -> Self {
156        self.random_state = Some(seed);
157        self
158    }
159}
160
161impl Estimator for Nystroem<Untrained> {
162    type Config = ();
163    type Error = SklearsError;
164    type Float = Float;
165
166    fn config(&self) -> &Self::Config {
167        &()
168    }
169}
170
171impl Nystroem<Untrained> {
172    /// Select component indices based on sampling strategy
173    fn select_components(
174        &self,
175        x: &Array2<Float>,
176        n_components: usize,
177        rng: &mut RealStdRng,
178    ) -> Result<Vec<usize>> {
179        let (n_samples, _) = x.dim();
180
181        match &self.sampling_strategy {
182            SamplingStrategy::Random => {
183                let mut indices: Vec<usize> = (0..n_samples).collect();
184                indices.shuffle(rng);
185                Ok(indices[..n_components].to_vec())
186            }
187            SamplingStrategy::KMeans => {
188                // Simple k-means based sampling
189                self.kmeans_sampling(x, n_components, rng)
190            }
191            SamplingStrategy::LeverageScore => {
192                // Leverage score based sampling
193                self.leverage_score_sampling(x, n_components, rng)
194            }
195            SamplingStrategy::ColumnNorm => {
196                // Column norm based sampling
197                self.column_norm_sampling(x, n_components, rng)
198            }
199        }
200    }
201
202    /// Simple k-means based sampling
203    fn kmeans_sampling(
204        &self,
205        x: &Array2<Float>,
206        n_components: usize,
207        rng: &mut RealStdRng,
208    ) -> Result<Vec<usize>> {
209        let (n_samples, n_features) = x.dim();
210        let mut centers = Array2::zeros((n_components, n_features));
211
212        // Initialize centers randomly
213        let mut indices: Vec<usize> = (0..n_samples).collect();
214        indices.shuffle(rng);
215        for (i, &idx) in indices[..n_components].iter().enumerate() {
216            centers.row_mut(i).assign(&x.row(idx));
217        }
218
219        // Run a few iterations of k-means
220        for _iter in 0..5 {
221            let mut assignments = vec![0; n_samples];
222
223            // Assign points to nearest centers
224            for i in 0..n_samples {
225                let mut min_dist = Float::INFINITY;
226                let mut best_center = 0;
227
228                for j in 0..n_components {
229                    let diff = &x.row(i) - &centers.row(j);
230                    let dist = diff.dot(&diff);
231                    if dist < min_dist {
232                        min_dist = dist;
233                        best_center = j;
234                    }
235                }
236                assignments[i] = best_center;
237            }
238
239            // Update centers
240            for j in 0..n_components {
241                let cluster_points: Vec<usize> = assignments
242                    .iter()
243                    .enumerate()
244                    .filter(|(_, &assignment)| assignment == j)
245                    .map(|(i, _)| i)
246                    .collect();
247
248                if !cluster_points.is_empty() {
249                    let mut new_center = Array1::zeros(n_features);
250                    for &point_idx in &cluster_points {
251                        new_center = new_center + x.row(point_idx);
252                    }
253                    new_center /= cluster_points.len() as Float;
254                    centers.row_mut(j).assign(&new_center);
255                }
256            }
257        }
258
259        // Find closest points to final centers
260        let mut selected_indices = Vec::new();
261        for j in 0..n_components {
262            let mut min_dist = Float::INFINITY;
263            let mut best_point = 0;
264
265            for i in 0..n_samples {
266                let diff = &x.row(i) - &centers.row(j);
267                let dist = diff.dot(&diff);
268                if dist < min_dist {
269                    min_dist = dist;
270                    best_point = i;
271                }
272            }
273            selected_indices.push(best_point);
274        }
275
276        selected_indices.sort_unstable();
277        selected_indices.dedup();
278
279        // Fill remaining slots randomly if needed
280        while selected_indices.len() < n_components {
281            let random_idx = rng.gen_range(0..n_samples);
282            if !selected_indices.contains(&random_idx) {
283                selected_indices.push(random_idx);
284            }
285        }
286
287        Ok(selected_indices[..n_components].to_vec())
288    }
289
290    /// Leverage score based sampling
291    fn leverage_score_sampling(
292        &self,
293        x: &Array2<Float>,
294        n_components: usize,
295        _rng: &mut RealStdRng,
296    ) -> Result<Vec<usize>> {
297        let (n_samples, _) = x.dim();
298
299        // Compute leverage scores (diagonal of hat matrix)
300        // For simplicity, we approximate using row norms as proxy
301        let mut scores = Vec::new();
302        for i in 0..n_samples {
303            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
304            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
305        }
306
307        // Sample based on scores using cumulative distribution
308        let total_score: Float = scores.iter().sum();
309        if total_score <= 0.0 {
310            return Err(SklearsError::InvalidInput(
311                "All scores are zero or negative".to_string(),
312            ));
313        }
314
315        // Create cumulative distribution
316        let mut cumulative = Vec::with_capacity(scores.len());
317        let mut sum = 0.0;
318        for &score in &scores {
319            sum += score / total_score;
320            cumulative.push(sum);
321        }
322
323        let mut selected_indices = Vec::new();
324        for _ in 0..n_components {
325            let r = thread_rng().gen::<Float>();
326            // Find index where cumulative probability >= r
327            let mut idx = cumulative
328                .iter()
329                .position(|&cum| cum >= r)
330                .unwrap_or(scores.len() - 1);
331
332            // Ensure no duplicates
333            while selected_indices.contains(&idx) {
334                let r = thread_rng().gen::<Float>();
335                idx = cumulative
336                    .iter()
337                    .position(|&cum| cum >= r)
338                    .unwrap_or(scores.len() - 1);
339            }
340            selected_indices.push(idx);
341        }
342
343        Ok(selected_indices)
344    }
345
346    /// Column norm based sampling
347    fn column_norm_sampling(
348        &self,
349        x: &Array2<Float>,
350        n_components: usize,
351        rng: &mut RealStdRng,
352    ) -> Result<Vec<usize>> {
353        let (n_samples, _) = x.dim();
354
355        // Compute row norms
356        let mut norms = Vec::new();
357        for i in 0..n_samples {
358            let norm = x.row(i).dot(&x.row(i)).sqrt();
359            norms.push(norm + 1e-10);
360        }
361
362        // Sort by norm and take diverse selection
363        let mut indices_with_norms: Vec<(usize, Float)> = norms
364            .iter()
365            .enumerate()
366            .map(|(i, &norm)| (i, norm))
367            .collect();
368        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
369
370        let mut selected_indices = Vec::new();
371        let step = n_samples.max(1) / n_components.max(1);
372
373        for i in 0..n_components {
374            let idx = (i * step).min(n_samples - 1);
375            selected_indices.push(indices_with_norms[idx].0);
376        }
377
378        // Fill remaining with random if needed
379        while selected_indices.len() < n_components {
380            let random_idx = rng.gen_range(0..n_samples);
381            if !selected_indices.contains(&random_idx) {
382                selected_indices.push(random_idx);
383            }
384        }
385
386        Ok(selected_indices)
387    }
388
389    /// Compute eigendecomposition using power iteration method
390    /// Returns (eigenvalues, eigenvectors) for symmetric matrix
391    fn compute_eigendecomposition(
392        &self,
393        matrix: &Array2<Float>,
394        rng: &mut RealStdRng,
395    ) -> Result<(Array1<Float>, Array2<Float>)> {
396        let n = matrix.nrows();
397
398        if n != matrix.ncols() {
399            return Err(SklearsError::InvalidInput(
400                "Matrix must be square for eigendecomposition".to_string(),
401            ));
402        }
403
404        let mut eigenvals = Array1::zeros(n);
405        let mut eigenvecs = Array2::zeros((n, n));
406
407        // Use deflation method to find multiple eigenvalues
408        let mut deflated_matrix = matrix.clone();
409
410        for k in 0..n {
411            // Power iteration for k-th eigenvalue/eigenvector
412            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8, rng)?;
413
414            eigenvals[k] = eigenval;
415            eigenvecs.column_mut(k).assign(&eigenvec);
416
417            // Deflate matrix: A_new = A - λ * v * v^T
418            for i in 0..n {
419                for j in 0..n {
420                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
421                }
422            }
423        }
424
425        // Sort eigenvalues and eigenvectors in descending order
426        let mut indices: Vec<usize> = (0..n).collect();
427        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());
428
429        let mut sorted_eigenvals = Array1::zeros(n);
430        let mut sorted_eigenvecs = Array2::zeros((n, n));
431
432        for (new_idx, &old_idx) in indices.iter().enumerate() {
433            sorted_eigenvals[new_idx] = eigenvals[old_idx];
434            sorted_eigenvecs
435                .column_mut(new_idx)
436                .assign(&eigenvecs.column(old_idx));
437        }
438
439        Ok((sorted_eigenvals, sorted_eigenvecs))
440    }
441
442    /// Power iteration method to find dominant eigenvalue and eigenvector
443    fn power_iteration(
444        &self,
445        matrix: &Array2<Float>,
446        max_iter: usize,
447        tol: Float,
448        rng: &mut RealStdRng,
449    ) -> Result<(Float, Array1<Float>)> {
450        let n = matrix.nrows();
451
452        // Initialize random vector
453        let mut v = Array1::from_shape_fn(n, |_| rng.gen::<Float>() - 0.5);
454
455        // Normalize
456        let norm = v.dot(&v).sqrt();
457        if norm < 1e-10 {
458            return Err(SklearsError::InvalidInput(
459                "Initial vector has zero norm".to_string(),
460            ));
461        }
462        v /= norm;
463
464        let mut eigenval = 0.0;
465
466        for _iter in 0..max_iter {
467            // Apply matrix
468            let w = matrix.dot(&v);
469
470            // Compute Rayleigh quotient
471            let new_eigenval = v.dot(&w);
472
473            // Normalize
474            let w_norm = w.dot(&w).sqrt();
475            if w_norm < 1e-10 {
476                break;
477            }
478            let new_v = w / w_norm;
479
480            // Check convergence
481            let eigenval_change = (new_eigenval - eigenval).abs();
482            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();
483
484            if eigenval_change < tol && vector_change < tol {
485                return Ok((new_eigenval, new_v));
486            }
487
488            eigenval = new_eigenval;
489            v = new_v;
490        }
491
492        Ok((eigenval, v))
493    }
494}
495
496impl Fit<Array2<Float>, ()> for Nystroem<Untrained> {
497    type Fitted = Nystroem<Trained>;
498
499    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
500        let (n_samples, _) = x.dim();
501
502        if self.n_components > n_samples {
503            eprintln!(
504                "Warning: n_components ({}) > n_samples ({})",
505                self.n_components, n_samples
506            );
507        }
508
509        let n_components_actual = self.n_components.min(n_samples);
510
511        let mut rng = if let Some(seed) = self.random_state {
512            RealStdRng::seed_from_u64(seed)
513        } else {
514            RealStdRng::from_seed(thread_rng().gen())
515        };
516
517        // Select component indices using specified strategy
518        let component_indices = self.select_components(x, n_components_actual, &mut rng)?;
519
520        // Extract component samples
521        let mut components = Array2::zeros((n_components_actual, x.ncols()));
522        for (i, &idx) in component_indices.iter().enumerate() {
523            components.row_mut(i).assign(&x.row(idx));
524        }
525
526        // Compute kernel matrix K_11 on sampled points
527        let k11: Array2<f64> = self.kernel.compute_kernel(&components, &components);
528
529        // Proper Nyström approximation using eigendecomposition
530        // K ≈ K₁₂ K₁₁⁻¹ K₁₂ᵀ where K₁₁⁻¹ is the pseudo-inverse of landmark kernel matrix
531        let eps = 1e-12;
532
533        // Add small regularization to diagonal for numerical stability
534        let mut k11_reg = k11.clone();
535        for i in 0..n_components_actual {
536            k11_reg[[i, i]] += eps;
537        }
538
539        // Compute pseudo-inverse using eigendecomposition
540        // For symmetric positive definite matrices, we can use power iteration for eigendecomposition
541        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(&k11_reg, &mut rng)?;
542
543        // Filter out small eigenvalues for numerical stability
544        let threshold = 1e-8;
545        let valid_indices: Vec<usize> = eigenvals
546            .iter()
547            .enumerate()
548            .filter(|(_, &val)| val > threshold)
549            .map(|(i, _)| i)
550            .collect();
551
552        if valid_indices.is_empty() {
553            return Err(SklearsError::InvalidInput(
554                "No valid eigenvalues found in kernel matrix".to_string(),
555            ));
556        }
557
558        // Construct pseudo-inverse: V * Λ⁻¹ * V^T
559        let _n_valid = valid_indices.len();
560        let mut pseudo_inverse = Array2::zeros((n_components_actual, n_components_actual));
561
562        for i in 0..n_components_actual {
563            for j in 0..n_components_actual {
564                let mut sum = 0.0;
565                for &k in &valid_indices {
566                    sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
567                }
568                pseudo_inverse[[i, j]] = sum;
569            }
570        }
571
572        let normalization = pseudo_inverse;
573
574        Ok(Nystroem {
575            kernel: self.kernel,
576            n_components: self.n_components,
577            sampling_strategy: self.sampling_strategy,
578            random_state: self.random_state,
579            components_: Some(components),
580            normalization_: Some(normalization),
581            component_indices_: Some(component_indices),
582            _state: PhantomData,
583        })
584    }
585}
586
587impl Transform<Array2<Float>, Array2<Float>> for Nystroem<Trained> {
588    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
589        let components = self.components_.as_ref().unwrap();
590        let normalization = self.normalization_.as_ref().unwrap();
591
592        if x.ncols() != components.ncols() {
593            return Err(SklearsError::InvalidInput(format!(
594                "X has {} features, but Nystroem was fitted with {} features",
595                x.ncols(),
596                components.ncols()
597            )));
598        }
599
600        // Compute kernel matrix K(X, components)
601        let k_x_components = self.kernel.compute_kernel(x, components);
602
603        // Apply normalization: K(X, components) @ normalization
604        let result = k_x_components.dot(normalization);
605
606        Ok(result)
607    }
608}
609
610impl Nystroem<Trained> {
611    /// Get the selected component samples
612    pub fn components(&self) -> &Array2<Float> {
613        self.components_.as_ref().unwrap()
614    }
615
616    /// Get the component indices
617    pub fn component_indices(&self) -> &[usize] {
618        self.component_indices_.as_ref().unwrap()
619    }
620
621    /// Get the normalization matrix
622    pub fn normalization(&self) -> &Array2<Float> {
623        self.normalization_.as_ref().unwrap()
624    }
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Eight 2-D points shared by the sampling-strategy tests.
    fn sample_points() -> Array2<Float> {
        array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ]
    }

    #[test]
    fn test_compute_kernel_values() {
        let x = array![[1.0, 2.0]];
        let y = array![[3.0, 4.0]];

        let linear = Kernel::Linear.compute_kernel(&x, &y);
        assert!((linear[[0, 0]] - 11.0).abs() < 1e-12);

        // ||x - y||² = 8, so exp(-0.5 * 8) = exp(-4).
        let rbf = Kernel::Rbf { gamma: 0.5 }.compute_kernel(&x, &y);
        assert!((rbf[[0, 0]] - (-4.0_f64).exp()).abs() < 1e-12);

        // (1 * 11 + 1)^2 = 144.
        let poly = Kernel::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        }
        .compute_kernel(&x, &y);
        assert!((poly[[0, 0]] - 144.0).abs() < 1e-9);
    }

    #[test]
    fn test_nystroem_linear_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem = Nystroem::new(Kernel::Linear, 3);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        // One output column per landmark: the normalization matrix is square.
        assert_eq!(x_transformed.nrows(), 4);
        assert_eq!(x_transformed.ncols(), 3);
    }

    #[test]
    fn test_nystroem_rbf_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let nystroem = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert_eq!(x_transformed.ncols(), 2);
    }

    #[test]
    fn test_nystroem_polynomial_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let kernel = Kernel::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let nystroem = Nystroem::new(kernel, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert_eq!(x_transformed.ncols(), 2);
    }

    #[test]
    fn test_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem1 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let nystroem2 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        // Same seed, same data: results must agree up to floating-point noise.
        assert_eq!(result1.shape(), result2.shape());
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_nystroem_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0],];

        let x_test = array![
            [1.0, 2.0, 3.0], // Wrong number of features
        ];

        let nystroem = Nystroem::new(Kernel::Linear, 2);
        let fitted = nystroem.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    #[test]
    fn test_nystroem_empty_input_is_rejected() {
        let x = Array2::<Float>::zeros((0, 2));
        let result = Nystroem::new(Kernel::Linear, 2).fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_nystroem_sampling_strategies() {
        let x = sample_points();

        let strategies = [
            SamplingStrategy::Random,
            SamplingStrategy::KMeans,
            SamplingStrategy::LeverageScore,
            SamplingStrategy::ColumnNorm,
        ];
        for strategy in strategies.iter().cloned() {
            let fitted = Nystroem::new(Kernel::Linear, 4)
                .sampling_strategy(strategy)
                .random_state(42)
                .fit(&x, &())
                .unwrap();
            let result = fitted.transform(&x).unwrap();
            assert_eq!(result.nrows(), 8);
            assert_eq!(result.ncols(), 4);
        }
    }

    #[test]
    fn test_nystroem_component_indices_are_distinct() {
        let x = sample_points();

        let strategies = [
            SamplingStrategy::Random,
            SamplingStrategy::KMeans,
            SamplingStrategy::LeverageScore,
            SamplingStrategy::ColumnNorm,
        ];
        for strategy in strategies.iter().cloned() {
            let fitted = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 5)
                .sampling_strategy(strategy)
                .random_state(7)
                .fit(&x, &())
                .unwrap();

            assert_eq!(fitted.components().nrows(), 5);

            let mut indices = fitted.component_indices().to_vec();
            assert_eq!(indices.len(), 5);
            indices.sort_unstable();
            indices.dedup();
            assert_eq!(indices.len(), 5, "landmark indices must be distinct");
        }
    }

    #[test]
    fn test_nystroem_rbf_with_different_sampling() {
        let x = sample_points();

        let kernel = Kernel::Rbf { gamma: 0.1 };

        // Leverage-score sampling with an RBF kernel.
        let nystroem = Nystroem::new(kernel, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.shape(), &[8, 4]);
        assert!(result.iter().all(|val| val.is_finite()));
    }

    #[test]
    fn test_nystroem_improved_eigendecomposition() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let nystroem = Nystroem::new(Kernel::Linear, 3)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 4);
        assert_eq!(result.ncols(), 3);

        // The rank-deficient landmark kernel must not produce NaN/inf.
        assert!(result.iter().all(|val| val.is_finite()));
    }
}