1use scirs2_core::ndarray::{s, Array1, Array2, Axis};
7use scirs2_core::random::rngs::StdRng as RealStdRng;
8use scirs2_core::random::seq::SliceRandom;
9use scirs2_core::random::Rng;
10use scirs2_core::random::{thread_rng, SeedableRng};
11use sklears_core::{
12 error::{Result, SklearsError},
13 traits::{Estimator, Fit, Trained, Transform, Untrained},
14 types::Float,
15};
16use std::marker::PhantomData;
17
18use crate::nystroem::{Kernel, SamplingStrategy};
19
/// Strategy used when folding a batch of new samples into an already-fitted
/// incremental Nyström approximation.
#[derive(Debug, Clone)]
pub enum UpdateStrategy {
    /// Keep existing landmarks and append new ones until `n_components` is reached.
    Append,
    /// Drop the oldest landmarks to make room for landmarks drawn from the new batch.
    SlidingWindow,
    /// Re-select a diverse landmark set from the union of old and new landmarks.
    Merge,
    /// Try several update strategies and keep the best candidate only if it
    /// improves the approximation quality by more than `threshold`.
    Selective { threshold: Float },
}
33
/// Nyström kernel approximation that can be refined incrementally as new data
/// arrives. The type-state parameter `State` (`Untrained`/`Trained`) gates
/// which operations are available at compile time.
#[derive(Debug, Clone)]
pub struct IncrementalNystroem<State = Untrained> {
    /// Kernel function used for all Gram-matrix computations.
    pub kernel: Kernel,
    /// Target number of landmark points (components).
    pub n_components: usize,
    /// How new batches are folded into the existing approximation.
    pub update_strategy: UpdateStrategy,
    /// Minimum number of accumulated rows before an update is actually applied.
    pub min_update_size: usize,
    /// How landmark points are chosen from the data.
    pub sampling_strategy: SamplingStrategy,
    /// Optional RNG seed for reproducible sampling.
    pub random_state: Option<u64>,

    // Eigenvectors of the landmark kernel matrix (one column per kept component).
    components_: Option<Array2<Float>>,
    // Normalization map applied to kernel rows at transform time.
    normalization_: Option<Array2<Float>>,
    // Row indices of the selected landmarks (relative to the data they were drawn from).
    component_indices_: Option<Vec<usize>>,
    // The landmark points themselves.
    landmark_data_: Option<Array2<Float>>,
    // Number of completed incremental updates.
    update_count_: usize,
    // Rows received via `update` but not yet numerous enough to trigger a refit.
    accumulated_data_: Option<Array2<Float>>,

    // Zero-sized type-state marker; stores no data.
    _state: PhantomData<State>,
}
65
66impl IncrementalNystroem<Untrained> {
67 pub fn new(kernel: Kernel, n_components: usize) -> Self {
69 Self {
70 kernel,
71 n_components,
72 update_strategy: UpdateStrategy::Append,
73 min_update_size: 10,
74 sampling_strategy: SamplingStrategy::Random,
75 random_state: None,
76 components_: None,
77 normalization_: None,
78 component_indices_: None,
79 landmark_data_: None,
80 update_count_: 0,
81 accumulated_data_: None,
82 _state: PhantomData,
83 }
84 }
85
86 pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
88 self.update_strategy = strategy;
89 self
90 }
91
92 pub fn min_update_size(mut self, size: usize) -> Self {
94 self.min_update_size = size;
95 self
96 }
97
98 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
100 self.sampling_strategy = strategy;
101 self
102 }
103
104 pub fn random_state(mut self, seed: u64) -> Self {
106 self.random_state = Some(seed);
107 self
108 }
109}
110
// Minimal `Estimator` implementation: this transformer carries no separate
// configuration object, so `config` returns a reference to the unit value.
impl Estimator for IncrementalNystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
120
impl Fit<Array2<Float>, ()> for IncrementalNystroem<Untrained> {
    type Fitted = IncrementalNystroem<Trained>;

