1use scirs2_core::ndarray::{Array1, Array2};
3use scirs2_core::random::rngs::StdRng as RealStdRng;
4use scirs2_core::random::seq::SliceRandom;
5use scirs2_core::random::{thread_rng, Rng, SeedableRng};
6use sklears_core::{
7 error::{Result, SklearsError},
8 prelude::{Fit, Transform},
9 traits::{Estimator, Trained, Untrained},
10 types::Float,
11};
12use std::marker::PhantomData;
13
/// Strategy used to pick the landmark (component) rows for the Nyström
/// approximation.
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Uniform random sampling without replacement.
    Random,
    /// Pick the data points closest to centers found by a few k-means steps.
    KMeans,
    /// Probability-weighted sampling by an approximate leverage score
    /// (row norm) distribution.
    LeverageScore,
    /// Deterministic, stride-based selection over norm-ranked rows.
    ColumnNorm,
}
27
/// Kernel functions supported by the Nyström approximation.
#[derive(Debug, Clone)]
pub enum Kernel {
    /// Linear kernel: k(a, b) = ⟨a, b⟩.
    Linear,
    /// Gaussian RBF kernel: k(a, b) = exp(-gamma * ||a - b||²).
    Rbf { gamma: Float },
    /// Polynomial kernel: k(a, b) = (gamma * ⟨a, b⟩ + coef0)^degree.
    Polynomial {
        /// Scale applied to the dot product.
        gamma: Float,
        /// Additive constant inside the power.
        coef0: Float,
        /// Polynomial degree.
        degree: u32,
    },
}
45
46impl Kernel {
47 pub fn compute_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
49 let (n_x, _) = x.dim();
50 let (n_y, _) = y.dim();
51 let mut kernel_matrix = Array2::zeros((n_x, n_y));
52
53 match self {
54 Kernel::Linear => {
55 kernel_matrix = x.dot(&y.t());
56 }
57 Kernel::Rbf { gamma } => {
58 for i in 0..n_x {
59 for j in 0..n_y {
60 let diff = &x.row(i) - &y.row(j);
61 let dist_sq = diff.dot(&diff);
62 kernel_matrix[[i, j]] = (-gamma * dist_sq).exp();
63 }
64 }
65 }
66 Kernel::Polynomial {
67 gamma,
68 coef0,
69 degree,
70 } => {
71 for i in 0..n_x {
72 for j in 0..n_y {
73 let dot_prod = x.row(i).dot(&y.row(j));
74 kernel_matrix[[i, j]] = (gamma * dot_prod + coef0).powf(*degree as Float);
75 }
76 }
77 }
78 }
79
80 kernel_matrix
81 }
82}
83
/// Nyström kernel approximation transformer.
///
/// Approximates a kernel feature map by evaluating the kernel against a
/// subset of `n_components` training rows (the "components") and applying a
/// pseudo-inverse-based normalization. The `State` type parameter encodes the
/// typestate: `Untrained` until `fit`, then `Trained`.
#[derive(Debug, Clone)]
pub struct Nystroem<State = Untrained> {
    /// Kernel to approximate.
    pub kernel: Kernel,
    /// Requested number of landmark samples.
    pub n_components: usize,
    /// How landmark samples are chosen during `fit`.
    pub sampling_strategy: SamplingStrategy,
    /// Optional RNG seed for reproducible fits.
    pub random_state: Option<u64>,

    // Selected landmark rows; `Some` only after `fit`.
    components_: Option<Array2<Float>>,
    // Pseudo-inverse-derived normalization matrix; `Some` only after `fit`.
    normalization_: Option<Array2<Float>>,
    // Indices of the selected rows in the training data; `Some` after `fit`.
    component_indices_: Option<Vec<usize>>,

