1use scirs2_core::ndarray::{Array1, Array2};
3use scirs2_core::random::rngs::StdRng as RealStdRng;
4use scirs2_core::random::seq::SliceRandom;
5use scirs2_core::random::Rng;
6use scirs2_core::random::{thread_rng, SeedableRng};
7use sklears_core::{
8 error::{Result, SklearsError},
9 prelude::{Fit, Transform},
10 traits::{Estimator, Trained, Untrained},
11 types::Float,
12};
13use std::marker::PhantomData;
14
/// Strategy used to pick the landmark (component) rows for the
/// Nystroem approximation.
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Uniform random sampling without replacement.
    Random,
    /// Run a few k-means iterations and pick the sample closest to each center.
    KMeans,
    /// Probability-proportional sampling using row norms as a cheap
    /// leverage-score proxy.
    LeverageScore,
    /// Deterministic strided selection over rows ranked by descending norm.
    ColumnNorm,
}
28
/// Kernel functions supported by the Nystroem feature map.
#[derive(Debug, Clone)]
pub enum Kernel {
    /// Linear kernel: k(x, y) = x · y.
    Linear,
    /// Radial basis function kernel: k(x, y) = exp(-gamma * ||x - y||²).
    Rbf { gamma: Float },
    /// Polynomial kernel: k(x, y) = (gamma * (x · y) + coef0)^degree.
    Polynomial {
        /// Scale applied to the dot product.
        gamma: Float,
        /// Independent (bias) term added before exponentiation.
        coef0: Float,
        /// Polynomial degree.
        degree: u32,
    },
}
46
47impl Kernel {
48 pub fn compute_kernel(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
50 let (n_x, _) = x.dim();
51 let (n_y, _) = y.dim();
52 let mut kernel_matrix = Array2::zeros((n_x, n_y));
53
54 match self {
55 Kernel::Linear => {
56 kernel_matrix = x.dot(&y.t());
57 }
58 Kernel::Rbf { gamma } => {
59 for i in 0..n_x {
60 for j in 0..n_y {
61 let diff = &x.row(i) - &y.row(j);
62 let dist_sq = diff.dot(&diff);
63 kernel_matrix[[i, j]] = (-gamma * dist_sq).exp();
64 }
65 }
66 }
67 Kernel::Polynomial {
68 gamma,
69 coef0,
70 degree,
71 } => {
72 for i in 0..n_x {
73 for j in 0..n_y {
74 let dot_prod = x.row(i).dot(&y.row(j));
75 kernel_matrix[[i, j]] = (gamma * dot_prod + coef0).powf(*degree as Float);
76 }
77 }
78 }
79 }
80
81 kernel_matrix
82 }
83}
84
/// Nystroem kernel feature-map approximation.
///
/// Selects `n_components` landmark rows from the training data, evaluates
/// the kernel against them, and normalizes with the pseudo-inverse of the
/// landmark Gram matrix. The `State` type parameter (`Untrained`/`Trained`)
/// ensures at compile time that `transform` is only callable after `fit`.
#[derive(Debug, Clone)]
pub struct Nystroem<State = Untrained> {
    /// Kernel to approximate.
    pub kernel: Kernel,
    /// Number of landmark samples to select.
    pub n_components: usize,
    /// Strategy used to pick the landmarks.
    pub sampling_strategy: SamplingStrategy,
    /// Optional RNG seed for reproducible sampling.
    pub random_state: Option<u64>,

    // Landmark rows copied from the training data (set by `fit`).
    components_: Option<Array2<Float>>,
    // Pseudo-inverse of the regularized landmark kernel matrix (set by `fit`).
    normalization_: Option<Array2<Float>>,
    // Indices of the selected landmarks within the training data.
    component_indices_: Option<Vec<usize>>,