    /// Fit the initial Nyström approximation on `x` (unsupervised; `_y` is ignored).
    ///
    /// Selects up to `n_components` landmark rows (capped at the number of
    /// samples), computes the landmark kernel matrix, and stores its
    /// eigendecomposition-derived normalization for later `transform` calls.
    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _n_features) = x.dim();
        // Never request more landmarks than there are rows.
        let n_components = self.n_components.min(n_samples);

        // Seeded RNG when `random_state` is set; otherwise seed from the thread RNG.
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let component_indices = self.select_components(x, n_components, &mut rng)?;
        let landmark_data = self.extract_landmarks(x, &component_indices);

        // Gram matrix restricted to the selected landmarks.
        let kernel_matrix = self.kernel.compute_kernel(&landmark_data, &landmark_data);

        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        Ok(IncrementalNystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            update_strategy: self.update_strategy,
            min_update_size: self.min_update_size,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            landmark_data_: Some(landmark_data),
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}
160
161impl IncrementalNystroem<Untrained> {
162 fn select_components(
164 &self,
165 x: &Array2<Float>,
166 n_components: usize,
167 rng: &mut RealStdRng,
168 ) -> Result<Vec<usize>> {
169 let (n_samples, _) = x.dim();
170
171 match &self.sampling_strategy {
172 SamplingStrategy::Random => {
173 let mut indices: Vec<usize> = (0..n_samples).collect();
174 indices.shuffle(rng);
175 Ok(indices[..n_components].to_vec())
176 }
177 SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
178 SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
179 SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
180 }
181 }
182
183 fn kmeans_sampling(
185 &self,
186 x: &Array2<Float>,
187 n_components: usize,
188 rng: &mut RealStdRng,
189 ) -> Result<Vec<usize>> {
190 let (n_samples, n_features) = x.dim();
191 let mut centers = Array2::zeros((n_components, n_features));
192
193 let mut indices: Vec<usize> = (0..n_samples).collect();
195 indices.shuffle(rng);
196 for (i, &idx) in indices[..n_components].iter().enumerate() {
197 centers.row_mut(i).assign(&x.row(idx));
198 }
199
200 for _iter in 0..5 {
202 let mut assignments = vec![0; n_samples];
203
204 for i in 0..n_samples {
206 let mut min_dist = Float::INFINITY;
207 let mut best_center = 0;
208
209 for j in 0..n_components {
210 let diff = &x.row(i) - ¢ers.row(j);
211 let dist = diff.dot(&diff);
212 if dist < min_dist {
213 min_dist = dist;
214 best_center = j;
215 }
216 }
217 assignments[i] = best_center;
218 }
219
220 for j in 0..n_components {
222 let cluster_points: Vec<usize> = assignments
223 .iter()
224 .enumerate()
225 .filter(|(_, &assignment)| assignment == j)
226 .map(|(i, _)| i)
227 .collect();
228
229 if !cluster_points.is_empty() {
230 let mut new_center = Array1::zeros(n_features);
231 for &point_idx in &cluster_points {
232 new_center = new_center + x.row(point_idx);
233 }
234 new_center /= cluster_points.len() as Float;
235 centers.row_mut(j).assign(&new_center);
236 }
237 }
238 }
239
240 let mut selected_indices = Vec::new();
242 for j in 0..n_components {
243 let mut min_dist = Float::INFINITY;
244 let mut best_point = 0;
245
246 for i in 0..n_samples {
247 let diff = &x.row(i) - ¢ers.row(j);
248 let dist = diff.dot(&diff);
249 if dist < min_dist {
250 min_dist = dist;
251 best_point = i;
252 }
253 }
254 selected_indices.push(best_point);
255 }
256
257 selected_indices.sort_unstable();
258 selected_indices.dedup();
259
260 while selected_indices.len() < n_components {
262 let random_idx = rng.gen_range(0..n_samples);
263 if !selected_indices.contains(&random_idx) {
264 selected_indices.push(random_idx);
265 }
266 }
267
268 Ok(selected_indices[..n_components].to_vec())
269 }
270
271 fn leverage_score_sampling(
273 &self,
274 x: &Array2<Float>,
275 n_components: usize,
276 _rng: &mut RealStdRng,
277 ) -> Result<Vec<usize>> {
278 let (n_samples, _) = x.dim();
279
280 let mut scores = Vec::new();
283 for i in 0..n_samples {
284 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
285 scores.push(row_norm + 1e-10); }
287
288 let total_score: Float = scores.iter().sum();
290 if total_score <= 0.0 {
291 return Err(SklearsError::InvalidInput(
292 "All scores are zero or negative".to_string(),
293 ));
294 }
295
296 let mut cumulative = Vec::with_capacity(scores.len());
298 let mut sum = 0.0;
299 for &score in &scores {
300 sum += score / total_score;
301 cumulative.push(sum);
302 }
303
304 let mut selected_indices = Vec::new();
305 for _ in 0..n_components {
306 let r = thread_rng().gen::<Float>();
307 let mut idx = cumulative
309 .iter()
310 .position(|&cum| cum >= r)
311 .unwrap_or(scores.len() - 1);
312
313 while selected_indices.contains(&idx) {
315 let r = thread_rng().gen::<Float>();
316 idx = cumulative
317 .iter()
318 .position(|&cum| cum >= r)
319 .unwrap_or(scores.len() - 1);
320 }
321 selected_indices.push(idx);
322 }
323
324 Ok(selected_indices)
325 }
326
327 fn column_norm_sampling(
329 &self,
330 x: &Array2<Float>,
331 n_components: usize,
332 rng: &mut RealStdRng,
333 ) -> Result<Vec<usize>> {
334 let (n_samples, _) = x.dim();
335
336 let mut norms = Vec::new();
338 for i in 0..n_samples {
339 let norm = x.row(i).dot(&x.row(i)).sqrt();
340 norms.push(norm + 1e-10);
341 }
342
343 let mut indices_with_norms: Vec<(usize, Float)> = norms
345 .iter()
346 .enumerate()
347 .map(|(i, &norm)| (i, norm))
348 .collect();
349 indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
350
351 let mut selected_indices = Vec::new();
352 let step = n_samples.max(1) / n_components.max(1);
353
354 for i in 0..n_components {
355 let idx = (i * step).min(n_samples - 1);
356 selected_indices.push(indices_with_norms[idx].0);
357 }
358
359 while selected_indices.len() < n_components {
361 let random_idx = rng.gen_range(0..n_samples);
362 if !selected_indices.contains(&random_idx) {
363 selected_indices.push(random_idx);
364 }
365 }
366
367 Ok(selected_indices)
368 }
369
370 fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
372 let (_, n_features) = x.dim();
373 let mut landmarks = Array2::zeros((indices.len(), n_features));
374
375 for (i, &idx) in indices.iter().enumerate() {
376 landmarks.row_mut(i).assign(&x.row(idx));
377 }
378
379 landmarks
380 }
381
    /// Full symmetric eigendecomposition via repeated power iteration with
    /// Hotelling deflation; returns eigenvalues sorted descending together
    /// with the matching eigenvector columns.
    ///
    /// NOTE(review): deflation accumulates floating-point error for the
    /// smaller eigenvalues; downstream callers filter by a 1e-8 threshold,
    /// which keeps this acceptable for kernel approximation.
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Dominant eigenpair of the current deflated matrix.
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate: subtract lambda * v v^T so the next dominant pair emerges.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Reorder eigenpairs by descending eigenvalue.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }
433
    /// Estimate the dominant eigenpair of `matrix` by power iteration.
    ///
    /// Starts from a fixed (deterministic) sinusoidal vector, normalizes it,
    /// and iterates `v <- A v / ||A v||` up to `max_iter` times, returning
    /// early once both the Rayleigh-quotient eigenvalue and the vector change
    /// by less than `tol`.
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Deterministic start vector so results are reproducible.
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh quotient estimate of the eigenvalue (v is unit length).
            let new_eigenval = v.dot(&w);

            // A (near-)zero image means v is in the null space; stop iterating.
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Converged when both the eigenvalue and the vector stabilize.
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
485
486 fn compute_decomposition(
488 &self,
489 mut kernel_matrix: Array2<Float>,
490 ) -> Result<(Array2<Float>, Array2<Float>)> {
491 let reg = 1e-8;
493 for i in 0..kernel_matrix.nrows() {
494 kernel_matrix[[i, i]] += reg;
495 }
496
497 let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;
499
500 let threshold = 1e-8;
502 let valid_indices: Vec<usize> = eigenvals
503 .iter()
504 .enumerate()
505 .filter(|(_, &val)| val > threshold)
506 .map(|(i, _)| i)
507 .collect();
508
509 if valid_indices.is_empty() {
510 return Err(SklearsError::InvalidInput(
511 "No valid eigenvalues found in kernel matrix".to_string(),
512 ));
513 }
514
515 let n_valid = valid_indices.len();
517 let mut components = Array2::zeros((eigenvals.len(), n_valid));
518 let mut normalization = Array2::zeros((n_valid, eigenvals.len()));
519
520 for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
521 let sqrt_eigenval = eigenvals[old_idx].sqrt();
522 components
523 .column_mut(new_idx)
524 .assign(&eigenvecs.column(old_idx));
525
526 for i in 0..eigenvals.len() {
528 normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
529 }
530 }
531
532 Ok((components, normalization))
533 }
534}
535
536impl IncrementalNystroem<Trained> {
537 fn select_components(
539 &self,
540 x: &Array2<Float>,
541 n_components: usize,
542 rng: &mut RealStdRng,
543 ) -> Result<Vec<usize>> {
544 let (n_samples, _) = x.dim();
545
546 match &self.sampling_strategy {
547 SamplingStrategy::Random => {
548 let mut indices: Vec<usize> = (0..n_samples).collect();
549 indices.shuffle(rng);
550 Ok(indices[..n_components].to_vec())
551 }
552 SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
553 SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
554 SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
555 }
556 }
557
558 fn kmeans_sampling(
560 &self,
561 x: &Array2<Float>,
562 n_components: usize,
563 rng: &mut RealStdRng,
564 ) -> Result<Vec<usize>> {
565 let (n_samples, n_features) = x.dim();
566 let mut centers = Array2::zeros((n_components, n_features));
567
568 let mut indices: Vec<usize> = (0..n_samples).collect();
570 indices.shuffle(rng);
571 for (i, &idx) in indices[..n_components].iter().enumerate() {
572 centers.row_mut(i).assign(&x.row(idx));
573 }
574
575 for _iter in 0..5 {
577 let mut assignments = vec![0; n_samples];
578
579 for i in 0..n_samples {
581 let mut min_dist = Float::INFINITY;
582 let mut best_center = 0;
583
584 for j in 0..n_components {
585 let diff = &x.row(i) - ¢ers.row(j);
586 let dist = diff.dot(&diff);
587 if dist < min_dist {
588 min_dist = dist;
589 best_center = j;
590 }
591 }
592 assignments[i] = best_center;
593 }
594
595 for j in 0..n_components {
597 let cluster_points: Vec<usize> = assignments
598 .iter()
599 .enumerate()
600 .filter(|(_, &assignment)| assignment == j)
601 .map(|(i, _)| i)
602 .collect();
603
604 if !cluster_points.is_empty() {
605 let mut new_center = Array1::zeros(n_features);
606 for &point_idx in &cluster_points {
607 new_center = new_center + x.row(point_idx);
608 }
609 new_center /= cluster_points.len() as Float;
610 centers.row_mut(j).assign(&new_center);
611 }
612 }
613 }
614
615 let mut selected_indices = Vec::new();
617 for j in 0..n_components {
618 let mut min_dist = Float::INFINITY;
619 let mut best_point = 0;
620
621 for i in 0..n_samples {
622 let diff = &x.row(i) - ¢ers.row(j);
623 let dist = diff.dot(&diff);
624 if dist < min_dist {
625 min_dist = dist;
626 best_point = i;
627 }
628 }
629 selected_indices.push(best_point);
630 }
631
632 selected_indices.sort_unstable();
633 selected_indices.dedup();
634
635 while selected_indices.len() < n_components {
637 let random_idx = rng.gen_range(0..n_samples);
638 if !selected_indices.contains(&random_idx) {
639 selected_indices.push(random_idx);
640 }
641 }
642
643 Ok(selected_indices[..n_components].to_vec())
644 }
645
646 fn leverage_score_sampling(
648 &self,
649 x: &Array2<Float>,
650 n_components: usize,
651 _rng: &mut RealStdRng,
652 ) -> Result<Vec<usize>> {
653 let (n_samples, _) = x.dim();
654
655 let mut scores = Vec::new();
658 for i in 0..n_samples {
659 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
660 scores.push(row_norm + 1e-10); }
662
663 let total_score: Float = scores.iter().sum();
665 if total_score <= 0.0 {
666 return Err(SklearsError::InvalidInput(
667 "All scores are zero or negative".to_string(),
668 ));
669 }
670
671 let mut cumulative = Vec::with_capacity(scores.len());
673 let mut sum = 0.0;
674 for &score in &scores {
675 sum += score / total_score;
676 cumulative.push(sum);
677 }
678
679 let mut selected_indices = Vec::new();
680 for _ in 0..n_components {
681 let r = thread_rng().gen::<Float>();
682 let mut idx = cumulative
684 .iter()
685 .position(|&cum| cum >= r)
686 .unwrap_or(scores.len() - 1);
687
688 while selected_indices.contains(&idx) {
690 let r = thread_rng().gen::<Float>();
691 idx = cumulative
692 .iter()
693 .position(|&cum| cum >= r)
694 .unwrap_or(scores.len() - 1);
695 }
696 selected_indices.push(idx);
697 }
698
699 Ok(selected_indices)
700 }
701
702 fn column_norm_sampling(
704 &self,
705 x: &Array2<Float>,
706 n_components: usize,
707 rng: &mut RealStdRng,
708 ) -> Result<Vec<usize>> {
709 let (n_samples, _) = x.dim();
710
711 let mut norms = Vec::new();
713 for i in 0..n_samples {
714 let norm = x.row(i).dot(&x.row(i)).sqrt();
715 norms.push(norm + 1e-10);
716 }
717
718 let mut indices_with_norms: Vec<(usize, Float)> = norms
720 .iter()
721 .enumerate()
722 .map(|(i, &norm)| (i, norm))
723 .collect();
724 indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
725
726 let mut selected_indices = Vec::new();
727 let step = n_samples.max(1) / n_components.max(1);
728
729 for i in 0..n_components {
730 let idx = (i * step).min(n_samples - 1);
731 selected_indices.push(indices_with_norms[idx].0);
732 }
733
734 while selected_indices.len() < n_components {
736 let random_idx = rng.gen_range(0..n_samples);
737 if !selected_indices.contains(&random_idx) {
738 selected_indices.push(random_idx);
739 }
740 }
741
742 Ok(selected_indices)
743 }
744
745 pub fn update(mut self, x_new: &Array2<Float>) -> Result<Self> {
747 match &self.accumulated_data_ {
749 Some(existing) => {
750 let combined =
751 scirs2_core::ndarray::concatenate![Axis(0), existing.clone(), x_new.clone()];
752 self.accumulated_data_ = Some(combined);
753 }
754 None => {
755 self.accumulated_data_ = Some(x_new.clone());
756 }
757 }
758
759 let should_update = if let Some(ref accumulated) = self.accumulated_data_ {
761 accumulated.nrows() >= self.min_update_size
762 } else {
763 false
764 };
765
766 if should_update {
767 if let Some(accumulated) = self.accumulated_data_.take() {
768 self = self.perform_update(&accumulated)?;
769 self.update_count_ += 1;
770 }
771 }
772
773 Ok(self)
774 }
775
776 fn perform_update(self, new_data: &Array2<Float>) -> Result<Self> {
778 match self.update_strategy.clone() {
779 UpdateStrategy::Append => self.append_update(new_data),
780 UpdateStrategy::SlidingWindow => self.sliding_window_update(new_data),
781 UpdateStrategy::Merge => self.merge_update(new_data),
782 UpdateStrategy::Selective { threshold } => self.selective_update(new_data, threshold),
783 }
784 }
785
    /// Append landmarks drawn from `new_data` until `n_components` landmarks
    /// exist, then refit the decomposition on the enlarged landmark set.
    /// No-op when the landmark budget is already exhausted.
    fn append_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let current_components = current_landmarks.nrows();

        // Already at capacity: keep the model unchanged.
        if current_components >= self.n_components {
            return Ok(self);
        }

        let available_space = self.n_components - current_components;
        let n_new = available_space.min(new_data.nrows());

        if n_new == 0 {
            return Ok(self);
        }

        // Offset the seed so the append draw differs from the fit-time draw.
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(1000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Uniformly sample `n_new` distinct rows from the batch.
        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        let combined_landmarks =
            scirs2_core::ndarray::concatenate![Axis(0), current_landmarks.clone(), new_landmarks];

        // Refit the decomposition on the full landmark set.
        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // NOTE(review): these indices mix positions in the original fit data
        // with batch-relative positions offset by the landmark count — they
        // identify landmarks, not rows of any single dataset. Confirm callers
        // do not interpret them as dataset row indices.
        let mut new_component_indices = self.component_indices_.as_ref().unwrap().clone();
        let base_index = current_landmarks.nrows();
        for &idx in selected_indices {
            new_component_indices.push(base_index + idx);
        }

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }
840
841 fn sliding_window_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
843 let current_landmarks = self.landmark_data_.as_ref().unwrap();
844 let n_new = new_data.nrows().min(self.n_components);
845
846 if n_new == 0 {
847 return Ok(self);
848 }
849
850 let mut rng = match self.random_state {
852 Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(2000)),
853 None => RealStdRng::from_seed(thread_rng().gen()),
854 };
855
856 let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
857 indices.shuffle(&mut rng);
858 let selected_indices = &indices[..n_new];
859
860 let new_landmarks = self.extract_landmarks(new_data, selected_indices);
861
862 let n_keep = self.n_components - n_new;
864 let combined_landmarks = if n_keep > 0 {
865 let kept_landmarks = current_landmarks.slice(s![n_new.., ..]).to_owned();
866 scirs2_core::ndarray::concatenate![Axis(0), kept_landmarks, new_landmarks]
867 } else {
868 new_landmarks
869 };
870
871 let kernel_matrix = self
873 .kernel
874 .compute_kernel(&combined_landmarks, &combined_landmarks);
875 let (components, normalization) = self.compute_decomposition(kernel_matrix)?;
876
877 let new_component_indices: Vec<usize> = (0..combined_landmarks.nrows()).collect();
879
880 self.components_ = Some(components);
881 self.normalization_ = Some(normalization);
882 self.component_indices_ = Some(new_component_indices);
883 self.landmark_data_ = Some(combined_landmarks);
884
885 Ok(self)
886 }
887
888 fn merge_update(self, new_data: &Array2<Float>) -> Result<Self> {
890 let current_landmarks = self.landmark_data_.as_ref().unwrap();
894 let _current_components = self.components_.as_ref().unwrap();
895 let _current_normalization = self.normalization_.as_ref().unwrap();
896
897 let n_new_components = (new_data.nrows().min(self.n_components) / 2).max(1);
899
900 let mut rng = match self.random_state {
901 Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(3000)),
902 None => RealStdRng::from_seed(thread_rng().gen()),
903 };
904
905 let new_component_indices = self.select_components(new_data, n_new_components, &mut rng)?;
907 let new_landmarks = self.extract_landmarks(new_data, &new_component_indices);
908
909 let new_kernel_matrix = self.kernel.compute_kernel(&new_landmarks, &new_landmarks);
911 let (_new_components, _new_normalization) =
912 self.compute_decomposition(new_kernel_matrix)?;
913
914 let merged_landmarks =
917 self.merge_landmarks_intelligently(current_landmarks, &new_landmarks, &mut rng)?;
918
919 let merged_kernel_matrix = self
921 .kernel
922 .compute_kernel(&merged_landmarks, &merged_landmarks);
923 let (final_components, final_normalization) =
924 self.compute_decomposition(merged_kernel_matrix)?;
925
926 let final_component_indices: Vec<usize> = (0..merged_landmarks.nrows()).collect();
928
929 let mut updated_self = self;
930 updated_self.components_ = Some(final_components);
931 updated_self.normalization_ = Some(final_normalization);
932 updated_self.component_indices_ = Some(final_component_indices);
933 updated_self.landmark_data_ = Some(merged_landmarks);
934
935 Ok(updated_self)
936 }
937
938 fn merge_landmarks_intelligently(
940 &self,
941 current_landmarks: &Array2<Float>,
942 new_landmarks: &Array2<Float>,
943 rng: &mut RealStdRng,
944 ) -> Result<Array2<Float>> {
945 let n_current = current_landmarks.nrows();
946 let n_new = new_landmarks.nrows();
947 let n_features = current_landmarks.ncols();
948
949 let all_landmarks = scirs2_core::ndarray::concatenate![
951 Axis(0),
952 current_landmarks.clone(),
953 new_landmarks.clone()
954 ];
955
956 let n_target = self.n_components.min(n_current + n_new);
958 let selected_indices = self.select_diverse_landmarks(&all_landmarks, n_target, rng)?;
959
960 let mut merged_landmarks = Array2::zeros((selected_indices.len(), n_features));
962 for (i, &idx) in selected_indices.iter().enumerate() {
963 merged_landmarks.row_mut(i).assign(&all_landmarks.row(idx));
964 }
965
966 Ok(merged_landmarks)
967 }
968
    /// Greedy farthest-point sampling: seed with one random landmark, then
    /// repeatedly add the candidate whose minimum distance to the already
    /// selected set is largest, until `n_select` landmarks are chosen.
    fn select_diverse_landmarks(
        &self,
        landmarks: &Array2<Float>,
        n_select: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let n_landmarks = landmarks.nrows();

        // Budget covers everything: keep all landmarks.
        if n_select >= n_landmarks {
            return Ok((0..n_landmarks).collect());
        }

        let mut selected = Vec::new();
        let mut available: Vec<usize> = (0..n_landmarks).collect();

        // Seed the selection with a random point.
        let first_idx = rng.gen_range(0..available.len());
        selected.push(available.remove(first_idx));

        while selected.len() < n_select && !available.is_empty() {
            let mut best_idx = 0;
            let mut max_min_distance = 0.0;

            for (i, &candidate_idx) in available.iter().enumerate() {
                // Distance from this candidate to its nearest selected landmark.
                let mut min_distance = Float::INFINITY;

                for &selected_idx in &selected {
                    let diff = &landmarks.row(candidate_idx) - &landmarks.row(selected_idx);
                    let distance = diff.dot(&diff).sqrt();
                    if distance < min_distance {
                        min_distance = distance;
                    }
                }

                // Keep the candidate farthest from the selected set
                // (ties resolve to the earliest candidate).
                if min_distance > max_min_distance {
                    max_min_distance = min_distance;
                    best_idx = i;
                }
            }

            selected.push(available.remove(best_idx));
        }

        Ok(selected)
    }
