1use scirs2_core::ndarray::{s, Array1, Array2, Axis};
7use scirs2_core::random::rngs::StdRng as RealStdRng;
8use scirs2_core::random::seq::SliceRandom;
9use scirs2_core::random::RngExt;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use sklears_core::{
12 error::{Result, SklearsError},
13 traits::{Estimator, Fit, Trained, Transform, Untrained},
14 types::Float,
15};
16use std::marker::PhantomData;
17
18use crate::nystroem::{Kernel, SamplingStrategy};
19
/// Strategy for folding newly accumulated samples into an existing
/// Nyström approximation during an incremental update.
#[derive(Debug, Clone)]
pub enum UpdateStrategy {
    /// Keep the current landmarks and append new ones until `n_components` is reached.
    Append,
    /// Drop the oldest landmarks to make room for newly sampled ones.
    SlidingWindow,
    /// Re-select a diverse landmark set from the union of old and new landmarks.
    Merge,
    /// Try several strategies and keep the best candidate, but only if it
    /// improves the approximation quality by more than `threshold`.
    Selective { threshold: Float },
}
33
/// Incremental Nyström kernel approximation.
///
/// The `State` type parameter (`Untrained`/`Trained`) tracks at compile time
/// whether the transformer has been fitted; the trailing `Option` fields are
/// populated only in the `Trained` state.
#[derive(Debug, Clone)]
pub struct IncrementalNystroem<State = Untrained> {
    /// Kernel function used to compare samples against landmarks.
    pub kernel: Kernel,
    /// Target number of landmark points (capped at the sample count in `fit`).
    pub n_components: usize,
    /// How newly arriving data is folded into the existing approximation.
    pub update_strategy: UpdateStrategy,
    /// Minimum number of buffered rows before an incremental update runs.
    pub min_update_size: usize,
    /// Strategy for choosing landmark indices from the data.
    pub sampling_strategy: SamplingStrategy,
    /// Optional seed making landmark selection reproducible.
    pub random_state: Option<u64>,

    // --- fitted state, populated by `fit` ---
    // Eigenvector matrix of the landmark kernel (one component per column).
    components_: Option<Array2<Float>>,
    // Normalization matrix applied during `transform` (eigvecs scaled by 1/sqrt(eigval)).
    normalization_: Option<Array2<Float>>,
    // Row indices of the selected landmarks.
    component_indices_: Option<Vec<usize>>,
    // The landmark rows themselves.
    landmark_data_: Option<Array2<Float>>,
    // Number of incremental updates applied so far.
    update_count_: usize,
    // Rows buffered until `min_update_size` is reached.
    accumulated_data_: Option<Array2<Float>>,

    // Zero-sized marker carrying the type-state.
    _state: PhantomData<State>,
}
65
66impl IncrementalNystroem<Untrained> {
67 pub fn new(kernel: Kernel, n_components: usize) -> Self {
69 Self {
70 kernel,
71 n_components,
72 update_strategy: UpdateStrategy::Append,
73 min_update_size: 10,
74 sampling_strategy: SamplingStrategy::Random,
75 random_state: None,
76 components_: None,
77 normalization_: None,
78 component_indices_: None,
79 landmark_data_: None,
80 update_count_: 0,
81 accumulated_data_: None,
82 _state: PhantomData,
83 }
84 }
85
86 pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
88 self.update_strategy = strategy;
89 self
90 }
91
92 pub fn min_update_size(mut self, size: usize) -> Self {
94 self.min_update_size = size;
95 self
96 }
97
98 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
100 self.sampling_strategy = strategy;
101 self
102 }
103
104 pub fn random_state(mut self, seed: u64) -> Self {
106 self.random_state = Some(seed);
107 self
108 }
109}
110
/// Estimator metadata for the untrained transformer.
impl Estimator for IncrementalNystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // No configuration object is used; `&()` is promoted to a 'static
        // reference, so this borrow is valid.
        &()
    }
}
120
impl Fit<Array2<Float>, ()> for IncrementalNystroem<Untrained> {
    type Fitted = IncrementalNystroem<Trained>;

