Skip to main content

sklears_kernel_approximation/
incremental_nystroem.rs

1//! Incremental Nyström method for online kernel approximation
2//!
3//! This module implements online/incremental versions of the Nyström method
4//! that can efficiently update kernel approximations as new data arrives.
5
6use scirs2_core::ndarray::{s, Array1, Array2, Axis};
7use scirs2_core::random::rngs::StdRng as RealStdRng;
8use scirs2_core::random::seq::SliceRandom;
9use scirs2_core::random::RngExt;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use sklears_core::{
12    error::{Result, SklearsError},
13    traits::{Estimator, Fit, Trained, Transform, Untrained},
14    types::Float,
15};
16use std::marker::PhantomData;
17
18use crate::nystroem::{Kernel, SamplingStrategy};
19
/// Update strategy for incremental Nyström
///
/// Controls how a buffered batch of new samples is folded into an existing
/// approximation (dispatched in `IncrementalNystroem::perform_update`).
#[derive(Debug, Clone)]
pub enum UpdateStrategy {
    /// Simple addition of new landmarks
    Append,
    /// Replace oldest landmarks with new ones (sliding window)
    SlidingWindow,
    /// Merge new data with existing approximation
    Merge,
    /// Selective update based on approximation quality
    Selective { threshold: Float },
}
33
/// Incremental Nyström method for online kernel approximation
///
/// Enables efficient updating of kernel approximations as new data arrives,
/// without requiring complete recomputation from scratch.
///
/// # Parameters
///
/// * `kernel` - Kernel function to approximate
/// * `n_components` - Maximum number of landmark points
/// * `update_strategy` - Strategy for incorporating new data
/// * `min_update_size` - Minimum number of new samples before updating
/// * `sampling_strategy` - Strategy for selecting new landmarks
///
/// The `State` type parameter (`Untrained` / `Trained`) encodes at compile
/// time whether the fitted attributes below have been populated.
#[derive(Debug, Clone)]
pub struct IncrementalNystroem<State = Untrained> {
    /// Kernel function to approximate
    pub kernel: Kernel,
    /// Maximum number of landmark points
    pub n_components: usize,
    /// Strategy for incorporating new data
    pub update_strategy: UpdateStrategy,
    /// Minimum number of buffered samples before an update is applied
    pub min_update_size: usize,
    /// Strategy for selecting new landmarks
    pub sampling_strategy: SamplingStrategy,
    /// Optional RNG seed for reproducible landmark selection
    pub random_state: Option<u64>,

    // Fitted attributes (populated by `fit`, None while Untrained)
    // Eigenvector columns of the landmark kernel matrix (kept eigenvalues only).
    components_: Option<Array2<Float>>,
    // Λ^{-1/2} Vᵀ factor built in `compute_decomposition`.
    normalization_: Option<Array2<Float>>,
    // Indices of the rows selected as landmarks; after updates these are only
    // meaningful relative to the data they were selected from (see update notes).
    component_indices_: Option<Vec<usize>>,
    // The landmark rows themselves, kept for kernel evaluation and updates.
    landmark_data_: Option<Array2<Float>>,
    // Number of batch updates applied so far.
    update_count_: usize,
    // Buffer of incoming samples not yet large enough to trigger an update.
    accumulated_data_: Option<Array2<Float>>,

    _state: PhantomData<State>,
}
65
66impl IncrementalNystroem<Untrained> {
67    /// Create a new incremental Nyström approximator
68    pub fn new(kernel: Kernel, n_components: usize) -> Self {
69        Self {
70            kernel,
71            n_components,
72            update_strategy: UpdateStrategy::Append,
73            min_update_size: 10,
74            sampling_strategy: SamplingStrategy::Random,
75            random_state: None,
76            components_: None,
77            normalization_: None,
78            component_indices_: None,
79            landmark_data_: None,
80            update_count_: 0,
81            accumulated_data_: None,
82            _state: PhantomData,
83        }
84    }
85
86    /// Set the update strategy
87    pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
88        self.update_strategy = strategy;
89        self
90    }
91
92    /// Set minimum update size
93    pub fn min_update_size(mut self, size: usize) -> Self {
94        self.min_update_size = size;
95        self
96    }
97
98    /// Set sampling strategy
99    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
100        self.sampling_strategy = strategy;
101        self
102    }
103
104    /// Set random state
105    pub fn random_state(mut self, seed: u64) -> Self {
106        self.random_state = Some(seed);
107        self
108    }
109}
110
impl Estimator for IncrementalNystroem<Untrained> {
    /// No separate configuration object: all settings live on the struct.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is promoted to a `'static` reference, so returning it is sound.
        &()
    }
}
120
impl Fit<Array2<Float>, ()> for IncrementalNystroem<Untrained> {
    type Fitted = IncrementalNystroem<Trained>;

