use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{thread_rng, Rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::marker::PhantomData;

use crate::nystroem::{Kernel, SamplingStrategy};

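/// How a buffered batch of new samples is folded into an already-fitted
/// Nystroem approximation.
///
/// `Append` adds landmarks until `n_components` is reached, `SlidingWindow`
/// replaces the oldest landmarks with new ones, `Merge` pools old and new
/// landmarks and keeps a diverse subset, and `Selective { threshold }` tries
/// the other strategies and adopts one only if it improves the measured
/// approximation quality by more than `threshold`.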
#[derive(Debug, Clone)]
pub enum UpdateStrategy {
    Append,
    SlidingWindow,
    Merge,
    Selective { threshold: Float },
}

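/// Nystroem kernel approximation that can absorb new data after fitting
/// instead of being retrained from scratch.
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this
/// module (not compiled here, since the public re-export path may differ):
///
/// ```ignore
/// let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let x_new = array![[4.0, 5.0], [5.0, 6.0]];
///
/// let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
///     .update_strategy(UpdateStrategy::Append)
///     .min_update_size(1);
///
/// let fitted = nystroem.fit(&x_initial, &())?;
/// let updated = fitted.update(&x_new)?;
/// let features = updated.transform(&x_new)?;
/// ```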
#[derive(Debug, Clone)]
pub struct IncrementalNystroem<State = Untrained> {
    pub kernel: Kernel,
    pub n_components: usize,
    pub update_strategy: UpdateStrategy,
    pub min_update_size: usize,
    pub sampling_strategy: SamplingStrategy,
    pub random_state: Option<u64>,

    components_: Option<Array2<Float>>,
    normalization_: Option<Array2<Float>>,
    component_indices_: Option<Vec<usize>>,
    landmark_data_: Option<Array2<Float>>,
    update_count_: usize,
    accumulated_data_: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl IncrementalNystroem<Untrained> {
    pub fn new(kernel: Kernel, n_components: usize) -> Self {
        Self {
            kernel,
            n_components,
            update_strategy: UpdateStrategy::Append,
            min_update_size: 10,
            sampling_strategy: SamplingStrategy::Random,
            random_state: None,
            components_: None,
            normalization_: None,
            component_indices_: None,
            landmark_data_: None,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
        self.update_strategy = strategy;
        self
    }

    pub fn min_update_size(mut self, size: usize) -> Self {
        self.min_update_size = size;
        self
    }

    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
        self.sampling_strategy = strategy;
        self
    }

    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for IncrementalNystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

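// Fitting selects `n_components` landmark rows, forms the landmark kernel
// matrix K_mm, and factors it so that `transform` can later compute
// K(x, landmarks) · K_mm^{-1/2}.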
impl Fit<Array2<Float>, ()> for IncrementalNystroem<Untrained> {
    type Fitted = IncrementalNystroem<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();
        let n_components = self.n_components.min(n_samples);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let component_indices = self.select_components(x, n_components, &mut rng)?;
        let landmark_data = self.extract_landmarks(x, &component_indices);

        let kernel_matrix = self.kernel.compute_kernel(&landmark_data, &landmark_data);

        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        Ok(IncrementalNystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            update_strategy: self.update_strategy,
            min_update_size: self.min_update_size,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            landmark_data_: Some(landmark_data),
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

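// Helpers shared by both typestates: landmark selection, a power-iteration
// eigensolver, and the K_mm^{-1/2} factor consumed by `transform`.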
impl<State> IncrementalNystroem<State> {
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }

    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers from a random subset of the data.
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        // A few Lloyd iterations are enough for landmark selection.
        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + &x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Snap each center to its nearest data point so landmarks are real samples.
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Different centers can snap to the same point; pad with random indices.
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }

    fn leverage_score_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Approximate leverage scores by row norms (a cheap proxy).
        let mut scores = Vec::new();
        for i in 0..n_samples {
            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
            scores.push(row_norm + 1e-10);
        }

        let total_score: Float = scores.iter().sum();
        if total_score <= 0.0 {
            return Err(SklearsError::InvalidInput(
                "All scores are zero or negative".to_string(),
            ));
        }

        // Build the cumulative distribution for weighted sampling.
        let mut cumulative = Vec::with_capacity(scores.len());
        let mut sum = 0.0;
        for &score in &scores {
            sum += score / total_score;
            cumulative.push(sum);
        }

        let mut selected_indices = Vec::new();
        for _ in 0..n_components {
            // Draw from the seeded RNG (not `thread_rng`) so that a fixed
            // `random_state` actually yields reproducible selections.
            let r = rng.gen::<Float>();
            let mut idx = cumulative
                .iter()
                .position(|&cum| cum >= r)
                .unwrap_or(scores.len() - 1);

            // Resample on collision to draw without replacement.
            while selected_indices.contains(&idx) {
                let r = rng.gen::<Float>();
                idx = cumulative
                    .iter()
                    .position(|&cum| cum >= r)
                    .unwrap_or(scores.len() - 1);
            }
            selected_indices.push(idx);
        }

        Ok(selected_indices)
    }

    fn column_norm_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        let mut norms = Vec::new();
        for i in 0..n_samples {
            let norm = x.row(i).dot(&x.row(i)).sqrt();
            norms.push(norm + 1e-10);
        }

        // Rank rows by norm, largest first.
        let mut indices_with_norms: Vec<(usize, Float)> = norms
            .iter()
            .enumerate()
            .map(|(i, &norm)| (i, norm))
            .collect();
        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        // Take evenly spaced entries from the ranking to spread the selection.
        let mut selected_indices = Vec::new();
        let step = n_samples.max(1) / n_components.max(1);

        for i in 0..n_components {
            let idx = (i * step).min(n_samples - 1);
            selected_indices.push(indices_with_norms[idx].0);
        }

        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices)
    }

    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
        let (_, n_features) = x.dim();
        let mut landmarks = Array2::zeros((indices.len(), n_features));

        for (i, &idx) in indices.iter().enumerate() {
            landmarks.row_mut(i).assign(&x.row(idx));
        }

        landmarks
    }

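    // Eigendecomposition via repeated power iteration with deflation: once a
    // dominant pair (eigenvalue, eigenvector) is found, its rank-one term is
    // subtracted so the next iteration converges to the following eigenpair.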
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix;

        for k in 0..n {
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate: remove the found component so the next pair emerges.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenpairs by descending eigenvalue.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Deterministic, non-degenerate starting vector.
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh quotient estimate of the dominant eigenvalue.
            let new_eigenval = v.dot(&w);

            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }

    fn compute_decomposition(
        &self,
        mut kernel_matrix: Array2<Float>,
    ) -> Result<(Array2<Float>, Array2<Float>)> {
        // Regularize the diagonal for numerical stability.
        let reg = 1e-8;
        for i in 0..kernel_matrix.nrows() {
            kernel_matrix[[i, i]] += reg;
        }

        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;

        // Keep only eigenpairs with significantly positive eigenvalues.
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        let n_valid = valid_indices.len();
        let mut components = Array2::zeros((eigenvals.len(), n_valid));
        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));

        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
            let sqrt_eigenval = eigenvals[old_idx].sqrt();
            components
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));

            // Rows of `normalization` form K_mm^{-1/2}: each eigenvector
            // scaled by the inverse square root of its eigenvalue.
            for i in 0..eigenvals.len() {
                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
            }
        }

        Ok((components, normalization))
    }
}

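// Incremental behavior: batches passed to `update` are buffered until at
// least `min_update_size` rows are pending, then folded into the landmark
// set according to the configured `UpdateStrategy`.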
impl IncrementalNystroem<Trained> {
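    /// Buffer `x_new` and, once at least `min_update_size` rows are pending,
    /// rebuild the approximation using the configured update strategy.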
    pub fn update(mut self, x_new: &Array2<Float>) -> Result<Self> {
        match &self.accumulated_data_ {
            Some(existing) => {
                let combined =
                    scirs2_core::ndarray::concatenate![Axis(0), existing.clone(), x_new.clone()];
                self.accumulated_data_ = Some(combined);
            }
            None => {
                self.accumulated_data_ = Some(x_new.clone());
            }
        }

        let should_update = if let Some(ref accumulated) = self.accumulated_data_ {
            accumulated.nrows() >= self.min_update_size
        } else {
            false
        };

        if should_update {
            if let Some(accumulated) = self.accumulated_data_.take() {
                self = self.perform_update(&accumulated)?;
                self.update_count_ += 1;
            }
        }

        Ok(self)
    }

    fn perform_update(self, new_data: &Array2<Float>) -> Result<Self> {
        match self.update_strategy.clone() {
            UpdateStrategy::Append => self.append_update(new_data),
            UpdateStrategy::SlidingWindow => self.sliding_window_update(new_data),
            UpdateStrategy::Merge => self.merge_update(new_data),
            UpdateStrategy::Selective { threshold } => self.selective_update(new_data, threshold),
        }
    }

    fn append_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let current_components = current_landmarks.nrows();

        // Already at capacity: nothing to append.
        if current_components >= self.n_components {
            return Ok(self);
        }

        let available_space = self.n_components - current_components;
        let n_new = available_space.min(new_data.nrows());

        if n_new == 0 {
            return Ok(self);
        }

        // Offset the seed so updates draw differently from the initial fit.
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(1000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        let combined_landmarks =
            scirs2_core::ndarray::concatenate![Axis(0), current_landmarks.clone(), new_landmarks];

        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        let mut new_component_indices = self.component_indices_.as_ref().unwrap().clone();
        let base_index = current_landmarks.nrows();
        for &idx in selected_indices {
            new_component_indices.push(base_index + idx);
        }

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }

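    // Sliding window: keep the newest landmarks and drop the oldest, so the
    // model tracks recent data while never exceeding `n_components` landmarks.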
    fn sliding_window_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let n_new = new_data.nrows().min(self.n_components);

        if n_new == 0 {
            return Ok(self);
        }

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(2000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        // Drop the oldest landmarks; clamp the slice start in case the current
        // landmark set is smaller than the incoming batch.
        let n_keep = self.n_components - n_new;
        let combined_landmarks = if n_keep > 0 {
            let start = n_new.min(current_landmarks.nrows());
            let kept_landmarks = current_landmarks.slice(s![start.., ..]).to_owned();
            scirs2_core::ndarray::concatenate![Axis(0), kept_landmarks, new_landmarks]
        } else {
            new_landmarks
        };

        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        let new_component_indices: Vec<usize> = (0..combined_landmarks.nrows()).collect();

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }

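    // Merge: sample landmarks from the new batch, pool them with the current
    // ones, and keep a maximally diverse subset via farthest-point selection.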
    fn merge_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();

        let n_new_components = (new_data.nrows().min(self.n_components) / 2).max(1);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(3000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let new_component_indices = self.select_components(new_data, n_new_components, &mut rng)?;
        let new_landmarks = self.extract_landmarks(new_data, &new_component_indices);

        let merged_landmarks =
            self.merge_landmarks_intelligently(current_landmarks, &new_landmarks, &mut rng)?;

        let merged_kernel_matrix = self
            .kernel
            .compute_kernel(&merged_landmarks, &merged_landmarks);
        let (final_components, final_normalization) =
            self.compute_decomposition(merged_kernel_matrix)?;

        let final_component_indices: Vec<usize> = (0..merged_landmarks.nrows()).collect();

        self.components_ = Some(final_components);
        self.normalization_ = Some(final_normalization);
        self.component_indices_ = Some(final_component_indices);
        self.landmark_data_ = Some(merged_landmarks);

        Ok(self)
    }

    fn merge_landmarks_intelligently(
        &self,
        current_landmarks: &Array2<Float>,
        new_landmarks: &Array2<Float>,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        let n_current = current_landmarks.nrows();
        let n_new = new_landmarks.nrows();
        let n_features = current_landmarks.ncols();

        let all_landmarks = scirs2_core::ndarray::concatenate![
            Axis(0),
            current_landmarks.clone(),
            new_landmarks.clone()
        ];

        let n_target = self.n_components.min(n_current + n_new);
        let selected_indices = self.select_diverse_landmarks(&all_landmarks, n_target, rng)?;

        let mut merged_landmarks = Array2::zeros((selected_indices.len(), n_features));
        for (i, &idx) in selected_indices.iter().enumerate() {
            merged_landmarks.row_mut(i).assign(&all_landmarks.row(idx));
        }

        Ok(merged_landmarks)
    }

    fn select_diverse_landmarks(
        &self,
        landmarks: &Array2<Float>,
        n_select: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let n_landmarks = landmarks.nrows();

        if n_select >= n_landmarks {
            return Ok((0..n_landmarks).collect());
        }

        let mut selected = Vec::new();
        let mut available: Vec<usize> = (0..n_landmarks).collect();

        // Seed the selection with a random landmark.
        let first_idx = rng.gen_range(0..available.len());
        selected.push(available.remove(first_idx));

        // Farthest-point sampling: repeatedly take the candidate whose nearest
        // selected landmark is as far away as possible.
        while selected.len() < n_select && !available.is_empty() {
            let mut best_idx = 0;
            let mut max_min_distance = 0.0;

            for (i, &candidate_idx) in available.iter().enumerate() {
                let mut min_distance = Float::INFINITY;

                for &selected_idx in &selected {
                    let diff = &landmarks.row(candidate_idx) - &landmarks.row(selected_idx);
                    let distance = diff.dot(&diff).sqrt();
                    if distance < min_distance {
                        min_distance = distance;
                    }
                }

                if min_distance > max_min_distance {
                    max_min_distance = min_distance;
                    best_idx = i;
                }
            }

            selected.push(available.remove(best_idx));
        }

        Ok(selected)
    }

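    // Selective: try each rebuild strategy on the incoming batch and keep
    // whichever improves the sampled approximation-quality score by more than
    // `threshold`; otherwise leave the current model unchanged.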
    fn selective_update(self, new_data: &Array2<Float>, threshold: Float) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();

        let current_quality = self.evaluate_approximation_quality(current_landmarks, new_data)?;

        let mut best_update = self.clone();
        let mut best_quality = current_quality;

        let append_candidate = self.clone().append_update(new_data)?;
        let append_quality = append_candidate.evaluate_approximation_quality(
            append_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if append_quality > best_quality + threshold {
            best_update = append_candidate;
            best_quality = append_quality;
        }

        // Merging needs a handful of rows to sample new landmarks from.
        if new_data.nrows() >= 3 {
            let merge_candidate = self.clone().merge_update(new_data)?;
            let merge_quality = merge_candidate.evaluate_approximation_quality(
                merge_candidate.landmark_data_.as_ref().unwrap(),
                new_data,
            )?;

            if merge_quality > best_quality + threshold {
                best_update = merge_candidate;
                best_quality = merge_quality;
            }
        }

        let sliding_candidate = self.clone().sliding_window_update(new_data)?;
        let sliding_quality = sliding_candidate.evaluate_approximation_quality(
            sliding_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if sliding_quality > best_quality + threshold {
            best_update = sliding_candidate;
            best_quality = sliding_quality;
        }

        if best_quality > current_quality + threshold {
            Ok(best_update)
        } else {
            Ok(self)
        }
    }

    fn evaluate_approximation_quality(
        &self,
        landmarks: &Array2<Float>,
        test_data: &Array2<Float>,
    ) -> Result<Float> {
        // Score on at most 50 rows to keep the evaluation cheap.
        let n_test = test_data.nrows().min(50);
        let test_subset = if test_data.nrows() > n_test {
            let mut rng = thread_rng();
            let mut indices: Vec<usize> = (0..test_data.nrows()).collect();
            indices.shuffle(&mut rng);
            test_data.select(Axis(0), &indices[..n_test])
        } else {
            test_data.to_owned()
        };

        let k_exact = self.kernel.compute_kernel(&test_subset, &test_subset);

        let k_test_landmarks = self.kernel.compute_kernel(&test_subset, landmarks);
        let k_landmarks = self.kernel.compute_kernel(landmarks, landmarks);

        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(k_landmarks)?;

        // Moore-Penrose pseudo-inverse of the landmark kernel from its eigenpairs.
        let threshold = 1e-8;
        let mut pseudo_inverse = Array2::zeros((landmarks.nrows(), landmarks.nrows()));

        for i in 0..landmarks.nrows() {
            for j in 0..landmarks.nrows() {
                let mut sum = 0.0;
                for k in 0..eigenvals.len() {
                    if eigenvals[k] > threshold {
                        sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
                    }
                }
                pseudo_inverse[[i, j]] = sum;
            }
        }

        // Nystroem approximation of the exact kernel block: K_nm K_mm^+ K_mn.
        let k_approx = k_test_landmarks
            .dot(&pseudo_inverse)
            .dot(&k_test_landmarks.t());

        let error_matrix = &k_exact - &k_approx;
        let approximation_error = error_matrix.mapv(|x| x * x).sum().sqrt();

        // Negative relative Frobenius error: higher is better.
        let quality = -approximation_error / (k_exact.mapv(|x| x * x).sum().sqrt() + 1e-10);

        Ok(quality)
    }

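    /// Number of batch updates that have been applied; batches smaller than
    /// `min_update_size` are buffered and do not count until folded in.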
    pub fn update_count(&self) -> usize {
        self.update_count_
    }

    /// Current number of landmark rows backing the approximation.
    pub fn n_landmarks(&self) -> usize {
        self.landmark_data_.as_ref().map_or(0, |data| data.nrows())
    }
}

impl Transform<Array2<Float>> for IncrementalNystroem<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        // `components_` is only probed here to confirm the model is fitted.
        let _components = self
            .components_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let normalization =
            self.normalization_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let landmark_data =
            self.landmark_data_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let kernel_x_landmarks = self.kernel.compute_kernel(x, landmark_data);

        // Map into the approximate feature space: K(x, landmarks) · K_mm^{-1/2}.
        let transformed = kernel_x_landmarks.dot(&normalization.t());

        Ok(transformed)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_incremental_nystroem_basic() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
            .update_strategy(UpdateStrategy::Append)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        assert_eq!(fitted.n_landmarks(), 3);

        let updated = fitted.update(&x_new).unwrap();
        assert_eq!(updated.n_landmarks(), 5);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_incremental_transform() {
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_test = array![[1.5, 2.5], [2.5, 3.5]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3);
        let fitted = nystroem.fit(&x_train, &()).unwrap();

        let transformed = fitted.transform(&x_test).unwrap();
        assert_eq!(transformed.shape()[0], 2);
        assert!(transformed.shape()[1] <= 3);
    }

    #[test]
    fn test_sliding_window_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 3)
            .update_strategy(UpdateStrategy::SlidingWindow)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert_eq!(updated.n_landmarks(), 3);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_different_kernels() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let rbf_nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3);
        let rbf_fitted = rbf_nystroem.fit(&x, &()).unwrap();
        let rbf_transformed = rbf_fitted.transform(&x).unwrap();
        assert_eq!(rbf_transformed.shape()[0], 3);

        let poly_nystroem = IncrementalNystroem::new(
            Kernel::Polynomial {
                gamma: 1.0,
                coef0: 1.0,
                degree: 2,
            },
            3,
        );
        let poly_fitted = poly_nystroem.fit(&x, &()).unwrap();
        let poly_transformed = poly_fitted.transform(&x).unwrap();
        assert_eq!(poly_transformed.shape()[0], 3);
    }

    #[test]
    fn test_min_update_size() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_small = array![[3.0, 4.0]];
        let x_large = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 5).min_update_size(2);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();

        // A single new row is below `min_update_size`, so it is only buffered.
        let after_small = fitted.update(&x_small).unwrap();
        assert_eq!(after_small.update_count(), 0);
        assert_eq!(after_small.n_landmarks(), 2);

        // The buffered row plus three new rows exceed the threshold.
        let after_large = after_small.update(&x_large).unwrap();
        assert_eq!(after_large.update_count(), 1);
        assert_eq!(after_large.n_landmarks(), 5);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0]];

        let nystroem1 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let updated1 = fitted1.update(&x_new).unwrap();
        let result1 = updated1.transform(&x).unwrap();

        let nystroem2 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let updated2 = fitted2.update(&x_new).unwrap();
        let result2 = updated2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());

        // Eigenvectors are only determined up to sign, so accept either an
        // exact match or a global sign flip.
        let mut direct_match = true;
        let mut sign_flip_match = true;

        for i in 0..result1.len() {
            let val1 = result1.as_slice().unwrap()[i];
            let val2 = result2.as_slice().unwrap()[i];

            if (val1 - val2).abs() > 1e-6 {
                direct_match = false;
            }
            if (val1 + val2).abs() > 1e-6 {
                sign_flip_match = false;
            }
        }

        assert!(
            direct_match || sign_flip_match,
            "Results differ too much and are not related by sign flip"
        );
    }
}