1use scirs2_core::ndarray::{Array1, Array2};
3use scirs2_core::random::rngs::StdRng as RealStdRng;
4use scirs2_core::random::seq::SliceRandom;
5use scirs2_core::random::Rng;
6use scirs2_core::random::{thread_rng, SeedableRng};
7use sklears_core::{
8 error::{Result, SklearsError},
9 prelude::{Fit, Transform},
10 traits::{Estimator, Trained, Untrained},
11 types::Float,
12};
13use std::marker::PhantomData;
14
15fn cholesky_inverse(a: &Array2<f64>, n: usize) -> std::result::Result<Array2<f64>, String> {
20 let mut l = Array2::<f64>::zeros((n, n));
22 for j in 0..n {
23 let mut sum = 0.0;
24 for k in 0..j {
25 sum += l[[j, k]] * l[[j, k]];
26 }
27 let diag = a[[j, j]] - sum;
28 if diag <= 0.0 {
29 return Err(format!(
30 "Matrix is not positive definite (diagonal element {} = {})",
31 j, diag
32 ));
33 }
34 l[[j, j]] = diag.sqrt();
35
36 for i in (j + 1)..n {
37 let mut s = 0.0;
38 for k in 0..j {
39 s += l[[i, k]] * l[[j, k]];
40 }
41 l[[i, j]] = (a[[i, j]] - s) / l[[j, j]];
42 }
43 }
44
45 let mut l_inv = Array2::<f64>::zeros((n, n));
47 for col in 0..n {
48 for i in 0..n {
50 let mut sum = if i == col { 1.0 } else { 0.0 };
51 for k in 0..i {
52 sum -= l[[i, k]] * l_inv[[k, col]];
53 }
54 l_inv[[i, col]] = sum / l[[i, i]];
55 }
56 }
57
58 let result = l_inv.t().dot(&l_inv);
60
61 Ok(result)
62}
63
/// Strategy used to pick the landmark (component) samples for the
/// Nyström approximation.
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Uniform random sampling without replacement.
    Random,
    /// Pick the samples closest to approximate k-means cluster centers.
    KMeans,
    /// Probability-proportional sampling using row norms as a cheap
    /// proxy for leverage scores.
    LeverageScore,
    /// Deterministic selection spread across samples ranked by row norm.
    ColumnNorm,
}
77
/// Kernel functions supported by the Nyström feature map.
#[derive(Debug, Clone)]
pub enum Kernel {
    /// Linear kernel: k(x, y) = xᵀy.
    Linear,
    /// Radial basis function kernel: k(x, y) = exp(-gamma · ||x − y||²).
    Rbf { gamma: Float },
    /// Polynomial kernel: k(x, y) = (gamma · xᵀy + coef0)^degree.
    Polynomial {
        // Scale applied to the dot product.
        gamma: Float,

        // Independent (bias) term.
        coef0: Float,

        // Polynomial degree.
        degree: u32,
    },
}
95
96impl Kernel {
97 pub fn compute_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
99 let (n_x, _) = x.dim();
100 let (n_y, _) = y.dim();
101 let mut kernel_matrix = Array2::zeros((n_x, n_y));
102
103 match self {
104 Kernel::Linear => {
105 kernel_matrix = x.dot(&y.t());
106 }
107 Kernel::Rbf { gamma } => {
108 for i in 0..n_x {
109 for j in 0..n_y {
110 let diff = &x.row(i) - &y.row(j);
111 let dist_sq = diff.dot(&diff);
112 kernel_matrix[[i, j]] = (-gamma * dist_sq).exp();
113 }
114 }
115 }
116 Kernel::Polynomial {
117 gamma,
118 coef0,
119 degree,
120 } => {
121 for i in 0..n_x {
122 for j in 0..n_y {
123 let dot_prod = x.row(i).dot(&y.row(j));
124 kernel_matrix[[i, j]] = (gamma * dot_prod + coef0).powf(*degree as Float);
125 }
126 }
127 }
128 }
129
130 kernel_matrix
131 }
132}
133
/// Nyström kernel approximation transformer.
///
/// Approximates a kernel feature map by evaluating the kernel against a
/// subset of `n_components` training samples (the landmarks) and
/// normalizing with the inverse of the landmark kernel matrix. The
/// typestate parameter `State` distinguishes an unfitted transformer
/// (`Untrained`) from a fitted one (`Trained`).
#[derive(Debug, Clone)]
pub struct Nystroem<State = Untrained> {
    /// Kernel function used for the approximation.
    pub kernel: Kernel,
    /// Number of landmark samples to select.
    pub n_components: usize,
    /// How landmark samples are chosen during `fit`.
    pub sampling_strategy: SamplingStrategy,
    /// Optional seed for reproducible sampling.
    pub random_state: Option<u64>,