    // Zero-sized marker carrying the typestate.
    _state: PhantomData<State>,
}
132
133impl Nystroem<Untrained> {
134 pub fn new(kernel: Kernel, n_components: usize) -> Self {
136 Self {
137 kernel,
138 n_components,
139 sampling_strategy: SamplingStrategy::Random,
140 random_state: None,
141 components_: None,
142 normalization_: None,
143 component_indices_: None,
144 _state: PhantomData,
145 }
146 }
147
148 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
150 self.sampling_strategy = strategy;
151 self
152 }
153
154 pub fn random_state(mut self, seed: u64) -> Self {
156 self.random_state = Some(seed);
157 self
158 }
159}
160
/// Marker `Estimator` implementation. This estimator carries no separate
/// configuration object, so `Config` is the unit type and `config` returns
/// a reference to the unit value.
impl Estimator for Nystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}
170
171impl Nystroem<Untrained> {
172 fn select_components(
174 &self,
175 x: &Array2<Float>,
176 n_components: usize,
177 rng: &mut RealStdRng,
178 ) -> Result<Vec<usize>> {
179 let (n_samples, _) = x.dim();
180
181 match &self.sampling_strategy {
182 SamplingStrategy::Random => {
183 let mut indices: Vec<usize> = (0..n_samples).collect();
184 indices.shuffle(rng);
185 Ok(indices[..n_components].to_vec())
186 }
187 SamplingStrategy::KMeans => {
188 self.kmeans_sampling(x, n_components, rng)
190 }
191 SamplingStrategy::LeverageScore => {
192 self.leverage_score_sampling(x, n_components, rng)
194 }
195 SamplingStrategy::ColumnNorm => {
196 self.column_norm_sampling(x, n_components, rng)
198 }
199 }
200 }
201
202 fn kmeans_sampling(
204 &self,
205 x: &Array2<Float>,
206 n_components: usize,
207 rng: &mut RealStdRng,
208 ) -> Result<Vec<usize>> {
209 let (n_samples, n_features) = x.dim();
210 let mut centers = Array2::zeros((n_components, n_features));
211
212 let mut indices: Vec<usize> = (0..n_samples).collect();
214 indices.shuffle(rng);
215 for (i, &idx) in indices[..n_components].iter().enumerate() {
216 centers.row_mut(i).assign(&x.row(idx));
217 }
218
219 for _iter in 0..5 {
221 let mut assignments = vec![0; n_samples];
222
223 for i in 0..n_samples {
225 let mut min_dist = Float::INFINITY;
226 let mut best_center = 0;
227
228 for j in 0..n_components {
229 let diff = &x.row(i) - ¢ers.row(j);
230 let dist = diff.dot(&diff);
231 if dist < min_dist {
232 min_dist = dist;
233 best_center = j;
234 }
235 }
236 assignments[i] = best_center;
237 }
238
239 for j in 0..n_components {
241 let cluster_points: Vec<usize> = assignments
242 .iter()
243 .enumerate()
244 .filter(|(_, &assignment)| assignment == j)
245 .map(|(i, _)| i)
246 .collect();
247
248 if !cluster_points.is_empty() {
249 let mut new_center = Array1::zeros(n_features);
250 for &point_idx in &cluster_points {
251 new_center = new_center + x.row(point_idx);
252 }
253 new_center /= cluster_points.len() as Float;
254 centers.row_mut(j).assign(&new_center);
255 }
256 }
257 }
258
259 let mut selected_indices = Vec::new();
261 for j in 0..n_components {
262 let mut min_dist = Float::INFINITY;
263 let mut best_point = 0;
264
265 for i in 0..n_samples {
266 let diff = &x.row(i) - ¢ers.row(j);
267 let dist = diff.dot(&diff);
268 if dist < min_dist {
269 min_dist = dist;
270 best_point = i;
271 }
272 }
273 selected_indices.push(best_point);
274 }
275
276 selected_indices.sort_unstable();
277 selected_indices.dedup();
278
279 while selected_indices.len() < n_components {
281 let random_idx = rng.gen_range(0..n_samples);
282 if !selected_indices.contains(&random_idx) {
283 selected_indices.push(random_idx);
284 }
285 }
286
287 Ok(selected_indices[..n_components].to_vec())
288 }
289
290 fn leverage_score_sampling(
292 &self,
293 x: &Array2<Float>,
294 n_components: usize,
295 _rng: &mut RealStdRng,
296 ) -> Result<Vec<usize>> {
297 let (n_samples, _) = x.dim();
298
299 let mut scores = Vec::new();
302 for i in 0..n_samples {
303 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
304 scores.push(row_norm + 1e-10); }
306
307 let total_score: Float = scores.iter().sum();
309 if total_score <= 0.0 {
310 return Err(SklearsError::InvalidInput(
311 "All scores are zero or negative".to_string(),
312 ));
313 }
314
315 let mut cumulative = Vec::with_capacity(scores.len());
317 let mut sum = 0.0;
318 for &score in &scores {
319 sum += score / total_score;
320 cumulative.push(sum);
321 }
322
323 let mut selected_indices = Vec::new();
324 for _ in 0..n_components {
325 let r = thread_rng().gen::<Float>();
326 let mut idx = cumulative
328 .iter()
329 .position(|&cum| cum >= r)
330 .unwrap_or(scores.len() - 1);
331
332 while selected_indices.contains(&idx) {
334 let r = thread_rng().gen::<Float>();
335 idx = cumulative
336 .iter()
337 .position(|&cum| cum >= r)
338 .unwrap_or(scores.len() - 1);
339 }
340 selected_indices.push(idx);
341 }
342
343 Ok(selected_indices)
344 }
345
346 fn column_norm_sampling(
348 &self,
349 x: &Array2<Float>,
350 n_components: usize,
351 rng: &mut RealStdRng,
352 ) -> Result<Vec<usize>> {
353 let (n_samples, _) = x.dim();
354
355 let mut norms = Vec::new();
357 for i in 0..n_samples {
358 let norm = x.row(i).dot(&x.row(i)).sqrt();
359 norms.push(norm + 1e-10);
360 }
361
362 let mut indices_with_norms: Vec<(usize, Float)> = norms
364 .iter()
365 .enumerate()
366 .map(|(i, &norm)| (i, norm))
367 .collect();
368 indices_with_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
369
370 let mut selected_indices = Vec::new();
371 let step = n_samples.max(1) / n_components.max(1);
372
373 for i in 0..n_components {
374 let idx = (i * step).min(n_samples - 1);
375 selected_indices.push(indices_with_norms[idx].0);
376 }
377
378 while selected_indices.len() < n_components {
380 let random_idx = rng.gen_range(0..n_samples);
381 if !selected_indices.contains(&random_idx) {
382 selected_indices.push(random_idx);
383 }
384 }
385
386 Ok(selected_indices)
387 }
388
389 fn compute_eigendecomposition(
392 &self,
393 matrix: &Array2<Float>,
394 rng: &mut RealStdRng,
395 ) -> Result<(Array1<Float>, Array2<Float>)> {
396 let n = matrix.nrows();
397
398 if n != matrix.ncols() {
399 return Err(SklearsError::InvalidInput(
400 "Matrix must be square for eigendecomposition".to_string(),
401 ));
402 }
403
404 let mut eigenvals = Array1::zeros(n);
405 let mut eigenvecs = Array2::zeros((n, n));
406
407 let mut deflated_matrix = matrix.clone();
409
410 for k in 0..n {
411 let (eigenval, eigenvec) = self.power_iteration(&deflated_matrix, 100, 1e-8, rng)?;
413
414 eigenvals[k] = eigenval;
415 eigenvecs.column_mut(k).assign(&eigenvec);
416
417 for i in 0..n {
419 for j in 0..n {
420 deflated_matrix[[i, j]] -= eigenval * eigenvec[i] * eigenvec[j];
421 }
422 }
423 }
424
425 let mut indices: Vec<usize> = (0..n).collect();
427 indices.sort_by(|&i, &j| eigenvals[j].partial_cmp(&eigenvals[i]).unwrap());
428
429 let mut sorted_eigenvals = Array1::zeros(n);
430 let mut sorted_eigenvecs = Array2::zeros((n, n));
431
432 for (new_idx, &old_idx) in indices.iter().enumerate() {
433 sorted_eigenvals[new_idx] = eigenvals[old_idx];
434 sorted_eigenvecs
435 .column_mut(new_idx)
436 .assign(&eigenvecs.column(old_idx));
437 }
438
439 Ok((sorted_eigenvals, sorted_eigenvecs))
440 }
441
    /// Power iteration for the dominant eigenpair of `matrix`.
    ///
    /// Starts from a random unit vector and iterates `v ← Av / ‖Av‖`,
    /// stopping early once both the Rayleigh-quotient eigenvalue estimate
    /// and the vector change fall below `tol`, or after `max_iter`
    /// iterations. Returns the last (eigenvalue, eigenvector) estimate.
    fn power_iteration(
        &self,
        matrix: &Array2<Float>,
        max_iter: usize,
        tol: Float,
        rng: &mut RealStdRng,
    ) -> Result<(Float, Array1<Float>)> {
        let n = matrix.nrows();

        // Random start with coordinates in [-0.5, 0.5), then normalized.
        let mut v = Array1::from_shape_fn(n, |_| rng.gen::<Float>() - 0.5);

        let norm = v.dot(&v).sqrt();
        if norm < 1e-10 {
            return Err(SklearsError::InvalidInput(
                "Initial vector has zero norm".to_string(),
            ));
        }
        v /= norm;

        let mut eigenval = 0.0;

        for _iter in 0..max_iter {
            let w = matrix.dot(&v);

            // Rayleigh quotient: current eigenvalue estimate (v is unit-norm).
            let new_eigenval = v.dot(&w);

            // If the iterate collapsed, stop and return the last estimate.
            let w_norm = w.dot(&w).sqrt();
            if w_norm < 1e-10 {
                break;
            }
            let new_v = w / w_norm;

            // Converged when both the eigenvalue and the vector stabilize.
            let eigenval_change = (new_eigenval - eigenval).abs();
            let vector_change = (&new_v - &v).mapv(|x| x.abs()).sum();

            if eigenval_change < tol && vector_change < tol {
                return Ok((new_eigenval, new_v));
            }

            eigenval = new_eigenval;
            v = new_v;
        }

        Ok((eigenval, v))
    }
494}
495
496impl Fit<Array2<Float>, ()> for Nystroem<Untrained> {
497 type Fitted = Nystroem<Trained>;
498
499 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
500 let (n_samples, _) = x.dim();
501
502 if self.n_components > n_samples {
503 eprintln!(
504 "Warning: n_components ({}) > n_samples ({})",
505 self.n_components, n_samples
506 );
507 }
508
509 let n_components_actual = self.n_components.min(n_samples);
510
511 let mut rng = if let Some(seed) = self.random_state {
512 RealStdRng::seed_from_u64(seed)
513 } else {
514 RealStdRng::from_seed(thread_rng().gen())
515 };
516
517 let component_indices = self.select_components(x, n_components_actual, &mut rng)?;
519
520 let mut components = Array2::zeros((n_components_actual, x.ncols()));
522 for (i, &idx) in component_indices.iter().enumerate() {
523 components.row_mut(i).assign(&x.row(idx));
524 }
525
526 let k11: Array2<f64> = self.kernel.compute_kernel(&components, &components);
528
529 let eps = 1e-12;
532
533 let mut k11_reg = k11.clone();
535 for i in 0..n_components_actual {
536 k11_reg[[i, i]] += eps;
537 }
538
539 let (eigenvals, eigenvecs) = self.compute_eigendecomposition(&k11_reg, &mut rng)?;
542
543 let threshold = 1e-8;
545 let valid_indices: Vec<usize> = eigenvals
546 .iter()
547 .enumerate()
548 .filter(|(_, &val)| val > threshold)
549 .map(|(i, _)| i)
550 .collect();
551
552 if valid_indices.is_empty() {
553 return Err(SklearsError::InvalidInput(
554 "No valid eigenvalues found in kernel matrix".to_string(),
555 ));
556 }
557
558 let _n_valid = valid_indices.len();
560 let mut pseudo_inverse = Array2::zeros((n_components_actual, n_components_actual));
561
562 for i in 0..n_components_actual {
563 for j in 0..n_components_actual {
564 let mut sum = 0.0;
565 for &k in &valid_indices {
566 sum += eigenvecs[[i, k]] * eigenvecs[[j, k]] / eigenvals[k];
567 }
568 pseudo_inverse[[i, j]] = sum;
569 }
570 }
571
572 let normalization = pseudo_inverse;
573
574 Ok(Nystroem {
575 kernel: self.kernel,
576 n_components: self.n_components,
577 sampling_strategy: self.sampling_strategy,
578 random_state: self.random_state,
579 components_: Some(components),
580 normalization_: Some(normalization),
581 component_indices_: Some(component_indices),
582 _state: PhantomData,
583 })
584 }
585}
586
587impl Transform<Array2<Float>, Array2<Float>> for Nystroem<Trained> {
588 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
589 let components = self.components_.as_ref().unwrap();
590 let normalization = self.normalization_.as_ref().unwrap();
591
592 if x.ncols() != components.ncols() {
593 return Err(SklearsError::InvalidInput(format!(
594 "X has {} features, but Nystroem was fitted with {} features",
595 x.ncols(),
596 components.ncols()
597 )));
598 }
599
600 let k_x_components = self.kernel.compute_kernel(x, components);
602
603 let result = k_x_components.dot(normalization);
605
606 Ok(result)
607 }
608}
609
impl Nystroem<Trained> {
    /// Landmark rows selected during fitting (one row per component).
    ///
    /// The backing field is always set in the `Trained` state, so the
    /// internal `unwrap` cannot panic in practice.
    pub fn components(&self) -> &Array2<Float> {
        self.components_.as_ref().unwrap()
    }