1017
    /// Try the append, merge (for batches of 3+ rows), and sliding-window
    /// strategies on clones of the model; adopt the best candidate only if it
    /// beats the current approximation quality by more than `threshold`,
    /// otherwise keep the model unchanged.
    fn selective_update(self, new_data: &Array2<Float>, threshold: Float) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();

        // Baseline: how well the current landmarks approximate the new batch.
        let current_quality = self.evaluate_approximation_quality(current_landmarks, new_data)?;

        let mut best_update = self.clone();
        let mut best_quality = current_quality;

        // Candidate 1: append new landmarks.
        let append_candidate = self.clone().append_update(new_data)?;
        let append_quality = append_candidate.evaluate_approximation_quality(
            append_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if append_quality > best_quality + threshold {
            best_update = append_candidate;
            best_quality = append_quality;
        }

        // Candidate 2: merge (needs enough rows to sample candidates from).
        if new_data.nrows() >= 3 {
            let merge_candidate = self.clone().merge_update(new_data)?;
            let merge_quality = merge_candidate.evaluate_approximation_quality(
                merge_candidate.landmark_data_.as_ref().unwrap(),
                new_data,
            )?;

            if merge_quality > best_quality + threshold {
                best_update = merge_candidate;
                best_quality = merge_quality;
            }
        }

        // Candidate 3: sliding window.
        let sliding_candidate = self.clone().sliding_window_update(new_data)?;
        let sliding_quality = sliding_candidate.evaluate_approximation_quality(
            sliding_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if sliding_quality > best_quality + threshold {
            best_update = sliding_candidate;
            best_quality = sliding_quality;
        }

        // Accept only a clear improvement over the unmodified model.
        if best_quality > current_quality + threshold {
            Ok(best_update)
        } else {
            Ok(self)
        }
    }
1077
    /// Score how well `landmarks` approximate the kernel on `test_data`:
    /// compares the exact kernel on a subsample (at most 50 rows) against the
    /// Nyström reconstruction K_tl * K_ll^+ * K_tl^T and returns the negated
    /// relative Frobenius error (higher is better, 0 is perfect).
    ///
    /// NOTE(review): the subsample is drawn with `thread_rng`, so this score
    /// is not reproducible even when `random_state` is set — confirm whether
    /// selective updates need to be deterministic.
    fn evaluate_approximation_quality(
        &self,
        landmarks: &Array2<Float>,
        test_data: &Array2<Float>,
    ) -> Result<Float> {
        // Cap the evaluation cost by subsampling at most 50 test rows.
        let n_test = test_data.nrows().min(50);
        let test_subset = if test_data.nrows() > n_test {
            let mut rng = thread_rng();
            let mut indices: Vec<usize> = (0..test_data.nrows()).collect();
            indices.shuffle(&mut rng);
            test_data.select(Axis(0), &indices[..n_test])
        } else {
            test_data.to_owned()
        };

        // Exact kernel on the subsample.
        let k_exact = self.kernel.compute_kernel(&test_subset, &test_subset);

        // Cross-kernel and landmark kernel for the Nyström reconstruction.
        let k_test_landmarks = self.kernel.compute_kernel(&test_subset, landmarks);
        let k_landmarks = self.kernel.compute_kernel(landmarks, landmarks);

        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(k_landmarks)?;

        // Moore-Penrose pseudo-inverse from the eigendecomposition, ignoring
        // near-zero eigenvalues for stability.
        let threshold = 1e-8;
        let mut pseudo_inverse = Array2::zeros((landmarks.nrows(), landmarks.nrows()));

        for i in 0..landmarks.nrows() {
            for j in 0..landmarks.nrows() {
                let mut sum = 0.0;
                for k in 0..eigenvals.len() {
                    if eigenvals[k] > threshold {
                        sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
                    }
                }
                pseudo_inverse[[i, j]] = sum;
            }
        }

        // Nyström reconstruction of the exact kernel.
        let k_approx = k_test_landmarks
            .dot(&pseudo_inverse)
            .dot(&k_test_landmarks.t());

        // Frobenius-norm error, normalized by the exact kernel's norm.
        let error_matrix = &k_exact - &k_approx;
        let approximation_error = error_matrix.mapv(|x| x * x).sum().sqrt();

        let quality = -approximation_error / (k_exact.mapv(|x| x * x).sum().sqrt() + 1e-10);

        Ok(quality)
    }
1137
1138 fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
1140 let (_, n_features) = x.dim();
1141 let mut landmarks = Array2::zeros((indices.len(), n_features));
1142
1143 for (i, &idx) in indices.iter().enumerate() {
1144 landmarks.row_mut(i).assign(&x.row(idx));
1145 }
1146
1147 landmarks
1148 }
1149
    /// Full symmetric eigendecomposition via repeated power iteration with
    /// Hotelling deflation (Trained-state copy); returns eigenvalues sorted
    /// descending together with the matching eigenvector columns.
    ///
    /// NOTE(review): deflation accumulates floating-point error for the
    /// smaller eigenvalues; downstream callers filter by a 1e-8 threshold.
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            // Dominant eigenpair of the current deflated matrix.
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate: subtract lambda * v v^T so the next dominant pair emerges.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Reorder eigenpairs by descending eigenvalue.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }
1201
    /// Estimate the dominant eigenpair of `matrix` by power iteration
    /// (Trained-state copy). Starts from a fixed sinusoidal vector and
    /// iterates `v <- A v / ||A v||`, returning early once both the
    /// Rayleigh-quotient eigenvalue and the vector change by less than `tol`.
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Deterministic start vector so results are reproducible.
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh quotient estimate of the eigenvalue (v is unit length).
            let new_eigenval = v.dot(&w);

            // A (near-)zero image means v is in the null space; stop iterating.
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Converged when both the eigenvalue and the vector stabilize.
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
1253
1254 fn compute_decomposition(
1256 &self,
1257 mut kernel_matrix: Array2<Float>,
1258 ) -> Result<(Array2<Float>, Array2<Float>)> {
1259 let reg = 1e-8;
1261 for i in 0..kernel_matrix.nrows() {
1262 kernel_matrix[[i, i]] += reg;
1263 }
1264
1265 let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;
1267
1268 let threshold = 1e-8;
1270 let valid_indices: Vec<usize> = eigenvals
1271 .iter()
1272 .enumerate()
1273 .filter(|(_, &val)| val > threshold)
1274 .map(|(i, _)| i)
1275 .collect();
1276
1277 if valid_indices.is_empty() {
1278 return Err(SklearsError::InvalidInput(
1279 "No valid eigenvalues found in kernel matrix".to_string(),
1280 ));
1281 }
1282
1283 let n_valid = valid_indices.len();
1285 let mut components = Array2::zeros((eigenvals.len(), n_valid));
1286 let mut normalization = Array2::zeros((n_valid, eigenvals.len()));
1287
1288 for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
1289 let sqrt_eigenval = eigenvals[old_idx].sqrt();
1290 components
1291 .column_mut(new_idx)
1292 .assign(&eigenvecs.column(old_idx));
1293
1294 for i in 0..eigenvals.len() {
1296 normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
1297 }
1298 }
1299
1300 Ok((components, normalization))
1301 }
1302
    /// Number of incremental updates applied so far.
    ///
    /// Batches smaller than `min_update_size` do not increment this counter
    /// (they are deferred — see `test_min_update_size`).
    pub fn update_count(&self) -> usize {
        self.update_count_
    }