    // Zero-sized typestate marker.
    _state: PhantomData<State>,
}
131
132impl Nystroem<Untrained> {
133 pub fn new(kernel: Kernel, n_components: usize) -> Self {
135 Self {
136 kernel,
137 n_components,
138 sampling_strategy: SamplingStrategy::Random,
139 random_state: None,
140 components_: None,
141 normalization_: None,
142 component_indices_: None,
143 _state: PhantomData,
144 }
145 }
146
147 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
149 self.sampling_strategy = strategy;
150 self
151 }
152
153 pub fn random_state(mut self, seed: u64) -> Self {
155 self.random_state = Some(seed);
156 self
157 }
158}
159
impl Estimator for Nystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Nystroem carries no separate config; returns a reference to the unit
    /// value (valid because `&()` is promoted to a `'static` constant).
    fn config(&self) -> &Self::Config {
        &()
    }
}
169
170impl Nystroem<Untrained> {
171 fn select_components(
173 &self,
174 x: &Array2<Float>,
175 n_components: usize,
176 rng: &mut RealStdRng,
177 ) -> Result<Vec<usize>> {
178 let (n_samples, _) = x.dim();
179
180 match &self.sampling_strategy {
181 SamplingStrategy::Random => {
182 let mut indices: Vec<usize> = (0..n_samples).collect();
183 indices.shuffle(rng);
184 Ok(indices[..n_components].to_vec())
185 }
186 SamplingStrategy::KMeans => {
187 self.kmeans_sampling(x, n_components, rng)
189 }
190 SamplingStrategy::LeverageScore => {
191 self.leverage_score_sampling(x, n_components, rng)
193 }
194 SamplingStrategy::ColumnNorm => {
195 self.column_norm_sampling(x, n_components, rng)
197 }
198 }
199 }
200
201 fn kmeans_sampling(
203 &self,
204 x: &Array2<Float>,
205 n_components: usize,
206 rng: &mut RealStdRng,
207 ) -> Result<Vec<usize>> {
208 let (n_samples, n_features) = x.dim();
209 let mut centers = Array2::zeros((n_components, n_features));
210
211 let mut indices: Vec<usize> = (0..n_samples).collect();
213 indices.shuffle(rng);
214 for (i, &idx) in indices[..n_components].iter().enumerate() {
215 centers.row_mut(i).assign(&x.row(idx));
216 }
217
218 for _iter in 0..5 {
220 let mut assignments = vec![0; n_samples];
221
222 for i in 0..n_samples {
224 let mut min_dist = Float::INFINITY;
225 let mut best_center = 0;
226
227 for j in 0..n_components {
228 let diff = &x.row(i) - ¢ers.row(j);
229 let dist = diff.dot(&diff);
230 if dist < min_dist {
231 min_dist = dist;
232 best_center = j;
233 }
234 }
235 assignments[i] = best_center;
236 }
237
238 for j in 0..n_components {
240 let cluster_points: Vec<usize> = assignments
241 .iter()
242 .enumerate()
243 .filter(|(_, &assignment)| assignment == j)
244 .map(|(i, _)| i)
245 .collect();
246
247 if !cluster_points.is_empty() {
248 let mut new_center = Array1::zeros(n_features);
249 for &point_idx in &cluster_points {
250 new_center = new_center + &x.row(point_idx);
251 }
252 new_center /= cluster_points.len() as Float;
253 centers.row_mut(j).assign(&new_center);
254 }
255 }
256 }
257
258 let mut selected_indices = Vec::new();
260 for j in 0..n_components {
261 let mut min_dist = Float::INFINITY;
262 let mut best_point = 0;
263
264 for i in 0..n_samples {
265 let diff = &x.row(i) - ¢ers.row(j);
266 let dist = diff.dot(&diff);
267 if dist < min_dist {
268 min_dist = dist;
269 best_point = i;
270 }
271 }
272 selected_indices.push(best_point);
273 }
274
275 selected_indices.sort_unstable();
276 selected_indices.dedup();
277
278 while selected_indices.len() < n_components {
280 let random_idx = rng.gen_range(0..n_samples);
281 if !selected_indices.contains(&random_idx) {
282 selected_indices.push(random_idx);
283 }
284 }
285
286 Ok(selected_indices[..n_components].to_vec())
287 }
288
289 fn leverage_score_sampling(
291 &self,
292 x: &Array2<Float>,
293 n_components: usize,
294 rng: &mut RealStdRng,
295 ) -> Result<Vec<usize>> {
296 let (n_samples, _) = x.dim();
297
298 let mut scores = Vec::new();
301 for i in 0..n_samples {
302 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
303 scores.push(row_norm + 1e-10); }
305
306 let total_score: Float = scores.iter().sum();
308 if total_score <= 0.0 {
309 return Err(SklearsError::InvalidInput(
310 "All scores are zero or negative".to_string(),
311 ));
312 }
313
314 let mut cumulative = Vec::with_capacity(scores.len());
316 let mut sum = 0.0;
317 for &score in &scores {
318 sum += score / total_score;
319 cumulative.push(sum);
320 }
321
322 let mut selected_indices = Vec::new();
323 for _ in 0..n_components {
324 let r = thread_rng().gen::<Float>();
325 let mut idx = cumulative
327 .iter()
328 .position(|&cum| cum >= r)
329 .unwrap_or(scores.len() - 1);
330
331 while selected_indices.contains(&idx) {
333 let r = thread_rng().gen::<Float>();
334 idx = cumulative
335 .iter()
336 .position(|&cum| cum >= r)
337 .unwrap_or(scores.len() - 1);
338 }
339 selected_indices.push(idx);
340 }
341
342 Ok(selected_indices)
343 }
344
345 fn column_norm_sampling(
347 &self,
348 x: &Array2<Float>,
349 n_components: usize,
350 rng: &mut RealStdRng,
351 ) -> Result<Vec<usize>> {
352 let (n_samples, _) = x.dim();
353
354 let mut norms = Vec::new();
356 for i in 0..n_samples {
357 let norm = x.row(i).dot(&x.row(i)).sqrt();
358 norms.push(norm + 1e-10);
359 }
360
361 let mut indices_with_norms: Vec<(usize, Float)> = norms
363 .iter()
364 .enumerate()
365 .map(|(i, &norm)| (i, norm))
366 .collect();
367 indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
368
369 let mut selected_indices = Vec::new();
370 let step = n_samples.max(1) / n_components.max(1);
371
372 for i in 0..n_components {
373 let idx = (i * step).min(n_samples - 1);
374 selected_indices.push(indices_with_norms[idx].0);
375 }
376
377 while selected_indices.len() < n_components {
379 let random_idx = rng.gen_range(0..n_samples);
380 if !selected_indices.contains(&random_idx) {
381 selected_indices.push(random_idx);
382 }
383 }
384
385 Ok(selected_indices)
386 }
387
    /// Full eigendecomposition of a square symmetric matrix via repeated
    /// power iteration with Hotelling deflation.
    ///
    /// Returns `(eigenvalues, eigenvectors)` sorted by descending eigenvalue;
    /// eigenvectors are stored as columns. Intended for the small, symmetric
    /// PSD landmark kernel matrix; accuracy degrades on the later, deflated
    /// eigenpairs.
    ///
    /// # Errors
    /// Returns an error if `matrix` is not square or if a power iteration
    /// cannot start from a usable random vector.
    fn compute_eigendecomposition(
        &self,
        matrix: &Array2<Float>,
        rng: &mut RealStdRng,
    ) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        if n != matrix.ncols() {
            return Err(SklearsError::InvalidInput(
                "Matrix must be square for eigendecomposition".to_string(),
            ));
        }

        let mut eigenvals = Array1::zeros(n);
        let mut eigenvecs = Array2::zeros((n, n));

        // Work on a copy: each found eigenpair is subtracted (deflated) so
        // the next power iteration converges to the next-largest pair.
        let mut deflated_matrix = matrix.clone();

        for k in 0..n {
            let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8, rng)?;

            eigenvals[k] = eigenval;
            eigenvecs.column_mut(k).assign(&eigenvec);

            // Hotelling deflation: A <- A - lambda * v * v^T. Must happen
            // before the next iteration or we would re-find the same pair.
            for i in 0..n {
                for j in 0..n {
                    deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
                }
            }
        }

        // Reorder eigenpairs by descending eigenvalue (power iteration does
        // not guarantee order once deflation error accumulates).
        let mut indices: Vec<usize> = (0..n).collect();
        indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());

        let mut sorted_eigenvals = Array1::zeros(n);
        let mut sorted_eigenvecs = Array2::zeros((n, n));

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            sorted_eigenvals[new_idx] = eigenvals[old_idx];
            sorted_eigenvecs
                .column_mut(new_idx)
                .assign(&eigenvecs.column(old_idx));
        }

        Ok((sorted_eigenvals, sorted_eigenvecs))
    }