    /// Indices of the selected landmarks within the training data.
    pub fn component_indices(&self) -> &[usize] {
        self.component_indices_.as_ref().unwrap()
    }

    /// Normalization matrix applied after the kernel evaluation in
    /// `transform` (pseudo-inverse of the regularized landmark kernel).
    pub fn normalization(&self) -> &Array2<Float> {
        self.normalization_.as_ref().unwrap()
    }
}
626
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Linear kernel: output keeps the sample count and has at most
    // `n_components` feature columns.
    #[test]
    fn test_nystroem_linear_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem = Nystroem::new(Kernel::Linear, 3);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        assert!(x_transformed.ncols() <= 3);
    }

    // RBF kernel: same shape guarantees as the linear case.
    #[test]
    fn test_nystroem_rbf_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let nystroem = Nystroem::new(Kernel::Rbf { gamma: 0.1 }, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Polynomial kernel: same shape guarantees as the linear case.
    #[test]
    fn test_nystroem_polynomial_kernel() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let kernel = Kernel::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let nystroem = Nystroem::new(kernel, 2);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 3);
        assert!(x_transformed.ncols() <= 2);
    }

    // Fitting twice with the same `random_state` must produce
    // (near-)identical transforms.
    #[test]
    fn test_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let nystroem1 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted1 = nystroem1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let nystroem2 = Nystroem::new(Kernel::Linear, 3).random_state(42);
        let fitted2 = nystroem2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());
        // Element-wise comparison with a small numerical tolerance.
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    // Transforming data with a different feature count must error.
    #[test]
    fn test_nystroem_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0],];

        let x_test = array![
            [1.0, 2.0, 3.0], ];

        let nystroem = Nystroem::new(Kernel::Linear, 2);
        let fitted = nystroem.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);

        assert!(result.is_err());
    }

    // Every sampling strategy should fit and transform without error and
    // preserve the number of rows.
    #[test]
    fn test_nystroem_sampling_strategies() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let nystroem_random = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted_random = nystroem_random.fit(&x, &()).unwrap();
        let result_random = fitted_random.transform(&x).unwrap();
        assert_eq!(result_random.nrows(), 8);

        let nystroem_kmeans = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::KMeans)
            .random_state(42);
        let fitted_kmeans = nystroem_kmeans.fit(&x, &()).unwrap();
        let result_kmeans = fitted_kmeans.transform(&x).unwrap();
        assert_eq!(result_kmeans.nrows(), 8);

        let nystroem_leverage = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted_leverage = nystroem_leverage.fit(&x, &()).unwrap();
        let result_leverage = fitted_leverage.transform(&x).unwrap();
        assert_eq!(result_leverage.nrows(), 8);

        let nystroem_norm = Nystroem::new(Kernel::Linear, 4)
            .sampling_strategy(SamplingStrategy::ColumnNorm)
            .random_state(42);
        let fitted_norm = nystroem_norm.fit(&x, &()).unwrap();
        let result_norm = fitted_norm.transform(&x).unwrap();
        assert_eq!(result_norm.nrows(), 8);
    }

    // Non-linear kernel combined with leverage-score sampling: the output
    // must have the expected shape and contain only finite values.
    #[test]
    fn test_nystroem_rbf_with_different_sampling() {
        let x = array![
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
            [2.0, 1.0],
            [4.0, 3.0],
            [6.0, 5.0],
            [8.0, 7.0]
        ];

        let kernel = Kernel::Rbf { gamma: 0.1 };

        let nystroem = Nystroem::new(kernel, 4)
            .sampling_strategy(SamplingStrategy::LeverageScore)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.shape(), &[8, 4]);

        for val in result.iter() {
            assert!(val.is_finite());
        }
    }

    // Symmetric input (points on the axes) exercises the power-iteration
    // eigendecomposition; all outputs must stay finite.
    #[test]
    fn test_nystroem_improved_eigendecomposition() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]];

        let nystroem = Nystroem::new(Kernel::Linear, 3)
            .sampling_strategy(SamplingStrategy::Random)
            .random_state(42);
        let fitted = nystroem.fit(&x, &()).unwrap();
        let result = fitted.transform(&x).unwrap();

        assert_eq!(result.nrows(), 4);
        assert!(result.ncols() <= 3);

        for val in result.iter() {
            assert!(val.is_finite());
        }
    }
}
807}