1use scirs2_core::ndarray::{Array1, Array2};
3use scirs2_core::random::rngs::StdRng as RealStdRng;
4use scirs2_core::random::seq::SliceRandom;
5use scirs2_core::random::RngExt;
6use scirs2_core::random::{thread_rng, SeedableRng};
7use sklears_core::{
8 error::{Result, SklearsError},
9 prelude::{Fit, Transform},
10 traits::{Estimator, Trained, Untrained},
11 types::Float,
12};
13use std::marker::PhantomData;
14
15fn cholesky_inverse(a: &Array2<f64>, n: usize) -> std::result::Result<Array2<f64>, String> {
20 let mut l = Array2::<f64>::zeros((n, n));
22 for j in 0..n {
23 let mut sum = 0.0;
24 for k in 0..j {
25 sum += l[[j, k]] * l[[j, k]];
26 }
27 let diag = a[[j, j]] - sum;
28 if diag <= 0.0 {
29 return Err(format!(
30 "Matrix is not positive definite (diagonal element {} = {})",
31 j, diag
32 ));
33 }
34 l[[j, j]] = diag.sqrt();
35
36 for i in (j + 1)..n {
37 let mut s = 0.0;
38 for k in 0..j {
39 s += l[[i, k]] * l[[j, k]];
40 }
41 l[[i, j]] = (a[[i, j]] - s) / l[[j, j]];
42 }
43 }
44
45 let mut l_inv = Array2::<f64>::zeros((n, n));
47 for col in 0..n {
48 for i in 0..n {
50 let mut sum = if i == col { 1.0 } else { 0.0 };
51 for k in 0..i {
52 sum -= l[[i, k]] * l_inv[[k, col]];
53 }
54 l_inv[[i, col]] = sum / l[[i, i]];
55 }
56 }
57
58 let result = l_inv.t().dot(&l_inv);
60
61 Ok(result)
62}
63
#[derive(Debug, Clone)]
/// Strategy used to select the landmark (component) rows for the Nystroem
/// approximation.
pub enum SamplingStrategy {
    /// Uniform random sampling of rows without replacement.
    Random,
    /// Rows closest to k-means cluster centers (a few Lloyd iterations).
    KMeans,
    /// Probability-weighted sampling using approximate leverage scores
    /// (row norms are used as the approximation in this implementation).
    LeverageScore,
    /// Deterministic stride-based selection over rows sorted by norm.
    ColumnNorm,
}
77
#[derive(Debug, Clone)]
/// Kernel functions supported by the Nystroem transformer.
pub enum Kernel {
    /// Linear kernel: k(x, y) = x . y
    Linear,
    /// RBF (Gaussian) kernel: k(x, y) = exp(-gamma * ||x - y||^2)
    Rbf { gamma: Float },
    /// Polynomial kernel: k(x, y) = (gamma * (x . y) + coef0)^degree
    Polynomial {
        /// Scale applied to the dot product.
        gamma: Float,

        /// Constant offset added before exponentiation.
        coef0: Float,

        /// Polynomial degree.
        degree: u32,
    },
}
95
96impl Kernel {
97 pub fn compute_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
99 let (n_x, _) = x.dim();
100 let (n_y, _) = y.dim();
101 let mut kernel_matrix = Array2::zeros((n_x, n_y));
102
103 match self {
104 Kernel::Linear => {
105 kernel_matrix = x.dot(&y.t());
106 }
107 Kernel::Rbf { gamma } => {
108 for i in 0..n_x {
109 for j in 0..n_y {
110 let diff = &x.row(i) - &y.row(j);
111 let dist_sq = diff.dot(&diff);
112 kernel_matrix[[i, j]] = (-gamma * dist_sq).exp();
113 }
114 }
115 }
116 Kernel::Polynomial {
117 gamma,
118 coef0,
119 degree,
120 } => {
121 for i in 0..n_x {
122 for j in 0..n_y {
123 let dot_prod = x.row(i).dot(&y.row(j));
124 kernel_matrix[[i, j]] = (gamma * dot_prod + coef0).powf(*degree as Float);
125 }
126 }
127 }
128 }
129
130 kernel_matrix
131 }
132}
133
#[derive(Debug, Clone)]
/// Nystroem kernel approximation transformer.
///
/// Approximates a kernel feature map by evaluating the kernel against a
/// subset of `n_components` training rows ("landmarks") and multiplying by a
/// normalization matrix derived from the landmark kernel matrix. The `State`
/// type parameter tracks fitting at compile time (`Untrained` -> `Trained`).
pub struct Nystroem<State = Untrained> {
    /// Kernel function used for the approximation.
    pub kernel: Kernel,
    /// Number of landmark rows to sample.
    pub n_components: usize,
    /// How the landmark rows are selected.
    pub sampling_strategy: SamplingStrategy,
    /// Seed for reproducible sampling; `None` seeds from the thread RNG.
    pub random_state: Option<u64>,

