Skip to main content

sklears_kernel_approximation/
incremental_nystroem.rs

//! Incremental Nyström method for online kernel approximation.
//!
//! This module implements online (incremental) variants of the Nyström method
//! that update an existing kernel approximation as new data arrives, instead of
//! recomputing it from scratch.
5
6use scirs2_core::ndarray::{s, Array1, Array2, Axis};
7use scirs2_core::random::rngs::StdRng as RealStdRng;
8use scirs2_core::random::seq::SliceRandom;
9use scirs2_core::random::Rng;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use sklears_core::{
12    error::{Result, SklearsError},
13    traits::{Estimator, Fit, Trained, Transform, Untrained},
14    types::Float,
15};
16use std::marker::PhantomData;
17
18use crate::nystroem::{Kernel, SamplingStrategy};
19
20/// Update strategy for incremental Nyström
21#[derive(Debug, Clone)]
22/// UpdateStrategy
23pub enum UpdateStrategy {
24    /// Simple addition of new landmarks
25    Append,
26    /// Replace oldest landmarks with new ones (sliding window)
27    SlidingWindow,
28    /// Merge new data with existing approximation
29    Merge,
30    /// Selective update based on approximation quality
31    Selective { threshold: Float },
32}
33
34/// Incremental Nyström method for online kernel approximation
35///
36/// Enables efficient updating of kernel approximations as new data arrives,
37/// without requiring complete recomputation from scratch.
38///
39/// # Parameters
40///
41/// * `kernel` - Kernel function to approximate
42/// * `n_components` - Maximum number of landmark points
43/// * `update_strategy` - Strategy for incorporating new data
44/// * `min_update_size` - Minimum number of new samples before updating
45/// * `sampling_strategy` - Strategy for selecting new landmarks
46#[derive(Debug, Clone)]
47pub struct IncrementalNystroem<State = Untrained> {
48    pub kernel: Kernel,
49    pub n_components: usize,
50    pub update_strategy: UpdateStrategy,
51    pub min_update_size: usize,
52    pub sampling_strategy: SamplingStrategy,
53    pub random_state: Option<u64>,
54
55    // Fitted attributes
56    components_: Option<Array2<Float>>,
57    normalization_: Option<Array2<Float>>,
58    component_indices_: Option<Vec<usize>>,
59    landmark_data_: Option<Array2<Float>>,
60    update_count_: usize,
61    accumulated_data_: Option<Array2<Float>>,
62
63    _state: PhantomData<State>,
64}
65
66impl IncrementalNystroem<Untrained> {
67    /// Create a new incremental Nyström approximator
68    pub fn new(kernel: Kernel, n_components: usize) -> Self {
69        Self {
70            kernel,
71            n_components,
72            update_strategy: UpdateStrategy::Append,
73            min_update_size: 10,
74            sampling_strategy: SamplingStrategy::Random,
75            random_state: None,
76            components_: None,
77            normalization_: None,
78            component_indices_: None,
79            landmark_data_: None,
80            update_count_: 0,
81            accumulated_data_: None,
82            _state: PhantomData,
83        }
84    }
85
86    /// Set the update strategy
87    pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
88        self.update_strategy = strategy;
89        self
90    }
91
92    /// Set minimum update size
93    pub fn min_update_size(mut self, size: usize) -> Self {
94        self.min_update_size = size;
95        self
96    }
97
98    /// Set sampling strategy
99    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
100        self.sampling_strategy = strategy;
101        self
102    }
103
104    /// Set random state
105    pub fn random_state(mut self, seed: u64) -> Self {
106        self.random_state = Some(seed);
107        self
108    }
109}
110
111impl Estimator for IncrementalNystroem<Untrained> {
112    type Config = ();
113    type Error = SklearsError;
114    type Float = Float;
115
116    fn config(&self) -> &Self::Config {
117        &()
118    }
119}
120
121impl Fit<Array2<Float>, ()> for IncrementalNystroem<Untrained> {
122    type Fitted = IncrementalNystroem<Trained>;
123
124    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
125        let (n_samples, _n_features) = x.dim();
126        let n_components = self.n_components.min(n_samples);
127
128        let mut rng = match self.random_state {
129            Some(seed) => RealStdRng::seed_from_u64(seed),
130            None => RealStdRng::from_seed(thread_rng().gen()),
131        };
132
133        // Select initial landmark points
134        let component_indices = self.select_components(x, n_components, &mut rng)?;
135        let landmark_data = self.extract_landmarks(x, &component_indices);
136
137        // Compute kernel matrix for landmarks
138        let kernel_matrix = self.kernel.compute_kernel(&landmark_data, &landmark_data);
139
140        // Compute eigendecomposition
141        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;
142
143        Ok(IncrementalNystroem {
144            kernel: self.kernel,
145            n_components: self.n_components,
146            update_strategy: self.update_strategy,
147            min_update_size: self.min_update_size,
148            sampling_strategy: self.sampling_strategy,
149            random_state: self.random_state,
150            components_: Some(components),
151            normalization_: Some(normalization),
152            component_indices_: Some(component_indices),
153            landmark_data_: Some(landmark_data),
154            update_count_: 0,
155            accumulated_data_: None,
156            _state: PhantomData,
157        })
158    }
159}
160
161impl IncrementalNystroem<Untrained> {
162    /// Select component indices based on sampling strategy
163    fn select_components(
164        &self,
165        x: &Array2<Float>,
166        n_components: usize,
167        rng: &mut RealStdRng,
168    ) -> Result<Vec<usize>> {
169        let (n_samples, _) = x.dim();
170
171        match &self.sampling_strategy {
172            SamplingStrategy::Random => {
173                let mut indices: Vec<usize> = (0..n_samples).collect();
174                indices.shuffle(rng);
175                Ok(indices[..n_components].to_vec())
176            }
177            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
178            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
179            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
180        }
181    }
182
183    /// Simple k-means based sampling
184    fn kmeans_sampling(
185        &self,
186        x: &Array2<Float>,
187        n_components: usize,
188        rng: &mut RealStdRng,
189    ) -> Result<Vec<usize>> {
190        let (n_samples, n_features) = x.dim();
191        let mut centers = Array2::zeros((n_components, n_features));
192
193        // Initialize centers randomly
194        let mut indices: Vec<usize> = (0..n_samples).collect();
195        indices.shuffle(rng);
196        for (i, &idx) in indices[..n_components].iter().enumerate() {
197            centers.row_mut(i).assign(&x.row(idx));
198        }
199
200        // Run a few iterations of k-means
201        for _iter in 0..5 {
202            let mut assignments = vec![0; n_samples];
203
204            // Assign points to nearest centers
205            for i in 0..n_samples {
206                let mut min_dist = Float::INFINITY;
207                let mut best_center = 0;
208
209                for j in 0..n_components {
210                    let diff = &x.row(i) - &centers.row(j);
211                    let dist = diff.dot(&diff);
212                    if dist < min_dist {
213                        min_dist = dist;
214                        best_center = j;
215                    }
216                }
217                assignments[i] = best_center;
218            }
219
220            // Update centers
221            for j in 0..n_components {
222                let cluster_points: Vec<usize> = assignments
223                    .iter()
224                    .enumerate()
225                    .filter(|(_, &assignment)| assignment == j)
226                    .map(|(i, _)| i)
227                    .collect();
228
229                if !cluster_points.is_empty() {
230                    let mut new_center = Array1::zeros(n_features);
231                    for &point_idx in &cluster_points {
232                        new_center = new_center + x.row(point_idx);
233                    }
234                    new_center /= cluster_points.len() as Float;
235                    centers.row_mut(j).assign(&new_center);
236                }
237            }
238        }
239
240        // Find closest points to final centers
241        let mut selected_indices = Vec::new();
242        for j in 0..n_components {
243            let mut min_dist = Float::INFINITY;
244            let mut best_point = 0;
245
246            for i in 0..n_samples {
247                let diff = &x.row(i) - &centers.row(j);
248                let dist = diff.dot(&diff);
249                if dist < min_dist {
250                    min_dist = dist;
251                    best_point = i;
252                }
253            }
254            selected_indices.push(best_point);
255        }
256
257        selected_indices.sort_unstable();
258        selected_indices.dedup();
259
260        // Fill remaining slots randomly if needed
261        while selected_indices.len() < n_components {
262            let random_idx = rng.gen_range(0..n_samples);
263            if !selected_indices.contains(&random_idx) {
264                selected_indices.push(random_idx);
265            }
266        }
267
268        Ok(selected_indices[..n_components].to_vec())
269    }
270
271    /// Leverage score based sampling
272    fn leverage_score_sampling(
273        &self,
274        x: &Array2<Float>,
275        n_components: usize,
276        _rng: &mut RealStdRng,
277    ) -> Result<Vec<usize>> {
278        let (n_samples, _) = x.dim();
279
280        // Compute leverage scores (diagonal of hat matrix)
281        // For simplicity, we approximate using row norms as proxy
282        let mut scores = Vec::new();
283        for i in 0..n_samples {
284            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
285            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
286        }
287
288        // Sample based on scores using cumulative distribution
289        let total_score: Float = scores.iter().sum();
290        if total_score <= 0.0 {
291            return Err(SklearsError::InvalidInput(
292                "All scores are zero or negative".to_string(),
293            ));
294        }
295
296        // Create cumulative distribution
297        let mut cumulative = Vec::with_capacity(scores.len());
298        let mut sum = 0.0;
299        for &score in &scores {
300            sum += score / total_score;
301            cumulative.push(sum);
302        }
303
304        let mut selected_indices = Vec::new();
305        for _ in 0..n_components {
306            let r = thread_rng().gen::<Float>();
307            // Find index where cumulative probability >= r
308            let mut idx = cumulative
309                .iter()
310                .position(|&cum| cum >= r)
311                .unwrap_or(scores.len() - 1);
312
313            // Ensure no duplicates
314            while selected_indices.contains(&idx) {
315                let r = thread_rng().gen::<Float>();
316                idx = cumulative
317                    .iter()
318                    .position(|&cum| cum >= r)
319                    .unwrap_or(scores.len() - 1);
320            }
321            selected_indices.push(idx);
322        }
323
324        Ok(selected_indices)
325    }
326
327    /// Column norm based sampling
328    fn column_norm_sampling(
329        &self,
330        x: &Array2<Float>,
331        n_components: usize,
332        rng: &mut RealStdRng,
333    ) -> Result<Vec<usize>> {
334        let (n_samples, _) = x.dim();
335
336        // Compute row norms
337        let mut norms = Vec::new();
338        for i in 0..n_samples {
339            let norm = x.row(i).dot(&x.row(i)).sqrt();
340            norms.push(norm + 1e-10);
341        }
342
343        // Sort by norm and take diverse selection
344        let mut indices_with_norms: Vec<(usize, Float)> = norms
345            .iter()
346            .enumerate()
347            .map(|(i, &norm)| (i, norm))
348            .collect();
349        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
350
351        let mut selected_indices = Vec::new();
352        let step = n_samples.max(1) / n_components.max(1);
353
354        for i in 0..n_components {
355            let idx = (i * step).min(n_samples - 1);
356            selected_indices.push(indices_with_norms[idx].0);
357        }
358
359        // Fill remaining with random if needed
360        while selected_indices.len() < n_components {
361            let random_idx = rng.gen_range(0..n_samples);
362            if !selected_indices.contains(&random_idx) {
363                selected_indices.push(random_idx);
364            }
365        }
366
367        Ok(selected_indices)
368    }
369
370    /// Extract landmark data points
371    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
372        let (_, n_features) = x.dim();
373        let mut landmarks = Array2::zeros((indices.len(), n_features));
374
375        for (i, &idx) in indices.iter().enumerate() {
376            landmarks.row_mut(i).assign(&x.row(idx));
377        }
378
379        landmarks
380    }
381
382    /// Compute eigendecomposition using power iteration method
383    /// Returns (eigenvalues, eigenvectors) for symmetric matrix
384    fn compute_eigendecomposition(
385        &self,
386        matrix: Array2<Float>,
387    ) -> Result<(Array1<Float>, Array2<Float>)> {
388        let n = matrix.nrows();
389
390        if n != matrix.ncols() {
391            return Err(SklearsError::InvalidInput(
392                "Matrix must be square for eigendecomposition".to_string(),
393            ));
394        }
395
396        let mut eigenvals = Array1::zeros(n);
397        let mut eigenvecs = Array2::zeros((n, n));
398
399        // Use deflation method to find multiple eigenvalues
400        let mut deflated_matrix = matrix.clone();
401
402        for k in 0..n {
403            // Power iteration for k-th eigenvalue/eigenvector
404            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;
405
406            eigenvals[k] = eigenval;
407            eigenvecs.column_mut(k).assign(&eigenvec);
408
409            // Deflate matrix: A_new = A - λ * v * v^T
410            for i in 0..n {
411                for j in 0..n {
412                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
413                }
414            }
415        }
416
417        // Sort eigenvalues and eigenvectors in descending order
418        let mut indices: Vec<usize> = (0..n).collect();
419        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());
420
421        let mut sorted_eigenvals = Array1::zeros(n);
422        let mut sorted_eigenvecs = Array2::zeros((n, n));
423
424        for (new_idx, &old_idx) in indices.iter().enumerate() {
425            sorted_eigenvals[new_idx] = eigenvals[old_idx];
426            sorted_eigenvecs
427                .column_mut(new_idx)
428                .assign(&eigenvecs.column(old_idx));
429        }
430
431        Ok((sorted_eigenvals, sorted_eigenvecs))
432    }
433
434    /// Power iteration method to find dominant eigenvalue and eigenvector
435    fn power_iteration(
436        &self,
437        matrix: &Array2<Float>,
438        max_iter: usize,
439        tol: Float,
440    ) -> Result<(Float, Array1<Float>)> {
441        let n = matrix.nrows();
442
443        // Initialize with deterministic vector to ensure reproducibility
444        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());
445
446        // Normalize
447        let norm = v.dot(&v).sqrt();
448        if norm < 1e-10 {
449            return Err(SklearsError::InvalidInput(
450                "Initial vector has zero norm".to_string(),
451            ));
452        }
453        v /= norm;
454
455        let mut eigenval = 0.0;
456
457        for _iter in 0..max_iter {
458            // Apply matrix
459            let w = matrix.dot(&v);
460
461            // Compute Rayleigh quotient
462            let new_eigenval = v.dot(&w);
463
464            // Normalize
465            let w_norm = w.dot(&w).sqrt();
466            if w_norm < 1e-10 {
467                break;
468            }
469            let new_v = w / w_norm;
470
471            // Check convergence
472            let eigenval_change = (new_eigenval - eigenval).abs();
473            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();
474
475            if eigenval_change < tol && vector_change < tol {
476                return Ok((new_eigenval, new_v));
477            }
478
479            eigenval = new_eigenval;
480            v = new_v;
481        }
482
483        Ok((eigenval, v))
484    }
485
486    /// Compute eigendecomposition of kernel matrix
487    fn compute_decomposition(
488        &self,
489        mut kernel_matrix: Array2<Float>,
490    ) -> Result<(Array2<Float>, Array2<Float>)> {
491        // Add small regularization to diagonal for numerical stability
492        let reg = 1e-8;
493        for i in 0..kernel_matrix.nrows() {
494            kernel_matrix[[i, i]] += reg;
495        }
496
497        // Proper eigendecomposition for Nyström method
498        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;
499
500        // Filter out small eigenvalues for numerical stability
501        let threshold = 1e-8;
502        let valid_indices: Vec<usize> = eigenvals
503            .iter()
504            .enumerate()
505            .filter(|(_, &val)| val > threshold)
506            .map(|(i, _)| i)
507            .collect();
508
509        if valid_indices.is_empty() {
510            return Err(SklearsError::InvalidInput(
511                "No valid eigenvalues found in kernel matrix".to_string(),
512            ));
513        }
514
515        // Construct components and normalization matrices
516        let n_valid = valid_indices.len();
517        let mut components = Array2::zeros((eigenvals.len(), n_valid));
518        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));
519
520        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
521            let sqrt_eigenval = eigenvals[old_idx].sqrt();
522            components
523                .column_mut(new_idx)
524                .assign(&eigenvecs.column(old_idx));
525
526            // For Nyström method: normalization = V * Λ^(-1/2)
527            for i in 0..eigenvals.len() {
528                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
529            }
530        }
531
532        Ok((components, normalization))
533    }
534}
535
536impl IncrementalNystroem<Trained> {
537    /// Select component indices based on sampling strategy
538    fn select_components(
539        &self,
540        x: &Array2<Float>,
541        n_components: usize,
542        rng: &mut RealStdRng,
543    ) -> Result<Vec<usize>> {
544        let (n_samples, _) = x.dim();
545
546        match &self.sampling_strategy {
547            SamplingStrategy::Random => {
548                let mut indices: Vec<usize> = (0..n_samples).collect();
549                indices.shuffle(rng);
550                Ok(indices[..n_components].to_vec())
551            }
552            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
553            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
554            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
555        }
556    }
557
558    /// Simple k-means based sampling
559    fn kmeans_sampling(
560        &self,
561        x: &Array2<Float>,
562        n_components: usize,
563        rng: &mut RealStdRng,
564    ) -> Result<Vec<usize>> {
565        let (n_samples, n_features) = x.dim();
566        let mut centers = Array2::zeros((n_components, n_features));
567
568        // Initialize centers randomly
569        let mut indices: Vec<usize> = (0..n_samples).collect();
570        indices.shuffle(rng);
571        for (i, &idx) in indices[..n_components].iter().enumerate() {
572            centers.row_mut(i).assign(&x.row(idx));
573        }
574
575        // Run a few iterations of k-means
576        for _iter in 0..5 {
577            let mut assignments = vec![0; n_samples];
578
579            // Assign points to nearest centers
580            for i in 0..n_samples {
581                let mut min_dist = Float::INFINITY;
582                let mut best_center = 0;
583
584                for j in 0..n_components {
585                    let diff = &x.row(i) - &centers.row(j);
586                    let dist = diff.dot(&diff);
587                    if dist < min_dist {
588                        min_dist = dist;
589                        best_center = j;
590                    }
591                }
592                assignments[i] = best_center;
593            }
594
595            // Update centers
596            for j in 0..n_components {
597                let cluster_points: Vec<usize> = assignments
598                    .iter()
599                    .enumerate()
600                    .filter(|(_, &assignment)| assignment == j)
601                    .map(|(i, _)| i)
602                    .collect();
603
604                if !cluster_points.is_empty() {
605                    let mut new_center = Array1::zeros(n_features);
606                    for &point_idx in &cluster_points {
607                        new_center = new_center + x.row(point_idx);
608                    }
609                    new_center /= cluster_points.len() as Float;
610                    centers.row_mut(j).assign(&new_center);
611                }
612            }
613        }
614
615        // Find closest points to final centers
616        let mut selected_indices = Vec::new();
617        for j in 0..n_components {
618            let mut min_dist = Float::INFINITY;
619            let mut best_point = 0;
620
621            for i in 0..n_samples {
622                let diff = &x.row(i) - &centers.row(j);
623                let dist = diff.dot(&diff);
624                if dist < min_dist {
625                    min_dist = dist;
626                    best_point = i;
627                }
628            }
629            selected_indices.push(best_point);
630        }
631
632        selected_indices.sort_unstable();
633        selected_indices.dedup();
634
635        // Fill remaining slots randomly if needed
636        while selected_indices.len() < n_components {
637            let random_idx = rng.gen_range(0..n_samples);
638            if !selected_indices.contains(&random_idx) {
639                selected_indices.push(random_idx);
640            }
641        }
642
643        Ok(selected_indices[..n_components].to_vec())
644    }
645
646    /// Leverage score based sampling
647    fn leverage_score_sampling(
648        &self,
649        x: &Array2<Float>,
650        n_components: usize,
651        _rng: &mut RealStdRng,
652    ) -> Result<Vec<usize>> {
653        let (n_samples, _) = x.dim();
654
655        // Compute leverage scores (diagonal of hat matrix)
656        // For simplicity, we approximate using row norms as proxy
657        let mut scores = Vec::new();
658        for i in 0..n_samples {
659            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
660            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
661        }
662
663        // Sample based on scores using cumulative distribution
664        let total_score: Float = scores.iter().sum();
665        if total_score <= 0.0 {
666            return Err(SklearsError::InvalidInput(
667                "All scores are zero or negative".to_string(),
668            ));
669        }
670
671        // Create cumulative distribution
672        let mut cumulative = Vec::with_capacity(scores.len());
673        let mut sum = 0.0;
674        for &score in &scores {
675            sum += score / total_score;
676            cumulative.push(sum);
677        }
678
679        let mut selected_indices = Vec::new();
680        for _ in 0..n_components {
681            let r = thread_rng().gen::<Float>();
682            // Find index where cumulative probability >= r
683            let mut idx = cumulative
684                .iter()
685                .position(|&cum| cum >= r)
686                .unwrap_or(scores.len() - 1);
687
688            // Ensure no duplicates
689            while selected_indices.contains(&idx) {
690                let r = thread_rng().gen::<Float>();
691                idx = cumulative
692                    .iter()
693                    .position(|&cum| cum >= r)
694                    .unwrap_or(scores.len() - 1);
695            }
696            selected_indices.push(idx);
697        }
698
699        Ok(selected_indices)
700    }
701
702    /// Column norm based sampling
703    fn column_norm_sampling(
704        &self,
705        x: &Array2<Float>,
706        n_components: usize,
707        rng: &mut RealStdRng,
708    ) -> Result<Vec<usize>> {
709        let (n_samples, _) = x.dim();
710
711        // Compute row norms
712        let mut norms = Vec::new();
713        for i in 0..n_samples {
714            let norm = x.row(i).dot(&x.row(i)).sqrt();
715            norms.push(norm + 1e-10);
716        }
717
718        // Sort by norm and take diverse selection
719        let mut indices_with_norms: Vec<(usize, Float)> = norms
720            .iter()
721            .enumerate()
722            .map(|(i, &norm)| (i, norm))
723            .collect();
724        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
725
726        let mut selected_indices = Vec::new();
727        let step = n_samples.max(1) / n_components.max(1);
728
729        for i in 0..n_components {
730            let idx = (i * step).min(n_samples - 1);
731            selected_indices.push(indices_with_norms[idx].0);
732        }
733
734        // Fill remaining with random if needed
735        while selected_indices.len() < n_components {
736            let random_idx = rng.gen_range(0..n_samples);
737            if !selected_indices.contains(&random_idx) {
738                selected_indices.push(random_idx);
739            }
740        }
741
742        Ok(selected_indices)
743    }
744
745    /// Update the approximation with new data
746    pub fn update(mut self, x_new: &Array2<Float>) -> Result<Self> {
747        // Accumulate new data
748        match &self.accumulated_data_ {
749            Some(existing) => {
750                let combined =
751                    scirs2_core::ndarray::concatenate![Axis(0), existing.clone(), x_new.clone()];
752                self.accumulated_data_ = Some(combined);
753            }
754            None => {
755                self.accumulated_data_ = Some(x_new.clone());
756            }
757        }
758
759        // Check if we have enough accumulated data to update
760        let should_update = if let Some(ref accumulated) = self.accumulated_data_ {
761            accumulated.nrows() >= self.min_update_size
762        } else {
763            false
764        };
765
766        if should_update {
767            if let Some(accumulated) = self.accumulated_data_.take() {
768                self = self.perform_update(&accumulated)?;
769                self.update_count_ += 1;
770            }
771        }
772
773        Ok(self)
774    }
775
776    /// Perform the actual update based on the strategy
777    fn perform_update(self, new_data: &Array2<Float>) -> Result<Self> {
778        match self.update_strategy.clone() {
779            UpdateStrategy::Append => self.append_update(new_data),
780            UpdateStrategy::SlidingWindow => self.sliding_window_update(new_data),
781            UpdateStrategy::Merge => self.merge_update(new_data),
782            UpdateStrategy::Selective { threshold } => self.selective_update(new_data, threshold),
783        }
784    }
785
786    /// Append new landmarks (if space available)
787    fn append_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
788        let current_landmarks = self.landmark_data_.as_ref().unwrap();
789        let current_components = current_landmarks.nrows();
790
791        if current_components >= self.n_components {
792            // No space to append, just return current state
793            return Ok(self);
794        }
795
796        let available_space = self.n_components - current_components;
797        let n_new = available_space.min(new_data.nrows());
798
799        if n_new == 0 {
800            return Ok(self);
801        }
802
803        // Select new landmarks from new data
804        let mut rng = match self.random_state {
805            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(1000)),
806            None => RealStdRng::from_seed(thread_rng().gen()),
807        };
808
809        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
810        indices.shuffle(&mut rng);
811        let selected_indices = &indices[..n_new];
812
813        // Extract new landmarks
814        let new_landmarks = self.extract_landmarks(new_data, selected_indices);
815
816        // Combine with existing landmarks
817        let combined_landmarks =
818            scirs2_core::ndarray::concatenate![Axis(0), current_landmarks.clone(), new_landmarks];
819
820        // Recompute decomposition
821        let kernel_matrix = self
822            .kernel
823            .compute_kernel(&combined_landmarks, &combined_landmarks);
824        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;
825
826        // Update indices
827        let mut new_component_indices = self.component_indices_.as_ref().unwrap().clone();
828        let base_index = current_landmarks.nrows();
829        for &idx in selected_indices {
830            new_component_indices.push(base_index + idx);
831        }
832
833        self.components_ = Some(components);
834        self.normalization_ = Some(normalization);
835        self.component_indices_ = Some(new_component_indices);
836        self.landmark_data_ = Some(combined_landmarks);
837
838        Ok(self)
839    }
840
841    /// Sliding window update (replace oldest landmarks)
842    fn sliding_window_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
843        let current_landmarks = self.landmark_data_.as_ref().unwrap();
844        let n_new = new_data.nrows().min(self.n_components);
845
846        if n_new == 0 {
847            return Ok(self);
848        }
849
850        // Select new landmarks
851        let mut rng = match self.random_state {
852            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(2000)),
853            None => RealStdRng::from_seed(thread_rng().gen()),
854        };
855
856        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
857        indices.shuffle(&mut rng);
858        let selected_indices = &indices[..n_new];
859
860        let new_landmarks = self.extract_landmarks(new_data, selected_indices);
861
862        // Replace oldest landmarks with new ones
863        let n_keep = self.n_components - n_new;
864        let combined_landmarks = if n_keep > 0 {
865            let kept_landmarks = current_landmarks.slice(s![n_new.., ..]).to_owned();
866            scirs2_core::ndarray::concatenate![Axis(0), kept_landmarks, new_landmarks]
867        } else {
868            new_landmarks
869        };
870
871        // Recompute decomposition
872        let kernel_matrix = self
873            .kernel
874            .compute_kernel(&combined_landmarks, &combined_landmarks);
875        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;
876
877        // Update component indices (simplified)
878        let new_component_indices: Vec<usize> = (0..combined_landmarks.nrows()).collect();
879
880        self.components_ = Some(components);
881        self.normalization_ = Some(normalization);
882        self.component_indices_ = Some(new_component_indices);
883        self.landmark_data_ = Some(combined_landmarks);
884
885        Ok(self)
886    }
887
888    /// Merge update (combine approximations)
889    fn merge_update(self, new_data: &Array2<Float>) -> Result<Self> {
890        // Sophisticated merging strategy that combines existing and new Nyström approximations
891        // This is based on the idea of merging two kernel approximations optimally
892
893        let current_landmarks = self.landmark_data_.as_ref().unwrap();
894        let _current_components = self.components_.as_ref().unwrap();
895        let _current_normalization = self.normalization_.as_ref().unwrap();
896
897        // Step 1: Create a new Nyström approximation from the new data
898        let n_new_components = (new_data.nrows().min(self.n_components) / 2).max(1);
899
900        let mut rng = match self.random_state {
901            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(3000)),
902            None => RealStdRng::from_seed(thread_rng().gen()),
903        };
904
905        // Select new landmarks using the same strategy
906        let new_component_indices = self.select_components(new_data, n_new_components, &mut rng)?;
907        let new_landmarks = self.extract_landmarks(new_data, &new_component_indices);
908
909        // Compute new kernel matrix and decomposition
910        let new_kernel_matrix = self.kernel.compute_kernel(&new_landmarks, &new_landmarks);
911        let (_new_components, _new_normalization) =
912            self.compute_decomposition(new_kernel_matrix)?;
913
914        // Step 2: Combine the landmarks intelligently
915        // Merge by selecting the most diverse/informative landmarks from both sets
916        let merged_landmarks =
917            self.merge_landmarks_intelligently(current_landmarks, &new_landmarks, &mut rng)?;
918
919        // Step 3: Recompute the full approximation on merged landmarks
920        let merged_kernel_matrix = self
921            .kernel
922            .compute_kernel(&merged_landmarks, &merged_landmarks);
923        let (final_components, final_normalization) =
924            self.compute_decomposition(merged_kernel_matrix)?;
925
926        // Update component indices (simplified for merged case)
927        let final_component_indices: Vec<usize> = (0..merged_landmarks.nrows()).collect();
928
929        let mut updated_self = self;
930        updated_self.components_ = Some(final_components);
931        updated_self.normalization_ = Some(final_normalization);
932        updated_self.component_indices_ = Some(final_component_indices);
933        updated_self.landmark_data_ = Some(merged_landmarks);
934
935        Ok(updated_self)
936    }
937
938    /// Intelligently merge landmarks from existing and new data
939    fn merge_landmarks_intelligently(
940        &self,
941        current_landmarks: &Array2<Float>,
942        new_landmarks: &Array2<Float>,
943        rng: &mut RealStdRng,
944    ) -> Result<Array2<Float>> {
945        let n_current = current_landmarks.nrows();
946        let n_new = new_landmarks.nrows();
947        let n_features = current_landmarks.ncols();
948
949        // Combine all landmarks temporarily
950        let all_landmarks = scirs2_core::ndarray::concatenate![
951            Axis(0),
952            current_landmarks.clone(),
953            new_landmarks.clone()
954        ];
955
956        // Use diversity-based selection to choose the best subset
957        let n_target = self.n_components.min(n_current + n_new);
958        let selected_indices = self.select_diverse_landmarks(&all_landmarks, n_target, rng)?;
959
960        // Extract selected landmarks
961        let mut merged_landmarks = Array2::zeros((selected_indices.len(), n_features));
962        for (i, &idx) in selected_indices.iter().enumerate() {
963            merged_landmarks.row_mut(i).assign(&all_landmarks.row(idx));
964        }
965
966        Ok(merged_landmarks)
967    }
968
969    /// Select diverse landmarks using maximum distance criterion
970    fn select_diverse_landmarks(
971        &self,
972        landmarks: &Array2<Float>,
973        n_select: usize,
974        rng: &mut RealStdRng,
975    ) -> Result<Vec<usize>> {
976        let n_landmarks = landmarks.nrows();
977
978        if n_select >= n_landmarks {
979            return Ok((0..n_landmarks).collect());
980        }
981
982        let mut selected = Vec::new();
983        let mut available: Vec<usize> = (0..n_landmarks).collect();
984
985        // Start with a random landmark
986        let first_idx = rng.gen_range(0..available.len());
987        selected.push(available.remove(first_idx));
988
989        // Greedily select landmarks that are maximally distant from already selected ones
990        while selected.len() < n_select && !available.is_empty() {
991            let mut best_idx = 0;
992            let mut max_min_distance = 0.0;
993
994            for (i, &candidate_idx) in available.iter().enumerate() {
995                // Compute minimum distance to already selected landmarks
996                let mut min_distance = Float::INFINITY;
997
998                for &selected_idx in &selected {
999                    let diff = &landmarks.row(candidate_idx) - &landmarks.row(selected_idx);
1000                    let distance = diff.dot(&diff).sqrt();
1001                    if distance < min_distance {
1002                        min_distance = distance;
1003                    }
1004                }
1005
1006                if min_distance > max_min_distance {
1007                    max_min_distance = min_distance;
1008                    best_idx = i;
1009                }
1010            }
1011
1012            selected.push(available.remove(best_idx));
1013        }
1014
1015        Ok(selected)
1016    }
1017
1018    /// Selective update based on approximation quality
1019    fn selective_update(self, new_data: &Array2<Float>, threshold: Float) -> Result<Self> {
1020        // Quality-based selective update that only incorporates new data if it improves approximation
1021
1022        let current_landmarks = self.landmark_data_.as_ref().unwrap();
1023
1024        // Step 1: Evaluate current approximation quality on new data
1025        let current_quality = self.evaluate_approximation_quality(current_landmarks, new_data)?;
1026
1027        // Step 2: Create candidate updates and evaluate their quality
1028        let mut best_update = self.clone();
1029        let mut best_quality = current_quality;
1030
1031        // Try append update
1032        let append_candidate = self.clone().append_update(new_data)?;
1033        let append_quality = append_candidate.evaluate_approximation_quality(
1034            append_candidate.landmark_data_.as_ref().unwrap(),
1035            new_data,
1036        )?;
1037
1038        if append_quality > best_quality + threshold {
1039            best_update = append_candidate;
1040            best_quality = append_quality;
1041        }
1042
1043        // Try merge update if we have enough data
1044        if new_data.nrows() >= 3 {
1045            let merge_candidate = self.clone().merge_update(new_data)?;
1046            let merge_quality = merge_candidate.evaluate_approximation_quality(
1047                merge_candidate.landmark_data_.as_ref().unwrap(),
1048                new_data,
1049            )?;
1050
1051            if merge_quality > best_quality + threshold {
1052                best_update = merge_candidate;
1053                best_quality = merge_quality;
1054            }
1055        }
1056
1057        // Try sliding window update
1058        let sliding_candidate = self.clone().sliding_window_update(new_data)?;
1059        let sliding_quality = sliding_candidate.evaluate_approximation_quality(
1060            sliding_candidate.landmark_data_.as_ref().unwrap(),
1061            new_data,
1062        )?;
1063
1064        if sliding_quality > best_quality + threshold {
1065            best_update = sliding_candidate;
1066            best_quality = sliding_quality;
1067        }
1068
1069        // Step 3: Only update if quality improvement exceeds threshold
1070        if best_quality > current_quality + threshold {
1071            Ok(best_update)
1072        } else {
1073            // No significant improvement, keep current state
1074            Ok(self)
1075        }
1076    }
1077
1078    /// Evaluate approximation quality using kernel approximation error
1079    fn evaluate_approximation_quality(
1080        &self,
1081        landmarks: &Array2<Float>,
1082        test_data: &Array2<Float>,
1083    ) -> Result<Float> {
1084        // Quality metric: negative approximation error (higher is better)
1085
1086        let n_test = test_data.nrows().min(50); // Limit for efficiency
1087        let test_subset = if test_data.nrows() > n_test {
1088            // Sample random subset for evaluation
1089            let mut rng = thread_rng();
1090            let mut indices: Vec<usize> = (0..test_data.nrows()).collect();
1091            indices.shuffle(&mut rng);
1092            test_data.select(Axis(0), &indices[..n_test])
1093        } else {
1094            test_data.to_owned()
1095        };
1096
1097        // Compute exact kernel matrix for test subset
1098        let k_exact = self.kernel.compute_kernel(&test_subset, &test_subset);
1099
1100        // Compute Nyström approximation: K(X,Z) * K(Z,Z)^(-1) * K(Z,X)
1101        let k_test_landmarks = self.kernel.compute_kernel(&test_subset, landmarks);
1102        let k_landmarks = self.kernel.compute_kernel(landmarks, landmarks);
1103
1104        // Use our eigendecomposition to compute pseudo-inverse
1105        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(k_landmarks)?;
1106
1107        // Construct pseudo-inverse
1108        let threshold = 1e-8;
1109        let mut pseudo_inverse = Array2::zeros((landmarks.nrows(), landmarks.nrows()));
1110
1111        for i in 0..landmarks.nrows() {
1112            for j in 0..landmarks.nrows() {
1113                let mut sum = 0.0;
1114                for k in 0..eigenvals.len() {
1115                    if eigenvals[k] > threshold {
1116                        sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
1117                    }
1118                }
1119                pseudo_inverse[[i, j]] = sum;
1120            }
1121        }
1122
1123        // Compute approximation: K(X,Z) * K(Z,Z)^(-1) * K(Z,X)
1124        let k_approx = k_test_landmarks
1125            .dot(&pseudo_inverse)
1126            .dot(&k_test_landmarks.t());
1127
1128        // Compute approximation error (Frobenius norm)
1129        let error_matrix = &k_exact - &k_approx;
1130        let approximation_error = error_matrix.mapv(|x| x * x).sum().sqrt();
1131
1132        // Convert to quality score (negative error, higher is better)
1133        let quality = -approximation_error / (k_exact.mapv(|x| x * x).sum().sqrt() + 1e-10);
1134
1135        Ok(quality)
1136    }
1137
1138    /// Extract landmark data points
1139    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
1140        let (_, n_features) = x.dim();
1141        let mut landmarks = Array2::zeros((indices.len(), n_features));
1142
1143        for (i, &idx) in indices.iter().enumerate() {
1144            landmarks.row_mut(i).assign(&x.row(idx));
1145        }
1146
1147        landmarks
1148    }
1149
1150    /// Compute eigendecomposition using power iteration method
1151    /// Returns (eigenvalues, eigenvectors) for symmetric matrix
1152    fn compute_eigendecomposition(
1153        &self,
1154        matrix: Array2<Float>,
1155    ) -> Result<(Array1<Float>, Array2<Float>)> {
1156        let n = matrix.nrows();
1157
1158        if n != matrix.ncols() {
1159            return Err(SklearsError::InvalidInput(
1160                "Matrix must be square for eigendecomposition".to_string(),
1161            ));
1162        }
1163
1164        let mut eigenvals = Array1::zeros(n);
1165        let mut eigenvecs = Array2::zeros((n, n));
1166
1167        // Use deflation method to find multiple eigenvalues
1168        let mut deflated_matrix = matrix.clone();
1169
1170        for k in 0..n {
1171            // Power iteration for k-th eigenvalue/eigenvector
1172            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;
1173
1174            eigenvals[k] = eigenval;
1175            eigenvecs.column_mut(k).assign(&eigenvec);
1176
1177            // Deflate matrix: A_new = A - λ * v * v^T
1178            for i in 0..n {
1179                for j in 0..n {
1180                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
1181                }
1182            }
1183        }
1184
1185        // Sort eigenvalues and eigenvectors in descending order
1186        let mut indices: Vec<usize> = (0..n).collect();
1187        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());
1188
1189        let mut sorted_eigenvals = Array1::zeros(n);
1190        let mut sorted_eigenvecs = Array2::zeros((n, n));
1191
1192        for (new_idx, &old_idx) in indices.iter().enumerate() {
1193            sorted_eigenvals[new_idx] = eigenvals[old_idx];
1194            sorted_eigenvecs
1195                .column_mut(new_idx)
1196                .assign(&eigenvecs.column(old_idx));
1197        }
1198
1199        Ok((sorted_eigenvals, sorted_eigenvecs))
1200    }
1201
1202    /// Power iteration method to find dominant eigenvalue and eigenvector
1203    fn power_iteration(
1204        &self,
1205        matrix: &Array2<Float>,
1206        max_iter: usize,
1207        tol: Float,
1208    ) -> Result<(Float, Array1<Float>)> {
1209        let n = matrix.nrows();
1210
1211        // Initialize with deterministic vector to ensure reproducibility
1212        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());
1213
1214        // Normalize
1215        let norm = v.dot(&v).sqrt();
1216        if norm < 1e-10 {
1217            return Err(SklearsError::InvalidInput(
1218                "Initial vector has zero norm".to_string(),
1219            ));
1220        }
1221        v /= norm;
1222
1223        let mut eigenval = 0.0;
1224
1225        for _iter in 0..max_iter {
1226            // Apply matrix
1227            let w = matrix.dot(&v);
1228
1229            // Compute Rayleigh quotient
1230            let new_eigenval = v.dot(&w);
1231
1232            // Normalize
1233            let w_norm = w.dot(&w).sqrt();
1234            if w_norm < 1e-10 {
1235                break;
1236            }
1237            let new_v = w / w_norm;
1238
1239            // Check convergence
1240            let eigenval_change = (new_eigenval - eigenval).abs();
1241            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();
1242
1243            if eigenval_change < tol && vector_change < tol {
1244                return Ok((new_eigenval, new_v));
1245            }
1246
1247            eigenval = new_eigenval;
1248            v = new_v;
1249        }
1250
1251        Ok((eigenval, v))
1252    }
1253
1254    /// Compute eigendecomposition of kernel matrix
1255    fn compute_decomposition(
1256        &self,
1257        mut kernel_matrix: Array2<Float>,
1258    ) -> Result<(Array2<Float>, Array2<Float>)> {
1259        // Add small regularization to diagonal for numerical stability
1260        let reg = 1e-8;
1261        for i in 0..kernel_matrix.nrows() {
1262            kernel_matrix[[i, i]] += reg;
1263        }
1264
1265        // Proper eigendecomposition for Nyström method
1266        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;
1267
1268        // Filter out small eigenvalues for numerical stability
1269        let threshold = 1e-8;
1270        let valid_indices: Vec<usize> = eigenvals
1271            .iter()
1272            .enumerate()
1273            .filter(|(_, &val)| val > threshold)
1274            .map(|(i, _)| i)
1275            .collect();
1276
1277        if valid_indices.is_empty() {
1278            return Err(SklearsError::InvalidInput(
1279                "No valid eigenvalues found in kernel matrix".to_string(),
1280            ));
1281        }
1282
1283        // Construct components and normalization matrices
1284        let n_valid = valid_indices.len();
1285        let mut components = Array2::zeros((eigenvals.len(), n_valid));
1286        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));
1287
1288        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
1289            let sqrt_eigenval = eigenvals[old_idx].sqrt();
1290            components
1291                .column_mut(new_idx)
1292                .assign(&eigenvecs.column(old_idx));
1293
1294            // For Nyström method: normalization = V * Λ^(-1/2)
1295            for i in 0..eigenvals.len() {
1296                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
1297            }
1298        }
1299
1300        Ok((components, normalization))
1301    }
1302
1303    /// Get number of updates performed
1304    pub fn update_count(&self) -> usize {
1305        self.update_count_
1306    }
1307
1308    /// Get current number of landmarks
1309    pub fn n_landmarks(&self) -> usize {
1310        self.landmark_data_.as_ref().map_or(0, |data| data.nrows())
1311    }
1312}
1313
1314impl Transform<Array2<Float>> for IncrementalNystroem<Trained> {
1315    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
1316        let _components = self
1317            .components_
1318            .as_ref()
1319            .ok_or_else(|| SklearsError::NotFitted {
1320                operation: "transform".to_string(),
1321            })?;
1322
1323        let normalization =
1324            self.normalization_
1325                .as_ref()
1326                .ok_or_else(|| SklearsError::NotFitted {
1327                    operation: "transform".to_string(),
1328                })?;
1329
1330        let landmark_data =
1331            self.landmark_data_
1332                .as_ref()
1333                .ok_or_else(|| SklearsError::NotFitted {
1334                    operation: "transform".to_string(),
1335                })?;
1336
1337        // Compute kernel between input and landmarks
1338        let kernel_x_landmarks = self.kernel.compute_kernel(x, landmark_data);
1339
1340        // Apply transformation: K(X, landmarks) @ normalization.T
1341        let transformed = kernel_x_landmarks.dot(&normalization.t());
1342
1343        Ok(transformed)
1344    }
1345}
1346
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_incremental_nystroem_basic() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
            .update_strategy(UpdateStrategy::Append)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        assert_eq!(fitted.n_landmarks(), 3);

        let updated = fitted.update(&x_new).unwrap();
        assert_eq!(updated.n_landmarks(), 5);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_incremental_transform() {
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_test = array![[1.5, 2.5], [2.5, 3.5]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3);
        let fitted = nystroem.fit(&x_train, &()).unwrap();

        let transformed = fitted.transform(&x_test).unwrap();
        assert_eq!(transformed.shape()[0], 2);
        assert!(transformed.shape()[1] <= 3);
    }

    #[test]
    fn test_sliding_window_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 3)
            .update_strategy(UpdateStrategy::SlidingWindow)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert_eq!(updated.n_landmarks(), 3);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_different_kernels() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        // Test with RBF kernel
        let rbf_nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3);
        let rbf_fitted = rbf_nystroem.fit(&x, &()).unwrap();
        let rbf_transformed = rbf_fitted.transform(&x).unwrap();
        assert_eq!(rbf_transformed.shape()[0], 3);

        // Test with polynomial kernel
        let poly_nystroem = IncrementalNystroem::new(
            Kernel::Polynomial {
                gamma: 1.0,
                coef0: 1.0,
                degree: 2,
            },
            3,
        );
        let poly_fitted = poly_nystroem.fit(&x, &()).unwrap();
        let poly_transformed = poly_fitted.transform(&x).unwrap();
        assert_eq!(poly_transformed.shape()[0], 3);
    }

    #[test]
    fn test_min_update_size() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_small = array![[3.0, 4.0]];
        let x_large = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 5).min_update_size(2);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();

        // Small update should not trigger recomputation
        let after_small = fitted.update(&x_small).unwrap();
        assert_eq!(after_small.update_count(), 0);
        assert_eq!(after_small.n_landmarks(), 2);

        // Large update should trigger recomputation
        let after_large = after_small.update(&x_large).unwrap();
        assert_eq!(after_large.update_count(), 1);
        assert_eq!(after_large.n_landmarks(), 5);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0]];

        let nystroem1 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let updated1 = fitted1.update(&x_new).unwrap();
        let result1 = updated1.transform(&x).unwrap();

        let nystroem2 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let updated2 = fitted2.update(&x_new).unwrap();
        let result2 = updated2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());

        // Each output column comes from one eigenvector, and eigenvectors are only defined up
        // to sign. So each column must match either directly or with its sign flipped,
        // independently of the other columns. Iterators avoid depending on memory layout.
        for (col1, col2) in result1.axis_iter(Axis(1)).zip(result2.axis_iter(Axis(1))) {
            let direct_match = col1
                .iter()
                .zip(col2.iter())
                .all(|(a, b)| (a - b).abs() <= 1e-6);
            let sign_flip_match = col1
                .iter()
                .zip(col2.iter())
                .all(|(a, b)| (a + b).abs() <= 1e-6);

            assert!(
                direct_match || sign_flip_match,
                "Results differ too much and are not related by sign flip"
            );
        }
    }
}