    /// Fit the Nyström approximation: select landmarks from `x`, compute the
    /// landmark kernel matrix, and derive the component/normalization
    /// matrices from its eigendecomposition. Consumes the untrained value and
    /// returns the `Trained` type-state.
    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _n_features) = x.dim();
        // Cannot select more landmarks than there are samples.
        let n_components = self.n_components.min(n_samples);

        // Seeded RNG when `random_state` is set; otherwise seed from the
        // thread-local RNG.
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().random()),
        };

        let component_indices = self.select_components(x, n_components, &mut rng)?;
        let landmark_data = self.extract_landmarks(x, &component_indices);

        // Kernel matrix over the landmarks only (n_components x n_components).
        let kernel_matrix = self.kernel.compute_kernel(&landmark_data, &landmark_data);

        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // Hand all configuration over to the trained value unchanged.
        Ok(IncrementalNystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            update_strategy: self.update_strategy,
            min_update_size: self.min_update_size,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            landmark_data_: Some(landmark_data),
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}
160
161impl IncrementalNystroem<Untrained> {
162 fn select_components(
164 &self,
165 x: &Array2<Float>,
166 n_components: usize,
167 rng: &mut RealStdRng,
168 ) -> Result<Vec<usize>> {
169 let (n_samples, _) = x.dim();
170
171 match &self.sampling_strategy {
172 SamplingStrategy::Random => {
173 let mut indices: Vec<usize> = (0..n_samples).collect();
174 indices.shuffle(rng);
175 Ok(indices[..n_components].to_vec())
176 }
177 SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
178 SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
179 SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
180 }
181 }
182
183 fn kmeans_sampling(
185 &self,
186 x: &Array2<Float>,
187 n_components: usize,
188 rng: &mut RealStdRng,
189 ) -> Result<Vec<usize>> {
190 let (n_samples, n_features) = x.dim();
191 let mut centers = Array2::zeros((n_components, n_features));
192
193 let mut indices: Vec<usize> = (0..n_samples).collect();
195 indices.shuffle(rng);
196 for (i, &idx) in indices[..n_components].iter().enumerate() {
197 centers.row_mut(i).assign(&x.row(idx));
198 }
199
200 for _iter in 0..5 {
202 let mut assignments = vec![0; n_samples];
203
204 for i in 0..n_samples {
206 let mut min_dist = Float::INFINITY;
207 let mut best_center = 0;
208
209 for j in 0..n_components {
210 let diff = &x.row(i) - ¢ers.row(j);
211 let dist = diff.dot(&diff);
212 if dist < min_dist {
213 min_dist = dist;
214 best_center = j;
215 }
216 }
217 assignments[i] = best_center;
218 }
219
220 for j in 0..n_components {
222 let cluster_points: Vec<usize> = assignments
223 .iter()
224 .enumerate()
225 .filter(|(_, &assignment)| assignment == j)
226 .map(|(i, _)| i)
227 .collect();
228
229 if !cluster_points.is_empty() {
230 let mut new_center = Array1::zeros(n_features);
231 for &point_idx in &cluster_points {
232 new_center = new_center + x.row(point_idx);
233 }
234 new_center /= cluster_points.len() as Float;
235 centers.row_mut(j).assign(&new_center);
236 }
237 }
238 }
239
240 let mut selected_indices = Vec::new();
242 for j in 0..n_components {
243 let mut min_dist = Float::INFINITY;
244 let mut best_point = 0;
245
246 for i in 0..n_samples {
247 let diff = &x.row(i) - ¢ers.row(j);
248 let dist = diff.dot(&diff);
249 if dist < min_dist {
250 min_dist = dist;
251 best_point = i;
252 }
253 }
254 selected_indices.push(best_point);
255 }
256
257 selected_indices.sort_unstable();
258 selected_indices.dedup();
259
260 while selected_indices.len() < n_components {
262 let random_idx = rng.random_range(0..n_samples);
263 if !selected_indices.contains(&random_idx) {
264 selected_indices.push(random_idx);
265 }
266 }
267
268 Ok(selected_indices[..n_components].to_vec())
269 }
270
271 fn leverage_score_sampling(
273 &self,
274 x: &Array2<Float>,
275 n_components: usize,
276 _rng: &mut RealStdRng,
277 ) -> Result<Vec<usize>> {
278 let (n_samples, _) = x.dim();
279
280 let mut scores = Vec::new();
283 for i in 0..n_samples {
284 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
285 scores.push(row_norm + 1e-10); }
287
288 let total_score: Float = scores.iter().sum();
290 if total_score <= 0.0 {
291 return Err(SklearsError::InvalidInput(
292 "All scores are zero or negative".to_string(),
293 ));
294 }
295
296 let mut cumulative = Vec::with_capacity(scores.len());
298 let mut sum = 0.0;
299 for &score in &scores {
300 sum += score / total_score;
301 cumulative.push(sum);
302 }
303
304 let mut selected_indices = Vec::new();
305 for _ in 0..n_components {
306 let r = thread_rng().random::<Float>();
307 let mut idx = cumulative
309 .iter()
310 .position(|&cum| cum >= r)
311 .unwrap_or(scores.len() - 1);
312
313 while selected_indices.contains(&idx) {
315 let r = thread_rng().random::<Float>();
316 idx = cumulative
317 .iter()
318 .position(|&cum| cum >= r)
319 .unwrap_or(scores.len() - 1);
320 }
321 selected_indices.push(idx);
322 }
323
324 Ok(selected_indices)
325 }
326
327 fn column_norm_sampling(
329 &self,
330 x: &Array2<Float>,
331 n_components: usize,
332 rng: &mut RealStdRng,
333 ) -> Result<Vec<usize>> {
334 let (n_samples, _) = x.dim();
335
336 let mut norms = Vec::new();
338 for i in 0..n_samples {
339 let norm = x.row(i).dot(&x.row(i)).sqrt();
340 norms.push(norm + 1e-10);
341 }
342
343 let mut indices_with_norms: Vec<(usize, Float)> = norms
345 .iter()
346 .enumerate()
347 .map(|(i, &norm)| (i, norm))
348 .collect();
349 indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("operation should succeed"));
350
351 let mut selected_indices = Vec::new();
352 let step = n_samples.max(1) / n_components.max(1);
353
354 for i in 0..n_components {
355 let idx = (i * step).min(n_samples - 1);
356 selected_indices.push(indices_with_norms[idx].0);
357 }
358
359 while selected_indices.len() < n_components {
361 let random_idx = rng.random_range(0..n_samples);
362 if !selected_indices.contains(&random_idx) {
363 selected_indices.push(random_idx);
364 }
365 }
366
367 Ok(selected_indices)
368 }
369
370 fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
372 let (_, n_features) = x.dim();
373 let mut landmarks = Array2::zeros((indices.len(), n_features));
374
375 for (i, &idx) in indices.iter().enumerate() {
376 landmarks.row_mut(i).assign(&x.row(idx));
377 }
378
379 landmarks
380 }
381
382 fn compute_eigendecomposition(
385 &self,
386 matrix: Array2<Float>,
387 ) -> Result<(Array1<Float>, Array2<Float>)> {
388 let n = matrix.nrows();
389
390 if n != matrix.ncols() {
391 return Err(SklearsError::InvalidInput(
392 "Matrix must be square for eigendecomposition".to_string(),
393 ));
394 }
395
396 let mut eigenvals = Array1::zeros(n);
397 let mut eigenvecs = Array2::zeros((n, n));
398
399 let mut deflated_matrix = matrix.clone();
401
402 for k in 0..n {
403 let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;
405
406 eigenvals[k] = eigenval;
407 eigenvecs.column_mut(k).assign(&eigenvec);
408
409 for i in 0..n {
411 for j in 0..n {
412 deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
413 }
414 }
415 }
416
417 let mut indices: Vec<usize> = (0..n).collect();
419 indices.sort_by(|&i, &j| {
420 eigenvals[j]
421 .partial_cmp(&eigenvals[i])
422 .expect("operation should succeed")
423 });
424
425 let mut sorted_eigenvals = Array1::zeros(n);
426 let mut sorted_eigenvecs = Array2::zeros((n, n));
427
428 for (new_idx, &old_idx) in indices.iter().enumerate() {
429 sorted_eigenvals[new_idx] = eigenvals[old_idx];
430 sorted_eigenvecs
431 .column_mut(new_idx)
432 .assign(&eigenvecs.column(old_idx));
433 }
434
435 Ok((sorted_eigenvals, sorted_eigenvecs))
436 }
437
438 fn power_iteration(
440 &self,
441 matrix: &Array2<Float>,
442 max_iter: usize,
443 tol: Float,
444 ) -> Result<(Float, Array1<Float>)> {
445 let n = matrix.nrows();
446
447 let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());
449
450 let norm = v.dot(&v).sqrt();
452 if norm < 1e-10 {
453 return Err(SklearsError::InvalidInput(
454 "Initial vector has zero norm".to_string(),
455 ));
456 }
457 v /= norm;
458
459 let mut eigenval = 0.0;
460
461 for _iter in 0..max_iter {
462 let w = matrix.dot(&v);
464
465 let new_eigenval = v.dot(&w);
467
468 let w_norm = w.dot(&w).sqrt();
470 if w_norm < 1e-10 {
471 break;
472 }
473 let new_v = w / w_norm;
474
475 let eigenval_change = (new_eigenval - eigenval).abs();
477 let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();
478
479 if eigenval_change < tol && vector_change < tol {
480 return Ok((new_eigenval, new_v));
481 }
482
483 eigenval = new_eigenval;
484 v = new_v;
485 }
486
487 Ok((eigenval, v))
488 }
489
490 fn compute_decomposition(
492 &self,
493 mut kernel_matrix: Array2<Float>,
494 ) -> Result<(Array2<Float>, Array2<Float>)> {
495 let reg = 1e-8;
497 for i in 0..kernel_matrix.nrows() {
498 kernel_matrix[[i, i]] += reg;
499 }
500
501 let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;
503
504 let threshold = 1e-8;
506 let valid_indices: Vec<usize> = eigenvals
507 .iter()
508 .enumerate()
509 .filter(|(_, &val)| val > threshold)
510 .map(|(i, _)| i)
511 .collect();
512
513 if valid_indices.is_empty() {
514 return Err(SklearsError::InvalidInput(
515 "No valid eigenvalues found in kernel matrix".to_string(),
516 ));
517 }
518
519 let n_valid = valid_indices.len();
521 let mut components = Array2::zeros((eigenvals.len(), n_valid));
522 let mut normalization = Array2::zeros((n_valid, eigenvals.len()));
523
524 for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
525 let sqrt_eigenval = eigenvals[old_idx].sqrt();
526 components
527 .column_mut(new_idx)
528 .assign(&eigenvecs.column(old_idx));
529
530 for i in 0..eigenvals.len() {
532 normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
533 }
534 }
535
536 Ok((components, normalization))
537 }
538}
539
540impl IncrementalNystroem<Trained> {
541 fn select_components(
543 &self,
544 x: &Array2<Float>,
545 n_components: usize,
546 rng: &mut RealStdRng,
547 ) -> Result<Vec<usize>> {
548 let (n_samples, _) = x.dim();
549
550 match &self.sampling_strategy {
551 SamplingStrategy::Random => {
552 let mut indices: Vec<usize> = (0..n_samples).collect();
553 indices.shuffle(rng);
554 Ok(indices[..n_components].to_vec())
555 }
556 SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
557 SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
558 SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
559 }
560 }
561
562 fn kmeans_sampling(
564 &self,
565 x: &Array2<Float>,
566 n_components: usize,
567 rng: &mut RealStdRng,
568 ) -> Result<Vec<usize>> {
569 let (n_samples, n_features) = x.dim();
570 let mut centers = Array2::zeros((n_components, n_features));
571
572 let mut indices: Vec<usize> = (0..n_samples).collect();
574 indices.shuffle(rng);
575 for (i, &idx) in indices[..n_components].iter().enumerate() {
576 centers.row_mut(i).assign(&x.row(idx));
577 }
578
579 for _iter in 0..5 {
581 let mut assignments = vec![0; n_samples];
582
583 for i in 0..n_samples {
585 let mut min_dist = Float::INFINITY;
586 let mut best_center = 0;
587
588 for j in 0..n_components {
589 let diff = &x.row(i) - ¢ers.row(j);
590 let dist = diff.dot(&diff);
591 if dist < min_dist {
592 min_dist = dist;
593 best_center = j;
594 }
595 }
596 assignments[i] = best_center;
597 }
598
599 for j in 0..n_components {
601 let cluster_points: Vec<usize> = assignments
602 .iter()
603 .enumerate()
604 .filter(|(_, &assignment)| assignment == j)
605 .map(|(i, _)| i)
606 .collect();
607
608 if !cluster_points.is_empty() {
609 let mut new_center = Array1::zeros(n_features);
610 for &point_idx in &cluster_points {
611 new_center = new_center + x.row(point_idx);
612 }
613 new_center /= cluster_points.len() as Float;
614 centers.row_mut(j).assign(&new_center);
615 }
616 }
617 }
618
619 let mut selected_indices = Vec::new();
621 for j in 0..n_components {
622 let mut min_dist = Float::INFINITY;
623 let mut best_point = 0;
624
625 for i in 0..n_samples {
626 let diff = &x.row(i) - ¢ers.row(j);
627 let dist = diff.dot(&diff);
628 if dist < min_dist {
629 min_dist = dist;
630 best_point = i;
631 }
632 }
633 selected_indices.push(best_point);
634 }
635
636 selected_indices.sort_unstable();
637 selected_indices.dedup();
638
639 while selected_indices.len() < n_components {
641 let random_idx = rng.random_range(0..n_samples);
642 if !selected_indices.contains(&random_idx) {
643 selected_indices.push(random_idx);
644 }
645 }
646
647 Ok(selected_indices[..n_components].to_vec())
648 }
649
650 fn leverage_score_sampling(
652 &self,
653 x: &Array2<Float>,
654 n_components: usize,
655 _rng: &mut RealStdRng,
656 ) -> Result<Vec<usize>> {
657 let (n_samples, _) = x.dim();
658
659 let mut scores = Vec::new();
662 for i in 0..n_samples {
663 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
664 scores.push(row_norm + 1e-10); }
666
667 let total_score: Float = scores.iter().sum();
669 if total_score <= 0.0 {
670 return Err(SklearsError::InvalidInput(
671 "All scores are zero or negative".to_string(),
672 ));
673 }
674
675 let mut cumulative = Vec::with_capacity(scores.len());
677 let mut sum = 0.0;
678 for &score in &scores {
679 sum += score / total_score;
680 cumulative.push(sum);
681 }
682
683 let mut selected_indices = Vec::new();
684 for _ in 0..n_components {
685 let r = thread_rng().random::<Float>();
686 let mut idx = cumulative
688 .iter()
689 .position(|&cum| cum >= r)
690 .unwrap_or(scores.len() - 1);
691
692 while selected_indices.contains(&idx) {
694 let r = thread_rng().random::<Float>();
695 idx = cumulative
696 .iter()
697 .position(|&cum| cum >= r)
698 .unwrap_or(scores.len() - 1);
699 }
700 selected_indices.push(idx);
701 }
702
703 Ok(selected_indices)
704 }
705
706 fn column_norm_sampling(
708 &self,
709 x: &Array2<Float>,
710 n_components: usize,
711 rng: &mut RealStdRng,
712 ) -> Result<Vec<usize>> {
713 let (n_samples, _) = x.dim();
714
715 let mut norms = Vec::new();
717 for i in 0..n_samples {
718 let norm = x.row(i).dot(&x.row(i)).sqrt();
719 norms.push(norm + 1e-10);
720 }
721
722 let mut indices_with_norms: Vec<(usize, Float)> = norms
724 .iter()
725 .enumerate()
726 .map(|(i, &norm)| (i, norm))
727 .collect();
728 indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("operation should succeed"));
729
730 let mut selected_indices = Vec::new();
731 let step = n_samples.max(1) / n_components.max(1);
732
733 for i in 0..n_components {
734 let idx = (i * step).min(n_samples - 1);
735 selected_indices.push(indices_with_norms[idx].0);
736 }
737
738 while selected_indices.len() < n_components {
740 let random_idx = rng.random_range(0..n_samples);
741 if !selected_indices.contains(&random_idx) {
742 selected_indices.push(random_idx);
743 }
744 }
745
746 Ok(selected_indices)
747 }
748
749 pub fn update(mut self, x_new: &Array2<Float>) -> Result<Self> {
751 match &self.accumulated_data_ {
753 Some(existing) => {
754 let combined =
755 scirs2_core::ndarray::concatenate![Axis(0), existing.clone(), x_new.clone()];
756 self.accumulated_data_ = Some(combined);
757 }
758 None => {
759 self.accumulated_data_ = Some(x_new.clone());
760 }
761 }
762
763 let should_update = if let Some(ref accumulated) = self.accumulated_data_ {
765 accumulated.nrows() >= self.min_update_size
766 } else {
767 false
768 };
769
770 if should_update {
771 if let Some(accumulated) = self.accumulated_data_.take() {
772 self = self.perform_update(&accumulated)?;
773 self.update_count_ += 1;
774 }
775 }
776
777 Ok(self)
778 }
779
780 fn perform_update(self, new_data: &Array2<Float>) -> Result<Self> {
782 match self.update_strategy.clone() {
783 UpdateStrategy::Append => self.append_update(new_data),
784 UpdateStrategy::SlidingWindow => self.sliding_window_update(new_data),
785 UpdateStrategy::Merge => self.merge_update(new_data),
786 UpdateStrategy::Selective { threshold } => self.selective_update(new_data, threshold),
787 }
788 }
789
790 fn append_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
792 let current_landmarks = self
793 .landmark_data_
794 .as_ref()
795 .expect("operation should succeed");
796 let current_components = current_landmarks.nrows();
797
798 if current_components >= self.n_components {
799 return Ok(self);
801 }
802
803 let available_space = self.n_components - current_components;
804 let n_new = available_space.min(new_data.nrows());
805
806 if n_new == 0 {
807 return Ok(self);
808 }
809
810 let mut rng = match self.random_state {
812 Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(1000)),
813 None => RealStdRng::from_seed(thread_rng().random()),
814 };
815
816 let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
817 indices.shuffle(&mut rng);
818 let selected_indices = &indices[..n_new];
819
820 let new_landmarks = self.extract_landmarks(new_data, selected_indices);
822
823 let combined_landmarks =
825 scirs2_core::ndarray::concatenate![Axis(0), current_landmarks.clone(), new_landmarks];
826
827 let kernel_matrix = self
829 .kernel
830 .compute_kernel(&combined_landmarks, &combined_landmarks);
831 let (components, normalization) = self.compute_decomposition(kernel_matrix)?;
832
833 let mut new_component_indices = self
835 .component_indices_
836 .as_ref()
837 .expect("operation should succeed")
838 .clone();
839 let base_index = current_landmarks.nrows();
840 for &idx in selected_indices {
841 new_component_indices.push(base_index + idx);
842 }
843
844 self.components_ = Some(components);
845 self.normalization_ = Some(normalization);
846 self.component_indices_ = Some(new_component_indices);
847 self.landmark_data_ = Some(combined_landmarks);
848
849 Ok(self)
850 }
851
852 fn sliding_window_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
854 let current_landmarks = self
855 .landmark_data_
856 .as_ref()
857 .expect("operation should succeed");
858 let n_new = new_data.nrows().min(self.n_components);
859
860 if n_new == 0 {
861 return Ok(self);
862 }
863
864 let mut rng = match self.random_state {
866 Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(2000)),
867 None => RealStdRng::from_seed(thread_rng().random()),
868 };
869
870 let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
871 indices.shuffle(&mut rng);
872 let selected_indices = &indices[..n_new];
873
874 let new_landmarks = self.extract_landmarks(new_data, selected_indices);
875
876 let n_keep = self.n_components - n_new;
878 let combined_landmarks = if n_keep > 0 {
879 let kept_landmarks = current_landmarks.slice(s![n_new.., ..]).to_owned();
880 scirs2_core::ndarray::concatenate![Axis(0), kept_landmarks, new_landmarks]
881 } else {
882 new_landmarks
883 };
884
885 let kernel_matrix = self
887 .kernel
888 .compute_kernel(&combined_landmarks, &combined_landmarks);
889 let (components, normalization) = self.compute_decomposition(kernel_matrix)?;
890
891 let new_component_indices: Vec<usize> = (0..combined_landmarks.nrows()).collect();
893
894 self.components_ = Some(components);
895 self.normalization_ = Some(normalization);
896 self.component_indices_ = Some(new_component_indices);
897 self.landmark_data_ = Some(combined_landmarks);
898
899 Ok(self)
900 }
901
902 fn merge_update(self, new_data: &Array2<Float>) -> Result<Self> {
904 let current_landmarks = self
908 .landmark_data_
909 .as_ref()
910 .expect("operation should succeed");
911 let _current_components = self.components_.as_ref().expect("operation should succeed");
912 let _current_normalization = self
913 .normalization_
914 .as_ref()
915 .expect("operation should succeed");
916
917 let n_new_components = (new_data.nrows().min(self.n_components) / 2).max(1);
919
920 let mut rng = match self.random_state {
921 Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(3000)),
922 None => RealStdRng::from_seed(thread_rng().random()),
923 };
924
925 let new_component_indices = self.select_components(new_data, n_new_components, &mut rng)?;
927 let new_landmarks = self.extract_landmarks(new_data, &new_component_indices);
928
929 let new_kernel_matrix = self.kernel.compute_kernel(&new_landmarks, &new_landmarks);
931 let (_new_components, _new_normalization) =
932 self.compute_decomposition(new_kernel_matrix)?;
933
934 let merged_landmarks =
937 self.merge_landmarks_intelligently(current_landmarks, &new_landmarks, &mut rng)?;
938
939 let merged_kernel_matrix = self
941 .kernel
942 .compute_kernel(&merged_landmarks, &merged_landmarks);
943 let (final_components, final_normalization) =
944 self.compute_decomposition(merged_kernel_matrix)?;
945
946 let final_component_indices: Vec<usize> = (0..merged_landmarks.nrows()).collect();
948
949 let mut updated_self = self;
950 updated_self.components_ = Some(final_components);
951 updated_self.normalization_ = Some(final_normalization);
952 updated_self.component_indices_ = Some(final_component_indices);
953 updated_self.landmark_data_ = Some(merged_landmarks);
954
955 Ok(updated_self)
956 }
957
958 fn merge_landmarks_intelligently(
960 &self,
961 current_landmarks: &Array2<Float>,
962 new_landmarks: &Array2<Float>,
963 rng: &mut RealStdRng,
964 ) -> Result<Array2<Float>> {
965 let n_current = current_landmarks.nrows();
966 let n_new = new_landmarks.nrows();
967 let n_features = current_landmarks.ncols();
968
969 let all_landmarks = scirs2_core::ndarray::concatenate![
971 Axis(0),
972 current_landmarks.clone(),
973 new_landmarks.clone()
974 ];
975
976 let n_target = self.n_components.min(n_current + n_new);
978 let selected_indices = self.select_diverse_landmarks(&all_landmarks, n_target, rng)?;
979
980 let mut merged_landmarks = Array2::zeros((selected_indices.len(), n_features));
982 for (i, &idx) in selected_indices.iter().enumerate() {
983 merged_landmarks.row_mut(i).assign(&all_landmarks.row(idx));
984 }
985
986 Ok(merged_landmarks)
987 }
988
989 fn select_diverse_landmarks(
991 &self,
992 landmarks: &Array2<Float>,
993 n_select: usize,
994 rng: &mut RealStdRng,
995 ) -> Result<Vec<usize>> {
996 let n_landmarks = landmarks.nrows();
997
998 if n_select >= n_landmarks {
999 return Ok((0..n_landmarks).collect());
1000 }
1001
1002 let mut selected = Vec::new();
1003 let mut available: Vec<usize> = (0..n_landmarks).collect();
1004
1005 let first_idx = rng.random_range(0..available.len());
1007 selected.push(available.remove(first_idx));
1008
1009 while selected.len() < n_select && !available.is_empty() {
1011 let mut best_idx = 0;
1012 let mut max_min_distance = 0.0;
1013
1014 for (i, &candidate_idx) in available.iter().enumerate() {
1015 let mut min_distance = Float::INFINITY;
1017
1018 for &selected_idx in &selected {
1019 let diff = &landmarks.row(candidate_idx) - &landmarks.row(selected_idx);
1020 let distance = diff.dot(&diff).sqrt();
1021 if distance < min_distance {
1022 min_distance = distance;
1023 }
1024 }
1025
1026 if min_distance > max_min_distance {
1027 max_min_distance = min_distance;
1028 best_idx = i;
1029 }
1030 }
1031
1032 selected.push(available.remove(best_idx));
1033 }
1034
1035 Ok(selected)
1036 }
1037
1038 fn selective_update(self, new_data: &Array2<Float>, threshold: Float) -> Result<Self> {
1040 let current_landmarks = self
1043 .landmark_data_
1044 .as_ref()
1045 .expect("operation should succeed");
1046
1047 let current_quality = self.evaluate_approximation_quality(current_landmarks, new_data)?;
1049
1050 let mut best_update = self.clone();
1052 let mut best_quality = current_quality;
1053
1054 let append_candidate = self.clone().append_update(new_data)?;
1056 let append_quality = append_candidate.evaluate_approximation_quality(
1057 append_candidate
1058 .landmark_data_
1059 .as_ref()
1060 .expect("operation should succeed"),
1061 new_data,
1062 )?;
1063
1064 if append_quality > best_quality + threshold {
1065 best_update = append_candidate;
1066 best_quality = append_quality;
1067 }
1068
1069 if new_data.nrows() >= 3 {
1071 let merge_candidate = self.clone().merge_update(new_data)?;
1072 let merge_quality = merge_candidate.evaluate_approximation_quality(
1073 merge_candidate
1074 .landmark_data_
1075 .as_ref()
1076 .expect("operation should succeed"),
1077 new_data,
1078 )?;
1079
1080 if merge_quality > best_quality + threshold {
1081 best_update = merge_candidate;
1082 best_quality = merge_quality;
1083 }
1084 }
1085
1086 let sliding_candidate = self.clone().sliding_window_update(new_data)?;
1088 let sliding_quality = sliding_candidate.evaluate_approximation_quality(
1089 sliding_candidate
1090 .landmark_data_
1091 .as_ref()
1092 .expect("operation should succeed"),
1093 new_data,
1094 )?;
1095
1096 if sliding_quality > best_quality + threshold {
1097 best_update = sliding_candidate;
1098 best_quality = sliding_quality;
1099 }
1100
1101 if best_quality > current_quality + threshold {
1103 Ok(best_update)
1104 } else {
1105 Ok(self)
1107 }
1108 }
1109
1110 fn evaluate_approximation_quality(
1112 &self,
1113 landmarks: &Array2<Float>,
1114 test_data: &Array2<Float>,
1115 ) -> Result<Float> {
1116 let n_test = test_data.nrows().min(50); let test_subset = if test_data.nrows() > n_test {
1120 let mut rng = thread_rng();
1122 let mut indices: Vec<usize> = (0..test_data.nrows()).collect();
1123 indices.shuffle(&mut rng);
1124 test_data.select(Axis(0), &indices[..n_test])
1125 } else {
1126 test_data.to_owned()
1127 };
1128
1129 let k_exact = self.kernel.compute_kernel(&test_subset, &test_subset);
1131
1132 let k_test_landmarks = self.kernel.compute_kernel(&test_subset, landmarks);
1134 let k_landmarks = self.kernel.compute_kernel(landmarks, landmarks);
1135
1136 let (eigenvals, eigenvecs) = self.compute_eigendecomposition(k_landmarks)?;
1138
1139 let threshold = 1e-8;
1141 let mut pseudo_inverse = Array2::zeros((landmarks.nrows(), landmarks.nrows()));
1142
1143 for i in 0..landmarks.nrows() {
1144 for j in 0..landmarks.nrows() {
1145 let mut sum = 0.0;
1146 for k in 0..eigenvals.len() {
1147 if eigenvals[k] > threshold {
1148 sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
1149 }
1150 }
1151 pseudo_inverse[[i, j]] = sum;
1152 }
1153 }
1154
1155 let k_approx = k_test_landmarks
1157 .dot(&pseudo_inverse)
1158 .dot(&k_test_landmarks.t());
1159
1160 let error_matrix = &k_exact - &k_approx;
1162 let approximation_error = error_matrix.mapv(|x| x * x).sum().sqrt();
1163
1164 let quality = -approximation_error / (k_exact.mapv(|x| x * x).sum().sqrt() + 1e-10);
1166
1167 Ok(quality)
1168 }
1169
1170 fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
1172 let (_, n_features) = x.dim();
1173 let mut landmarks = Array2::zeros((indices.len(), n_features));
1174
1175 for (i, &idx) in indices.iter().enumerate() {
1176 landmarks.row_mut(i).assign(&x.row(idx));
1177 }
1178
1179 landmarks
1180 }
1181
1182 fn compute_eigendecomposition(
1185 &self,
1186 matrix: Array2<Float>,
1187 ) -> Result<(Array1<Float>, Array2<Float>)> {
1188 let n = matrix.nrows();
1189
1190 if n != matrix.ncols() {
1191 return Err(SklearsError::InvalidInput(
1192 "Matrix must be square for eigendecomposition".to_string(),
1193 ));
1194 }
1195
1196 let mut eigenvals = Array1::zeros(n);
1197 let mut eigenvecs = Array2::zeros((n, n));
1198
1199 let mut deflated_matrix = matrix.clone();
1201
1202 for k in 0..n {
1203 let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;
1205
1206 eigenvals[k] = eigenval;
1207 eigenvecs.column_mut(k).assign(&eigenvec);
1208
1209 for i in 0..n {
1211 for j in 0..n {
1212 deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
1213 }
1214 }
1215 }
1216
1217 let mut indices: Vec<usize> = (0..n).collect();
1219 indices.sort_by(|&i, &j| {
1220 eigenvals[j]
1221 .partial_cmp(&eigenvals[i])
1222 .expect("operation should succeed")
1223 });
1224
1225 let mut sorted_eigenvals = Array1::zeros(n);
1226 let mut sorted_eigenvecs = Array2::zeros((n, n));
1227
1228 for (new_idx, &old_idx) in indices.iter().enumerate() {
1229 sorted_eigenvals[new_idx] = eigenvals[old_idx];
1230 sorted_eigenvecs
1231 .column_mut(new_idx)
1232 .assign(&eigenvecs.column(old_idx));
1233 }
1234
1235 Ok((sorted_eigenvals, sorted_eigenvecs))
1236 }
1237
1238 fn power_iteration(
1240 &self,
1241 matrix: &Array2<Float>,
1242 max_iter: usize,
1243 tol: Float,
1244 ) -> Result<(Float, Array1<Float>)> {
1245 let n = matrix.nrows();
1246
1247 let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());
1249
1250 let norm = v.dot(&v).sqrt();
1252 if norm < 1e-10 {
1253 return Err(SklearsError::InvalidInput(
1254 "Initial vector has zero norm".to_string(),
1255 ));
1256 }
1257 v /= norm;
1258
1259 let mut eigenval = 0.0;
1260
1261 for _iter in 0..max_iter {
1262 let w = matrix.dot(&v);
1264
1265 let new_eigenval = v.dot(&w);
1267
1268 let w_norm = w.dot(&w).sqrt();
1270 if w_norm < 1e-10 {
1271 break;
1272 }
1273 let new_v = w / w_norm;
1274
1275 let eigenval_change = (new_eigenval - eigenval).abs();
1277 let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();
1278
1279 if eigenval_change < tol && vector_change < tol {
1280 return Ok((new_eigenval, new_v));
1281 }
1282
1283 eigenval = new_eigenval;
1284 v = new_v;
1285 }
1286
1287 Ok((eigenval, v))
1288 }
1289
1290 fn compute_decomposition(
1292 &self,
1293 mut kernel_matrix: Array2<Float>,
1294 ) -> Result<(Array2<Float>, Array2<Float>)> {
1295 let reg = 1e-8;
1297 for i in 0..kernel_matrix.nrows() {
1298 kernel_matrix[[i, i]] += reg;
1299 }
1300
1301 let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;
1303
1304 let threshold = 1e-8;
1306 let valid_indices: Vec<usize> = eigenvals
1307 .iter()
1308 .enumerate()
1309 .filter(|(_, &val)| val > threshold)
1310 .map(|(i, _)| i)
1311 .collect();
1312
1313 if valid_indices.is_empty() {
1314 return Err(SklearsError::InvalidInput(
1315 "No valid eigenvalues found in kernel matrix".to_string(),
1316 ));
1317 }
1318
1319 let n_valid = valid_indices.len();
1321 let mut components = Array2::zeros((eigenvals.len(), n_valid));
1322 let mut normalization = Array2::zeros((n_valid, eigenvals.len()));
1323
1324 for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
1325 let sqrt_eigenval = eigenvals[old_idx].sqrt();
1326 components
1327 .column_mut(new_idx)
1328 .assign(&eigenvecs.column(old_idx));
1329
1330 for i in 0..eigenvals.len() {
1332 normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
1333 }
1334 }
1335
1336 Ok((components, normalization))
1337 }
1338
    /// Number of incremental updates applied since fitting.
    ///
    /// Batches smaller than `min_update_size` are accepted without
    /// incrementing this counter (see `test_min_update_size`).
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
1343
1344 pub fn n_landmarks(&self) -> usize {
1346 self.landmark_data_.as_ref().map_or(0, |data| data.nrows())
1347 }
1348}
1349
1350impl Transform<Array2<Float>> for IncrementalNystroem<Trained> {
1351 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
1352 let _components = self
1353 .components_
1354 .as_ref()
1355 .ok_or_else(|| SklearsError::NotFitted {
1356 operation: "transform".to_string(),
1357 })?;
1358
1359 let normalization =
1360 self.normalization_
1361 .as_ref()
1362 .ok_or_else(|| SklearsError::NotFitted {
1363 operation: "transform".to_string(),
1364 })?;
1365
1366 let landmark_data =
1367 self.landmark_data_
1368 .as_ref()
1369 .ok_or_else(|| SklearsError::NotFitted {
1370 operation: "transform".to_string(),
1371 })?;
1372
1373 let kernel_x_landmarks = self.kernel.compute_kernel(x, landmark_data);
1375
1376 let transformed = kernel_x_landmarks.dot(&normalization.t());
1378
1379 Ok(transformed)
1380 }
1381}
1382
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Append-strategy update grows the landmark set and bumps the counter.
    #[test]
    fn test_incremental_nystroem_basic() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
            .update_strategy(UpdateStrategy::Append)
            .min_update_size(1);

        let fitted = nystroem
            .fit(&x_initial, &())
            .expect("operation should succeed");
        assert_eq!(fitted.n_landmarks(), 3);

        let updated = fitted.update(&x_new).expect("operation should succeed");
        assert_eq!(updated.n_landmarks(), 5);
        assert_eq!(updated.update_count(), 1);
    }

    /// Transform output has one row per sample and at most n_components columns.
    #[test]
    fn test_incremental_transform() {
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_test = array![[1.5, 2.5], [2.5, 3.5]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3);
        let fitted = nystroem
            .fit(&x_train, &())
            .expect("operation should succeed");

        let transformed = fitted.transform(&x_test).expect("operation should succeed");
        assert_eq!(transformed.shape()[0], 2);
        // Components with near-zero eigenvalues may be dropped.
        assert!(transformed.shape()[1] <= 3);
    }

    /// Sliding-window strategy caps the landmark count at n_components.
    #[test]
    fn test_sliding_window_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 3)
            .update_strategy(UpdateStrategy::SlidingWindow)
            .min_update_size(1);

        let fitted = nystroem
            .fit(&x_initial, &())
            .expect("operation should succeed");
        let updated = fitted.update(&x_new).expect("operation should succeed");

        assert_eq!(updated.n_landmarks(), 3);
        assert_eq!(updated.update_count(), 1);
    }

    /// RBF and polynomial kernels both fit and transform without error.
    #[test]
    fn test_different_kernels() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let rbf_nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3);
        let rbf_fitted = rbf_nystroem.fit(&x, &()).expect("operation should succeed");
        let rbf_transformed = rbf_fitted.transform(&x).expect("operation should succeed");
        assert_eq!(rbf_transformed.shape()[0], 3);

        let poly_nystroem = IncrementalNystroem::new(
            Kernel::Polynomial {
                gamma: 1.0,
                coef0: 1.0,
                degree: 2,
            },
            3,
        );
        let poly_fitted = poly_nystroem
            .fit(&x, &())
            .expect("operation should succeed");
        let poly_transformed = poly_fitted.transform(&x).expect("operation should succeed");
        assert_eq!(poly_transformed.shape()[0], 3);
    }

    /// Batches below `min_update_size` are absorbed without counting as updates.
    #[test]
    fn test_min_update_size() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_small = array![[3.0, 4.0]];
        let x_large = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 5).min_update_size(2);

        let fitted = nystroem
            .fit(&x_initial, &())
            .expect("operation should succeed");

        // One sample < min_update_size(2): accepted but not counted.
        let after_small = fitted.update(&x_small).expect("operation should succeed");
        assert_eq!(after_small.update_count(), 0);
        assert_eq!(after_small.n_landmarks(), 2);

        // Three samples >= min_update_size(2): counted as a real update.
        let after_large = after_small
            .update(&x_large)
            .expect("operation should succeed");
        assert_eq!(after_large.update_count(), 1);
        assert_eq!(after_large.n_landmarks(), 5);
    }

    /// Same seed ⇒ same feature map, up to a global eigenvector sign flip.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0]];

        let nystroem1 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted1 = nystroem1.fit(&x, &()).expect("operation should succeed");
        let updated1 = fitted1.update(&x_new).expect("operation should succeed");
        let result1 = updated1.transform(&x).expect("operation should succeed");

        let nystroem2 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted2 = nystroem2.fit(&x, &()).expect("operation should succeed");
        let updated2 = fitted2.update(&x_new).expect("operation should succeed");
        let result2 = updated2.transform(&x).expect("operation should succeed");

        assert_eq!(result1.shape(), result2.shape());

        // Eigenvectors are sign-ambiguous, so two runs may agree exactly or
        // differ by a global sign flip; both are "reproducible".
        let mut direct_match = true;
        let mut sign_flip_match = true;

        for (&val1, &val2) in result1.iter().zip(result2.iter()) {
            if (val1 - val2).abs() > 1e-6 {
                direct_match = false;
            }
            if (val1 + val2).abs() > 1e-6 {
                sign_flip_match = false;
            }
        }

        assert!(
            direct_match || sign_flip_match,
            "Results differ too much and are not related by sign flip"
        );
    }
}