    // Learned state; populated only after a successful `fit`.
    components_: Option<Array2<Float>>,
    normalization_: Option<Array2<Float>>,
    component_indices_: Option<Vec<usize>>,

    // Zero-sized marker carrying the typestate.
    _state: PhantomData<State>,
}
181
182impl Nystroem<Untrained> {
183 pub fn new(kernel: Kernel, n_components: usize) -> Self {
185 Self {
186 kernel,
187 n_components,
188 sampling_strategy: SamplingStrategy::Random,
189 random_state: None,
190 components_: None,
191 normalization_: None,
192 component_indices_: None,
193 _state: PhantomData,
194 }
195 }
196
197 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
199 self.sampling_strategy = strategy;
200 self
201 }
202
203 pub fn random_state(mut self, seed: u64) -> Self {
205 self.random_state = Some(seed);
206 self
207 }
208}
209
impl Estimator for Nystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator carries no separate configuration object; returns a
    /// reference to the unit value (constant-promoted to `'static`).
    fn config(&self) -> &Self::Config {
        &()
    }
}
219
220impl Nystroem<Untrained> {
221 fn select_components(
223 &self,
224 x: &Array2<Float>,
225 n_components: usize,
226 rng: &mut RealStdRng,
227 ) -> Result<Vec<usize>> {
228 let (n_samples, _) = x.dim();
229
230 match &self.sampling_strategy {
231 SamplingStrategy::Random => {
232 let mut indices: Vec<usize> = (0..n_samples).collect();
233 indices.shuffle(rng);
234 Ok(indices[..n_components].to_vec())
235 }
236 SamplingStrategy::KMeans => {
237 self.kmeans_sampling(x, n_components, rng)
239 }
240 SamplingStrategy::LeverageScore => {
241 self.leverage_score_sampling(x, n_components, rng)
243 }
244 SamplingStrategy::ColumnNorm => {
245 self.column_norm_sampling(x, n_components, rng)
247 }
248 }
249 }
250
251 fn kmeans_sampling(
253 &self,
254 x: &Array2<Float>,
255 n_components: usize,
256 rng: &mut RealStdRng,
257 ) -> Result<Vec<usize>> {
258 let (n_samples, n_features) = x.dim();
259 let mut centers = Array2::zeros((n_components, n_features));
260
261 let mut indices: Vec<usize> = (0..n_samples).collect();
263 indices.shuffle(rng);
264 for (i, &idx) in indices[..n_components].iter().enumerate() {
265 centers.row_mut(i).assign(&x.row(idx));
266 }
267
268 for _iter in 0..5 {
270 let mut assignments = vec![0; n_samples];
271
272 for i in 0..n_samples {
274 let mut min_dist = Float::INFINITY;
275 let mut best_center = 0;
276
277 for j in 0..n_components {
278 let diff = &x.row(i) - ¢ers.row(j);
279 let dist = diff.dot(&diff);
280 if dist < min_dist {
281 min_dist = dist;
282 best_center = j;
283 }
284 }
285 assignments[i] = best_center;
286 }
287
288 for j in 0..n_components {
290 let cluster_points: Vec<usize> = assignments
291 .iter()
292 .enumerate()
293 .filter(|(_, &assignment)| assignment == j)
294 .map(|(i, _)| i)
295 .collect();
296
297 if !cluster_points.is_empty() {
298 let mut new_center = Array1::zeros(n_features);
299 for &point_idx in &cluster_points {
300 new_center = new_center + x.row(point_idx);
301 }
302 new_center /= cluster_points.len() as Float;
303 centers.row_mut(j).assign(&new_center);
304 }
305 }
306 }
307
308 let mut selected_indices = Vec::new();
310 for j in 0..n_components {
311 let mut min_dist = Float::INFINITY;
312 let mut best_point = 0;
313
314 for i in 0..n_samples {
315 let diff = &x.row(i) - ¢ers.row(j);
316 let dist = diff.dot(&diff);
317 if dist < min_dist {
318 min_dist = dist;
319 best_point = i;
320 }
321 }
322 selected_indices.push(best_point);
323 }
324
325 selected_indices.sort_unstable();
326 selected_indices.dedup();
327
328 while selected_indices.len() < n_components {
330 let random_idx = rng.gen_range(0..n_samples);
331 if !selected_indices.contains(&random_idx) {
332 selected_indices.push(random_idx);
333 }
334 }
335
336 Ok(selected_indices[..n_components].to_vec())
337 }
338
339 fn leverage_score_sampling(
341 &self,
342 x: &Array2<Float>,
343 n_components: usize,
344 _rng: &mut RealStdRng,
345 ) -> Result<Vec<usize>> {
346 let (n_samples, _) = x.dim();
347
348 let mut scores = Vec::new();
351 for i in 0..n_samples {
352 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
353 scores.push(row_norm + 1e-10); }
355
356 let total_score: Float = scores.iter().sum();
358 if total_score <= 0.0 {
359 return Err(SklearsError::InvalidInput(
360 "All scores are zero or negative".to_string(),
361 ));
362 }
363
364 let mut cumulative = Vec::with_capacity(scores.len());
366 let mut sum = 0.0;
367 for &score in &scores {
368 sum += score / total_score;
369 cumulative.push(sum);
370 }
371
372 let mut selected_indices = Vec::new();
373 for _ in 0..n_components {
374 let r = thread_rng().gen::<Float>();
375 let mut idx = cumulative
377 .iter()
378 .position(|&cum| cum >= r)
379 .unwrap_or(scores.len() - 1);
380
381 while selected_indices.contains(&idx) {
383 let r = thread_rng().gen::<Float>();
384 idx = cumulative
385 .iter()
386 .position(|&cum| cum >= r)
387 .unwrap_or(scores.len() - 1);
388 }
389 selected_indices.push(idx);
390 }
391
392 Ok(selected_indices)
393 }
394
395 fn column_norm_sampling(
397 &self,
398 x: &Array2<Float>,
399 n_components: usize,
400 rng: &mut RealStdRng,
401 ) -> Result<Vec<usize>> {
402 let (n_samples, _) = x.dim();
403
404 let mut norms = Vec::new();
406 for i in 0..n_samples {
407 let norm = x.row(i).dot(&x.row(i)).sqrt();
408 norms.push(norm + 1e-10);
409 }
410
411 let mut indices_with_norms: Vec<(usize, Float)> = norms
413 .iter()
414 .enumerate()
415 .map(|(i, &norm)| (i, norm))
416 .collect();
417 indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
418
419 let mut selected_indices = Vec::new();
420 let step = n_samples.max(1) / n_components.max(1);
421
422 for i in 0..n_components {
423 let idx = (i * step).min(n_samples - 1);
424 selected_indices.push(indices_with_norms[idx].0);
425 }
426
427 while selected_indices.len() < n_components {
429 let random_idx = rng.gen_range(0..n_samples);
430 if !selected_indices.contains(&random_idx) {
431 selected_indices.push(random_idx);
432 }
433 }
434
435 Ok(selected_indices)
436 }
437
438 fn compute_eigendecomposition(
441 &self,
442 matrix: &Array2<Float>,
443 _rng: &mut RealStdRng,
444 ) -> Result<(Array1<Float>, Array2<Float>)> {
445 use scirs2_linalg::compat::{ArrayLinalgExt, UPLO};
446
447 let n = matrix.nrows();
448
449 if n != matrix.ncols() {
450 return Err(SklearsError::InvalidInput(
451 "Matrix must be square for eigendecomposition".to_string(),
452 ));
453 }
454
455 let (eigenvals, eigenvecs) = matrix
457 .eigh(UPLO::Lower)
458 .map_err(|e| SklearsError::InvalidInput(format!("Eigendecomposition failed: {}", e)))?;
459
460 let mut sorted_eigenvals = Array1::zeros(n);
462 let mut sorted_eigenvecs = Array2::zeros((n, n));
463
464 for i in 0..n {
465 sorted_eigenvals[i] = eigenvals[n - 1 - i];
466 sorted_eigenvecs
467 .column_mut(i)
468 .assign(&eigenvecs.column(n - 1 - i));
469 }
470
471 Ok((sorted_eigenvals, sorted_eigenvecs))
472 }
473
474 fn power_iteration(
476 &self,
477 matrix: &Array2<Float>,
478 max_iter: usize,
479 tol: Float,
480 rng: &mut RealStdRng,
481 ) -> Result<(Float, Array1<Float>)> {
482 let n = matrix.nrows();
483
484 let mut v = Array1::from_shape_fn(n, |_| rng.gen::<Float>() - 0.5);
486
487 let norm = v.dot(&v).sqrt();
489 if norm < 1e-10 {
490 return Err(SklearsError::InvalidInput(
491 "Initial vector has zero norm".to_string(),
492 ));
493 }
494 v /= norm;
495
496 let mut eigenval = 0.0;
497
498 for _iter in 0..max_iter {
499 let w = matrix.dot(&v);
501
502 let new_eigenval = v.dot(&w);
504
505 let w_norm = w.dot(&w).sqrt();
507 if w_norm < 1e-10 {
508 break;
509 }
510 let new_v = w / w_norm;
511
512 let eigenval_change = (new_eigenval - eigenval).abs();
514 let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();
515
516 if eigenval_change < tol && vector_change < tol {
517 return Ok((new_eigenval, new_v));
518 }
519
520 eigenval = new_eigenval;
521 v = new_v;
522 }
523
524 Ok((eigenval, v))
525 }
526}
527
528impl Fit<Array2<Float>, ()> for Nystroem<Untrained> {
529 type Fitted = Nystroem<Trained>;
530
531 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
532 let (n_samples, _) = x.dim();
533
534 if self.n_components > n_samples {
535 eprintln!(
536 "Warning: n_components ({}) > n_samples ({})",
537 self.n_components, n_samples
538 );
539 }
540
541 let n_components_actual = self.n_components.min(n_samples);
542
543 let mut rng = if let Some(seed) = self.random_state {
544 RealStdRng::seed_from_u64(seed)
545 } else {
546 RealStdRng::from_seed(thread_rng().gen())
547 };
548
549 let component_indices = self.select_components(x, n_components_actual, &mut rng)?;
551
552 let mut components = Array2::zeros((n_components_actual, x.ncols()));
554 for (i, &idx) in component_indices.iter().enumerate() {
555 components.row_mut(i).assign(&x.row(idx));
556 }
557
558 let k11: Array2<f64> = self.kernel.compute_kernel(&components, &components);
560
561 let eps = 1e-8;
564
565 let mut k11_reg = k11;
567 for i in 0..n_components_actual {
568 k11_reg[[i, i]] += eps;
569 }
570
571 let normalization = cholesky_inverse(&k11_reg, n_components_actual).map_err(|e| {
574 SklearsError::InvalidInput(format!("Failed to invert regularized kernel matrix: {}", e))
575 })?;
576
577 Ok(Nystroem {
578 kernel: self.kernel,
579 n_components: self.n_components,
580 sampling_strategy: self.sampling_strategy,
581 random_state: self.random_state,
582 components_: Some(components),
583 normalization_: Some(normalization),
584 component_indices_: Some(component_indices),
585 _state: PhantomData,
586 })
587 }
588}
589
590impl Transform<Array2<Float>, Array2<Float>> for Nystroem<Trained> {
591 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
592 let components = self.components_.as_ref().unwrap();
593 let normalization = self.normalization_.as_ref().unwrap();
594
595 if x.ncols() != components.ncols() {
596 return Err(SklearsError::InvalidInput(format!(
597 "X has {} features, but Nystroem was fitted with {} features",
598 x.ncols(),
599 components.ncols()
600 )));
601 }
602
603 let k_x_components = self.kernel.compute_kernel(x, components);
605
606 let result = k_x_components.dot(normalization);
608
609 Ok(result)
610 }
611}
612
impl Nystroem<Trained> {
    /// The selected landmark samples (rows copied from the training data).
    /// Always `Some` in the `Trained` typestate, hence the `unwrap`.
    pub fn components(&self) -> &Array2<Float> {
        self.components_.as_ref().unwrap()
    }