    /// Fit the initial Nyström approximation on `x`.
    ///
    /// Clamps `n_components` to the number of rows of `x`, selects landmark
    /// rows with the configured sampling strategy, computes the
    /// landmark–landmark kernel matrix, and derives the component and
    /// normalization matrices from its eigendecomposition.
    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _n_features) = x.dim();
        // Cannot select more landmarks than there are samples.
        let n_components = self.n_components.min(n_samples);

        // Seeded RNG when `random_state` is set; otherwise seeded from the
        // thread-local RNG (non-reproducible).
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().random()),
        };

        // Select initial landmark points
        let component_indices = self.select_components(x, n_components, &mut rng)?;
        let landmark_data = self.extract_landmarks(x, &component_indices);

        // Compute kernel matrix for landmarks
        let kernel_matrix = self.kernel.compute_kernel(&landmark_data, &landmark_data);

        // Compute eigendecomposition
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        Ok(IncrementalNystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            update_strategy: self.update_strategy,
            min_update_size: self.min_update_size,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            landmark_data_: Some(landmark_data),
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}
160
161impl IncrementalNystroem<Untrained> {
    /// Select component indices based on sampling strategy
    ///
    /// Dispatches to the sampler matching `self.sampling_strategy`.
    ///
    /// NOTE(review): the `Random` arm slices `indices[..n_components]` and
    /// will panic if `n_components > n_samples`; `fit` clamps
    /// `n_components` first, so callers must uphold that invariant.
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }
182
    /// Simple k-means based sampling
    ///
    /// Runs 5 fixed iterations of Lloyd's algorithm on `x`, then returns the
    /// index of the data point closest to each final center. Duplicate picks
    /// are removed and the remaining slots are filled with random indices, so
    /// exactly `n_components` indices are returned.
    ///
    /// NOTE(review): the random fill loop can only terminate when
    /// `n_components <= n_samples`; callers are expected to clamp first.
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers randomly
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // Run a few iterations of k-means
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign points to nearest centers (squared Euclidean distance)
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Update centers; empty clusters keep their previous center
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Find closest points to final centers
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        // Distinct centers can map to the same nearest point; dedup first
        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Fill remaining slots randomly if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.random_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }
270
271    /// Leverage score based sampling
272    fn leverage_score_sampling(
273        &self,
274        x: &Array2<Float>,
275        n_components: usize,
276        _rng: &mut RealStdRng,
277    ) -> Result<Vec<usize>> {
278        let (n_samples, _) = x.dim();
279
280        // Compute leverage scores (diagonal of hat matrix)
281        // For simplicity, we approximate using row norms as proxy
282        let mut scores = Vec::new();
283        for i in 0..n_samples {
284            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
285            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
286        }
287
288        // Sample based on scores using cumulative distribution
289        let total_score: Float = scores.iter().sum();
290        if total_score <= 0.0 {
291            return Err(SklearsError::InvalidInput(
292                "All scores are zero or negative".to_string(),
293            ));
294        }
295
296        // Create cumulative distribution
297        let mut cumulative = Vec::with_capacity(scores.len());
298        let mut sum = 0.0;
299        for &score in &scores {
300            sum += score / total_score;
301            cumulative.push(sum);
302        }
303
304        let mut selected_indices = Vec::new();
305        for _ in 0..n_components {
306            let r = thread_rng().random::<Float>();
307            // Find index where cumulative probability >= r
308            let mut idx = cumulative
309                .iter()
310                .position(|&cum| cum >= r)
311                .unwrap_or(scores.len() - 1);
312
313            // Ensure no duplicates
314            while selected_indices.contains(&idx) {
315                let r = thread_rng().random::<Float>();
316                idx = cumulative
317                    .iter()
318                    .position(|&cum| cum >= r)
319                    .unwrap_or(scores.len() - 1);
320            }
321            selected_indices.push(idx);
322        }
323
324        Ok(selected_indices)
325    }
326
327    /// Column norm based sampling
328    fn column_norm_sampling(
329        &self,
330        x: &Array2<Float>,
331        n_components: usize,
332        rng: &mut RealStdRng,
333    ) -> Result<Vec<usize>> {
334        let (n_samples, _) = x.dim();
335
336        // Compute row norms
337        let mut norms = Vec::new();
338        for i in 0..n_samples {
339            let norm = x.row(i).dot(&x.row(i)).sqrt();
340            norms.push(norm + 1e-10);
341        }
342
343        // Sort by norm and take diverse selection
344        let mut indices_with_norms: Vec<(usize, Float)> = norms
345            .iter()
346            .enumerate()
347            .map(|(i, &norm)| (i, norm))
348            .collect();
349        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("operation should succeed"));
350
351        let mut selected_indices = Vec::new();
352        let step = n_samples.max(1) / n_components.max(1);
353
354        for i in 0..n_components {
355            let idx = (i * step).min(n_samples - 1);
356            selected_indices.push(indices_with_norms[idx].0);
357        }
358
359        // Fill remaining with random if needed
360        while selected_indices.len() < n_components {
361            let random_idx = rng.random_range(0..n_samples);
362            if !selected_indices.contains(&random_idx) {
363                selected_indices.push(random_idx);
364            }
365        }
366
367        Ok(selected_indices)
368    }
369
370    /// Extract landmark data points
371    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
372        let (_, n_features) = x.dim();
373        let mut landmarks = Array2::zeros((indices.len(), n_features));
374
375        for (i, &idx) in indices.iter().enumerate() {
376            landmarks.row_mut(i).assign(&x.row(idx));
377        }
378
379        landmarks
380    }
381
    /// Compute eigendecomposition using power iteration method
    /// Returns (eigenvalues, eigenvectors) for symmetric matrix
    ///
    /// Strategy: repeated power iteration with Hotelling deflation. After
    /// each dominant eigenpair (λ, v) is found, the matrix is deflated via
    /// `A <- A - λ v vᵀ` so the next power iteration converges toward the
    /// next eigenpair. Pairs are finally sorted by descending eigenvalue.
    ///
    /// NOTE(review): deflation of this form is reliable for symmetric PSD
    /// matrices such as the regularized kernel matrices built upstream, but
    /// accumulates error for general matrices — confirm inputs stay
    /// symmetric.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `matrix` is not square.
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| {
            eigenvals[j]
                .partial_cmp(&eigenvals[i])
                .expect("operation should succeed")
        });

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }
437
    /// Power iteration method to find dominant eigenvalue and eigenvector
    ///
    /// Iterates `v <- A v / ||A v||`, estimating the eigenvalue with the
    /// Rayleigh quotient `vᵀ A v`. Stops when both the eigenvalue and the
    /// vector change by less than `tol`, or after `max_iter` iterations
    /// (returning the latest estimate in that case).
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if the deterministic start vector has
    /// (near-)zero norm — not expected in practice for n >= 1, since its
    /// entries are nonzero sine values.
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize with deterministic vector to ensure reproducibility
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize; a (near-)zero image means v is (numerically) in the
            // null space, so keep the current estimate and stop.
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
489
    /// Compute eigendecomposition of kernel matrix
    ///
    /// Regularizes the diagonal, eigendecomposes, drops eigenvalues below
    /// 1e-8, and builds the Nyström factors:
    /// * `components` — the kept eigenvector columns V_k
    ///   (shape: n_landmarks x n_valid)
    /// * `normalization` — Λ_k^(-1/2) V_kᵀ (shape: n_valid x n_landmarks),
    ///   the factor that maps landmark kernel evaluations into the
    ///   approximate feature space.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` when every eigenvalue falls below the
    /// threshold (degenerate kernel matrix).
    fn compute_decomposition(
        &self,
        mut kernel_matrix: Array2<Float>,
    ) -> Result<(Array2<Float>, Array2<Float>)> {
        // Add small regularization to diagonal for numerical stability
        let reg = 1e-8;
        for i in 0..kernel_matrix.nrows() {
            kernel_matrix[[i, i]] += reg;
        }

        // Proper eigendecomposition for Nyström method
        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;

        // Filter out small eigenvalues for numerical stability
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        // Construct components and normalization matrices
        let n_valid = valid_indices.len();
        let mut components = Array2::zeros((eigenvals.len(), n_valid));
        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));

        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
            let sqrt_eigenval = eigenvals[old_idx].sqrt();
            components
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));

            // For Nyström method: normalization = V * Λ^(-1/2)
            for i in 0..eigenvals.len() {
                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
            }
        }

        Ok((components, normalization))
    }