440
441 fn power_iteration(
443 &self,
444 matrix: &Array2<Float>,
445 max_iter: usize,
446 tol: Float,
447 rng: &mut RealStdRng,
448 ) -> Result<(Float, Array1<Float>)> {
449 let n = matrix.nrows();
450
451 let mut v = Array1::from_shape_fn(n, |_| rng.gen::<Float>() - 0.5);
453
454 let norm = v.dot(&v).sqrt();
456 if norm < 1e-10 {
457 return Err(SklearsError::InvalidInput(
458 "Initial vector has zero norm".to_string(),
459 ));
460 }
461 v /= norm;
462
463 let mut eigenval = 0.0;
464
465 for _iter in 0..max_iter {
466 let w = matrix.dot(&v);
468
469 let new_eigenval = v.dot(&w);
471
472 let w_norm = w.dot(&w).sqrt();
474 if w_norm < 1e-10 {
475 break;
476 }
477 let new_v = w / w_norm;
478
479 let eigenval_change = (new_eigenval - eigenval).abs();
481 let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();
482
483 if eigenval_change < tol && vector_change < tol {
484 return Ok((new_eigenval, new_v));
485 }
486
487 eigenval = new_eigenval;
488 v = new_v;
489 }
490
491 Ok((eigenval, v))
492 }
493}
494
495impl Fit<Array2<Float>, ()> for Nystroem<Untrained> {
496 type Fitted = Nystroem<Trained>;
497
498 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
499 let (n_samples, _) = x.dim();
500
501 if self.n_components > n_samples {
502 eprintln!(
503 "Warning: n_components ({}) > n_samples ({})",
504 self.n_components, n_samples
505 );
506 }
507
508 let n_components_actual = self.n_components.min(n_samples);
509
510 let mut rng = if let Some(seed) = self.random_state {
511 RealStdRng::seed_from_u64(seed)
512 } else {
513 RealStdRng::from_seed(thread_rng().gen())
514 };
515
516 let component_indices = self.select_components(x, n_components_actual, &mut rng)?;
518
519 let mut components = Array2::zeros((n_components_actual, x.ncols()));
521 for (i, &idx) in component_indices.iter().enumerate() {
522 components.row_mut(i).assign(&x.row(idx));
523 }
524
525 let k11: Array2<f64> = self.kernel.compute_kernel(&components, &components);
527
528 let eps = 1e-12;
531
532 let mut k11_reg = k11.clone();
534 for i in 0..n_components_actual {
535 k11_reg[[i, i]] += eps;
536 }
537
538 let (eigenvals, eigenvecs) = self.compute_eigendecomposition(&k11_reg, &mut rng)?;
541
542 let threshold = 1e-8;
544 let valid_indices: Vec<usize> = eigenvals
545 .iter()
546 .enumerate()
547 .filter(|(_, &val)| val > threshold)
548 .map(|(i, _)| i)
549 .collect();
550
551 if valid_indices.is_empty() {
552 return Err(SklearsError::InvalidInput(
553 "No valid eigenvalues found in kernel matrix".to_string(),
554 ));
555 }
556
557 let n_valid = valid_indices.len();
559 let mut pseudo_inverse = Array2::zeros((n_components_actual, n_components_actual));
560
561 for i in 0..n_components_actual {
562 for j in 0..n_components_actual {
563 let mut sum = 0.0;
564 for &k in &valid_indices {
565 sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
566 }
567 pseudo_inverse[[i, j]] = sum;
568 }
569 }
570
571 let normalization = pseudo_inverse;
572
573 Ok(Nystroem {
574 kernel: self.kernel,
575 n_components: self.n_components,
576 sampling_strategy: self.sampling_strategy,
577 random_state: self.random_state,
578 components_: Some(components),
579 normalization_: Some(normalization),
580 component_indices_: Some(component_indices),
581 _state: PhantomData,
582 })
583 }
584}
585
586impl Transform<Array2<Float>, Array2<Float>> for Nystroem<Trained> {
587 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
588 let components = self.components_.as_ref().unwrap();
589 let normalization = self.normalization_.as_ref().unwrap();
590
591 if x.ncols() != components.ncols() {
592 return Err(SklearsError::InvalidInput(format!(
593 "X has {} features, but Nystroem was fitted with {} features",
594 x.ncols(),
595 components.ncols()
596 )));
597 }
598
599 let k_x_components = self.kernel.compute_kernel(x, components);
601
602 let result = k_x_components.dot(normalization);
604
605 Ok(result)
606 }
607}
608
impl Nystroem<Trained> {
    /// Selected landmark rows (shape: actual components × features).
    /// Safe to unwrap: always `Some` in the `Trained` typestate.
    pub fn components(&self) -> &Array2<Float> {
        self.components_.as_ref().unwrap()
    }