    /// Indices (into the training data) of the selected landmark samples.
    pub fn component_indices(&self) -> &[usize] {
        self.component_indices_.as_ref().unwrap()
    }

    /// The normalization matrix (inverse regularized landmark kernel
    /// matrix) applied during `transform`.
    pub fn normalization(&self) -> &Array2<Float> {
        self.normalization_.as_ref().unwrap()
    }
}
629
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Smoke test: linear kernel yields an (n_samples, <= n_components) map.
    #[test]
    fn test_nystroem_linear_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem = Nystroem::new(Kernel::Linear, 3);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        assert!(x_transformed.ncols() <= 3); }

    // Smoke test for the RBF kernel path.
    #[test]
    fn test_nystroem_rbf_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let nystroem = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Smoke test for the polynomial kernel path.
    #[test]
    fn test_nystroem_polynomial_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let kernel = Kernel::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let nystroem = Nystroem::new(kernel, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Two fits with the same seed must produce numerically identical output.
    #[test]
    fn test_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem1 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let nystroem2 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    // Transforming data with a different feature count must error, not panic.
    #[test]
    fn test_nystroem_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0],];

        let x_test = array![
            [1.0, 2.0, 3.0], ];

        let nystroem = Nystroem::new(Kernel::Linear, 2);
        let fitted = nystroem.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    // Every sampling strategy should fit and transform without error.
    #[test]
    fn test_nystroem_sampling_strategies() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let nystroem_random = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted_random = nystroem_random.fit(&x, &()).unwrap();
        let result_random = fitted_random.transform(&x).unwrap();
        assert_eq!(result_random.nrows(), 8);

        let nystroem_kmeans = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::KMeans)
            .random_state(42);
        let fitted_kmeans = nystroem_kmeans.fit(&x, &()).unwrap();
        let result_kmeans = fitted_kmeans.transform(&x).unwrap();
        assert_eq!(result_kmeans.nrows(), 8);

        let nystroem_leverage = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted_leverage = nystroem_leverage.fit(&x, &()).unwrap();
        let result_leverage = fitted_leverage.transform(&x).unwrap();
        assert_eq!(result_leverage.nrows(), 8);

        let nystroem_norm = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::ColumnNorm)
            .random_state(42);
        let fitted_norm = nystroem_norm.fit(&x, &()).unwrap();
        let result_norm = fitted_norm.transform(&x).unwrap();
        assert_eq!(result_norm.nrows(), 8);
    }

    // RBF kernel with leverage-score sampling: the map must be full width
    // (8 x 4) and contain only finite values.
    #[test]
    fn test_nystroem_rbf_with_different_sampling() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let kernel = Kernel::Rbf { gamma: 0.1 };

        let nystroem = Nystroem::new(kernel, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.shape(), &[8, 4]);

        for val in result.iter() {
            assert!(val.is_finite());
        }
    }

    // Degenerate geometry (symmetric points around the origin) should still
    // produce finite feature values.
    #[test]
    fn test_nystroem_improved_eigendecomposition() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let nystroem = Nystroem::new(Kernel::Linear, 3)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 4);
        assert!(result.ncols() <= 3);

        for val in result.iter() {
            assert!(val.is_finite());
        }
    }
}
810}