538}
539
540impl IncrementalNystroem<Trained> {
    /// Select component indices based on sampling strategy
    ///
    /// Duplicate of `IncrementalNystroem::<Untrained>::select_components`;
    /// the two copies exist because the samplers are defined separately per
    /// typestate. NOTE(review): consider a shared `impl<State>` block to
    /// deduplicate.
    ///
    /// NOTE(review): the `Random` arm panics if `n_components > n_samples`;
    /// callers must clamp first.
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }
561
    /// Simple k-means based sampling
    ///
    /// Duplicate of the Untrained-state implementation. Runs 5 fixed
    /// iterations of Lloyd's algorithm, then returns the index of the data
    /// point closest to each final center; duplicates are removed and the
    /// remaining slots are filled with random indices.
    ///
    /// NOTE(review): the random fill loop can only terminate when
    /// `n_components <= n_samples`; callers are expected to clamp first.
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers randomly
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // Run a few iterations of k-means
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign points to nearest centers (squared Euclidean distance)
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Update centers; empty clusters keep their previous center
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Find closest points to final centers
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        // Distinct centers can map to the same nearest point; dedup first
        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Fill remaining slots randomly if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.random_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }
649
650    /// Leverage score based sampling
651    fn leverage_score_sampling(
652        &self,
653        x: &Array2<Float>,
654        n_components: usize,
655        _rng: &mut RealStdRng,
656    ) -> Result<Vec<usize>> {
657        let (n_samples, _) = x.dim();
658
659        // Compute leverage scores (diagonal of hat matrix)
660        // For simplicity, we approximate using row norms as proxy
661        let mut scores = Vec::new();
662        for i in 0..n_samples {
663            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
664            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
665        }
666
667        // Sample based on scores using cumulative distribution
668        let total_score: Float = scores.iter().sum();
669        if total_score <= 0.0 {
670            return Err(SklearsError::InvalidInput(
671                "All scores are zero or negative".to_string(),
672            ));
673        }
674
675        // Create cumulative distribution
676        let mut cumulative = Vec::with_capacity(scores.len());
677        let mut sum = 0.0;
678        for &score in &scores {
679            sum += score / total_score;
680            cumulative.push(sum);
681        }
682
683        let mut selected_indices = Vec::new();
684        for _ in 0..n_components {
685            let r = thread_rng().random::<Float>();
686            // Find index where cumulative probability >= r
687            let mut idx = cumulative
688                .iter()
689                .position(|&cum| cum >= r)
690                .unwrap_or(scores.len() - 1);
691
692            // Ensure no duplicates
693            while selected_indices.contains(&idx) {
694                let r = thread_rng().random::<Float>();
695                idx = cumulative
696                    .iter()
697                    .position(|&cum| cum >= r)
698                    .unwrap_or(scores.len() - 1);
699            }
700            selected_indices.push(idx);
701        }
702
703        Ok(selected_indices)
704    }
705
706    /// Column norm based sampling
707    fn column_norm_sampling(
708        &self,
709        x: &Array2<Float>,
710        n_components: usize,
711        rng: &mut RealStdRng,
712    ) -> Result<Vec<usize>> {
713        let (n_samples, _) = x.dim();
714
715        // Compute row norms
716        let mut norms = Vec::new();
717        for i in 0..n_samples {
718            let norm = x.row(i).dot(&x.row(i)).sqrt();
719            norms.push(norm + 1e-10);
720        }
721
722        // Sort by norm and take diverse selection
723        let mut indices_with_norms: Vec<(usize, Float)> = norms
724            .iter()
725            .enumerate()
726            .map(|(i, &norm)| (i, norm))
727            .collect();
728        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("operation should succeed"));
729
730        let mut selected_indices = Vec::new();
731        let step = n_samples.max(1) / n_components.max(1);
732
733        for i in 0..n_components {
734            let idx = (i * step).min(n_samples - 1);
735            selected_indices.push(indices_with_norms[idx].0);
736        }
737
738        // Fill remaining with random if needed
739        while selected_indices.len() < n_components {
740            let random_idx = rng.random_range(0..n_samples);
741            if !selected_indices.contains(&random_idx) {
742                selected_indices.push(random_idx);
743            }
744        }
745
746        Ok(selected_indices)
747    }
748
749    /// Update the approximation with new data
750    pub fn update(mut self, x_new: &Array2<Float>) -> Result<Self> {
751        // Accumulate new data
752        match &self.accumulated_data_ {
753            Some(existing) => {
754                let combined =
755                    scirs2_core::ndarray::concatenate![Axis(0), existing.clone(), x_new.clone()];
756                self.accumulated_data_ = Some(combined);
757            }
758            None => {
759                self.accumulated_data_ = Some(x_new.clone());
760            }
761        }
762
763        // Check if we have enough accumulated data to update
764        let should_update = if let Some(ref accumulated) = self.accumulated_data_ {
765            accumulated.nrows() >= self.min_update_size
766        } else {
767            false
768        };
769
770        if should_update {
771            if let Some(accumulated) = self.accumulated_data_.take() {
772                self = self.perform_update(&accumulated)?;
773                self.update_count_ += 1;
774            }
775        }
776
777        Ok(self)
778    }
779
    /// Perform the actual update based on the strategy
    ///
    /// The strategy is cloned so the match does not hold a borrow of `self`
    /// while `self` is moved into the chosen handler; the clone is cheap
    /// (small enum).
    fn perform_update(self, new_data: &Array2<Float>) -> Result<Self> {
        match self.update_strategy.clone() {
            UpdateStrategy::Append => self.append_update(new_data),
            UpdateStrategy::SlidingWindow => self.sliding_window_update(new_data),
            UpdateStrategy::Merge => self.merge_update(new_data),
            UpdateStrategy::Selective { threshold } => self.selective_update(new_data, threshold),
        }
    }