1307
1308 pub fn n_landmarks(&self) -> usize {
1310 self.landmark_data_.as_ref().map_or(0, |data| data.nrows())
1311 }
1312}
1313
1314impl Transform<Array2<Float>> for IncrementalNystroem<Trained> {
1315 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
1316 let _components = self
1317 .components_
1318 .as_ref()
1319 .ok_or_else(|| SklearsError::NotFitted {
1320 operation: "transform".to_string(),
1321 })?;
1322
1323 let normalization =
1324 self.normalization_
1325 .as_ref()
1326 .ok_or_else(|| SklearsError::NotFitted {
1327 operation: "transform".to_string(),
1328 })?;
1329
1330 let landmark_data =
1331 self.landmark_data_
1332 .as_ref()
1333 .ok_or_else(|| SklearsError::NotFitted {
1334 operation: "transform".to_string(),
1335 })?;
1336
1337 let kernel_x_landmarks = self.kernel.compute_kernel(x, landmark_data);
1339
1340 let transformed = kernel_x_landmarks.dot(&normalization.t());
1342
1343 Ok(transformed)
1344 }
1345}
1346
// Fix: dropped the dead `#[allow(non_snake_case)]` — every identifier in this
// module is already snake_case, so the allow suppressed nothing and only hid
// future violations. The reproducibility check now iterates with `zip` over
// hoisted slices instead of re-unwrapping and indexing per element.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Append-strategy update grows the landmark set and bumps the counter.
    #[test]
    fn test_incremental_nystroem_basic() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
            .update_strategy(UpdateStrategy::Append)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        assert_eq!(fitted.n_landmarks(), 3);

        let updated = fitted.update(&x_new).unwrap();
        assert_eq!(updated.n_landmarks(), 5);
        assert_eq!(updated.update_count(), 1);
    }

    /// Transform output has one row per sample and at most n_components cols.
    #[test]
    fn test_incremental_transform() {
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_test = array![[1.5, 2.5], [2.5, 3.5]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3);
        let fitted = nystroem.fit(&x_train, &()).unwrap();

        let transformed = fitted.transform(&x_test).unwrap();
        assert_eq!(transformed.shape()[0], 2);
        assert!(transformed.shape()[1] <= 3);
    }

    /// Sliding-window strategy caps landmarks at n_components.
    #[test]
    fn test_sliding_window_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 3)
            .update_strategy(UpdateStrategy::SlidingWindow)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert_eq!(updated.n_landmarks(), 3);
        assert_eq!(updated.update_count(), 1);
    }

    /// RBF and polynomial kernels both fit and transform without error.
    #[test]
    fn test_different_kernels() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let rbf_nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3);
        let rbf_fitted = rbf_nystroem.fit(&x, &()).unwrap();
        let rbf_transformed = rbf_fitted.transform(&x).unwrap();
        assert_eq!(rbf_transformed.shape()[0], 3);

        let poly_nystroem = IncrementalNystroem::new(
            Kernel::Polynomial {
                gamma: 1.0,
                coef0: 1.0,
                degree: 2,
            },
            3,
        );
        let poly_fitted = poly_nystroem.fit(&x, &()).unwrap();
        let poly_transformed = poly_fitted.transform(&x).unwrap();
        assert_eq!(poly_transformed.shape()[0], 3);
    }

    /// Batches below min_update_size are deferred; larger ones apply.
    #[test]
    fn test_min_update_size() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_small = array![[3.0, 4.0]];
        let x_large = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 5).min_update_size(2);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();

        // A single sample is below the threshold: no update is applied.
        let after_small = fitted.update(&x_small).unwrap();
        assert_eq!(after_small.update_count(), 0);
        assert_eq!(after_small.n_landmarks(), 2);

        // A large enough batch triggers the update.
        let after_large = after_small.update(&x_large).unwrap();
        assert_eq!(after_large.update_count(), 1);
        assert_eq!(after_large.n_landmarks(), 5);
    }

    /// Same seed twice yields the same embedding, up to a global sign flip
    /// (eigenvectors are defined only up to sign).
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0]];

        let nystroem1 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let updated1 = fitted1.update(&x_new).unwrap();
        let result1 = updated1.transform(&x).unwrap();

        let nystroem2 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let updated2 = fitted2.update(&x_new).unwrap();
        let result2 = updated2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());

        let slice1 = result1.as_slice().unwrap();
        let slice2 = result2.as_slice().unwrap();

        let mut direct_match = true;
        let mut sign_flip_match = true;

        for (&val1, &val2) in slice1.iter().zip(slice2.iter()) {
            if (val1 - val2).abs() > 1e-6 {
                direct_match = false;
            }
            if (val1 + val2).abs() > 1e-6 {
                sign_flip_match = false;
            }
        }

        assert!(
            direct_match || sign_flip_match,
            "Results differ too much and are not related by sign flip"
        );
    }
}