    // Landmark rows selected during fitting (None until fitted).
    components_: Option<Array2<Float>>,
    // Inverse of the regularized landmark kernel matrix (None until fitted).
    normalization_: Option<Array2<Float>>,
    // Indices of the landmarks in the training data (None until fitted).
    component_indices_: Option<Vec<usize>>,

    // Zero-sized marker carrying the Untrained/Trained type state.
    _state: PhantomData<State>,
}
181
182impl Nystroem<Untrained> {
183 pub fn new(kernel: Kernel, n_components: usize) -> Self {
185 Self {
186 kernel,
187 n_components,
188 sampling_strategy: SamplingStrategy::Random,
189 random_state: None,
190 components_: None,
191 normalization_: None,
192 component_indices_: None,
193 _state: PhantomData,
194 }
195 }
196
197 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
199 self.sampling_strategy = strategy;
200 self
201 }
202
203 pub fn random_state(mut self, seed: u64) -> Self {
205 self.random_state = Some(seed);
206 self
207 }
208}
209
/// Marker `Estimator` implementation: the transformer carries all of its
/// configuration in its own fields, so `Config` is the unit type.
impl Estimator for Nystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` can be returned for any lifetime: `()` is zero-sized.
        &()
    }
}
219
220impl Nystroem<Untrained> {
221 fn select_components(
223 &self,
224 x: &Array2<Float>,
225 n_components: usize,
226 rng: &mut RealStdRng,
227 ) -> Result<Vec<usize>> {
228 let (n_samples, _) = x.dim();
229
230 match &self.sampling_strategy {
231 SamplingStrategy::Random => {
232 let mut indices: Vec<usize> = (0..n_samples).collect();
233 indices.shuffle(rng);
234 Ok(indices[..n_components].to_vec())
235 }
236 SamplingStrategy::KMeans => {
237 self.kmeans_sampling(x, n_components, rng)
239 }
240 SamplingStrategy::LeverageScore => {
241 self.leverage_score_sampling(x, n_components, rng)
243 }
244 SamplingStrategy::ColumnNorm => {
245 self.column_norm_sampling(x, n_components, rng)
247 }
248 }
249 }
250
251 fn kmeans_sampling(
253 &self,
254 x: &Array2<Float>,
255 n_components: usize,
256 rng: &mut RealStdRng,
257 ) -> Result<Vec<usize>> {
258 let (n_samples, n_features) = x.dim();
259 let mut centers = Array2::zeros((n_components, n_features));
260
261 let mut indices: Vec<usize> = (0..n_samples).collect();
263 indices.shuffle(rng);
264 for (i, &idx) in indices[..n_components].iter().enumerate() {
265 centers.row_mut(i).assign(&x.row(idx));
266 }
267
268 for _iter in 0..5 {
270 let mut assignments = vec![0; n_samples];
271
272 for i in 0..n_samples {
274 let mut min_dist = Float::INFINITY;
275 let mut best_center = 0;
276
277 for j in 0..n_components {
278 let diff = &x.row(i) - ¢ers.row(j);
279 let dist = diff.dot(&diff);
280 if dist < min_dist {
281 min_dist = dist;
282 best_center = j;
283 }
284 }
285 assignments[i] = best_center;
286 }
287
288 for j in 0..n_components {
290 let cluster_points: Vec<usize> = assignments
291 .iter()
292 .enumerate()
293 .filter(|(_, &assignment)| assignment == j)
294 .map(|(i, _)| i)
295 .collect();
296
297 if !cluster_points.is_empty() {
298 let mut new_center = Array1::zeros(n_features);
299 for &point_idx in &cluster_points {
300 new_center = new_center + x.row(point_idx);
301 }
302 new_center /= cluster_points.len() as Float;
303 centers.row_mut(j).assign(&new_center);
304 }
305 }
306 }
307
308 let mut selected_indices = Vec::new();
310 for j in 0..n_components {
311 let mut min_dist = Float::INFINITY;
312 let mut best_point = 0;
313
314 for i in 0..n_samples {
315 let diff = &x.row(i) - ¢ers.row(j);
316 let dist = diff.dot(&diff);
317 if dist < min_dist {
318 min_dist = dist;
319 best_point = i;
320 }
321 }
322 selected_indices.push(best_point);
323 }
324
325 selected_indices.sort_unstable();
326 selected_indices.dedup();
327
328 while selected_indices.len() < n_components {
330 let random_idx = rng.random_range(0..n_samples);
331 if !selected_indices.contains(&random_idx) {
332 selected_indices.push(random_idx);
333 }
334 }
335
336 Ok(selected_indices[..n_components].to_vec())
337 }
338
339 fn leverage_score_sampling(
341 &self,
342 x: &Array2<Float>,
343 n_components: usize,
344 _rng: &mut RealStdRng,
345 ) -> Result<Vec<usize>> {
346 let (n_samples, _) = x.dim();
347
348 let mut scores = Vec::new();
351 for i in 0..n_samples {
352 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
353 scores.push(row_norm + 1e-10); }
355
356 let total_score: Float = scores.iter().sum();
358 if total_score <= 0.0 {
359 return Err(SklearsError::InvalidInput(
360 "All scores are zero or negative".to_string(),
361 ));
362 }
363
364 let mut cumulative = Vec::with_capacity(scores.len());
366 let mut sum = 0.0;
367 for &score in &scores {
368 sum += score / total_score;
369 cumulative.push(sum);
370 }
371
372 let mut selected_indices = Vec::new();
373 for _ in 0..n_components {
374 let r = thread_rng().random::<Float>();
375 let mut idx = cumulative
377 .iter()
378 .position(|&cum| cum >= r)
379 .unwrap_or(scores.len() - 1);
380
381 while selected_indices.contains(&idx) {
383 let r = thread_rng().random::<Float>();
384 idx = cumulative
385 .iter()
386 .position(|&cum| cum >= r)
387 .unwrap_or(scores.len() - 1);
388 }
389 selected_indices.push(idx);
390 }
391
392 Ok(selected_indices)
393 }
394
395 fn column_norm_sampling(
397 &self,
398 x: &Array2<Float>,
399 n_components: usize,
400 rng: &mut RealStdRng,
401 ) -> Result<Vec<usize>> {
402 let (n_samples, _) = x.dim();
403
404 let mut norms = Vec::new();
406 for i in 0..n_samples {
407 let norm = x.row(i).dot(&x.row(i)).sqrt();
408 norms.push(norm + 1e-10);
409 }
410
411 let mut indices_with_norms: Vec<(usize, Float)> = norms
413 .iter()
414 .enumerate()
415 .map(|(i, &norm)| (i, norm))
416 .collect();
417 indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("operation should succeed"));
418
419 let mut selected_indices = Vec::new();
420 let step = n_samples.max(1) / n_components.max(1);
421
422 for i in 0..n_components {
423 let idx = (i * step).min(n_samples - 1);
424 selected_indices.push(indices_with_norms[idx].0);
425 }
426
427 while selected_indices.len() < n_components {
429 let random_idx = rng.random_range(0..n_samples);
430 if !selected_indices.contains(&random_idx) {
431 selected_indices.push(random_idx);
432 }
433 }
434
435 Ok(selected_indices)
436 }
437
438 fn compute_eigendecomposition(
441 &self,
442 matrix: &Array2<Float>,
443 _rng: &mut RealStdRng,
444 ) -> Result<(Array1<Float>, Array2<Float>)> {
445 use scirs2_linalg::compat::{ArrayLinalgExt, UPLO};
446
447 let n = matrix.nrows();
448
449 if n != matrix.ncols() {
450 return Err(SklearsError::InvalidInput(
451 "Matrix must be square for eigendecomposition".to_string(),
452 ));
453 }
454
455 let (eigenvals, eigenvecs) = matrix
457 .eigh(UPLO::Lower)
458 .map_err(|e| SklearsError::InvalidInput(format!("Eigendecomposition failed: {}", e)))?;
459
460 let mut sorted_eigenvals = Array1::zeros(n);
462 let mut sorted_eigenvecs = Array2::zeros((n, n));
463
464 for i in 0..n {
465 sorted_eigenvals[i] = eigenvals[n - 1 - i];
466 sorted_eigenvecs
467 .column_mut(i)
468 .assign(&eigenvecs.column(n - 1 - i));
469 }
470
471 Ok((sorted_eigenvals, sorted_eigenvecs))
472 }
473
    /// Estimates the dominant eigenvalue/eigenvector pair of `matrix` by
    /// power iteration.
    ///
    /// Starts from a random vector (drawn from `rng`), repeatedly applies the
    /// matrix, and stops early when both the Rayleigh-quotient eigenvalue
    /// estimate and the vector change fall below `tol`; otherwise returns the
    /// estimate after `max_iter` iterations.
    ///
    /// # Errors
    /// Returns an error if the random starting vector has (near-)zero norm.
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
        rng: &mut RealStdRng,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Random start with each coordinate in [-0.5, 0.5).
        let mut v = Array1::from_shape_fn(n, |_| rng.random::<Float>() - 0.5);

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            // One application of the matrix.
            let w = matrix.dot(&v);

            // Rayleigh quotient v^T (A v); v is unit-norm at this point.
            let new_eigenval = v.dot(&w);

            // If A v collapses to ~0, v lies (numerically) in the null
            // space; stop and return the current estimate.
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Converged once both the value and the vector stabilize.
            // NOTE(review): for a negative dominant eigenvalue the iterate's
            // sign can alternate, so vector_change may never shrink below
            // tol — confirm this is the intended convergence criterion.
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
526}
527
impl Fit<Array2<Float>, ()> for Nystroem<Untrained> {
    type Fitted = Nystroem<Trained>;

    /// Fits the Nystroem approximation:
    /// 1. selects `n_components` landmark rows from `x` (clamped to the
    ///    number of available samples),
    /// 2. evaluates the landmark kernel matrix K11,
    /// 3. stores the inverse of the ridge-regularized K11 as the
    ///    normalization matrix applied in `transform`.
    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (n_samples, _) = x.dim();

        // Soft warning only; n_components is clamped below instead of
        // failing the fit.
        if self.n_components > n_samples {
            eprintln!(
                "Warning: n_components ({}) > n_samples ({})",
                self.n_components, n_samples
            );
        }

        let n_components_actual = self.n_components.min(n_samples);

        // Seeded RNG for reproducible fits; otherwise seed from the
        // thread-local RNG.
        let mut rng = if let Some(seed) = self.random_state {
            RealStdRng::seed_from_u64(seed)
        } else {
            RealStdRng::from_seed(thread_rng().random())
        };

        let component_indices = self.select_components(x, n_components_actual, &mut rng)?;

        // Copy the selected rows into the landmark matrix.
        let mut components = Array2::zeros((n_components_actual, x.ncols()));
        for (i, &idx) in component_indices.iter().enumerate() {
            components.row_mut(i).assign(&x.row(idx));
        }

        // Kernel matrix among the landmarks.
        let k11: Array2<f64> = self.kernel.compute_kernel(&components, &components);

        // Ridge regularization keeps K11 positive definite so the Cholesky
        // inversion below cannot fail on rank-deficient kernel matrices.
        let eps = 1e-8;

        let mut k11_reg = k11;
        for i in 0..n_components_actual {
            k11_reg[[i, i]] += eps;
        }

        // NOTE(review): the classic Nystroem method normalizes with
        // K11^{-1/2}; this implementation uses the full inverse K11^{-1} —
        // confirm this is intended.
        let normalization = cholesky_inverse(&k11_reg, n_components_actual).map_err(|e| {
            SklearsError::InvalidInput(format!("Failed to invert regularized kernel matrix: {}", e))
        })?;

        Ok(Nystroem {
            kernel: self.kernel,
            n_components: self.n_components,
            sampling_strategy: self.sampling_strategy,
            random_state: self.random_state,
            components_: Some(components),
            normalization_: Some(normalization),
            component_indices_: Some(component_indices),
            _state: PhantomData,
        })
    }
}
589
590impl Transform<Array2<Float>, Array2<Float>> for Nystroem<Trained> {
591 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
592 let components = self.components_.as_ref().expect("operation should succeed");
593 let normalization = self
594 .normalization_
595 .as_ref()
596 .expect("operation should succeed");
597
598 if x.ncols() != components.ncols() {
599 return Err(SklearsError::InvalidInput(format!(
600 "X has {} features, but Nystroem was fitted with {} features",
601 x.ncols(),
602 components.ncols()
603 )));
604 }
605
606 let k_x_components = self.kernel.compute_kernel(x, components);
608
609 let result = k_x_components.dot(normalization);
611
612 Ok(result)
613 }
614}
615
616impl Nystroem<Trained> {
617 pub fn components(&self) -> &Array2<Float> {
619 self.components_.as_ref().expect("operation should succeed")
620 }
621
622 pub fn component_indices(&self) -> &[usize] {
624 self.component_indices_
625 .as_ref()
626 .expect("operation should succeed")
627 }
628
629 pub fn normalization(&self) -> &Array2<Float> {
631 self.normalization_
632 .as_ref()
633 .expect("operation should succeed")
634 }
635}
636
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Linear kernel: output keeps the sample count and has at most
    // n_components columns.
    #[test]
    fn test_nystroem_linear_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem = Nystroem::new(Kernel::Linear, 3);
        let fitted = nystroem.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(x_transformed.nrows(), 4);
        assert!(x_transformed.ncols() <= 3);
    }

    // RBF kernel: same shape contract as above.
    #[test]
    fn test_nystroem_rbf_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let nystroem = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 2);
        let fitted = nystroem.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Polynomial kernel: same shape contract as above.
    #[test]
    fn test_nystroem_polynomial_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let kernel = Kernel::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let nystroem = Nystroem::new(kernel, 2);
        let fitted = nystroem.fit(&x, &()).expect("operation should succeed");
        let x_transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Fitting twice with the same seed must give (numerically) the same
    // embedding.
    #[test]
    fn test_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem1 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted1 = nystroem1.fit(&x, &()).expect("operation should succeed");
        let result1 = fitted1.transform(&x).expect("operation should succeed");

        let nystroem2 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted2 = nystroem2.fit(&x, &()).expect("operation should succeed");
        let result2 = fitted2.transform(&x).expect("operation should succeed");

        assert_eq!(result1.shape(), result2.shape());
        // Element-wise comparison with a small tolerance for float noise.
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    // Transforming data with the wrong feature count must return an error.
    #[test]
    fn test_nystroem_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0],];

        // Three features vs. the two used during fitting.
        let x_test = array![[1.0, 2.0, 3.0],];

        let nystroem = Nystroem::new(Kernel::Linear, 2);
        let fitted = nystroem
            .fit(&x_train, &())
            .expect("operation should succeed");
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    // Every sampling strategy should fit and transform without error and
    // preserve the row count.
    #[test]
    fn test_nystroem_sampling_strategies() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        // Random sampling.
        let nystroem_random = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted_random = nystroem_random
            .fit(&x, &())
            .expect("operation should succeed");
        let result_random = fitted_random
            .transform(&x)
            .expect("operation should succeed");
        assert_eq!(result_random.nrows(), 8);

        // KMeans-based sampling.
        let nystroem_kmeans = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::KMeans)
            .random_state(42);
        let fitted_kmeans = nystroem_kmeans
            .fit(&x, &())
            .expect("operation should succeed");
        let result_kmeans = fitted_kmeans
            .transform(&x)
            .expect("operation should succeed");
        assert_eq!(result_kmeans.nrows(), 8);

        // Leverage-score sampling.
        let nystroem_leverage = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted_leverage = nystroem_leverage
            .fit(&x, &())
            .expect("operation should succeed");
        let result_leverage = fitted_leverage
            .transform(&x)
            .expect("operation should succeed");
        assert_eq!(result_leverage.nrows(), 8);

        // Column-norm sampling.
        let nystroem_norm = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::ColumnNorm)
            .random_state(42);
        let fitted_norm = nystroem_norm
            .fit(&x, &())
            .expect("operation should succeed");
        let result_norm = fitted_norm.transform(&x).expect("operation should succeed");
        assert_eq!(result_norm.nrows(), 8);
    }

    // RBF kernel with leverage-score sampling: exact output shape and all
    // values finite.
    #[test]
    fn test_nystroem_rbf_with_different_sampling() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let kernel = Kernel::Rbf { gamma: 0.1 };

        let nystroem = Nystroem::new(kernel, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).expect("operation should succeed");
        let result = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(result.shape(), &[8, 4]);

        // No NaN/inf should leak out of the pipeline.
        for val in result.iter() {
            assert!(val.is_finite());
        }
    }

    // Symmetric input: fit/transform keeps the shape contract and stays
    // finite.
    #[test]
    fn test_nystroem_improved_eigendecomposition() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let nystroem = Nystroem::new(Kernel::Linear, 3)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).expect("operation should succeed");
        let result = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(result.nrows(), 4);
        assert!(result.ncols() <= 3);

        for val in result.iter() {
            assert!(val.is_finite());
        }
    }
}
833}