sklears_kernel_approximation/
incremental_nystroem.rs

//! Incremental Nyström method for online kernel approximation
//!
//! This module implements online/incremental versions of the Nyström method
//! that can efficiently update kernel approximations as new data arrives.
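//!
//! # Example
//!
//! A minimal usage sketch (hypothetical input arrays; see the tests at the
//! bottom of this file for runnable versions):
//!
//! ```rust,ignore
//! use sklears_kernel_approximation::incremental_nystroem::{IncrementalNystroem, UpdateStrategy};
//! use sklears_kernel_approximation::nystroem::Kernel;
//!
//! let fitted = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 64)
//!     .update_strategy(UpdateStrategy::SlidingWindow)
//!     .min_update_size(16)
//!     .random_state(42)
//!     .fit(&x_initial, &())?;
//! let fitted = fitted.update(&x_new_batch)?;
//! let features = fitted.transform(&x_test)?;
//! ```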

use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::marker::PhantomData;

use crate::nystroem::{Kernel, SamplingStrategy};

/// Update strategy for incremental Nyström
#[derive(Debug, Clone)]
pub enum UpdateStrategy {
    /// Append new landmarks until `n_components` is reached
    Append,
    /// Replace the oldest landmarks with new ones (sliding window)
    SlidingWindow,
    /// Merge new data with the existing approximation via diversity-based
    /// landmark selection
    Merge,
    /// Adopt an update only if it improves approximation quality by more
    /// than `threshold`
    Selective { threshold: Float },
}

/// Incremental Nyström method for online kernel approximation
///
/// Enables efficient updating of kernel approximations as new data arrives,
/// without requiring complete recomputation from scratch.
///
/// # Parameters
///
/// * `kernel` - Kernel function to approximate
/// * `n_components` - Maximum number of landmark points
/// * `update_strategy` - Strategy for incorporating new data
/// * `min_update_size` - Minimum number of new samples before updating
/// * `sampling_strategy` - Strategy for selecting new landmarks
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::incremental_nystroem::{
///     IncrementalNystroem, UpdateStrategy,
/// };
/// ```
#[derive(Debug, Clone)]
pub struct IncrementalNystroem<State = Untrained> {
    pub kernel: Kernel,
    pub n_components: usize,
    pub update_strategy: UpdateStrategy,
    pub min_update_size: usize,
    pub sampling_strategy: SamplingStrategy,
    pub random_state: Option<u64>,

    // Fitted attributes
    components_: Option<Array2<Float>>,
    normalization_: Option<Array2<Float>>,
    component_indices_: Option<Vec<usize>>,
    landmark_data_: Option<Array2<Float>>,
    update_count_: usize,
    accumulated_data_: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl IncrementalNystroem<Untrained> {
    /// Create a new incremental Nyström approximator
    pub fn new(kernel: Kernel, n_components: usize) -> Self {
        Self {
            kernel,
            n_components,
            update_strategy: UpdateStrategy::Append,
            min_update_size: 10,
            sampling_strategy: SamplingStrategy::Random,
            random_state: None,
            components_: None,
            normalization_: None,
            component_indices_: None,
            landmark_data_: None,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    /// Set the update strategy
    pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
        self.update_strategy = strategy;
        self
    }

    /// Set minimum update size
    pub fn min_update_size(mut self, size: usize) -> Self {
        self.min_update_size = size;
        self
    }

    /// Set sampling strategy
    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
        self.sampling_strategy = strategy;
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for IncrementalNystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for IncrementalNystroem<Untrained> {
    type Fitted = IncrementalNystroem<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _n_features) = x.dim();
        let n_components = self.n_components.min(n_samples);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Select initial landmark points
        let component_indices = self.select_components(x, n_components, &mut rng)?;
        let landmark_data = self.extract_landmarks(x, &component_indices);

        // Compute kernel matrix for landmarks
        let kernel_matrix = self.kernel.compute_kernel(&landmark_data, &landmark_data);

        // Compute eigendecomposition
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        Ok(IncrementalNystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            update_strategy: self.update_strategy,
            min_update_size: self.min_update_size,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            landmark_data_: Some(landmark_data),
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

impl IncrementalNystroem<Untrained> {
    /// Select component indices based on sampling strategy
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }

    /// Simple k-means based sampling
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers randomly
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // Run a few iterations of k-means
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign points to nearest centers
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Update centers
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Find closest points to final centers
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Fill remaining slots randomly if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }

    /// Leverage score based sampling
    fn leverage_score_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute leverage scores (diagonal of hat matrix)
        // For simplicity, we approximate using row norms as proxy
        let mut scores = Vec::new();
        for i in 0..n_samples {
            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
        }

        // Sample based on scores using cumulative distribution
        let total_score: Float = scores.iter().sum();
        if total_score <= 0.0 {
            return Err(SklearsError::InvalidInput(
                "All scores are zero or negative".to_string(),
            ));
        }

        // Create cumulative distribution
        let mut cumulative = Vec::with_capacity(scores.len());
        let mut sum = 0.0;
        for &score in &scores {
            sum += score / total_score;
            cumulative.push(sum);
        }

        // Draw from the seeded RNG (not thread_rng) so that `random_state`
        // makes this sampling reproducible
        let mut selected_indices = Vec::new();
        for _ in 0..n_components {
            let r = rng.gen::<Float>();
            // Find index where cumulative probability >= r
            let mut idx = cumulative
                .iter()
                .position(|&cum| cum >= r)
                .unwrap_or(scores.len() - 1);

            // Ensure no duplicates
            while selected_indices.contains(&idx) {
                let r = rng.gen::<Float>();
                idx = cumulative
                    .iter()
                    .position(|&cum| cum >= r)
                    .unwrap_or(scores.len() - 1);
            }
            selected_indices.push(idx);
        }

        Ok(selected_indices)
    }

    /// Column norm based sampling (uses row norms of the data as a proxy)
    fn column_norm_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute row norms
        let mut norms = Vec::new();
        for i in 0..n_samples {
            let norm = x.row(i).dot(&x.row(i)).sqrt();
            norms.push(norm + 1e-10);
        }

        // Sort by norm (descending) and take a strided, diverse selection
        let mut indices_with_norms: Vec<(usize, Float)> = norms
            .iter()
            .enumerate()
            .map(|(i, &norm)| (i, norm))
            .collect();
        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut selected_indices = Vec::new();
        let step = n_samples.max(1) / n_components.max(1);

        for i in 0..n_components {
            let idx = (i * step).min(n_samples - 1);
            selected_indices.push(indices_with_norms[idx].0);
        }

        // Fill remaining with random if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices)
    }

    /// Extract landmark data points
    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
        let (_, n_features) = x.dim();
        let mut landmarks = Array2::zeros((indices.len(), n_features));

        for (i, &idx) in indices.iter().enumerate() {
            landmarks.row_mut(i).assign(&x.row(idx));
        }

        landmarks
    }

    /// Compute eigendecomposition using power iteration method
    /// Returns (eigenvalues, eigenvectors) for symmetric matrix
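    ///
    /// Note: this is a sequential power-iteration-with-deflation scheme
    /// (`A ← A − λ v vᵀ` after each extracted pair), so accuracy degrades for
    /// the smaller eigenvalues; `compute_decomposition` filters those out.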
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix;

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Power iteration method to find dominant eigenvalue and eigenvector
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize with deterministic vector to ensure reproducibility
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }

    /// Compute eigendecomposition of kernel matrix
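    ///
    /// With `K(Z,Z) = V Λ Vᵀ`, the Nyström feature map is
    /// `φ(x) = K(x, Z) · V · Λ^{-1/2}`; the returned `normalization` holds
    /// `Λ^{-1/2} Vᵀ` and is applied transposed in `transform`.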
    fn compute_decomposition(
        &self,
        mut kernel_matrix: Array2<Float>,
    ) -> Result<(Array2<Float>, Array2<Float>)> {
        // Add small regularization to diagonal for numerical stability
        let reg = 1e-8;
        for i in 0..kernel_matrix.nrows() {
            kernel_matrix[[i, i]] += reg;
        }

        // Proper eigendecomposition for Nyström method
        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;

        // Filter out small eigenvalues for numerical stability
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        // Construct components and normalization matrices
        let n_valid = valid_indices.len();
        let mut components = Array2::zeros((eigenvals.len(), n_valid));
        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));

        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
            let sqrt_eigenval = eigenvals[old_idx].sqrt();
            components
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));

            // For Nyström method: normalization = V * Λ^(-1/2)
            for i in 0..eigenvals.len() {
                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
            }
        }

        Ok((components, normalization))
    }
}

impl IncrementalNystroem<Trained> {
    /// Select component indices based on sampling strategy
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }

    /// Simple k-means based sampling
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers randomly
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // Run a few iterations of k-means
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign points to nearest centers
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Update centers
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Find closest points to final centers
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Fill remaining slots randomly if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }

    /// Leverage score based sampling
    fn leverage_score_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute leverage scores (diagonal of hat matrix)
        // For simplicity, we approximate using row norms as proxy
        let mut scores = Vec::new();
        for i in 0..n_samples {
            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
        }

        // Sample based on scores using cumulative distribution
        let total_score: Float = scores.iter().sum();
        if total_score <= 0.0 {
            return Err(SklearsError::InvalidInput(
                "All scores are zero or negative".to_string(),
            ));
        }

        // Create cumulative distribution
        let mut cumulative = Vec::with_capacity(scores.len());
        let mut sum = 0.0;
        for &score in &scores {
            sum += score / total_score;
            cumulative.push(sum);
        }

        // Draw from the seeded RNG (not thread_rng) so that `random_state`
        // makes this sampling reproducible
        let mut selected_indices = Vec::new();
        for _ in 0..n_components {
            let r = rng.gen::<Float>();
            // Find index where cumulative probability >= r
            let mut idx = cumulative
                .iter()
                .position(|&cum| cum >= r)
                .unwrap_or(scores.len() - 1);

            // Ensure no duplicates
            while selected_indices.contains(&idx) {
                let r = rng.gen::<Float>();
                idx = cumulative
                    .iter()
                    .position(|&cum| cum >= r)
                    .unwrap_or(scores.len() - 1);
            }
            selected_indices.push(idx);
        }

        Ok(selected_indices)
    }

    /// Column norm based sampling (uses row norms of the data as a proxy)
    fn column_norm_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute row norms
        let mut norms = Vec::new();
        for i in 0..n_samples {
            let norm = x.row(i).dot(&x.row(i)).sqrt();
            norms.push(norm + 1e-10);
        }

        // Sort by norm (descending) and take a strided, diverse selection
        let mut indices_with_norms: Vec<(usize, Float)> = norms
            .iter()
            .enumerate()
            .map(|(i, &norm)| (i, norm))
            .collect();
        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut selected_indices = Vec::new();
        let step = n_samples.max(1) / n_components.max(1);

        for i in 0..n_components {
            let idx = (i * step).min(n_samples - 1);
            selected_indices.push(indices_with_norms[idx].0);
        }

        // Fill remaining with random if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices)
    }

    /// Update the approximation with new data
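    ///
    /// New rows are buffered until at least `min_update_size` of them have
    /// accumulated; only then is the configured `UpdateStrategy` applied.
    ///
    /// ```rust,ignore
    /// // Hypothetical streaming loop (each chunk is an Array2<Float>):
    /// let mut model = fitted;
    /// for chunk in stream {
    ///     model = model.update(&chunk)?;
    /// }
    /// ```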
    pub fn update(mut self, x_new: &Array2<Float>) -> Result<Self> {
        // Accumulate new data
        match &self.accumulated_data_ {
            Some(existing) => {
                let combined =
                    scirs2_core::ndarray::concatenate![Axis(0), existing.clone(), x_new.clone()];
                self.accumulated_data_ = Some(combined);
            }
            None => {
                self.accumulated_data_ = Some(x_new.clone());
            }
        }

        // Check if we have enough accumulated data to update
        let should_update = if let Some(ref accumulated) = self.accumulated_data_ {
            accumulated.nrows() >= self.min_update_size
        } else {
            false
        };

        if should_update {
            if let Some(accumulated) = self.accumulated_data_.take() {
                self = self.perform_update(&accumulated)?;
                self.update_count_ += 1;
            }
        }

        Ok(self)
    }

    /// Perform the actual update based on the strategy
    fn perform_update(self, new_data: &Array2<Float>) -> Result<Self> {
        match self.update_strategy.clone() {
            UpdateStrategy::Append => self.append_update(new_data),
            UpdateStrategy::SlidingWindow => self.sliding_window_update(new_data),
            UpdateStrategy::Merge => self.merge_update(new_data),
            UpdateStrategy::Selective { threshold } => self.selective_update(new_data, threshold),
        }
    }

    /// Append new landmarks (if space available)
    fn append_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let current_components = current_landmarks.nrows();

        if current_components >= self.n_components {
            // No space to append, just return current state
            return Ok(self);
        }

        let available_space = self.n_components - current_components;
        let n_new = available_space.min(new_data.nrows());

        if n_new == 0 {
            return Ok(self);
        }

        // Select new landmarks from new data
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(1000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        // Extract new landmarks
        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        // Combine with existing landmarks
        let combined_landmarks =
            scirs2_core::ndarray::concatenate![Axis(0), current_landmarks.clone(), new_landmarks];

        // Recompute decomposition
        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // Update indices (relative to the combined landmark set; original
        // training-set indices are not preserved after an update)
        let mut new_component_indices = self.component_indices_.as_ref().unwrap().clone();
        let base_index = current_landmarks.nrows();
        for &idx in selected_indices {
            new_component_indices.push(base_index + idx);
        }

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }

    /// Sliding window update (replace oldest landmarks)
    fn sliding_window_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let n_new = new_data.nrows().min(self.n_components);

        if n_new == 0 {
            return Ok(self);
        }

        // Select new landmarks
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(2000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        // Drop just enough of the oldest landmarks so the combined set fits
        // within n_components (slicing relative to n_components alone could
        // index past the end when fewer landmarks are currently stored)
        let n_current = current_landmarks.nrows();
        let n_drop = (n_current + n_new).saturating_sub(self.n_components);
        let combined_landmarks = if n_drop < n_current {
            let kept_landmarks = current_landmarks.slice(s![n_drop.., ..]).to_owned();
            scirs2_core::ndarray::concatenate![Axis(0), kept_landmarks, new_landmarks]
        } else {
            new_landmarks
        };

        // Recompute decomposition
        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // Update component indices (simplified)
        let new_component_indices: Vec<usize> = (0..combined_landmarks.nrows()).collect();

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }

    /// Merge update (combine approximations)
    fn merge_update(self, new_data: &Array2<Float>) -> Result<Self> {
        // Merging strategy that combines the existing and new Nyström
        // approximations by re-selecting a diverse landmark set from both

        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let _current_components = self.components_.as_ref().unwrap();
        let _current_normalization = self.normalization_.as_ref().unwrap();

        // Step 1: Create a new Nyström approximation from the new data
        let n_new_components = (new_data.nrows().min(self.n_components) / 2).max(1);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(3000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Select new landmarks using the same strategy
        let new_component_indices = self.select_components(new_data, n_new_components, &mut rng)?;
        let new_landmarks = self.extract_landmarks(new_data, &new_component_indices);

        // Compute new kernel matrix and decomposition
        let new_kernel_matrix = self.kernel.compute_kernel(&new_landmarks, &new_landmarks);
        let (_new_components, _new_normalization) =
            self.compute_decomposition(new_kernel_matrix)?;

        // Step 2: Combine the landmarks by selecting the most diverse and
        // informative subset from both sets
        let merged_landmarks =
            self.merge_landmarks_intelligently(current_landmarks, &new_landmarks, &mut rng)?;

        // Step 3: Recompute the full approximation on merged landmarks
        let merged_kernel_matrix = self
            .kernel
            .compute_kernel(&merged_landmarks, &merged_landmarks);
        let (final_components, final_normalization) =
            self.compute_decomposition(merged_kernel_matrix)?;

        // Update component indices (simplified for merged case)
        let final_component_indices: Vec<usize> = (0..merged_landmarks.nrows()).collect();

        let mut updated_self = self;
        updated_self.components_ = Some(final_components);
        updated_self.normalization_ = Some(final_normalization);
        updated_self.component_indices_ = Some(final_component_indices);
        updated_self.landmark_data_ = Some(merged_landmarks);

        Ok(updated_self)
    }

    /// Intelligently merge landmarks from existing and new data
    fn merge_landmarks_intelligently(
        &self,
        current_landmarks: &Array2<Float>,
        new_landmarks: &Array2<Float>,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        let n_current = current_landmarks.nrows();
        let n_new = new_landmarks.nrows();
        let n_features = current_landmarks.ncols();

        // Combine all landmarks temporarily
        let all_landmarks = scirs2_core::ndarray::concatenate![
            Axis(0),
            current_landmarks.clone(),
            new_landmarks.clone()
        ];

        // Use diversity-based selection to choose the best subset
        let n_target = self.n_components.min(n_current + n_new);
        let selected_indices = self.select_diverse_landmarks(&all_landmarks, n_target, rng)?;

        // Extract selected landmarks
        let mut merged_landmarks = Array2::zeros((selected_indices.len(), n_features));
        for (i, &idx) in selected_indices.iter().enumerate() {
            merged_landmarks.row_mut(i).assign(&all_landmarks.row(idx));
        }

        Ok(merged_landmarks)
    }

    /// Select diverse landmarks using maximum distance criterion
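    ///
    /// This is greedy farthest-point ("max-min") sampling: each step adds the
    /// candidate whose minimum distance to the already-selected set is
    /// largest.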
    fn select_diverse_landmarks(
        &self,
        landmarks: &Array2<Float>,
        n_select: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let n_landmarks = landmarks.nrows();

        if n_select >= n_landmarks {
            return Ok((0..n_landmarks).collect());
        }

        let mut selected = Vec::new();
        let mut available: Vec<usize> = (0..n_landmarks).collect();

        // Start with a random landmark
        let first_idx = rng.gen_range(0..available.len());
        selected.push(available.remove(first_idx));

        // Greedily select landmarks that are maximally distant from already selected ones
        while selected.len() < n_select && !available.is_empty() {
            let mut best_idx = 0;
            let mut max_min_distance = 0.0;

            for (i, &candidate_idx) in available.iter().enumerate() {
                // Compute minimum distance to already selected landmarks
                let mut min_distance = Float::INFINITY;

                for &selected_idx in &selected {
                    let diff = &landmarks.row(candidate_idx) - &landmarks.row(selected_idx);
                    let distance = diff.dot(&diff).sqrt();
                    if distance < min_distance {
                        min_distance = distance;
                    }
                }

                if min_distance > max_min_distance {
                    max_min_distance = min_distance;
                    best_idx = i;
                }
            }

            selected.push(available.remove(best_idx));
        }

        Ok(selected)
    }

    /// Selective update based on approximation quality
    fn selective_update(self, new_data: &Array2<Float>, threshold: Float) -> Result<Self> {
        // Quality-based selective update that only incorporates new data if it
        // improves the approximation by more than `threshold`

        let current_landmarks = self.landmark_data_.as_ref().unwrap();

        // Step 1: Evaluate current approximation quality on new data
        let current_quality = self.evaluate_approximation_quality(current_landmarks, new_data)?;

        // Step 2: Create candidate updates and evaluate their quality
        let mut best_update = self.clone();
        let mut best_quality = current_quality;

        // Try append update
        let append_candidate = self.clone().append_update(new_data)?;
        let append_quality = append_candidate.evaluate_approximation_quality(
            append_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if append_quality > best_quality + threshold {
            best_update = append_candidate;
            best_quality = append_quality;
        }

        // Try merge update if we have enough data
        if new_data.nrows() >= 3 {
            let merge_candidate = self.clone().merge_update(new_data)?;
            let merge_quality = merge_candidate.evaluate_approximation_quality(
                merge_candidate.landmark_data_.as_ref().unwrap(),
                new_data,
            )?;

            if merge_quality > best_quality + threshold {
                best_update = merge_candidate;
                best_quality = merge_quality;
            }
        }

        // Try sliding window update
        let sliding_candidate = self.clone().sliding_window_update(new_data)?;
        let sliding_quality = sliding_candidate.evaluate_approximation_quality(
            sliding_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if sliding_quality > best_quality + threshold {
            best_update = sliding_candidate;
            best_quality = sliding_quality;
        }

        // Step 3: Only update if quality improvement exceeds threshold
        if best_quality > current_quality + threshold {
            Ok(best_update)
        } else {
            // No significant improvement, keep current state
            Ok(self)
        }
    }

    /// Evaluate approximation quality using kernel approximation error
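    ///
    /// Returns `-‖K_exact − K_approx‖_F / ‖K_exact‖_F`, the negative relative
    /// Frobenius error on a (sub)sample of `test_data`, so higher is better
    /// and a perfect approximation scores 0.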
    fn evaluate_approximation_quality(
        &self,
        landmarks: &Array2<Float>,
        test_data: &Array2<Float>,
    ) -> Result<Float> {
        // Quality metric: negative approximation error (higher is better)

        let n_test = test_data.nrows().min(50); // Limit for efficiency
        let test_subset = if test_data.nrows() > n_test {
            // Sample random subset for evaluation
            let mut rng = thread_rng();
            let mut indices: Vec<usize> = (0..test_data.nrows()).collect();
            indices.shuffle(&mut rng);
            test_data.select(Axis(0), &indices[..n_test])
        } else {
            test_data.to_owned()
        };

        // Compute exact kernel matrix for test subset
        let k_exact = self.kernel.compute_kernel(&test_subset, &test_subset);

        // Compute Nyström approximation: K(X,Z) * K(Z,Z)^(-1) * K(Z,X)
        let k_test_landmarks = self.kernel.compute_kernel(&test_subset, landmarks);
        let k_landmarks = self.kernel.compute_kernel(landmarks, landmarks);

        // Use our eigendecomposition to compute pseudo-inverse
        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(k_landmarks)?;

        // Construct pseudo-inverse
        let threshold = 1e-8;
        let mut pseudo_inverse = Array2::zeros((landmarks.nrows(), landmarks.nrows()));

        for i in 0..landmarks.nrows() {
            for j in 0..landmarks.nrows() {
                let mut sum = 0.0;
                for k in 0..eigenvals.len() {
                    if eigenvals[k] > threshold {
                        sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
                    }
                }
                pseudo_inverse[[i, j]] = sum;
            }
        }

        // Compute approximation: K(X,Z) * K(Z,Z)^(-1) * K(Z,X)
        let k_approx = k_test_landmarks
            .dot(&pseudo_inverse)
            .dot(&k_test_landmarks.t());

        // Compute approximation error (Frobenius norm)
        let error_matrix = &k_exact - &k_approx;
        let approximation_error = error_matrix.mapv(|x| x * x).sum().sqrt();

        // Convert to quality score (negative error, higher is better)
        let quality = -approximation_error / (k_exact.mapv(|x| x * x).sum().sqrt() + 1e-10);

        Ok(quality)
    }

    /// Extract landmark data points
    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
        let (_, n_features) = x.dim();
        let mut landmarks = Array2::zeros((indices.len(), n_features));

        for (i, &idx) in indices.iter().enumerate() {
            landmarks.row_mut(i).assign(&x.row(idx));
        }

        landmarks
    }

    /// Compute eigendecomposition using power iteration method
    /// Returns (eigenvalues, eigenvectors) for symmetric matrix
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix;

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Power iteration method to find dominant eigenvalue and eigenvector
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize with deterministic vector to ensure reproducibility
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }

    /// Compute eigendecomposition of kernel matrix
    fn compute_decomposition(
        &self,
        mut kernel_matrix: Array2<Float>,
    ) -> Result<(Array2<Float>, Array2<Float>)> {
        // Add small regularization to diagonal for numerical stability
        let reg = 1e-8;
        for i in 0..kernel_matrix.nrows() {
            kernel_matrix[[i, i]] += reg;
        }

        // Proper eigendecomposition for Nyström method
        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;

        // Filter out small eigenvalues for numerical stability
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        // Construct components and normalization matrices
        let n_valid = valid_indices.len();
        let mut components = Array2::zeros((eigenvals.len(), n_valid));
        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));

        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
            let sqrt_eigenval = eigenvals[old_idx].sqrt();
            components
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));

            // For Nyström method: normalization = V * Λ^(-1/2)
            for i in 0..eigenvals.len() {
                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
            }
        }

        Ok((components, normalization))
    }

    /// Get number of updates performed
    pub fn update_count(&self) -> usize {
        self.update_count_
    }

    /// Get current number of landmarks
    pub fn n_landmarks(&self) -> usize {
        self.landmark_data_.as_ref().map_or(0, |data| data.nrows())
    }
}

impl Transform<Array2<Float>> for IncrementalNystroem<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let _components = self
            .components_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let normalization =
            self.normalization_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let landmark_data =
            self.landmark_data_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        // Compute kernel between input and landmarks
        let kernel_x_landmarks = self.kernel.compute_kernel(x, landmark_data);

        // Apply transformation: K(X, landmarks) @ normalization.T
        let transformed = kernel_x_landmarks.dot(&normalization.t());

        Ok(transformed)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_incremental_nystroem_basic() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
            .update_strategy(UpdateStrategy::Append)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        assert_eq!(fitted.n_landmarks(), 3);

        let updated = fitted.update(&x_new).unwrap();
        assert_eq!(updated.n_landmarks(), 5);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_incremental_transform() {
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_test = array![[1.5, 2.5], [2.5, 3.5]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3);
        let fitted = nystroem.fit(&x_train, &()).unwrap();

        let transformed = fitted.transform(&x_test).unwrap();
        assert_eq!(transformed.shape()[0], 2);
        assert!(transformed.shape()[1] <= 3);
    }

    #[test]
    fn test_sliding_window_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 3)
            .update_strategy(UpdateStrategy::SlidingWindow)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert_eq!(updated.n_landmarks(), 3);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_different_kernels() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        // Test with RBF kernel
        let rbf_nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3);
        let rbf_fitted = rbf_nystroem.fit(&x, &()).unwrap();
        let rbf_transformed = rbf_fitted.transform(&x).unwrap();
        assert_eq!(rbf_transformed.shape()[0], 3);

        // Test with polynomial kernel
        let poly_nystroem = IncrementalNystroem::new(
            Kernel::Polynomial {
                gamma: 1.0,
                coef0: 1.0,
                degree: 2,
            },
            3,
        );
        let poly_fitted = poly_nystroem.fit(&x, &()).unwrap();
        let poly_transformed = poly_fitted.transform(&x).unwrap();
        assert_eq!(poly_transformed.shape()[0], 3);
    }

    #[test]
    fn test_min_update_size() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_small = array![[3.0, 4.0]];
        let x_large = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 5).min_update_size(2);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();

        // Small update should not trigger recomputation
        let after_small = fitted.update(&x_small).unwrap();
        assert_eq!(after_small.update_count(), 0);
        assert_eq!(after_small.n_landmarks(), 2);

        // Large update should trigger recomputation
        let after_large = after_small.update(&x_large).unwrap();
        assert_eq!(after_large.update_count(), 1);
        assert_eq!(after_large.n_landmarks(), 5);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0]];

        let nystroem1 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let updated1 = fitted1.update(&x_new).unwrap();
        let result1 = updated1.transform(&x).unwrap();

        let nystroem2 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let updated2 = fitted2.update(&x_new).unwrap();
        let result2 = updated2.transform(&x).unwrap();

        // Results should be very similar with the same random seed (allowing
        // for numerical precision); eigendecomposition can also produce
        // results that differ by a sign flip
        assert_eq!(result1.shape(), result2.shape());

        // Check if results match directly, or match up to a global sign flip
        let mut direct_match = true;
        let mut sign_flip_match = true;

        for i in 0..result1.len() {
            let val1 = result1.as_slice().unwrap()[i];
            let val2 = result2.as_slice().unwrap()[i];

            if (val1 - val2).abs() > 1e-6 {
                direct_match = false;
            }
            if (val1 + val2).abs() > 1e-6 {
                sign_flip_match = false;
            }
        }

        assert!(
            direct_match || sign_flip_match,
            "Results differ too much and are not related by sign flip"
        );
    }
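
    // Additional sketch: exercises the Merge strategy end-to-end. The exact
    // landmark count after merging depends on the diversity-based selection,
    // so we only assert the invariant that it never exceeds n_components.
    #[test]
    fn test_merge_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[10.0, 11.0], [11.0, 12.0], [12.0, 13.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 4)
            .update_strategy(UpdateStrategy::Merge)
            .min_update_size(1)
            .random_state(7);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert!(updated.n_landmarks() <= 4);
        assert_eq!(updated.update_count(), 1);
    }

    // Additional sketch: with a very large Selective threshold, no candidate
    // update should beat the current approximation by enough, so the landmark
    // set is kept as-is (the buffered data is still consumed, so the update
    // counter advances).
    #[test]
    fn test_selective_update_keeps_state_under_high_threshold() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 4)
            .update_strategy(UpdateStrategy::Selective { threshold: 1e6 })
            .min_update_size(1)
            .random_state(7);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert_eq!(updated.n_landmarks(), 2);
        assert_eq!(updated.update_count(), 1);
    }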
}