use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Trained, Transform, Untrained},
    types::Float,
};
use std::marker::PhantomData;

use crate::nystroem::{Kernel, SamplingStrategy};

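/// Strategy used when folding a new batch of data into a fitted model.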
#[derive(Debug, Clone)]
pub enum UpdateStrategy {
    /// Add new landmarks until `n_components` is reached.
    Append,
    /// Drop the oldest landmarks to make room for newly sampled ones.
    SlidingWindow,
    /// Re-select a diverse landmark set from the old and new points.
    Merge,
    /// Try the other strategies and keep a candidate only when its
    /// approximation quality improves by at least `threshold`.
    Selective { threshold: Float },
}

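/// Incremental Nyström kernel approximation.
///
/// Maintains a set of landmark points together with the eigendecomposition
/// of the landmark kernel matrix, and supports several strategies for
/// folding in new data batches without refitting from scratch.
///
/// Illustrative usage (a sketch mirroring the tests below, not a doctest):
///
/// ```ignore
/// let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
///     .update_strategy(UpdateStrategy::Append)
///     .min_update_size(1);
/// let fitted = nystroem.fit(&x_initial, &())?;
/// let updated = fitted.update(&x_new)?;
/// let features = updated.transform(&x_test)?;
/// ```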
#[derive(Debug, Clone)]
pub struct IncrementalNystroem<State = Untrained> {
    pub kernel: Kernel,
    pub n_components: usize,
    pub update_strategy: UpdateStrategy,
    pub min_update_size: usize,
    pub sampling_strategy: SamplingStrategy,
    pub random_state: Option<u64>,

    // Fitted state (populated after `fit`).
    components_: Option<Array2<Float>>,
    normalization_: Option<Array2<Float>>,
    component_indices_: Option<Vec<usize>>,
    landmark_data_: Option<Array2<Float>>,
    update_count_: usize,
    accumulated_data_: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl IncrementalNystroem<Untrained> {
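    /// Creates an untrained transformer with defaults: `Append` updates,
    /// a minimum update batch of 10 rows, and uniform random sampling.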
    pub fn new(kernel: Kernel, n_components: usize) -> Self {
        Self {
            kernel,
            n_components,
            update_strategy: UpdateStrategy::Append,
            min_update_size: 10,
            sampling_strategy: SamplingStrategy::Random,
            random_state: None,
            components_: None,
            normalization_: None,
            component_indices_: None,
            landmark_data_: None,
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        }
    }

    /// Sets the strategy used when new batches arrive.
    pub fn update_strategy(mut self, strategy: UpdateStrategy) -> Self {
        self.update_strategy = strategy;
        self
    }

    /// Sets the minimum number of buffered rows before an update runs.
    pub fn min_update_size(mut self, size: usize) -> Self {
        self.min_update_size = size;
        self
    }

    /// Sets how landmark indices are sampled from the data.
    pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
        self.sampling_strategy = strategy;
        self
    }

    /// Fixes the random seed for reproducible sampling.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for IncrementalNystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for IncrementalNystroem<Untrained> {
    type Fitted = IncrementalNystroem<Trained>;

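    // Fit pipeline: sample landmark indices, extract the landmark rows,
    // compute the landmark kernel matrix, and eigendecompose it into the
    // feature-map components and normalization.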
    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _n_features) = x.dim();
        let n_components = self.n_components.min(n_samples);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let component_indices = self.select_components(x, n_components, &mut rng)?;
        let landmark_data = self.extract_landmarks(x, &component_indices);

        let kernel_matrix = self.kernel.compute_kernel(&landmark_data, &landmark_data);

        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        Ok(IncrementalNystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            update_strategy: self.update_strategy,
            min_update_size: self.min_update_size,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            landmark_data_: Some(landmark_data),
            update_count_: 0,
            accumulated_data_: None,
            _state: PhantomData,
        })
    }
}

impl IncrementalNystroem<Untrained> {
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }

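    // Landmark selection via a few rounds of Lloyd's k-means (5 iterations),
    // then picking the training point closest to each final centroid.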
    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        // Initialize centers from a random sample of the data.
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            // Assign each point to its nearest center.
            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            // Recompute each center as the mean of its cluster.
            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        // Select the actual data point closest to each final center.
        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        selected_indices.sort_unstable();
        selected_indices.dedup();

        // Top up with random indices if deduplication removed any.
        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }

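    // NOTE: true statistical leverage scores would come from an SVD; this
    // uses row norms as a cheap proxy and samples indices proportionally.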
    fn leverage_score_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        // Small floor keeps every row sampleable.
        let mut scores = Vec::new();
        for i in 0..n_samples {
            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
            scores.push(row_norm + 1e-10);
        }

        let total_score: Float = scores.iter().sum();
        if total_score <= 0.0 {
            return Err(SklearsError::InvalidInput(
                "All scores are zero or negative".to_string(),
            ));
        }

        // Build the cumulative distribution for inverse-CDF sampling.
        let mut cumulative = Vec::with_capacity(scores.len());
        let mut sum = 0.0;
        for &score in &scores {
            sum += score / total_score;
            cumulative.push(sum);
        }

        // Draw with the seeded RNG so `random_state` controls this path too.
        let mut selected_indices = Vec::new();
        for _ in 0..n_components {
            let r = rng.gen::<Float>();
            let mut idx = cumulative
                .iter()
                .position(|&cum| cum >= r)
                .unwrap_or(scores.len() - 1);

            // Resample until we draw an index not yet selected.
            while selected_indices.contains(&idx) {
                let r = rng.gen::<Float>();
                idx = cumulative
                    .iter()
                    .position(|&cum| cum >= r)
                    .unwrap_or(scores.len() - 1);
            }
            selected_indices.push(idx);
        }

        Ok(selected_indices)
    }

    fn column_norm_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        let mut norms = Vec::new();
        for i in 0..n_samples {
            let norm = x.row(i).dot(&x.row(i)).sqrt();
            norms.push(norm + 1e-10);
        }

        // Sort indices by descending norm, then take evenly spaced entries
        // so the selection spans the full norm range.
        let mut indices_with_norms: Vec<(usize, Float)> = norms
            .iter()
            .enumerate()
            .map(|(i, &norm)| (i, norm))
            .collect();
        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut selected_indices = Vec::new();
        let step = n_samples.max(1) / n_components.max(1);

        for i in 0..n_components {
            let idx = (i * step).min(n_samples - 1);
            selected_indices.push(indices_with_norms[idx].0);
        }

        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices)
    }

    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
        let (_, n_features) = x.dim();
        let mut landmarks = Array2::zeros((indices.len(), n_features));

        for (i, &idx) in indices.iter().enumerate() {
            landmarks.row_mut(i).assign(&x.row(idx));
        }

        landmarks
    }

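    // Computes a full eigendecomposition by repeated power iteration with
    // Hotelling deflation: after extracting (λ, v), subtract λ·v·vᵀ and
    // iterate on the deflated matrix to recover the next eigenpair.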
    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate: remove the extracted component.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenpairs by descending eigenvalue.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Deterministic, non-degenerate starting vector.
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh quotient estimate of the dominant eigenvalue.
            let new_eigenval = v.dot(&w);

            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }

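    // Nyström feature map: with K_mm = U Λ Uᵀ, the normalization rows are
    // Λ^{-1/2} Uᵀ, so transform(x) = k(x, landmarks) · (Λ^{-1/2} Uᵀ)ᵀ.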
    fn compute_decomposition(
        &self,
        mut kernel_matrix: Array2<Float>,
    ) -> Result<(Array2<Float>, Array2<Float>)> {
        // Regularize the diagonal for numerical stability.
        let reg = 1e-8;
        for i in 0..kernel_matrix.nrows() {
            kernel_matrix[[i, i]] += reg;
        }

        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;

        // Keep only eigenpairs with eigenvalues above the threshold.
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        let n_valid = valid_indices.len();
        let mut components = Array2::zeros((eigenvals.len(), n_valid));
        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));

        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
            let sqrt_eigenval = eigenvals[old_idx].sqrt();
            components
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));

            // Row of the feature-map normalization: uᵀ / √λ.
            for i in 0..eigenvals.len() {
                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
            }
        }

        Ok((components, normalization))
    }
}

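// NOTE: the sampling and decomposition helpers below mirror the `Untrained`
// implementations above; they are duplicated so that a fitted model can
// reuse them during incremental updates within the typestate design.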
impl IncrementalNystroem<Trained> {
    fn select_components(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        match &self.sampling_strategy {
            SamplingStrategy::Random => {
                let mut indices: Vec<usize> = (0..n_samples).collect();
                indices.shuffle(rng);
                Ok(indices[..n_components].to_vec())
            }
            SamplingStrategy::KMeans => self.kmeans_sampling(x, n_components, rng),
            SamplingStrategy::LeverageScore => self.leverage_score_sampling(x, n_components, rng),
            SamplingStrategy::ColumnNorm => self.column_norm_sampling(x, n_components, rng),
        }
    }

    fn kmeans_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, n_features) = x.dim();
        let mut centers = Array2::zeros((n_components, n_features));

        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(rng);
        for (i, &idx) in indices[..n_components].iter().enumerate() {
            centers.row_mut(i).assign(&x.row(idx));
        }

        for _iter in 0..5 {
            let mut assignments = vec![0; n_samples];

            for i in 0..n_samples {
                let mut min_dist = Float::INFINITY;
                let mut best_center = 0;

                for j in 0..n_components {
                    let diff = &x.row(i) - &centers.row(j);
                    let dist = diff.dot(&diff);
                    if dist < min_dist {
                        min_dist = dist;
                        best_center = j;
                    }
                }
                assignments[i] = best_center;
            }

            for j in 0..n_components {
                let cluster_points: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &assignment)| assignment == j)
                    .map(|(i, _)| i)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_center = Array1::zeros(n_features);
                    for &point_idx in &cluster_points {
                        new_center = new_center + x.row(point_idx);
                    }
                    new_center /= cluster_points.len() as Float;
                    centers.row_mut(j).assign(&new_center);
                }
            }
        }

        let mut selected_indices = Vec::new();
        for j in 0..n_components {
            let mut min_dist = Float::INFINITY;
            let mut best_point = 0;

            for i in 0..n_samples {
                let diff = &x.row(i) - &centers.row(j);
                let dist = diff.dot(&diff);
                if dist < min_dist {
                    min_dist = dist;
                    best_point = i;
                }
            }
            selected_indices.push(best_point);
        }

        selected_indices.sort_unstable();
        selected_indices.dedup();

        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices[..n_components].to_vec())
    }

    fn leverage_score_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        let mut scores = Vec::new();
        for i in 0..n_samples {
            let row_norm = x.row(i).dot(&x.row(i)).sqrt();
            scores.push(row_norm + 1e-10);
        }

        let total_score: Float = scores.iter().sum();
        if total_score <= 0.0 {
            return Err(SklearsError::InvalidInput(
                "All scores are zero or negative".to_string(),
            ));
        }

        let mut cumulative = Vec::with_capacity(scores.len());
        let mut sum = 0.0;
        for &score in &scores {
            sum += score / total_score;
            cumulative.push(sum);
        }

        // Draw with the seeded RNG so `random_state` controls this path too.
        let mut selected_indices = Vec::new();
        for _ in 0..n_components {
            let r = rng.gen::<Float>();
            let mut idx = cumulative
                .iter()
                .position(|&cum| cum >= r)
                .unwrap_or(scores.len() - 1);

            while selected_indices.contains(&idx) {
                let r = rng.gen::<Float>();
                idx = cumulative
                    .iter()
                    .position(|&cum| cum >= r)
                    .unwrap_or(scores.len() - 1);
            }
            selected_indices.push(idx);
        }

        Ok(selected_indices)
    }

    fn column_norm_sampling(
        &self,
        x: &Array2<Float>,
        n_components: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let (n_samples, _) = x.dim();

        let mut norms = Vec::new();
        for i in 0..n_samples {
            let norm = x.row(i).dot(&x.row(i)).sqrt();
            norms.push(norm + 1e-10);
        }

        let mut indices_with_norms: Vec<(usize, Float)> = norms
            .iter()
            .enumerate()
            .map(|(i, &norm)| (i, norm))
            .collect();
        indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        let mut selected_indices = Vec::new();
        let step = n_samples.max(1) / n_components.max(1);

        for i in 0..n_components {
            let idx = (i * step).min(n_samples - 1);
            selected_indices.push(indices_with_norms[idx].0);
        }

        while selected_indices.len() < n_components {
            let random_idx = rng.gen_range(0..n_samples);
            if !selected_indices.contains(&random_idx) {
                selected_indices.push(random_idx);
            }
        }

        Ok(selected_indices)
    }

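    /// Buffers `x_new` and, once at least `min_update_size` rows have
    /// accumulated, applies the configured [`UpdateStrategy`] and clears
    /// the buffer.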
    pub fn update(mut self, x_new: &Array2<Float>) -> Result<Self> {
        // Accumulate the new rows into the pending buffer.
        match &self.accumulated_data_ {
            Some(existing) => {
                let combined =
                    scirs2_core::ndarray::concatenate![Axis(0), existing.clone(), x_new.clone()];
                self.accumulated_data_ = Some(combined);
            }
            None => {
                self.accumulated_data_ = Some(x_new.clone());
            }
        }

        let should_update = if let Some(ref accumulated) = self.accumulated_data_ {
            accumulated.nrows() >= self.min_update_size
        } else {
            false
        };

        if should_update {
            if let Some(accumulated) = self.accumulated_data_.take() {
                self = self.perform_update(&accumulated)?;
                self.update_count_ += 1;
            }
        }

        Ok(self)
    }

    fn perform_update(self, new_data: &Array2<Float>) -> Result<Self> {
        match self.update_strategy.clone() {
            UpdateStrategy::Append => self.append_update(new_data),
            UpdateStrategy::SlidingWindow => self.sliding_window_update(new_data),
            UpdateStrategy::Merge => self.merge_update(new_data),
            UpdateStrategy::Selective { threshold } => self.selective_update(new_data, threshold),
        }
    }

    fn append_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let current_components = current_landmarks.nrows();

        if current_components >= self.n_components {
            return Ok(self);
        }

        let available_space = self.n_components - current_components;
        let n_new = available_space.min(new_data.nrows());

        if n_new == 0 {
            return Ok(self);
        }

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(1000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        let combined_landmarks =
            scirs2_core::ndarray::concatenate![Axis(0), current_landmarks.clone(), new_landmarks];

        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        // The new landmarks occupy rows `base_index..` of the combined set
        // in selection order, so index them by position rather than by their
        // row in `new_data`.
        let mut new_component_indices = self.component_indices_.as_ref().unwrap().clone();
        let base_index = current_landmarks.nrows();
        for i in 0..n_new {
            new_component_indices.push(base_index + i);
        }

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }

    fn sliding_window_update(mut self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let n_new = new_data.nrows().min(self.n_components);

        if n_new == 0 {
            return Ok(self);
        }

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(2000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let mut indices: Vec<usize> = (0..new_data.nrows()).collect();
        indices.shuffle(&mut rng);
        let selected_indices = &indices[..n_new];

        let new_landmarks = self.extract_landmarks(new_data, selected_indices);

        // Drop the oldest landmarks to make room; clamp the drop count so
        // the slice stays in bounds when fewer landmarks exist than `n_new`.
        let n_keep = self.n_components - n_new;
        let combined_landmarks = if n_keep > 0 {
            let n_drop = n_new.min(current_landmarks.nrows());
            let kept_landmarks = current_landmarks.slice(s![n_drop.., ..]).to_owned();
            scirs2_core::ndarray::concatenate![Axis(0), kept_landmarks, new_landmarks]
        } else {
            new_landmarks
        };

        let kernel_matrix = self
            .kernel
            .compute_kernel(&combined_landmarks, &combined_landmarks);
        let (components, normalization) = self.compute_decomposition(kernel_matrix)?;

        let new_component_indices: Vec<usize> = (0..combined_landmarks.nrows()).collect();

        self.components_ = Some(components);
        self.normalization_ = Some(normalization);
        self.component_indices_ = Some(new_component_indices);
        self.landmark_data_ = Some(combined_landmarks);

        Ok(self)
    }

    fn merge_update(self, new_data: &Array2<Float>) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();
        let _current_components = self.components_.as_ref().unwrap();
        let _current_normalization = self.normalization_.as_ref().unwrap();

        let n_new_components = (new_data.nrows().min(self.n_components) / 2).max(1);

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed.wrapping_add(3000)),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        let new_component_indices = self.select_components(new_data, n_new_components, &mut rng)?;
        let new_landmarks = self.extract_landmarks(new_data, &new_component_indices);

        // Computed for its side effect of validating that the new landmark
        // kernel is decomposable (errors propagate via `?`); only the merged
        // decomposition below is kept.
        let new_kernel_matrix = self.kernel.compute_kernel(&new_landmarks, &new_landmarks);
        let (_new_components, _new_normalization) =
            self.compute_decomposition(new_kernel_matrix)?;

        let merged_landmarks =
            self.merge_landmarks_intelligently(current_landmarks, &new_landmarks, &mut rng)?;

        let merged_kernel_matrix = self
            .kernel
            .compute_kernel(&merged_landmarks, &merged_landmarks);
        let (final_components, final_normalization) =
            self.compute_decomposition(merged_kernel_matrix)?;

        let final_component_indices: Vec<usize> = (0..merged_landmarks.nrows()).collect();

        let mut updated_self = self;
        updated_self.components_ = Some(final_components);
        updated_self.normalization_ = Some(final_normalization);
        updated_self.component_indices_ = Some(final_component_indices);
        updated_self.landmark_data_ = Some(merged_landmarks);

        Ok(updated_self)
    }

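    // Pools old and new landmarks, then keeps an `n_components`-sized
    // subset chosen for spatial diversity (see the helper below).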
    fn merge_landmarks_intelligently(
        &self,
        current_landmarks: &Array2<Float>,
        new_landmarks: &Array2<Float>,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        let n_current = current_landmarks.nrows();
        let n_new = new_landmarks.nrows();
        let n_features = current_landmarks.ncols();

        let all_landmarks = scirs2_core::ndarray::concatenate![
            Axis(0),
            current_landmarks.clone(),
            new_landmarks.clone()
        ];

        let n_target = self.n_components.min(n_current + n_new);
        let selected_indices = self.select_diverse_landmarks(&all_landmarks, n_target, rng)?;

        let mut merged_landmarks = Array2::zeros((selected_indices.len(), n_features));
        for (i, &idx) in selected_indices.iter().enumerate() {
            merged_landmarks.row_mut(i).assign(&all_landmarks.row(idx));
        }

        Ok(merged_landmarks)
    }

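    // Greedy farthest-point sampling: start from a random landmark, then
    // repeatedly add the candidate whose minimum distance to the selected
    // set is largest.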
    fn select_diverse_landmarks(
        &self,
        landmarks: &Array2<Float>,
        n_select: usize,
        rng: &mut RealStdRng,
    ) -> Result<Vec<usize>> {
        let n_landmarks = landmarks.nrows();

        if n_select >= n_landmarks {
            return Ok((0..n_landmarks).collect());
        }

        let mut selected = Vec::new();
        let mut available: Vec<usize> = (0..n_landmarks).collect();

        let first_idx = rng.gen_range(0..available.len());
        selected.push(available.remove(first_idx));

        while selected.len() < n_select && !available.is_empty() {
            let mut best_idx = 0;
            let mut max_min_distance = 0.0;

            for (i, &candidate_idx) in available.iter().enumerate() {
                let mut min_distance = Float::INFINITY;

                for &selected_idx in &selected {
                    let diff = &landmarks.row(candidate_idx) - &landmarks.row(selected_idx);
                    let distance = diff.dot(&diff).sqrt();
                    if distance < min_distance {
                        min_distance = distance;
                    }
                }

                if min_distance > max_min_distance {
                    max_min_distance = min_distance;
                    best_idx = i;
                }
            }

            selected.push(available.remove(best_idx));
        }

        Ok(selected)
    }

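    // Evaluates append, merge (when enough rows are available), and
    // sliding-window candidates, keeping whichever improves approximation
    // quality over the current model by more than `threshold`.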
    fn selective_update(self, new_data: &Array2<Float>, threshold: Float) -> Result<Self> {
        let current_landmarks = self.landmark_data_.as_ref().unwrap();

        let current_quality = self.evaluate_approximation_quality(current_landmarks, new_data)?;

        let mut best_update = self.clone();
        let mut best_quality = current_quality;

        let append_candidate = self.clone().append_update(new_data)?;
        let append_quality = append_candidate.evaluate_approximation_quality(
            append_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if append_quality > best_quality + threshold {
            best_update = append_candidate;
            best_quality = append_quality;
        }

        if new_data.nrows() >= 3 {
            let merge_candidate = self.clone().merge_update(new_data)?;
            let merge_quality = merge_candidate.evaluate_approximation_quality(
                merge_candidate.landmark_data_.as_ref().unwrap(),
                new_data,
            )?;

            if merge_quality > best_quality + threshold {
                best_update = merge_candidate;
                best_quality = merge_quality;
            }
        }

        let sliding_candidate = self.clone().sliding_window_update(new_data)?;
        let sliding_quality = sliding_candidate.evaluate_approximation_quality(
            sliding_candidate.landmark_data_.as_ref().unwrap(),
            new_data,
        )?;

        if sliding_quality > best_quality + threshold {
            best_update = sliding_candidate;
            best_quality = sliding_quality;
        }

        if best_quality > current_quality + threshold {
            Ok(best_update)
        } else {
            Ok(self)
        }
    }

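    // Scores a landmark set by the (negated) relative Frobenius error of
    // the Nyström reconstruction on a subsample of the new data:
    //   quality = -‖K − K̃‖_F / ‖K‖_F, with K̃ = K_tl · K_ll⁺ · K_tlᵀ.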
    fn evaluate_approximation_quality(
        &self,
        landmarks: &Array2<Float>,
        test_data: &Array2<Float>,
    ) -> Result<Float> {
        // Cap the evaluation at 50 rows to bound the cost.
        let n_test = test_data.nrows().min(50);
        let test_subset = if test_data.nrows() > n_test {
            let mut rng = thread_rng();
            let mut indices: Vec<usize> = (0..test_data.nrows()).collect();
            indices.shuffle(&mut rng);
            test_data.select(Axis(0), &indices[..n_test])
        } else {
            test_data.to_owned()
        };

        let k_exact = self.kernel.compute_kernel(&test_subset, &test_subset);

        let k_test_landmarks = self.kernel.compute_kernel(&test_subset, landmarks);
        let k_landmarks = self.kernel.compute_kernel(landmarks, landmarks);

        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(k_landmarks)?;

        // Moore-Penrose pseudo-inverse of the landmark kernel from its
        // eigendecomposition, skipping near-zero eigenvalues.
        let threshold = 1e-8;
        let mut pseudo_inverse = Array2::zeros((landmarks.nrows(), landmarks.nrows()));

        for i in 0..landmarks.nrows() {
            for j in 0..landmarks.nrows() {
                let mut sum = 0.0;
                for k in 0..eigenvals.len() {
                    if eigenvals[k] > threshold {
                        sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
                    }
                }
                pseudo_inverse[[i, j]] = sum;
            }
        }

        let k_approx = k_test_landmarks
            .dot(&pseudo_inverse)
            .dot(&k_test_landmarks.t());

        let error_matrix = &k_exact - &k_approx;
        let approximation_error = error_matrix.mapv(|x| x * x).sum().sqrt();

        // Higher is better: negative relative Frobenius error.
        let quality = -approximation_error / (k_exact.mapv(|x| x * x).sum().sqrt() + 1e-10);

        Ok(quality)
    }

    fn extract_landmarks(&self, x: &Array2<Float>, indices: &[usize]) -> Array2<Float> {
        let (_, n_features) = x.dim();
        let mut landmarks = Array2::zeros((indices.len(), n_features));

        for (i, &idx) in indices.iter().enumerate() {
            landmarks.row_mut(i).assign(&x.row(idx));
        }

        landmarks
    }

    fn compute_eigendecomposition(
        &self,
        matrix: Array2<Float>,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Deflate: remove the extracted component.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Sort eigenpairs by descending eigenvalue.
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }

    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Deterministic, non-degenerate starting vector.
        let mut v = Array1::from_shape_fn(n, |i| ((i as Float + 1.0) * 0.1).sin());

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh quotient estimate of the dominant eigenvalue.
            let new_eigenval = v.dot(&w);

            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }

    fn compute_decomposition(
        &self,
        mut kernel_matrix: Array2<Float>,
    ) -> Result<(Array2<Float>, Array2<Float>)> {
        // Regularize the diagonal for numerical stability.
        let reg = 1e-8;
        for i in 0..kernel_matrix.nrows() {
            kernel_matrix[[i, i]] += reg;
        }

        let (eigenvals, eigenvecs) = self.compute_eigendecomposition(kernel_matrix)?;

        // Keep only eigenpairs with eigenvalues above the threshold.
        let threshold = 1e-8;
        let valid_indices: Vec<usize> = eigenvals
            .iter()
            .enumerate()
            .filter(|(_, &val)| val > threshold)
            .map(|(i, _)| i)
            .collect();

        if valid_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid eigenvalues found in kernel matrix".to_string(),
            ));
        }

        let n_valid = valid_indices.len();
        let mut components = Array2::zeros((eigenvals.len(), n_valid));
        let mut normalization = Array2::zeros((n_valid, eigenvals.len()));

        for (new_idx, &old_idx) in valid_indices.iter().enumerate() {
            let sqrt_eigenval = eigenvals[old_idx].sqrt();
            components
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));

            // Row of the feature-map normalization: uᵀ / √λ.
            for i in 0..eigenvals.len() {
                normalization[[new_idx, i]] = eigenvecs[[i, old_idx]] / sqrt_eigenval;
            }
        }

        Ok((components, normalization))
    }

    /// Number of incremental updates applied so far.
    pub fn update_count(&self) -> usize {
        self.update_count_
    }

    /// Current number of landmark points.
    pub fn n_landmarks(&self) -> usize {
        self.landmark_data_.as_ref().map_or(0, |data| data.nrows())
    }
}

impl Transform<Array2<Float>> for IncrementalNystroem<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let _components = self
            .components_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let normalization =
            self.normalization_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let landmark_data =
            self.landmark_data_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        // Φ(x) = k(x, landmarks) · normalizationᵀ
        let kernel_x_landmarks = self.kernel.compute_kernel(x, landmark_data);

        let transformed = kernel_x_landmarks.dot(&normalization.t());

        Ok(transformed)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_incremental_nystroem_basic() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 5)
            .update_strategy(UpdateStrategy::Append)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        assert_eq!(fitted.n_landmarks(), 3);

        let updated = fitted.update(&x_new).unwrap();
        assert_eq!(updated.n_landmarks(), 5);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_incremental_transform() {
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_test = array![[1.5, 2.5], [2.5, 3.5]];

        let nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3);
        let fitted = nystroem.fit(&x_train, &()).unwrap();

        let transformed = fitted.transform(&x_test).unwrap();
        assert_eq!(transformed.shape()[0], 2);
        assert!(transformed.shape()[1] <= 3);
    }

    #[test]
    fn test_sliding_window_update() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_new = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 3)
            .update_strategy(UpdateStrategy::SlidingWindow)
            .min_update_size(1);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();
        let updated = fitted.update(&x_new).unwrap();

        assert_eq!(updated.n_landmarks(), 3);
        assert_eq!(updated.update_count(), 1);
    }

    #[test]
    fn test_different_kernels() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let rbf_nystroem = IncrementalNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3);
        let rbf_fitted = rbf_nystroem.fit(&x, &()).unwrap();
        let rbf_transformed = rbf_fitted.transform(&x).unwrap();
        assert_eq!(rbf_transformed.shape()[0], 3);

        let poly_nystroem = IncrementalNystroem::new(
            Kernel::Polynomial {
                gamma: 1.0,
                coef0: 1.0,
                degree: 2,
            },
            3,
        );
        let poly_fitted = poly_nystroem.fit(&x, &()).unwrap();
        let poly_transformed = poly_fitted.transform(&x).unwrap();
        assert_eq!(poly_transformed.shape()[0], 3);
    }

    #[test]
    fn test_min_update_size() {
        let x_initial = array![[1.0, 2.0], [2.0, 3.0]];
        let x_small = array![[3.0, 4.0]];
        let x_large = array![[4.0, 5.0], [5.0, 6.0], [6.0, 7.0]];

        let nystroem = IncrementalNystroem::new(Kernel::Linear, 5).min_update_size(2);

        let fitted = nystroem.fit(&x_initial, &()).unwrap();

        // A single row is below min_update_size, so it is only buffered.
        let after_small = fitted.update(&x_small).unwrap();
        assert_eq!(after_small.update_count(), 0);
        assert_eq!(after_small.n_landmarks(), 2);

        // The buffered row plus three new rows exceed the threshold.
        let after_large = after_small.update(&x_large).unwrap();
        assert_eq!(after_large.update_count(), 1);
        assert_eq!(after_large.n_landmarks(), 5);
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let x_new = array![[4.0, 5.0]];

        let nystroem1 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let updated1 = fitted1.update(&x_new).unwrap();
        let result1 = updated1.transform(&x).unwrap();

        let nystroem2 = IncrementalNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3)
            .random_state(42)
            .min_update_size(1);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let updated2 = fitted2.update(&x_new).unwrap();
        let result2 = updated2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());

        // Power iteration can converge to ±v, so results may agree either
        // directly or up to a global sign flip.
        let mut direct_match = true;
        let mut sign_flip_match = true;

        for i in 0..result1.len() {
            let val1 = result1.as_slice().unwrap()[i];
            let val2 = result2.as_slice().unwrap()[i];

            if (val1 - val2).abs() > 1e-6 {
                direct_match = false;
            }
            if (val1 + val2).abs() > 1e-6 {
                sign_flip_match = false;
            }
        }

        assert!(
            direct_match || sign_flip_match,
            "Results differ too much and are not related by sign flip"
        );
    }
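
    // A minimal configuration check for the builder API; it only asserts on
    // the public fields set by the builders above.
    #[test]
    fn test_builder_configuration() {
        let nystroem = IncrementalNystroem::new(Kernel::Linear, 4)
            .update_strategy(UpdateStrategy::Selective { threshold: 0.1 })
            .min_update_size(7)
            .sampling_strategy(SamplingStrategy::ColumnNorm)
            .random_state(7);

        assert_eq!(nystroem.n_components, 4);
        assert_eq!(nystroem.min_update_size, 7);
        assert_eq!(nystroem.random_state, Some(7));
        assert!(matches!(
            nystroem.update_strategy,
            UpdateStrategy::Selective { .. }
        ));
    }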
}