    /// Indices of the landmark rows within the training data.
    pub fn component_indices(&self) -> &[usize] {
        self.component_indices_.as_ref().unwrap()
    }

    /// Pseudo-inverse-derived normalization matrix applied in `transform`.
    pub fn normalization(&self) -> &Array2<Float> {
        self.normalization_.as_ref().unwrap()
    }
}
625
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Linear kernel: output keeps one row per sample and at most
    // `n_components` columns.
    #[test]
    fn test_nystroem_linear_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem = Nystroem::new(Kernel::Linear, 3);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        assert!(x_transformed.ncols() <= 3); }

    // Same shape contract for the RBF kernel.
    #[test]
    fn test_nystroem_rbf_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let nystroem = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Same shape contract for the polynomial kernel.
    #[test]
    fn test_nystroem_polynomial_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let kernel = Kernel::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let nystroem = Nystroem::new(kernel, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Two fits with the same `random_state` must produce (numerically)
    // identical transforms.
    #[test]
    fn test_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem1 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let nystroem2 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    // Transforming data with a different feature count than the training
    // data must fail with an error, not panic.
    #[test]
    fn test_nystroem_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0],];

        let x_test = array![
            [1.0, 2.0, 3.0], ];

        let nystroem = Nystroem::new(Kernel::Linear, 2);
        let fitted = nystroem.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    // Every sampling strategy should fit and transform without error and
    // preserve the sample count.
    #[test]
    fn test_nystroem_sampling_strategies() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let nystroem_random = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted_random = nystroem_random.fit(&x, &()).unwrap();
        let result_random = fitted_random.transform(&x).unwrap();
        assert_eq!(result_random.nrows(), 8);

        let nystroem_kmeans = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::KMeans)
            .random_state(42);
        let fitted_kmeans = nystroem_kmeans.fit(&x, &()).unwrap();
        let result_kmeans = fitted_kmeans.transform(&x).unwrap();
        assert_eq!(result_kmeans.nrows(), 8);

        let nystroem_leverage = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted_leverage = nystroem_leverage.fit(&x, &()).unwrap();
        let result_leverage = fitted_leverage.transform(&x).unwrap();
        assert_eq!(result_leverage.nrows(), 8);

        let nystroem_norm = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::ColumnNorm)
            .random_state(42);
        let fitted_norm = nystroem_norm.fit(&x, &()).unwrap();
        let result_norm = fitted_norm.transform(&x).unwrap();
        assert_eq!(result_norm.nrows(), 8);
    }

    // RBF kernel + leverage-score sampling: exact output shape and all
    // entries finite (no NaN/inf from the pseudo-inverse).
    #[test]
    fn test_nystroem_rbf_with_different_sampling() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let kernel = Kernel::Rbf { gamma: 0.1 };

        let nystroem = Nystroem::new(kernel, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.shape(), &[8, 4]);

        for val in result.iter() {
            assert!(val.is_finite());
        }
    }

    // Symmetric data exercising the power-iteration eigendecomposition;
    // results must stay finite.
    #[test]
    fn test_nystroem_improved_eigendecomposition() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let nystroem = Nystroem::new(Kernel::Linear, 3)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 4);
        assert!(result.ncols() <= 3);

        for val in result.iter() {
            assert!(val.is_finite());
        }
    }
}