sklears_kernel_approximation/
incremental_nystroem.rs

//! Incremental Nyström method for online kernel approximation
//!
//! This module implements online/incremental versions of the Nyström method
//! that can efficiently update kernel approximations as new data arrives.
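//!
//! The approximation rests on the Nyström identity
//! `K ≈ K(X, Z) · K(Z, Z)^(-1) · K(Z, X)` for a set of landmark points `Z`;
//! incremental updates extend or replace `Z` without revisiting all past data.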

use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{thread_rng, Rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::marker::PhantomData;

use crate::nystroem::{Kernel, SamplingStrategy};

/// Update strategy for incremental Nyström
#[derive(Debug, Clone)]
pub enum UpdateStrategy {
    /// Simple addition of new landmarks
    Append,
    /// Replace oldest landmarks with new ones (sliding window)
    SlidingWindow,
    /// Merge new data with existing approximation
    Merge,
    /// Selective update: apply a candidate update only if it improves the
    /// approximation quality score by more than `threshold`
    Selective { threshold: Float },
}

/// Incremental Nyström method for online kernel approximation
///
/// Enables efficient updating of kernel approximations as new data arrives,
/// without requiring complete recomputation from scratch.
///
/// # Parameters
///
/// * `kernel` - Kernel function to approximate
/// * `n_components` - Maximum number of landmark points
/// * `update_strategy` - Strategy for incorporating new data
/// * `min_update_size` - Minimum number of new samples before updating
/// * `sampling_strategy` - Strategy for selecting new landmarks
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::incremental_nystroem::{
///     IncrementalNystroem, UpdateStrategy,
/// };
/// use sklears_kernel_approximation::nystroem::Kernel;
/// use sklears_core::traits::{Fit, Transform};
///
/// let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 8)
///     .update_strategy(UpdateStrategy::Append)
///     .min_update_size(4)
///     .random_state(42);
/// let fitted = nystroem.fit(&x_initial, &())?;
/// let fitted = fitted.update(&x_new)?;
/// let features = fitted.transform(&x_test)?;
/// ```
#[derive(Debug, Clone)]
pub struct IncrementalNystroem<State = Untrained> {
    pub kernel: Kernel,
    pub n_components: usize,
    pub update_strategy: UpdateStrategy,
    pub min_update_size: usize,
    pub sampling_strategy: SamplingStrategy,
    pub random_state: Option<u64>,

    // Fitted attributes
    components_: Option<Array2<Float>>,
    normalization_: Option<Array2<Float>>,
    component_indices_: Option<Vec<usize>>,
    landmark_data_: Option<Array2<Float>>,
    update_count_: usize,
    accumulated_data_: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl IncrementalNystroem<Untrained> {
    /// Create a new incremental Nyström approximator
    pub fn new(kernel: Kernel, n_components: usize) -> Self {
        Self {
            kernel,
            n_components,
            update_strategy: UpdateStrategy::Append,
            min_update_size: 10,
            sampling_strategy: SamplingStrategy::Random,
            random_state: None,
            components_: None,
            normalization_: None,
            component_indices_: None,
            landmark_data_: None,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    /// Set the update strategy
    pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
        self.update_strategy = strategy;
        self
    }

    /// Set minimum update size
    pub fn min_update_size(mut self, size: usize) -> Self {
        self.min_update_size = size;
        self
    }

    /// Set sampling strategy
    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
        self.sampling_strategy = strategy;
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for IncrementalNystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for IncrementalNystroem<Untrained> {
    type Fitted = IncrementalNystroem<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();
        let n_components = self.n_components.min(n_samples);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Select initial landmark points
        let component_indices = self.select_components(x, n_components, &mut rng)?;
        let landmark_data = self.extract_landmarks(x, &component_indices);

        // Compute kernel matrix for landmarks
        let kernel_matrix = self.kernel.compute_kernel(&landmark_data, &landmark_data);

        // Compute eigendecomposition
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        Ok(IncrementalNystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            update_strategy: self.update_strategy,
            min_update_size: self.min_update_size,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            landmark_data_: Some(landmark_data),
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

impl IncrementalNystroem<Untrained> {
    /// Select component indices based on sampling strategy
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }

    /// Simple k-means based sampling
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers randomly
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // Run a few iterations of k-means
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign points to nearest centers
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Update centers
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + &x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Find closest points to final centers
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Fill remaining slots randomly if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }

    /// Leverage score based sampling
    fn leverage_score_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute leverage scores (diagonal of the hat matrix,
        // l_i = [X (X^T X)^(-1) X^T]_{ii}). For simplicity, we approximate
        // them using row norms as a cheap proxy.
        let mut scores = Vec::new();
        for i in 0..n_samples {
            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
        }

        // Sample based on scores using cumulative distribution
        let total_score: Float = scores.iter().sum();
        if total_score <= 0.0 {
            return Err(SklearsError::InvalidInput(
                "All scores are zero or negative".to_string(),
            ));
        }

        // Create cumulative distribution
        let mut cumulative = Vec::with_capacity(scores.len());
        let mut sum = 0.0;
        for &score in &scores {
            sum += score / total_score;
            cumulative.push(sum);
        }

        // Draw from the seeded RNG (not `thread_rng`) so that `random_state`
        // keeps this sampler reproducible.
        let mut selected_indices = Vec::new();
        for _ in 0..n_components {
            let r = rng.gen::<Float>();
            // Find index where cumulative probability >= r
            let mut idx = cumulative
                .iter()
                .position(|&cum| cum >= r)
                .unwrap_or(scores.len() - 1);

            // Ensure no duplicates
            while selected_indices.contains(&idx) {
                let r = rng.gen::<Float>();
                idx = cumulative
                    .iter()
                    .position(|&cum| cum >= r)
                    .unwrap_or(scores.len() - 1);
            }
            selected_indices.push(idx);
        }

        Ok(selected_indices)
    }

    /// Column norm based sampling
    fn column_norm_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute row norms
        let mut norms = Vec::new();
        for i in 0..n_samples {
            let norm = x.row(i).dot(&x.row(i)).sqrt();
            norms.push(norm + 1e-10);
        }

        // Sort by norm and take diverse selection
        let mut indices_with_norms: Vec<(usize, Float)> = norms
            .iter()
            .enumerate()
            .map(|(i, &norm)| (i, norm))
            .collect();
        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut selected_indices = Vec::new();
        let step = n_samples.max(1) / n_components.max(1);

        for i in 0..n_components {
            let idx = (i * step).min(n_samples - 1);
            selected_indices.push(indices_with_norms[idx].0);
        }

        // Fill remaining with random if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices)
    }

    /// Extract landmark data points
    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
        let (_, n_features) = x.dim();
        let mut landmarks = Array2::zeros((indices.len(), n_features));

        for (i, &idx) in indices.iter().enumerate() {
            landmarks.row_mut(i).assign(&x.row(idx));
        }

        landmarks
    }

    /// Compute an approximate eigendecomposition using power iteration with deflation
    /// Returns `(eigenvalues, eigenvectors)` for a symmetric matrix
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Power iteration method to find dominant eigenvalue and eigenvector
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize with deterministic vector to ensure reproducibility
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }

    /// Compute eigendecomposition of kernel matrix
    fn compute_decomposition(
        &self,
        mut kernel_matrix: Array2<Float>,
    ) -> Result<(Array2<Float>, Array2<Float>)> {
        // Add small regularization to diagonal for numerical stability
        let reg = 1e-8;
        for i in 0..kernel_matrix.nrows() {
            kernel_matrix[[i, i]] += reg;
        }

        // Proper eigendecomposition for Nyström method
        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;

        // Filter out small eigenvalues for numerical stability
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        // Construct components and normalization matrices
        let n_valid = valid_indices.len();
        let mut components = Array2::zeros((eigenvals.len(), n_valid));
        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));

        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
            let sqrt_eigenval = eigenvals[old_idx].sqrt();
            components
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));

            // For Nyström method: normalization = V * Λ^(-1/2)
            for i in 0..eigenvals.len() {
                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
            }
        }

        Ok((components, normalization))
    }
}

impl IncrementalNystroem<Trained> {
    /// Select component indices based on sampling strategy
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }

    /// Simple k-means based sampling
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers randomly
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // Run a few iterations of k-means
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign points to nearest centers
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Update centers
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + &x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Find closest points to final centers
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Fill remaining slots randomly if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }

    /// Leverage score based sampling
    fn leverage_score_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute leverage scores (diagonal of the hat matrix,
        // l_i = [X (X^T X)^(-1) X^T]_{ii}). For simplicity, we approximate
        // them using row norms as a cheap proxy.
        let mut scores = Vec::new();
        for i in 0..n_samples {
            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
            scores.push(row_norm + 1e-10); // Add small epsilon for numerical stability
        }

        // Sample based on scores using cumulative distribution
        let total_score: Float = scores.iter().sum();
        if total_score <= 0.0 {
            return Err(SklearsError::InvalidInput(
                "All scores are zero or negative".to_string(),
            ));
        }

        // Create cumulative distribution
        let mut cumulative = Vec::with_capacity(scores.len());
        let mut sum = 0.0;
        for &score in &scores {
            sum += score / total_score;
            cumulative.push(sum);
        }

        // Draw from the seeded RNG (not `thread_rng`) so that `random_state`
        // keeps this sampler reproducible.
        let mut selected_indices = Vec::new();
        for _ in 0..n_components {
            let r = rng.gen::<Float>();
            // Find index where cumulative probability >= r
            let mut idx = cumulative
                .iter()
                .position(|&cum| cum >= r)
                .unwrap_or(scores.len() - 1);

            // Ensure no duplicates
            while selected_indices.contains(&idx) {
                let r = rng.gen::<Float>();
                idx = cumulative
                    .iter()
                    .position(|&cum| cum >= r)
                    .unwrap_or(scores.len() - 1);
            }
            selected_indices.push(idx);
        }

        Ok(selected_indices)
    }

    /// Column norm based sampling
    fn column_norm_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Compute row norms
        let mut norms = Vec::new();
        for i in 0..n_samples {
            let norm = x.row(i).dot(&x.row(i)).sqrt();
            norms.push(norm + 1e-10);
        }

        // Sort by norm and take diverse selection
        let mut indices_with_norms: Vec<(usize, Float)> = norms
            .iter()
            .enumerate()
            .map(|(i, &norm)| (i, norm))
            .collect();
        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut selected_indices = Vec::new();
        let step = n_samples.max(1) / n_components.max(1);

        for i in 0..n_components {
            let idx = (i * step).min(n_samples - 1);
            selected_indices.push(indices_with_norms[idx].0);
        }

        // Fill remaining with random if needed
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices)
    }

    /// Update the approximation with new data
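    ///
    /// New samples are buffered until at least `min_update_size` rows have
    /// accumulated; only then is `update_strategy` applied and the
    /// decomposition recomputed. A minimal usage sketch (assuming `fitted`
    /// came from `fit` and `batch` has the same feature count):
    ///
    /// ```rust,ignore
    /// let fitted = fitted.update(&batch)?;
    /// println!("updates so far: {}", fitted.update_count());
    /// ```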
    pub fn update(mut self, x_new: &Array2<Float>) -> Result<Self> {
        // Accumulate new data
        match &self.accumulated_data_ {
            Some(existing) => {
                let combined =
                    scirs2_core::ndarray::concatenate![Axis(0), existing.clone(), x_new.clone()];
                self.accumulated_data_ = Some(combined);
            }
            None => {
                self.accumulated_data_ = Some(x_new.clone());
            }
        }

        // Check if we have enough accumulated data to update
        let should_update = if let Some(ref accumulated) = self.accumulated_data_ {
            accumulated.nrows() >= self.min_update_size
        } else {
            false
        };

        if should_update {
            if let Some(accumulated) = self.accumulated_data_.take() {
                self = self.perform_update(&accumulated)?;
                self.update_count_ += 1;
            }
        }

        Ok(self)
    }

    /// Perform the actual update based on the strategy
    fn perform_update(self, new_data: &Array2<Float>) -> Result<Self> {
        match self.update_strategy.clone() {
            UpdateStrategy::Append => self.append_update(new_data),
            UpdateStrategy::SlidingWindow => self.sliding_window_update(new_data),
            UpdateStrategy::Merge => self.merge_update(new_data),
            UpdateStrategy::Selective { threshold } => self.selective_update(new_data, threshold),
        }
    }

    /// Append new landmarks (if space available)
    fn append_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let current_components = current_landmarks.nrows();

        if current_components >= self.n_components {
            // No space to append, just return current state
            return Ok(self);
        }

        let available_space = self.n_components - current_components;
        let n_new = available_space.min(new_data.nrows());

        if n_new == 0 {
            return Ok(self);
        }

        // Select new landmarks from new data
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(1000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        // Extract new landmarks
        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        // Combine with existing landmarks
        let combined_landmarks =
            scirs2_core::ndarray::concatenate![Axis(0), current_landmarks.clone(), new_landmarks];

        // Recompute decomposition
        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // Update indices
        let mut new_component_indices = self.component_indices_.as_ref().unwrap().clone();
        let base_index = current_landmarks.nrows();
        for &idx in selected_indices {
            new_component_indices.push(base_index + idx);
        }

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }

    /// Sliding window update (replace oldest landmarks)
    fn sliding_window_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let n_new = new_data.nrows().min(self.n_components);

        if n_new == 0 {
            return Ok(self);
        }

        // Select new landmarks
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(2000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        // Replace the oldest landmarks with new ones, keeping at most
        // `n_components - n_new` of the most recent existing landmarks
        // (bounded by how many landmarks actually exist, so the slice below
        // cannot run past the end of the array).
        let n_keep = (self.n_components - n_new).min(current_landmarks.nrows());
        let combined_landmarks = if n_keep > 0 {
            let start = current_landmarks.nrows() - n_keep;
            let kept_landmarks = current_landmarks.slice(s![start.., ..]).to_owned();
            scirs2_core::ndarray::concatenate![Axis(0), kept_landmarks, new_landmarks]
        } else {
            new_landmarks
        };

        // Recompute decomposition
        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // Update component indices (simplified)
        let new_component_indices: Vec<usize> = (0..combined_landmarks.nrows()).collect();

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }

    /// Merge update (combine approximations)
    fn merge_update(self, new_data: &Array2<Float>) -> Result<Self> {
        // Merging strategy that combines the existing landmark set with
        // landmarks drawn from the new data, then refits on the merged set.

        let current_landmarks = self.landmark_data_.as_ref().unwrap();

        // Step 1: Select landmarks from the new data using the configured
        // sampling strategy (at most half the budget, but at least one)
        let n_new_components = (new_data.nrows().min(self.n_components) / 2).max(1);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(3000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let new_component_indices = self.select_components(new_data, n_new_components, &mut rng)?;
        let new_landmarks = self.extract_landmarks(new_data, &new_component_indices);

        // Step 2: Combine the landmarks, selecting the most diverse subset
        // from both sets
        let merged_landmarks =
            self.merge_landmarks_intelligently(current_landmarks, &new_landmarks, &mut rng)?;

        // Step 3: Recompute the full approximation on the merged landmarks
        let merged_kernel_matrix = self
            .kernel
            .compute_kernel(&merged_landmarks, &merged_landmarks);
        let (final_components, final_normalization) =
            self.compute_decomposition(merged_kernel_matrix)?;

        // Update component indices (simplified for merged case)
        let final_component_indices: Vec<usize> = (0..merged_landmarks.nrows()).collect();

        let mut updated_self = self;
        updated_self.components_ = Some(final_components);
        updated_self.normalization_ = Some(final_normalization);
        updated_self.component_indices_ = Some(final_component_indices);
        updated_self.landmark_data_ = Some(merged_landmarks);

        Ok(updated_self)
    }

    /// Intelligently merge landmarks from existing and new data
    fn merge_landmarks_intelligently(
        &self,
        current_landmarks: &Array2<Float>,
        new_landmarks: &Array2<Float>,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        let n_current = current_landmarks.nrows();
        let n_new = new_landmarks.nrows();
        let n_features = current_landmarks.ncols();

        // Combine all landmarks temporarily
        let all_landmarks = scirs2_core::ndarray::concatenate![
            Axis(0),
            current_landmarks.clone(),
            new_landmarks.clone()
        ];

        // Use diversity-based selection to choose the best subset
        let n_target = self.n_components.min(n_current + n_new);
        let selected_indices = self.select_diverse_landmarks(&all_landmarks, n_target, rng)?;

        // Extract selected landmarks
        let mut merged_landmarks = Array2::zeros((selected_indices.len(), n_features));
        for (i, &idx) in selected_indices.iter().enumerate() {
            merged_landmarks.row_mut(i).assign(&all_landmarks.row(idx));
        }

        Ok(merged_landmarks)
    }

    /// Select diverse landmarks using maximum distance criterion
    fn select_diverse_landmarks(
        &self,
        landmarks: &Array2<Float>,
        n_select: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let n_landmarks = landmarks.nrows();

        if n_select >= n_landmarks {
            return Ok((0..n_landmarks).collect());
        }

        let mut selected = Vec::new();
        let mut available: Vec<usize> = (0..n_landmarks).collect();

        // Start with a random landmark
        let first_idx = rng.gen_range(0..available.len());
        selected.push(available.remove(first_idx));

        // Greedily select landmarks that are maximally distant from already selected ones
        while selected.len() < n_select && !available.is_empty() {
            let mut best_idx = 0;
            let mut max_min_distance = 0.0;

            for (i, &candidate_idx) in available.iter().enumerate() {
                // Compute minimum distance to already selected landmarks
                let mut min_distance = Float::INFINITY;

                for &selected_idx in &selected {
                    let diff = &landmarks.row(candidate_idx) - &landmarks.row(selected_idx);
                    let distance = diff.dot(&diff).sqrt();
                    if distance < min_distance {
                        min_distance = distance;
                    }
                }

                if min_distance > max_min_distance {
                    max_min_distance = min_distance;
                    best_idx = i;
                }
            }

            selected.push(available.remove(best_idx));
        }

        Ok(selected)
    }

    /// Selective update based on approximation quality
    fn selective_update(self, new_data: &Array2<Float>, threshold: Float) -> Result<Self> {
        // Quality-based selective update that only incorporates new data if it improves approximation

        let current_landmarks = self.landmark_data_.as_ref().unwrap();

        // Step 1: Evaluate current approximation quality on new data
        let current_quality = self.evaluate_approximation_quality(current_landmarks, new_data)?;

        // Step 2: Create candidate updates and evaluate their quality
        let mut best_update = self.clone();
        let mut best_quality = current_quality;

        // Try append update
        let append_candidate = self.clone().append_update(new_data)?;
        let append_quality = append_candidate.evaluate_approximation_quality(
            append_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if append_quality > best_quality + threshold {
            best_update = append_candidate;
            best_quality = append_quality;
        }

        // Try merge update if we have enough data
        if new_data.nrows() >= 3 {
            let merge_candidate = self.clone().merge_update(new_data)?;
            let merge_quality = merge_candidate.evaluate_approximation_quality(
                merge_candidate.landmark_data_.as_ref().unwrap(),
                new_data,
            )?;

            if merge_quality > best_quality + threshold {
                best_update = merge_candidate;
                best_quality = merge_quality;
            }
        }

        // Try sliding window update
        let sliding_candidate = self.clone().sliding_window_update(new_data)?;
        let sliding_quality = sliding_candidate.evaluate_approximation_quality(
            sliding_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if sliding_quality > best_quality + threshold {
            best_update = sliding_candidate;
            best_quality = sliding_quality;
        }

        // Step 3: Only update if quality improvement exceeds threshold
        if best_quality > current_quality + threshold {
            Ok(best_update)
        } else {
            // No significant improvement, keep current state
            Ok(self)
        }
    }

    /// Evaluate approximation quality using kernel approximation error
    fn evaluate_approximation_quality(
        &self,
        landmarks: &Array2<Float>,
        test_data: &Array2<Float>,
    ) -> Result<Float> {
        // Quality metric: negative approximation error (higher is better)

        let n_test = test_data.nrows().min(50); // Limit for efficiency
        let test_subset = if test_data.nrows() > n_test {
            // Sample random subset for evaluation
            let mut rng = thread_rng();
            let mut indices: Vec<usize> = (0..test_data.nrows()).collect();
            indices.shuffle(&mut rng);
            test_data.select(Axis(0), &indices[..n_test])
        } else {
            test_data.to_owned()
        };

        // Compute exact kernel matrix for test subset
        let k_exact = self.kernel.compute_kernel(&test_subset, &test_subset);

        // Compute Nyström approximation: K(X,Z) * K(Z,Z)^(-1) * K(Z,X)
        let k_test_landmarks = self.kernel.compute_kernel(&test_subset, landmarks);
        let k_landmarks = self.kernel.compute_kernel(landmarks, landmarks);

        // Use our eigendecomposition to compute pseudo-inverse
        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(k_landmarks)?;

        // Construct pseudo-inverse
        let threshold = 1e-8;
        let mut pseudo_inverse = Array2::zeros((landmarks.nrows(), landmarks.nrows()));

        for i in 0..landmarks.nrows() {
            for j in 0..landmarks.nrows() {
                let mut sum = 0.0;
                for k in 0..eigenvals.len() {
                    if eigenvals[k] > threshold {
                        sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
                    }
                }
                pseudo_inverse[[i, j]] = sum;
            }
        }

        // Compute approximation: K(X,Z) * K(Z,Z)^(-1) * K(Z,X)
        let k_approx = k_test_landmarks
            .dot(&pseudo_inverse)
            .dot(&k_test_landmarks.t());

        // Compute approximation error (Frobenius norm)
        let error_matrix = &k_exact - &k_approx;
        let approximation_error = error_matrix.mapv(|x| x * x).sum().sqrt();

        // Convert to quality score (negative error, higher is better)
        let quality = -approximation_error / (k_exact.mapv(|x| x * x).sum().sqrt() + 1e-10);

        Ok(quality)
    }

    /// Extract landmark data points
    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
        let (_, n_features) = x.dim();
        let mut landmarks = Array2::zeros((indices.len(), n_features));

        for (i, &idx) in indices.iter().enumerate() {
            landmarks.row_mut(i).assign(&x.row(idx));
        }

        landmarks
    }

    /// Compute an approximate eigendecomposition using power iteration with deflation
    /// Returns `(eigenvalues, eigenvectors)` for a symmetric matrix
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Use deflation method to find multiple eigenvalues
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Power iteration for k-th eigenvalue/eigenvector
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate matrix: A_new = A - λ * v * v^T
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenvalues and eigenvectors in descending order
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    /// Power iteration method to find dominant eigenvalue and eigenvector
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Initialize with deterministic vector to ensure reproducibility
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        // Normalize
        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // Apply matrix
            let w = matrix.dot(&v);

            // Compute Rayleigh quotient
            let new_eigenval = v.dot(&w);

            // Normalize
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Check convergence
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }

    /// Compute eigendecomposition of kernel matrix
    fn compute_decomposition(
        &self,
        mut kernel_matrix: Array2<Float>,
    ) -> Result<(Array2<Float>, Array2<Float>)> {
        // Add small regularization to diagonal for numerical stability
        let reg = 1e-8;
        for i in 0..kernel_matrix.nrows() {
            kernel_matrix[[i, i]] += reg;
        }

        // Proper eigendecomposition for Nyström method
        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;

        // Filter out small eigenvalues for numerical stability
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        // Construct components and normalization matrices
        let n_valid = valid_indices.len();
        let mut components = Array2::zeros((eigenvals.len(), n_valid));
        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));

        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
            let sqrt_eigenval = eigenvals[old_idx].sqrt();
            components
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));

            // For Nyström method: normalization = V * Λ^(-1/2)
            for i in 0..eigenvals.len() {
                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
            }
        }

        Ok((components, normalization))
    }

    /// Get number of updates performed
    pub fn update_count(&self) -> usize {
        self.update_count_
    }

    /// Get current number of landmarks
    pub fn n_landmarks(&self) -> usize {
        self.landmark_data_.as_ref().map_or(0, |data| data.nrows())
    }
}

impl Transform<Array2<Float>> for IncrementalNystroem<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        // Fitted check only; the components matrix itself is not needed here.
        let _components = self
            .components_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let normalization =
            self.normalization_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let landmark_data =
            self.landmark_data_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;
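
        // Nyström feature map: phi(x) = K(x, Z) · V · Λ^(-1/2), where Z are the
        // landmark points and (V, Λ) is the eigendecomposition of K(Z, Z).
        // `normalization` stores Λ^(-1/2) · V^T, hence the transpose below, and
        // phi(x) · phi(y)^T ≈ K(x, y).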
        // Compute kernel between input and landmarks
        let kernel_x_landmarks = self.kernel.compute_kernel(x, landmark_data);

        // Apply transformation: K(X, landmarks) @ normalization.T
        let transformed = kernel_x_landmarks.dot(&normalization.t());

        Ok(transformed)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_incremental_nystroem_basic() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
            .update_strategy(UpdateStrategy::Append)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        assert_eq!(fitted.n_landmarks(), 3);

        let updated = fitted.update(&x_new).unwrap();
        assert_eq!(updated.n_landmarks(), 5);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_incremental_transform() {
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_test = array![[1.5, 2.5], [2.5, 3.5]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3);
        let fitted = nystroem.fit(&x_train, &()).unwrap();

        let transformed = fitted.transform(&x_test).unwrap();
        assert_eq!(transformed.shape()[0], 2);
        assert!(transformed.shape()[1] <= 3);
    }

    #[test]
    fn test_sliding_window_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 3)
            .update_strategy(UpdateStrategy::SlidingWindow)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert_eq!(updated.n_landmarks(), 3);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_different_kernels() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        // Test with RBF kernel
        let rbf_nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3);
        let rbf_fitted = rbf_nystroem.fit(&x, &()).unwrap();
        let rbf_transformed = rbf_fitted.transform(&x).unwrap();
        assert_eq!(rbf_transformed.shape()[0], 3);

        // Test with polynomial kernel
        let poly_nystroem = IncrementalNystroem::new(
            Kernel::Polynomial {
                gamma: 1.0,
                coef0: 1.0,
                degree: 2,
            },
            3,
        );
        let poly_fitted = poly_nystroem.fit(&x, &()).unwrap();
        let poly_transformed = poly_fitted.transform(&x).unwrap();
        assert_eq!(poly_transformed.shape()[0], 3);
    }

    #[test]
    fn test_min_update_size() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_small = array![[3.0, 4.0]];
        let x_large = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 5).min_update_size(2);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();

        // Small update should not trigger recomputation
        let after_small = fitted.update(&x_small).unwrap();
        assert_eq!(after_small.update_count(), 0);
        assert_eq!(after_small.n_landmarks(), 2);

        // Large update should trigger recomputation
        let after_large = after_small.update(&x_large).unwrap();
        assert_eq!(after_large.update_count(), 1);
        assert_eq!(after_large.n_landmarks(), 5);
    }
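
    // A sketch exercising the remaining update strategies end to end. The
    // exact landmark set after merging depends on the diversity selection, so
    // these tests only assert the configured bounds, not specific landmarks.
    #[test]
    fn test_merge_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 4)
            .update_strategy(UpdateStrategy::Merge)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert!(updated.n_landmarks() <= 4);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_selective_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 4)
            .update_strategy(UpdateStrategy::Selective { threshold: 0.0 })
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        // Whether or not a candidate update is accepted, the landmark budget
        // must be respected and the update attempt must be counted.
        assert!(updated.n_landmarks() <= 4);
        assert_eq!(updated.update_count(), 1);
    }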

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0]];

        let nystroem1 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let updated1 = fitted1.update(&x_new).unwrap();
        let result1 = updated1.transform(&x).unwrap();

        let nystroem2 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let updated2 = fitted2.update(&x_new).unwrap();
        let result2 = updated2.transform(&x).unwrap();

        // Results should be very similar with same random seed (allowing for numerical precision)
        // Note: eigendecomposition can produce results that differ by a sign flip
        assert_eq!(result1.shape(), result2.shape());

        // Check if results are similar or similar up to sign flip
        let mut direct_match = true;
        let mut sign_flip_match = true;

        for i in 0..result1.len() {
            let val1 = result1.as_slice().unwrap()[i];
            let val2 = result2.as_slice().unwrap()[i];

            if (val1 - val2).abs() > 1e-6 {
                direct_match = false;
            }
            if (val1 + val2).abs() > 1e-6 {
                sign_flip_match = false;
            }
        }

        assert!(
            direct_match || sign_flip_match,
            "Results differ too much and are not related by sign flip"
        );
    }
}