789
    /// Append new landmarks (if space available)
    ///
    /// Randomly selects up to `n_components - current` rows from `new_data`,
    /// appends them to the existing landmarks, and recomputes the kernel
    /// decomposition over the combined set. A no-op when the landmark
    /// budget is already full or `new_data` is empty.
    ///
    /// NOTE(review): the RNG is reseeded from `random_state + 1000` on every
    /// call, so repeated appends with a fixed seed reuse the same shuffle
    /// sequence — TODO confirm this is intended.
    ///
    /// NOTE(review): the pushed index values (`base_index + idx`, where
    /// `idx` indexes into `new_data`) are not positions in any retained
    /// array — verify how `component_indices_` is consumed downstream.
    fn append_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self
            .landmark_data_
            .as_ref()
            .expect("operation should succeed");
        let current_components = current_landmarks.nrows();

        if current_components >= self.n_components {
            // No space to append, just return current state
            return Ok(self);
        }

        let available_space = self.n_components - current_components;
        let n_new = available_space.min(new_data.nrows());

        if n_new == 0 {
            return Ok(self);
        }

        // Select new landmarks from new data
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(1000)),
            None => RealStdRng::from_seed(thread_rng().random()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        // Extract new landmarks
        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        // Combine with existing landmarks
        let combined_landmarks =
            scirs2_core::ndarray::concatenate![Axis(0), current_landmarks.clone(), new_landmarks];

        // Recompute decomposition
        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // Update indices
        let mut new_component_indices = self
            .component_indices_
            .as_ref()
            .expect("operation should succeed")
            .clone();
        let base_index = current_landmarks.nrows();
        for &idx in selected_indices {
            new_component_indices.push(base_index + idx);
        }

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }
851
    /// Sliding window update (replace oldest landmarks)
    ///
    /// Randomly selects up to `n_components` rows from `new_data`, drops the
    /// same number of oldest landmarks (the leading rows), and recomputes
    /// the decomposition over the kept + new landmarks.
    ///
    /// NOTE(review): the slice `s![n_new.., ..]` panics if the current
    /// landmark count is less than `n_new` — confirm landmarks always
    /// number `n_components` when this runs.
    ///
    /// NOTE(review): after this update `component_indices_` is reset to
    /// `0..n_landmarks` and no longer refers to original training rows.
    fn sliding_window_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self
            .landmark_data_
            .as_ref()
            .expect("operation should succeed");
        let n_new = new_data.nrows().min(self.n_components);

        if n_new == 0 {
            return Ok(self);
        }

        // Select new landmarks; seed offset keeps this stream distinct from
        // the one used by fit/append.
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(2000)),
            None => RealStdRng::from_seed(thread_rng().random()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        // Replace oldest landmarks with new ones
        let n_keep = self.n_components - n_new;
        let combined_landmarks = if n_keep > 0 {
            let kept_landmarks = current_landmarks.slice(s![n_new.., ..]).to_owned();
            scirs2_core::ndarray::concatenate![Axis(0), kept_landmarks, new_landmarks]
        } else {
            new_landmarks
        };

        // Recompute decomposition
        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // Update component indices (simplified)
        let new_component_indices: Vec<usize> = (0..combined_landmarks.nrows()).collect();

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }
901
902    /// Merge update (combine approximations)
903    fn merge_update(self, new_data: &Array2<Float>) -> Result<Self> {
904        // Sophisticated merging strategy that combines existing and new Nyström approximations
905        // This is based on the idea of merging two kernel approximations optimally
906
907        let current_landmarks = self
908            .landmark_data_
909            .as_ref()
910            .expect("operation should succeed");
911        let _current_components = self.components_.as_ref().expect("operation should succeed");
912        let _current_normalization = self
913            .normalization_
914            .as_ref()
915            .expect("operation should succeed");
916
917        // Step 1: Create a new Nyström approximation from the new data
918        let n_new_components = (new_data.nrows().min(self.n_components) / 2).max(1);
919
920        let mut rng = match self.random_state {
921            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(3000)),
922            None => RealStdRng::from_seed(thread_rng().random()),
923        };
924
925        // Select new landmarks using the same strategy
926        let new_component_indices = self.select_components(new_data, n_new_components, &mut rng)?;
927        let new_landmarks = self.extract_landmarks(new_data, &new_component_indices);
928
929        // Compute new kernel matrix and decomposition
930        let new_kernel_matrix = self.kernel.compute_kernel(&new_landmarks, &new_landmarks);
931        let (_new_components, _new_normalization) =
932            self.compute_decomposition(new_kernel_matrix)?;
933
934        // Step 2: Combine the landmarks intelligently
935        // Merge by selecting the most diverse/informative landmarks from both sets
936        let merged_landmarks =
937            self.merge_landmarks_intelligently(current_landmarks, &new_landmarks, &mut rng)?;
938
939        // Step 3: Recompute the full approximation on merged landmarks
940        let merged_kernel_matrix = self
941            .kernel
942            .compute_kernel(&merged_landmarks, &merged_landmarks);
943        let (final_components, final_normalization) =
944            self.compute_decomposition(merged_kernel_matrix)?;
945
946        // Update component indices (simplified for merged case)
947        let final_component_indices: Vec<usize> = (0..merged_landmarks.nrows()).collect();
948
949        let mut updated_self = self;
950        updated_self.components_ = Some(final_components);
951        updated_self.normalization_ = Some(final_normalization);
952        updated_self.component_indices_ = Some(final_component_indices);
953        updated_self.landmark_data_ = Some(merged_landmarks);
954
955        Ok(updated_self)
956    }
957
958    /// Intelligently merge landmarks from existing and new data
959    fn merge_landmarks_intelligently(
960        &self,
961        current_landmarks: &Array2<Float>,
962        new_landmarks: &Array2<Float>,
963        rng: &mut RealStdRng,
964    ) -> Result<Array2<Float>> {
965        let n_current = current_landmarks.nrows();
966        let n_new = new_landmarks.nrows();
967        let n_features = current_landmarks.ncols();
968
969        // Combine all landmarks temporarily
970        let all_landmarks = scirs2_core::ndarray::concatenate![
971            Axis(0),
972            current_landmarks.clone(),
973            new_landmarks.clone()
974        ];
975
976        // Use diversity-based selection to choose the best subset
977        let n_target = self.n_components.min(n_current + n_new);
978        let selected_indices = self.select_diverse_landmarks(&all_landmarks, n_target, rng)?;
979
980        // Extract selected landmarks
981        let mut merged_landmarks = Array2::zeros((selected_indices.len(), n_features));
982        for (i, &idx) in selected_indices.iter().enumerate() {
983            merged_landmarks.row_mut(i).assign(&all_landmarks.row(idx));
984        }
985
986        Ok(merged_landmarks)
987    }
988
989    /// Select diverse landmarks using maximum distance criterion
990    fn select_diverse_landmarks(
991        &self,
992        landmarks: &Array2<Float>,
993        n_select: usize,
994        rng: &mut RealStdRng,
995    ) -> Result<Vec<usize>> {
996        let n_landmarks = landmarks.nrows();
997
998        if n_select >= n_landmarks {
999            return Ok((0..n_landmarks).collect());
1000        }
1001
1002        let mut selected = Vec::new();
1003        let mut available: Vec<usize> = (0..n_landmarks).collect();
1004
1005        // Start with a random landmark
1006        let first_idx = rng.random_range(0..available.len());
1007        selected.push(available.remove(first_idx));
1008
1009        // Greedily select landmarks that are maximally distant from already selected ones
1010        while selected.len() < n_select && !available.is_empty() {
1011            let mut best_idx = 0;
1012            let mut max_min_distance = 0.0;
1013
1014            for (i, &candidate_idx) in available.iter().enumerate() {
1015                // Compute minimum distance to already selected landmarks
1016                let mut min_distance = Float::INFINITY;
1017
1018                for &selected_idx in &selected {
1019                    let diff = &landmarks.row(candidate_idx) - &landmarks.row(selected_idx);
1020                    let distance = diff.dot(&diff).sqrt();
1021                    if distance < min_distance {
1022                        min_distance = distance;
1023                    }
1024                }
1025
1026                if min_distance > max_min_distance {
1027                    max_min_distance = min_distance;
1028                    best_idx = i;
1029                }
1030            }
1031
1032            selected.push(available.remove(best_idx));
1033        }
1034
1035        Ok(selected)
1036    }
1037
1038    /// Selective update based on approximation quality
1039    fn selective_update(self, new_data: &Array2<Float>, threshold: Float) -> Result<Self> {
1040        // Quality-based selective update that only incorporates new data if it improves approximation
1041
1042        let current_landmarks = self
1043            .landmark_data_
1044            .as_ref()
1045            .expect("operation should succeed");
1046
1047        // Step 1: Evaluate current approximation quality on new data
1048        let current_quality = self.evaluate_approximation_quality(current_landmarks, new_data)?;
1049
1050        // Step 2: Create candidate updates and evaluate their quality
1051        let mut best_update = self.clone();
1052        let mut best_quality = current_quality;
1053
1054        // Try append update
1055        let append_candidate = self.clone().append_update(new_data)?;
1056        let append_quality = append_candidate.evaluate_approximation_quality(
1057            append_candidate
1058                .landmark_data_
1059                .as_ref()
1060                .expect("operation should succeed"),
1061            new_data,
1062        )?;
1063
1064        if append_quality > best_quality + threshold {
1065            best_update = append_candidate;
1066            best_quality = append_quality;
1067        }
1068
1069        // Try merge update if we have enough data
1070        if new_data.nrows() >= 3 {
1071            let merge_candidate = self.clone().merge_update(new_data)?;
1072            let merge_quality = merge_candidate.evaluate_approximation_quality(
1073                merge_candidate
1074                    .landmark_data_
1075                    .as_ref()
1076                    .expect("operation should succeed"),
1077                new_data,
1078            )?;
1079
1080            if merge_quality > best_quality + threshold {
1081                best_update = merge_candidate;
1082                best_quality = merge_quality;
1083            }
1084        }
1085
1086        // Try sliding window update
1087        let sliding_candidate = self.clone().sliding_window_update(new_data)?;
1088        let sliding_quality = sliding_candidate.evaluate_approximation_quality(
1089            sliding_candidate
1090                .landmark_data_
1091                .as_ref()
1092                .expect("operation should succeed"),
1093            new_data,
1094        )?;
1095
1096        if sliding_quality > best_quality + threshold {
1097            best_update = sliding_candidate;
1098            best_quality = sliding_quality;
1099        }
1100
1101        // Step 3: Only update if quality improvement exceeds threshold
1102        if best_quality > current_quality + threshold {
1103            Ok(best_update)
1104        } else {
1105            // No significant improvement, keep current state
1106            Ok(self)
1107        }
1108    }
1109
1110    /// Evaluate approximation quality using kernel approximation error
1111    fn evaluate_approximation_quality(
1112        &self,
1113        landmarks: &Array2<Float>,
1114        test_data: &Array2<Float>,
1115    ) -> Result<Float> {
1116        // Quality metric: negative approximation error (higher is better)
1117
1118        let n_test = test_data.nrows().min(50); // Limit for efficiency
1119        let test_subset = if test_data.nrows() > n_test {
1120            // Sample random subset for evaluation
1121            let mut rng = thread_rng();
1122            let mut indices: Vec<usize> = (0..test_data.nrows()).collect();
1123            indices.shuffle(&mut rng);
1124            test_data.select(Axis(0), &indices[..n_test])
1125        } else {
1126            test_data.to_owned()
1127        };
1128
1129        // Compute exact kernel matrix for test subset
1130        let k_exact = self.kernel.compute_kernel(&test_subset, &test_subset);
1131
1132        // Compute Nyström approximation: K(X,Z) * K(Z,Z)^(-1) * K(Z,X)
1133        let k_test_landmarks = self.kernel.compute_kernel(&test_subset, landmarks);
1134        let k_landmarks = self.kernel.compute_kernel(landmarks, landmarks);
1135
1136        // Use our eigendecomposition to compute pseudo-inverse
1137        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(k_landmarks)?;
1138
1139        // Construct pseudo-inverse
1140        let threshold = 1e-8;
1141        let mut pseudo_inverse = Array2::zeros((landmarks.nrows(), landmarks.nrows()));
1142
1143        for i in 0..landmarks.nrows() {
1144            for j in 0..landmarks.nrows() {
1145                let mut sum = 0.0;
1146                for k in 0..eigenvals.len() {
1147                    if eigenvals[k] > threshold {
1148                        sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
1149                    }
1150                }
1151                pseudo_inverse[[i, j]] = sum;
1152            }
1153        }
1154
1155        // Compute approximation: K(X,Z) * K(Z,Z)^(-1) * K(Z,X)
1156        let k_approx = k_test_landmarks
1157            .dot(&pseudo_inverse)
1158            .dot(&k_test_landmarks.t());
1159
1160        // Compute approximation error (Frobenius norm)
1161        let error_matrix = &k_exact - &k_approx;
1162        let approximation_error = error_matrix.mapv(|x| x * x).sum().sqrt();
1163
1164        // Convert to quality score (negative error, higher is better)
1165        let quality = -approximation_error / (k_exact.mapv(|x| x * x).sum().sqrt() + 1e-10);
1166
1167        Ok(quality)
1168    }
1169
1170    /// Extract landmark data points
1171    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
1172        let (_, n_features) = x.dim();
1173        let mut landmarks = Array2::zeros((indices.len(), n_features));
1174
1175        for (i, &idx) in indices.iter().enumerate() {
1176            landmarks.row_mut(i).assign(&x.row(idx));
1177        }
1178
1179        landmarks
1180    }
1181
    /// Compute eigendecomposition using power iteration method
    /// Returns (eigenvalues, eigenvectors) for symmetric matrix
    ///
    /// Uses power iteration with Hotelling deflation: once the dominant
    /// eigenpair is found, its rank-one contribution λ·v·vᵀ is subtracted so
    /// the next round converges to the next eigenvector. The pairs are then
    /// sorted by eigenvalue in descending order.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `matrix` is not square.
    ///
    /// NOTE(review): deflation accumulates floating-point error, so the
    /// smallest eigenpairs are the least accurate; the descending sort below
    /// also panics if any eigenvalue is NaN (`partial_cmp().expect(...)`).
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            // (capped at 100 iterations, convergence tolerance 1e-8).
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| {
            eigenvals[j]
                .partial_cmp(&eigenvals[i])
                .expect("operation should succeed")
        });

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        // Gather each eigenpair into its sorted position.
        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }
1237
1238    /// Power iteration method to find dominant eigenvalue and eigenvector
1239    fn power_iteration(
1240        &self,
1241        matrix: &Array2<Float>,
1242        max_iter: usize,
1243        tol: Float,
1244    ) -> Result<(Float, Array1<Float>)> {
1245        let n = matrix.nrows();
1246
1247        // Initialize with deterministic vector to ensure reproducibility
1248        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());
1249
1250        // Normalize
1251        let norm = v.dot(&v).sqrt();
1252        if norm < 1e-10 {
1253            return Err(SklearsError::InvalidInput(
1254                "Initial vector has zero norm".to_string(),
1255            ));
1256        }
1257        v /= norm;
1258
1259        let mut eigenval = 0.0;
1260
1261        for _iter in 0..max_iter {
1262            // Apply matrix
1263            let w = matrix.dot(&v);
1264
1265            // Compute Rayleigh quotient
1266            let new_eigenval = v.dot(&w);
1267
1268            // Normalize
1269            let w_norm = w.dot(&w).sqrt();
1270            if w_norm < 1e-10 {
1271                break;
1272            }
1273            let new_v = w / w_norm;
1274
1275            // Check convergence
1276            let eigenval_change = (new_eigenval - eigenval).abs();
1277            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();
1278
1279            if eigenval_change < tol && vector_change < tol {
1280                return Ok((new_eigenval, new_v));
1281            }
1282
1283            eigenval = new_eigenval;
1284            v = new_v;
1285        }
1286
1287        Ok((eigenval, v))
1288    }
1289
1290    /// Compute eigendecomposition of kernel matrix
1291    fn compute_decomposition(
1292        &self,
1293        mut kernel_matrix: Array2<Float>,
1294    ) -> Result<(Array2<Float>, Array2<Float>)> {
1295        // Add small regularization to diagonal for numerical stability
1296        let reg = 1e-8;
1297        for i in 0..kernel_matrix.nrows() {
1298            kernel_matrix[[i, i]] += reg;
1299        }
1300
1301        // Proper eigendecomposition for Nyström method
1302        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;
1303
1304        // Filter out small eigenvalues for numerical stability
1305        let threshold = 1e-8;
1306        let valid_indices: Vec<usize> = eigenvals
1307            .iter()
1308            .enumerate()
1309            .filter(|(_, &val)| val > threshold)
1310            .map(|(i, _)| i)
1311            .collect();
1312
1313        if valid_indices.is_empty() {
1314            return Err(SklearsError::InvalidInput(
1315                "No valid eigenvalues found in kernel matrix".to_string(),
1316            ));
1317        }
1318
1319        // Construct components and normalization matrices
1320        let n_valid = valid_indices.len();
1321        let mut components = Array2::zeros((eigenvals.len(), n_valid));
1322        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));
1323
1324        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
1325            let sqrt_eigenval = eigenvals[old_idx].sqrt();
1326            components
1327                .column_mut(new_idx)
1328                .assign(&eigenvecs.column(old_idx));
1329
1330            // For Nyström method: normalization = V * Λ^(-1/2)
1331            for i in 0..eigenvals.len() {
1332                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
1333            }
1334        }
1335
1336        Ok((components, normalization))
1337    }
1338
    /// Get number of updates performed
    ///
    /// Only updates that actually triggered recomputation are counted —
    /// batches smaller than `min_update_size` leave this unchanged (see
    /// `test_min_update_size`).
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
1343
1344    /// Get current number of landmarks
1345    pub fn n_landmarks(&self) -> usize {
1346        self.landmark_data_.as_ref().map_or(0, |data| data.nrows())
1347    }
1348}
1349
1350impl Transform<Array2<Float>> for IncrementalNystroem<Trained> {
1351    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
1352        let _components = self
1353            .components_
1354            .as_ref()
1355            .ok_or_else(|| SklearsError::NotFitted {
1356                operation: "transform".to_string(),
1357            })?;
1358
1359        let normalization =
1360            self.normalization_
1361                .as_ref()
1362                .ok_or_else(|| SklearsError::NotFitted {
1363                    operation: "transform".to_string(),
1364                })?;
1365
1366        let landmark_data =
1367            self.landmark_data_
1368                .as_ref()
1369                .ok_or_else(|| SklearsError::NotFitted {
1370                    operation: "transform".to_string(),
1371                })?;
1372
1373        // Compute kernel between input and landmarks
1374        let kernel_x_landmarks = self.kernel.compute_kernel(x, landmark_data);
1375
1376        // Apply transformation: K(X, landmarks) @ normalization.T
1377        let transformed = kernel_x_landmarks.dot(&normalization.t());
1378
1379        Ok(transformed)
1380    }
1381}
1382
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Append strategy: updating adds landmarks until n_components is reached
    // and bumps the update counter.
    #[test]
    fn test_incremental_nystroem_basic() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
            .update_strategy(UpdateStrategy::Append)
            .min_update_size(1);

        let fitted = nystroem
            .fit(&x_initial, &())
            .expect("operation should succeed");
        assert_eq!(fitted.n_landmarks(), 3);

        let updated = fitted.update(&x_new).expect("operation should succeed");
        assert_eq!(updated.n_landmarks(), 5);
        assert_eq!(updated.update_count(), 1);
    }

    // transform() yields one row per sample and at most n_components features
    // (fewer when small eigenvalues are filtered out).
    #[test]
    fn test_incremental_transform() {
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_test = array![[1.5, 2.5], [2.5, 3.5]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3);
        let fitted = nystroem
            .fit(&x_train, &())
            .expect("operation should succeed");

        let transformed = fitted.transform(&x_test).expect("operation should succeed");
        assert_eq!(transformed.shape()[0], 2);
        assert!(transformed.shape()[1] <= 3);
    }

    // Sliding-window strategy: landmark count stays capped at n_components
    // after an update larger than the remaining capacity.
    #[test]
    fn test_sliding_window_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 3)
            .update_strategy(UpdateStrategy::SlidingWindow)
            .min_update_size(1);

        let fitted = nystroem
            .fit(&x_initial, &())
            .expect("operation should succeed");
        let updated = fitted.update(&x_new).expect("operation should succeed");

        assert_eq!(updated.n_landmarks(), 3);
        assert_eq!(updated.update_count(), 1);
    }

    // Fit/transform works for both RBF and polynomial kernels.
    #[test]
    fn test_different_kernels() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        // Test with RBF kernel
        let rbf_nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3);
        let rbf_fitted = rbf_nystroem.fit(&x, &()).expect("operation should succeed");
        let rbf_transformed = rbf_fitted.transform(&x).expect("operation should succeed");
        assert_eq!(rbf_transformed.shape()[0], 3);

        // Test with polynomial kernel
        let poly_nystroem = IncrementalNystroem::new(
            Kernel::Polynomial {
                gamma: 1.0,
                coef0: 1.0,
                degree: 2,
            },
            3,
        );
        let poly_fitted = poly_nystroem
            .fit(&x, &())
            .expect("operation should succeed");
        let poly_transformed = poly_fitted.transform(&x).expect("operation should succeed");
        assert_eq!(poly_transformed.shape()[0], 3);
    }

    // Batches below min_update_size are buffered without recomputation;
    // a later large-enough batch triggers a single recomputation.
    #[test]
    fn test_min_update_size() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_small = array![[3.0, 4.0]];
        let x_large = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 5).min_update_size(2);

        let fitted = nystroem
            .fit(&x_initial, &())
            .expect("operation should succeed");

        // Small update should not trigger recomputation
        let after_small = fitted.update(&x_small).expect("operation should succeed");
        assert_eq!(after_small.update_count(), 0);
        assert_eq!(after_small.n_landmarks(), 2);

        // Large update should trigger recomputation
        let after_large = after_small
            .update(&x_large)
            .expect("operation should succeed");
        assert_eq!(after_large.update_count(), 1);
        assert_eq!(after_large.n_landmarks(), 5);
    }

    // Same random_state must produce the same transform output (up to a
    // global sign flip, which eigendecomposition cannot pin down).
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0]];

        let nystroem1 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted1 = nystroem1.fit(&x, &()).expect("operation should succeed");
        let updated1 = fitted1.update(&x_new).expect("operation should succeed");
        let result1 = updated1.transform(&x).expect("operation should succeed");

        let nystroem2 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted2 = nystroem2.fit(&x, &()).expect("operation should succeed");
        let updated2 = fitted2.update(&x_new).expect("operation should succeed");
        let result2 = updated2.transform(&x).expect("operation should succeed");

        // Results should be very similar with same random seed (allowing for numerical precision)
        // Note: eigendecomposition can produce results that differ by a sign flip
        assert_eq!(result1.shape(), result2.shape());

        // Check if results are similar or similar up to sign flip
        let mut direct_match = true;
        let mut sign_flip_match = true;

        for i in 0..result1.len() {
            let val1 = result1.as_slice().expect("operation should succeed")[i];
            let val2 = result2.as_slice().expect("operation should succeed")[i];

            if (val1 - val2).abs() > 1e-6 {
                direct_match = false;
            }
            if (val1 + val2).abs() > 1e-6 {
                sign_flip_match = false;
            }
        }

        assert!(
            direct_match || sign_flip_match,
            "Results differ too much and are not related by sign flip"
        );
    }
}