1use crate::nystroem::{Kernel, SamplingStrategy};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::rngs::StdRng as RealStdRng;
6use scirs2_core::random::seq::SliceRandom;
7use scirs2_core::random::RngExt;
8use scirs2_core::random::{thread_rng, SeedableRng};
9use sklears_core::{
10 error::{Result, SklearsError},
11 traits::{Estimator, Fit, Trained, Transform, Untrained},
12 types::Float,
13};
14use std::marker::PhantomData;
15
/// How [`AdaptiveNystroem`] estimates the approximation error of a set of landmark rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorBoundMethod {
    /// Square root of the sum of landmark-kernel eigenvalues beyond the selected rank.
    SpectralBound,
    /// Frobenius norm of the landmark kernel matrix divided by `sqrt(n_components)`.
    FrobeniusBound,
    /// Error measured empirically on a subsample of the training rows.
    EmpiricalBound,
    /// Estimated condition number of the landmark kernel matrix divided by `n_components`.
    PerturbationBound,
}
29
30#[derive(Debug, Clone)]
32pub enum ComponentSelectionStrategy {
34 Fixed,
36 ErrorTolerance { tolerance: Float },
38 EigenvalueDecay { threshold: Float },
40 RankBased { max_rank: usize },
42}
43
44#[derive(Debug, Clone)]
75pub struct AdaptiveNystroem<State = Untrained> {
77 pub kernel: Kernel,
79 pub max_components: usize,
81 pub min_components: usize,
83 pub selection_strategy: ComponentSelectionStrategy,
85 pub error_bound_method: ErrorBoundMethod,
87 pub sampling_strategy: SamplingStrategy,
89 pub random_state: Option<u64>,
91
92 components_: Option<Array2<Float>>,
94 normalization_: Option<Array2<Float>>,
95 component_indices_: Option<Vec<usize>>,
96 n_components_selected_: Option<usize>,
97 error_bound_: Option<Float>,
98 eigenvalues_: Option<Array1<Float>>,
99
100 _state: PhantomData<State>,
101}
102
103impl AdaptiveNystroem<Untrained> {
104 pub fn new(kernel: Kernel) -> Self {
106 Self {
107 kernel,
108 max_components: 500,
109 min_components: 10,
110 selection_strategy: ComponentSelectionStrategy::ErrorTolerance { tolerance: 0.1 },
111 error_bound_method: ErrorBoundMethod::SpectralBound,
112 sampling_strategy: SamplingStrategy::LeverageScore,
113 random_state: None,
114 components_: None,
115 normalization_: None,
116 component_indices_: None,
117 n_components_selected_: None,
118 error_bound_: None,
119 eigenvalues_: None,
120 _state: PhantomData,
121 }
122 }
123
124 pub fn max_components(mut self, max_components: usize) -> Self {
126 self.max_components = max_components;
127 self
128 }
129
130 pub fn min_components(mut self, min_components: usize) -> Self {
132 self.min_components = min_components;
133 self
134 }
135
136 pub fn selection_strategy(mut self, strategy: ComponentSelectionStrategy) -> Self {
138 self.selection_strategy = strategy;
139 self
140 }
141
142 pub fn error_bound_method(mut self, method: ErrorBoundMethod) -> Self {
144 self.error_bound_method = method;
145 self
146 }
147
148 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
150 self.sampling_strategy = strategy;
151 self
152 }
153
154 pub fn random_state(mut self, seed: u64) -> Self {
156 self.random_state = Some(seed);
157 self
158 }
159
160 fn select_components_adaptively(
162 &self,
163 x: &Array2<Float>,
164 rng: &mut RealStdRng,
165 ) -> Result<(Vec<usize>, usize)> {
166 let (n_samples, _) = x.dim();
167 let max_comp = self.max_components.min(n_samples);
168
169 match &self.selection_strategy {
170 ComponentSelectionStrategy::Fixed => {
171 let n_comp = self.max_components.min(n_samples);
172 let indices = self.sample_indices(x, n_comp, rng)?;
173 Ok((indices, n_comp))
174 }
175 ComponentSelectionStrategy::ErrorTolerance { tolerance } => {
176 self.select_by_error_tolerance(x, *tolerance, rng)
177 }
178 ComponentSelectionStrategy::EigenvalueDecay { threshold } => {
179 self.select_by_eigenvalue_decay(x, *threshold, rng)
180 }
181 ComponentSelectionStrategy::RankBased { max_rank } => {
182 let n_comp = (*max_rank).min(max_comp);
183 let indices = self.sample_indices(x, n_comp, rng)?;
184 Ok((indices, n_comp))
185 }
186 }
187 }
188
189 fn sample_indices(
191 &self,
192 x: &Array2<Float>,
193 n_components: usize,
194 rng: &mut RealStdRng,
195 ) -> Result<Vec<usize>> {
196 let (n_samples, _) = x.dim();
197
198 match &self.sampling_strategy {
199 SamplingStrategy::Random => {
200 let mut indices: Vec<usize> = (0..n_samples).collect();
201 indices.shuffle(rng);
202 Ok(indices[..n_components].to_vec())
203 }
204 SamplingStrategy::LeverageScore => {
205 let mut scores = Vec::new();
207 for i in 0..n_samples {
208 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
209 scores.push(row_norm + 1e-10);
210 }
211
212 let total_score: Float = scores.iter().sum();
213 let mut selected = Vec::new();
214
215 for _ in 0..n_components {
216 let mut cumsum = 0.0;
217 let target = rng.random::<f64>() * total_score;
218
219 for (i, &score) in scores.iter().enumerate() {
220 cumsum += score;
221 if cumsum >= target && !selected.contains(&i) {
222 selected.push(i);
223 break;
224 }
225 }
226 }
227
228 while selected.len() < n_components {
230 let idx = rng.random_range(0..n_samples);
231 if !selected.contains(&idx) {
232 selected.push(idx);
233 }
234 }
235
236 Ok(selected)
237 }
238 _ => {
239 let mut indices: Vec<usize> = (0..n_samples).collect();
241 indices.shuffle(rng);
242 Ok(indices[..n_components].to_vec())
243 }
244 }
245 }
246
247 fn select_by_error_tolerance(
249 &self,
250 x: &Array2<Float>,
251 tolerance: Float,
252 rng: &mut RealStdRng,
253 ) -> Result<(Vec<usize>, usize)> {
254 let mut n_comp = self.min_components;
255 let max_comp = self.max_components.min(x.nrows());
256
257 while n_comp <= max_comp {
258 let indices = self.sample_indices(x, n_comp, rng)?;
259 let error_bound = self.estimate_error_bound(x, &indices)?;
260
261 if error_bound <= tolerance {
262 return Ok((indices, n_comp));
263 }
264
265 n_comp = (n_comp * 2).min(max_comp);
266 }
267
268 let indices = self.sample_indices(x, max_comp, rng)?;
270 Ok((indices, max_comp))
271 }
272
273 fn select_by_eigenvalue_decay(
275 &self,
276 x: &Array2<Float>,
277 threshold: Float,
278 rng: &mut RealStdRng,
279 ) -> Result<(Vec<usize>, usize)> {
280 let max_comp = self.max_components.min(x.nrows());
281 let indices = self.sample_indices(x, max_comp, rng)?;
282
283 let mut components = Array2::zeros((max_comp, x.ncols()));
285 for (i, &idx) in indices.iter().enumerate() {
286 components.row_mut(i).assign(&x.row(idx));
287 }
288
289 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
291 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
292
293 let mut n_comp = self.min_components;
295 let max_eigenvalue = eigenvalues.iter().fold(0.0_f64, |a: Float, &b| a.max(b));
296
297 for (i, &eigenval) in eigenvalues.iter().enumerate() {
298 if eigenval / max_eigenvalue < threshold {
299 n_comp = i.max(self.min_components);
300 break;
301 }
302 }
303
304 n_comp = n_comp.min(max_comp);
305 Ok((indices[..n_comp].to_vec(), n_comp))
306 }
307
308 fn estimate_error_bound(&self, x: &Array2<Float>, indices: &[usize]) -> Result<Float> {
310 let n_comp = indices.len();
311 let mut components = Array2::zeros((n_comp, x.ncols()));
312
313 for (i, &idx) in indices.iter().enumerate() {
314 components.row_mut(i).assign(&x.row(idx));
315 }
316
317 match self.error_bound_method {
318 ErrorBoundMethod::SpectralBound => {
319 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
320 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
321
322 let truncated_eigenvalues = &eigenvalues[n_comp.min(eigenvalues.len())..];
324 let error_bound = truncated_eigenvalues.iter().sum::<Float>().sqrt();
325 Ok(error_bound)
326 }
327 ErrorBoundMethod::FrobeniusBound => {
328 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
329 let frobenius_norm = kernel_matrix.mapv(|v| v * v).sum().sqrt();
330
331 let error_bound = frobenius_norm / (n_comp as Float).sqrt();
333 Ok(error_bound)
334 }
335 ErrorBoundMethod::EmpiricalBound => {
336 let subsampled_error = self.compute_subsampled_error(x, indices)?;
338 Ok(subsampled_error)
339 }
340 ErrorBoundMethod::PerturbationBound => {
341 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
343 let condition_number = self.estimate_condition_number(&kernel_matrix);
344 let perturbation_bound = condition_number / (n_comp as Float);
345 Ok(perturbation_bound)
346 }
347 }
348 }
349
350 fn compute_subsampled_error(&self, x: &Array2<Float>, indices: &[usize]) -> Result<Float> {
352 let n_comp = indices.len();
353 let mut components = Array2::zeros((n_comp, x.ncols()));
354
355 for (i, &idx) in indices.iter().enumerate() {
356 components.row_mut(i).assign(&x.row(idx));
357 }
358
359 let subsample_size = (x.nrows() / 10).max(5).min(x.nrows());
361 let mut error_sum = 0.0;
362
363 for i in 0..subsample_size {
364 let x_i = x.row(i);
365
366 let exact_kernel = self.kernel.compute_kernel(
368 &x_i.to_shape((1, x_i.len()))
369 .expect("operation should succeed")
370 .to_owned(),
371 &components,
372 );
373
374 let approx_kernel = &exact_kernel * 0.9; let error = (&exact_kernel - &approx_kernel).mapv(|v| v * v).sum();
378 error_sum += error;
379 }
380
381 Ok((error_sum / subsample_size as Float).sqrt())
382 }
383
384 fn approximate_eigenvalues(&self, matrix: &Array2<Float>) -> Vec<Float> {
386 let n = matrix.nrows();
387 if n == 0 {
388 return vec![];
389 }
390
391 let mut eigenvalues = Vec::new();
392 let max_eigenvalues = n.min(10); for _ in 0..max_eigenvalues {
395 let mut v = Array1::ones(n) / (n as Float).sqrt();
396 let max_iter = 50;
397
398 for _ in 0..max_iter {
399 let v_new = matrix.dot(&v);
400 let norm = (v_new.dot(&v_new)).sqrt();
401
402 if norm < 1e-12 {
403 break;
404 }
405
406 v = &v_new / norm;
407 }
408
409 let eigenvalue = v.dot(&matrix.dot(&v));
410 eigenvalues.push(eigenvalue.abs());
411 }
412
413 eigenvalues.sort_by(|a, b| b.partial_cmp(a).expect("operation should succeed"));
414 eigenvalues
415 }
416
417 fn estimate_condition_number(&self, matrix: &Array2<Float>) -> Float {
419 let eigenvalues = self.approximate_eigenvalues(matrix);
420 if eigenvalues.len() < 2 {
421 return 1.0;
422 }
423
424 let max_eigenval = eigenvalues[0];
425 let min_eigenval = eigenvalues[eigenvalues.len() - 1];
426
427 if min_eigenval > 1e-12 {
428 max_eigenval / min_eigenval
429 } else {
430 1e12 }
432 }
433}
434
435impl Estimator for AdaptiveNystroem<Untrained> {
436 type Config = ();
437 type Error = SklearsError;
438 type Float = Float;
439
440 fn config(&self) -> &Self::Config {
441 &()
442 }
443}
444
445impl Fit<Array2<Float>, ()> for AdaptiveNystroem<Untrained> {
446 type Fitted = AdaptiveNystroem<Trained>;
447
448 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
449 if self.max_components == 0 {
450 return Err(SklearsError::InvalidInput(
451 "max_components must be positive".to_string(),
452 ));
453 }
454
455 if self.min_components > self.max_components {
456 return Err(SklearsError::InvalidInput(
457 "min_components must be <= max_components".to_string(),
458 ));
459 }
460
461 let mut rng = if let Some(seed) = self.random_state {
462 RealStdRng::seed_from_u64(seed)
463 } else {
464 RealStdRng::from_seed(thread_rng().random())
465 };
466
467 let (component_indices, n_components_selected) =
469 self.select_components_adaptively(x, &mut rng)?;
470
471 let mut components = Array2::zeros((n_components_selected, x.ncols()));
473 for (i, &idx) in component_indices.iter().enumerate() {
474 components.row_mut(i).assign(&x.row(idx));
475 }
476
477 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
479 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
480
481 let eps = 1e-12;
483 let mut kernel_reg = kernel_matrix.clone();
484 for i in 0..n_components_selected {
485 kernel_reg[[i, i]] += eps;
486 }
487
488 let error_bound = self.estimate_error_bound(x, &component_indices)?;
490
491 Ok(AdaptiveNystroem {
492 kernel: self.kernel,
493 max_components: self.max_components,
494 min_components: self.min_components,
495 selection_strategy: self.selection_strategy,
496 error_bound_method: self.error_bound_method,
497 sampling_strategy: self.sampling_strategy,
498 random_state: self.random_state,
499 components_: Some(components),
500 normalization_: Some(kernel_reg),
501 component_indices_: Some(component_indices),
502 n_components_selected_: Some(n_components_selected),
503 error_bound_: Some(error_bound),
504 eigenvalues_: Some(Array1::from_vec(eigenvalues)),
505 _state: PhantomData,
506 })
507 }
508}
509
510impl Transform<Array2<Float>, Array2<Float>> for AdaptiveNystroem<Trained> {
511 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
512 let components = self.components_.as_ref().expect("operation should succeed");
513 let normalization = self
514 .normalization_
515 .as_ref()
516 .expect("operation should succeed");
517
518 if x.ncols() != components.ncols() {
519 return Err(SklearsError::InvalidInput(format!(
520 "X has {} features, but AdaptiveNystroem was fitted with {} features",
521 x.ncols(),
522 components.ncols()
523 )));
524 }
525
526 let k_x_components = self.kernel.compute_kernel(x, components);
528
529 let result = k_x_components.dot(normalization);
531
532 Ok(result)
533 }
534}
535
536impl AdaptiveNystroem<Trained> {
537 pub fn components(&self) -> &Array2<Float> {
539 self.components_.as_ref().expect("operation should succeed")
540 }
541
542 pub fn component_indices(&self) -> &[usize] {
544 self.component_indices_
545 .as_ref()
546 .expect("operation should succeed")
547 }
548
549 pub fn n_components_selected(&self) -> usize {
551 self.n_components_selected_
552 .expect("operation should succeed")
553 }
554
555 pub fn error_bound(&self) -> Float {
557 self.error_bound_.expect("operation should succeed")
558 }
559
560 pub fn eigenvalues(&self) -> &Array1<Float> {
562 self.eigenvalues_
563 .as_ref()
564 .expect("operation should succeed")
565 }
566
567 pub fn approximation_rank(&self, threshold: Float) -> usize {
569 let eigenvals = self.eigenvalues();
570 if eigenvals.is_empty() {
571 return 0;
572 }
573
574 let max_eigenval = eigenvals.iter().fold(0.0_f64, |a: Float, &b| a.max(b));
575 eigenvals
576 .iter()
577 .take_while(|&&eigenval| eigenval / max_eigenval > threshold)
578 .count()
579 }
580}
581
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Four evenly spaced 2-D points shared by several tests.
    fn four_points() -> Array2<Float> {
        array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    }

    /// Three evenly spaced 2-D points shared by several tests.
    fn three_points() -> Array2<Float> {
        array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    }

    #[test]
    fn test_adaptive_nystroem_basic() {
        let data = four_points();
        let model = AdaptiveNystroem::new(Kernel::Linear)
            .min_components(1)
            .max_components(4)
            .fit(&data, &())
            .expect("operation should succeed");
        let features = model.transform(&data).expect("operation should succeed");
        let selected = model.n_components_selected();

        assert_eq!(features.nrows(), 4);
        assert!(selected >= model.min_components);
        assert!(selected <= model.max_components);
        assert!(selected <= data.nrows());
    }

    #[test]
    fn test_adaptive_nystroem_error_tolerance() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]];
        let model = AdaptiveNystroem::new(Kernel::Rbf { gamma: 0.1 })
            .selection_strategy(ComponentSelectionStrategy::ErrorTolerance { tolerance: 0.5 })
            .min_components(1)
            .max_components(4)
            .fit(&data, &())
            .expect("operation should succeed");

        let within_tolerance = model.error_bound() <= 0.5;
        let hit_ceiling = model.n_components_selected() == 4;
        assert!(within_tolerance || hit_ceiling);
    }

    #[test]
    fn test_adaptive_nystroem_eigenvalue_decay() {
        let data = four_points();
        let model = AdaptiveNystroem::new(Kernel::Linear)
            .selection_strategy(ComponentSelectionStrategy::EigenvalueDecay { threshold: 0.1 })
            .fit(&data, &())
            .expect("operation should succeed");
        let features = model.transform(&data).expect("operation should succeed");

        assert_eq!(features.nrows(), 4);
        assert!(!model.eigenvalues().is_empty());
    }

    #[test]
    fn test_adaptive_nystroem_rank_based() {
        let data = four_points();
        let model = AdaptiveNystroem::new(Kernel::Linear)
            .selection_strategy(ComponentSelectionStrategy::RankBased { max_rank: 3 })
            .fit(&data, &())
            .expect("operation should succeed");

        assert_eq!(model.n_components_selected(), 3);
    }

    #[test]
    fn test_adaptive_nystroem_different_error_bounds() {
        let data = three_points();
        let methods = [
            ErrorBoundMethod::SpectralBound,
            ErrorBoundMethod::FrobeniusBound,
            ErrorBoundMethod::EmpiricalBound,
            ErrorBoundMethod::PerturbationBound,
        ];

        for method in methods {
            let bound = AdaptiveNystroem::new(Kernel::Rbf { gamma: 0.1 })
                .error_bound_method(method)
                .fit(&data, &())
                .expect("operation should succeed")
                .error_bound();

            assert!(bound.is_finite());
            assert!(bound >= 0.0);
        }
    }

    #[test]
    fn test_adaptive_nystroem_reproducibility() {
        let data = three_points();
        let run = || {
            let model = AdaptiveNystroem::new(Kernel::Linear)
                .random_state(42)
                .fit(&data, &())
                .expect("operation should succeed");
            let features = model.transform(&data).expect("operation should succeed");
            (model.n_components_selected(), features)
        };

        let (n_first, first) = run();
        let (n_second, second) = run();

        assert_eq!(n_first, n_second);
        assert_eq!(first.shape(), second.shape());
    }

    #[test]
    fn test_adaptive_nystroem_approximation_rank() {
        let data = four_points();
        let model = AdaptiveNystroem::new(Kernel::Linear)
            .fit(&data, &())
            .expect("operation should succeed");

        let rank = model.approximation_rank(0.1);
        assert!(rank <= model.n_components_selected());
        assert!(rank > 0);
    }

    #[test]
    fn test_adaptive_nystroem_invalid_parameters() {
        let single = array![[1.0, 2.0]];

        let zero_max = AdaptiveNystroem::new(Kernel::Linear).max_components(0);
        assert!(zero_max.fit(&single, &()).is_err());

        let inverted_bounds = AdaptiveNystroem::new(Kernel::Linear)
            .min_components(10)
            .max_components(5);
        assert!(inverted_bounds.fit(&single, &()).is_err());
    }
}