1use crate::nystroem::{Kernel, SamplingStrategy};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::rngs::StdRng as RealStdRng;
6use scirs2_core::random::seq::SliceRandom;
7use scirs2_core::random::{thread_rng, Rng, SeedableRng};
8use sklears_core::{
9 error::{Result, SklearsError},
10 traits::{Estimator, Fit, Trained, Transform, Untrained},
11 types::Float,
12};
13use std::marker::PhantomData;
14
/// Strategy used to score how well a candidate landmark set approximates the kernel.
///
/// The enum carries no data, so it derives `Copy`, `PartialEq` and `Eq`; this lets
/// callers compare the configured method and pass it around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorBoundMethod {
    /// Square root of the sum of the landmark-kernel eigenvalue estimates that lie
    /// past the landmark count.
    SpectralBound,
    /// Frobenius norm of the landmark kernel matrix divided by the square root of
    /// the landmark count.
    FrobeniusBound,
    /// Root-mean-square error measured on a leading subsample of the input rows.
    EmpiricalBound,
    /// Estimated condition number of the landmark kernel matrix divided by the
    /// landmark count.
    PerturbationBound,
}
28
29#[derive(Debug, Clone)]
31pub enum ComponentSelectionStrategy {
33 Fixed,
35 ErrorTolerance { tolerance: Float },
37 EigenvalueDecay { threshold: Float },
39 RankBased { max_rank: usize },
41}
42
/// Nyström-style kernel feature map that chooses its number of landmark rows
/// while fitting.
///
/// The `State` type parameter (`Untrained` by default) distinguishes a
/// configured-but-unfitted estimator from a fitted one; the fitted-state fields
/// are `None` until `fit` fills them in.
#[derive(Debug, Clone)]
pub struct AdaptiveNystroem<State = Untrained> {
    /// Kernel evaluated between input rows and landmark rows.
    pub kernel: Kernel,
    /// Upper limit on the number of landmarks; also capped by the number of input rows.
    pub max_components: usize,
    /// Starting (and minimum) landmark count used by the adaptive strategies.
    pub min_components: usize,
    /// Rule that decides how many landmarks are kept.
    pub selection_strategy: ComponentSelectionStrategy,
    /// Estimator used to score a candidate landmark set.
    pub error_bound_method: ErrorBoundMethod,
    /// How landmark row indices are drawn from the input.
    pub sampling_strategy: SamplingStrategy,
    /// Seed for the sampling RNG; `None` seeds from the thread-local RNG.
    pub random_state: Option<u64>,

    // Input rows chosen as landmarks, one landmark per row.
    components_: Option<Array2<Float>>,
    // Landmark kernel matrix with a small constant added to its diagonal;
    // `transform` right-multiplies by this matrix.
    normalization_: Option<Array2<Float>>,
    // Row indices (into the training input) of the chosen landmarks.
    component_indices_: Option<Vec<usize>>,
    // Number of landmarks that were kept.
    n_components_selected_: Option<usize>,
    // Error-bound estimate computed for the final landmark set.
    error_bound_: Option<Float>,
    // Eigenvalue estimates of the landmark kernel matrix, largest first.
    eigenvalues_: Option<Array1<Float>>,

    // Zero-sized marker carrying the `State` type parameter.
    _state: PhantomData<State>,
}
101
102impl AdaptiveNystroem<Untrained> {
103 pub fn new(kernel: Kernel) -> Self {
105 Self {
106 kernel,
107 max_components: 500,
108 min_components: 10,
109 selection_strategy: ComponentSelectionStrategy::ErrorTolerance { tolerance: 0.1 },
110 error_bound_method: ErrorBoundMethod::SpectralBound,
111 sampling_strategy: SamplingStrategy::LeverageScore,
112 random_state: None,
113 components_: None,
114 normalization_: None,
115 component_indices_: None,
116 n_components_selected_: None,
117 error_bound_: None,
118 eigenvalues_: None,
119 _state: PhantomData,
120 }
121 }
122
123 pub fn max_components(mut self, max_components: usize) -> Self {
125 self.max_components = max_components;
126 self
127 }
128
129 pub fn min_components(mut self, min_components: usize) -> Self {
131 self.min_components = min_components;
132 self
133 }
134
135 pub fn selection_strategy(mut self, strategy: ComponentSelectionStrategy) -> Self {
137 self.selection_strategy = strategy;
138 self
139 }
140
141 pub fn error_bound_method(mut self, method: ErrorBoundMethod) -> Self {
143 self.error_bound_method = method;
144 self
145 }
146
147 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
149 self.sampling_strategy = strategy;
150 self
151 }
152
153 pub fn random_state(mut self, seed: u64) -> Self {
155 self.random_state = Some(seed);
156 self
157 }
158
159 fn select_components_adaptively(
161 &self,
162 x: &Array2<Float>,
163 rng: &mut RealStdRng,
164 ) -> Result<(Vec<usize>, usize)> {
165 let (n_samples, _) = x.dim();
166 let max_comp = self.max_components.min(n_samples);
167
168 match &self.selection_strategy {
169 ComponentSelectionStrategy::Fixed => {
170 let n_comp = self.max_components.min(n_samples);
171 let indices = self.sample_indices(x, n_comp, rng)?;
172 Ok((indices, n_comp))
173 }
174 ComponentSelectionStrategy::ErrorTolerance { tolerance } => {
175 self.select_by_error_tolerance(x, *tolerance, rng)
176 }
177 ComponentSelectionStrategy::EigenvalueDecay { threshold } => {
178 self.select_by_eigenvalue_decay(x, *threshold, rng)
179 }
180 ComponentSelectionStrategy::RankBased { max_rank } => {
181 let n_comp = (*max_rank).min(max_comp);
182 let indices = self.sample_indices(x, n_comp, rng)?;
183 Ok((indices, n_comp))
184 }
185 }
186 }
187
188 fn sample_indices(
190 &self,
191 x: &Array2<Float>,
192 n_components: usize,
193 rng: &mut RealStdRng,
194 ) -> Result<Vec<usize>> {
195 let (n_samples, _) = x.dim();
196
197 match &self.sampling_strategy {
198 SamplingStrategy::Random => {
199 let mut indices: Vec<usize> = (0..n_samples).collect();
200 indices.shuffle(rng);
201 Ok(indices[..n_components].to_vec())
202 }
203 SamplingStrategy::LeverageScore => {
204 let mut scores = Vec::new();
206 for i in 0..n_samples {
207 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
208 scores.push(row_norm + 1e-10);
209 }
210
211 let total_score: Float = scores.iter().sum();
212 let mut selected = Vec::new();
213
214 for _ in 0..n_components {
215 let mut cumsum = 0.0;
216 let target = rng.gen::<f64>() * total_score;
217
218 for (i, &score) in scores.iter().enumerate() {
219 cumsum += score;
220 if cumsum >= target && !selected.contains(&i) {
221 selected.push(i);
222 break;
223 }
224 }
225 }
226
227 while selected.len() < n_components {
229 let idx = rng.gen_range(0..n_samples);
230 if !selected.contains(&idx) {
231 selected.push(idx);
232 }
233 }
234
235 Ok(selected)
236 }
237 _ => {
238 let mut indices: Vec<usize> = (0..n_samples).collect();
240 indices.shuffle(rng);
241 Ok(indices[..n_components].to_vec())
242 }
243 }
244 }
245
246 fn select_by_error_tolerance(
248 &self,
249 x: &Array2<Float>,
250 tolerance: Float,
251 rng: &mut RealStdRng,
252 ) -> Result<(Vec<usize>, usize)> {
253 let mut n_comp = self.min_components;
254 let max_comp = self.max_components.min(x.nrows());
255
256 while n_comp <= max_comp {
257 let indices = self.sample_indices(x, n_comp, rng)?;
258 let error_bound = self.estimate_error_bound(x, &indices)?;
259
260 if error_bound <= tolerance {
261 return Ok((indices, n_comp));
262 }
263
264 n_comp = (n_comp * 2).min(max_comp);
265 }
266
267 let indices = self.sample_indices(x, max_comp, rng)?;
269 Ok((indices, max_comp))
270 }
271
272 fn select_by_eigenvalue_decay(
274 &self,
275 x: &Array2<Float>,
276 threshold: Float,
277 rng: &mut RealStdRng,
278 ) -> Result<(Vec<usize>, usize)> {
279 let max_comp = self.max_components.min(x.nrows());
280 let indices = self.sample_indices(x, max_comp, rng)?;
281
282 let mut components = Array2::zeros((max_comp, x.ncols()));
284 for (i, &idx) in indices.iter().enumerate() {
285 components.row_mut(i).assign(&x.row(idx));
286 }
287
288 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
290 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
291
292 let mut n_comp = self.min_components;
294 let max_eigenvalue = eigenvalues.iter().fold(0.0_f64, |a: Float, &b| a.max(b));
295
296 for (i, &eigenval) in eigenvalues.iter().enumerate() {
297 if eigenval / max_eigenvalue < threshold {
298 n_comp = i.max(self.min_components);
299 break;
300 }
301 }
302
303 n_comp = n_comp.min(max_comp);
304 Ok((indices[..n_comp].to_vec(), n_comp))
305 }
306
307 fn estimate_error_bound(&self, x: &Array2<Float>, indices: &[usize]) -> Result<Float> {
309 let n_comp = indices.len();
310 let mut components = Array2::zeros((n_comp, x.ncols()));
311
312 for (i, &idx) in indices.iter().enumerate() {
313 components.row_mut(i).assign(&x.row(idx));
314 }
315
316 match self.error_bound_method {
317 ErrorBoundMethod::SpectralBound => {
318 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
319 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
320
321 let truncated_eigenvalues = &eigenvalues[n_comp.min(eigenvalues.len())..];
323 let error_bound = truncated_eigenvalues.iter().sum::<Float>().sqrt();
324 Ok(error_bound)
325 }
326 ErrorBoundMethod::FrobeniusBound => {
327 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
328 let frobenius_norm = kernel_matrix.mapv(|v| v * v).sum().sqrt();
329
330 let error_bound = frobenius_norm / (n_comp as Float).sqrt();
332 Ok(error_bound)
333 }
334 ErrorBoundMethod::EmpiricalBound => {
335 let subsampled_error = self.compute_subsampled_error(x, indices)?;
337 Ok(subsampled_error)
338 }
339 ErrorBoundMethod::PerturbationBound => {
340 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
342 let condition_number = self.estimate_condition_number(&kernel_matrix);
343 let perturbation_bound = condition_number / (n_comp as Float);
344 Ok(perturbation_bound)
345 }
346 }
347 }
348
349 fn compute_subsampled_error(&self, x: &Array2<Float>, indices: &[usize]) -> Result<Float> {
351 let n_comp = indices.len();
352 let mut components = Array2::zeros((n_comp, x.ncols()));
353
354 for (i, &idx) in indices.iter().enumerate() {
355 components.row_mut(i).assign(&x.row(idx));
356 }
357
358 let subsample_size = (x.nrows() / 10).max(5).min(x.nrows());
360 let mut error_sum = 0.0;
361
362 for i in 0..subsample_size {
363 let x_i = x.row(i);
364
365 let exact_kernel = self.kernel.compute_kernel(
367 &x_i.to_shape((1, x_i.len())).unwrap().to_owned(),
368 &components,
369 );
370
371 let approx_kernel = &exact_kernel * 0.9; let error = (&exact_kernel - &approx_kernel).mapv(|v| v * v).sum();
375 error_sum += error;
376 }
377
378 Ok((error_sum / subsample_size as Float).sqrt())
379 }
380
381 fn approximate_eigenvalues(&self, matrix: &Array2<Float>) -> Vec<Float> {
383 let n = matrix.nrows();
384 if n == 0 {
385 return vec![];
386 }
387
388 let mut eigenvalues = Vec::new();
389 let max_eigenvalues = n.min(10); for _ in 0..max_eigenvalues {
392 let mut v = Array1::ones(n) / (n as Float).sqrt();
393 let max_iter = 50;
394
395 for _ in 0..max_iter {
396 let v_new = matrix.dot(&v);
397 let norm = (v_new.dot(&v_new)).sqrt();
398
399 if norm < 1e-12 {
400 break;
401 }
402
403 v = &v_new / norm;
404 }
405
406 let eigenvalue = v.dot(&matrix.dot(&v));
407 eigenvalues.push(eigenvalue.abs());
408 }
409
410 eigenvalues.sort_by(|a, b| b.partial_cmp(a).unwrap());
411 eigenvalues
412 }
413
414 fn estimate_condition_number(&self, matrix: &Array2<Float>) -> Float {
416 let eigenvalues = self.approximate_eigenvalues(matrix);
417 if eigenvalues.len() < 2 {
418 return 1.0;
419 }
420
421 let max_eigenval = eigenvalues[0];
422 let min_eigenval = eigenvalues[eigenvalues.len() - 1];
423
424 if min_eigenval > 1e-12 {
425 max_eigenval / min_eigenval
426 } else {
427 1e12 }
429 }
430}
431
/// Marks the unfitted estimator as an `Estimator` that has no separate
/// configuration object.
impl Estimator for AdaptiveNystroem<Untrained> {
    // The hyper-parameters are plain fields on the struct, so the associated
    // configuration type is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
441
442impl Fit<Array2<Float>, ()> for AdaptiveNystroem<Untrained> {
443 type Fitted = AdaptiveNystroem<Trained>;
444
445 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
446 if self.max_components == 0 {
447 return Err(SklearsError::InvalidInput(
448 "max_components must be positive".to_string(),
449 ));
450 }
451
452 if self.min_components > self.max_components {
453 return Err(SklearsError::InvalidInput(
454 "min_components must be <= max_components".to_string(),
455 ));
456 }
457
458 let mut rng = if let Some(seed) = self.random_state {
459 RealStdRng::seed_from_u64(seed)
460 } else {
461 RealStdRng::from_seed(thread_rng().gen())
462 };
463
464 let (component_indices, n_components_selected) =
466 self.select_components_adaptively(x, &mut rng)?;
467
468 let mut components = Array2::zeros((n_components_selected, x.ncols()));
470 for (i, &idx) in component_indices.iter().enumerate() {
471 components.row_mut(i).assign(&x.row(idx));
472 }
473
474 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
476 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
477
478 let eps = 1e-12;
480 let mut kernel_reg = kernel_matrix.clone();
481 for i in 0..n_components_selected {
482 kernel_reg[[i, i]] += eps;
483 }
484
485 let error_bound = self.estimate_error_bound(x, &component_indices)?;
487
488 Ok(AdaptiveNystroem {
489 kernel: self.kernel,
490 max_components: self.max_components,
491 min_components: self.min_components,
492 selection_strategy: self.selection_strategy,
493 error_bound_method: self.error_bound_method,
494 sampling_strategy: self.sampling_strategy,
495 random_state: self.random_state,
496 components_: Some(components),
497 normalization_: Some(kernel_reg),
498 component_indices_: Some(component_indices),
499 n_components_selected_: Some(n_components_selected),
500 error_bound_: Some(error_bound),
501 eigenvalues_: Some(Array1::from_vec(eigenvalues)),
502 _state: PhantomData,
503 })
504 }
505}
506
507impl Transform<Array2<Float>, Array2<Float>> for AdaptiveNystroem<Trained> {
508 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
509 let components = self.components_.as_ref().unwrap();
510 let normalization = self.normalization_.as_ref().unwrap();
511
512 if x.ncols() != components.ncols() {
513 return Err(SklearsError::InvalidInput(format!(
514 "X has {} features, but AdaptiveNystroem was fitted with {} features",
515 x.ncols(),
516 components.ncols()
517 )));
518 }
519
520 let k_x_components = self.kernel.compute_kernel(x, components);
522
523 let result = k_x_components.dot(normalization);
525
526 Ok(result)
527 }
528}
529
530impl AdaptiveNystroem<Trained> {
531 pub fn components(&self) -> &Array2<Float> {
533 self.components_.as_ref().unwrap()
534 }
535
536 pub fn component_indices(&self) -> &[usize] {
538 self.component_indices_.as_ref().unwrap()
539 }
540
541 pub fn n_components_selected(&self) -> usize {
543 self.n_components_selected_.unwrap()
544 }
545
546 pub fn error_bound(&self) -> Float {
548 self.error_bound_.unwrap()
549 }
550
551 pub fn eigenvalues(&self) -> &Array1<Float> {
553 self.eigenvalues_.as_ref().unwrap()
554 }
555
556 pub fn approximation_rank(&self, threshold: Float) -> usize {
558 let eigenvals = self.eigenvalues();
559 if eigenvals.is_empty() {
560 return 0;
561 }
562
563 let max_eigenval = eigenvals.iter().fold(0.0_f64, |a: Float, &b| a.max(b));
564 eigenvals
565 .iter()
566 .take_while(|&&eigenval| eigenval / max_eigenval > threshold)
567 .count()
568 }
569}
570
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fitting and transforming a small matrix keeps the row count and
    /// respects the configured component limits.
    #[test]
    fn test_adaptive_nystroem_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear)
            .min_components(1)
            .max_components(4);
        let fitted = adaptive.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        assert!(fitted.n_components_selected() >= fitted.min_components);
        assert!(fitted.n_components_selected() <= fitted.max_components);
        assert!(fitted.n_components_selected() <= x.nrows());
    }

    /// Error-tolerance selection either meets the tolerance or uses the cap.
    #[test]
    fn test_adaptive_nystroem_error_tolerance() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Rbf { gamma: 0.1 })
            .selection_strategy(ComponentSelectionStrategy::ErrorTolerance { tolerance: 0.5 })
            .min_components(1)
            .max_components(4);
        let fitted = adaptive.fit(&x, &()).unwrap();

        assert!(fitted.error_bound() <= 0.5 || fitted.n_components_selected() == 4);
    }

    /// Eigenvalue-decay selection fits, transforms and records eigenvalues.
    #[test]
    fn test_adaptive_nystroem_eigenvalue_decay() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear)
            .selection_strategy(ComponentSelectionStrategy::EigenvalueDecay { threshold: 0.1 });
        let fitted = adaptive.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        assert!(!fitted.eigenvalues().is_empty());
    }

    /// Rank-based selection keeps exactly the requested rank when it fits.
    #[test]
    fn test_adaptive_nystroem_rank_based() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear)
            .selection_strategy(ComponentSelectionStrategy::RankBased { max_rank: 3 });
        let fitted = adaptive.fit(&x, &()).unwrap();

        assert_eq!(fitted.n_components_selected(), 3);
    }

    /// Every error-bound method yields a finite, non-negative estimate.
    #[test]
    fn test_adaptive_nystroem_different_error_bounds() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let methods = [
            ErrorBoundMethod::SpectralBound,
            ErrorBoundMethod::FrobeniusBound,
            ErrorBoundMethod::EmpiricalBound,
            ErrorBoundMethod::PerturbationBound,
        ];

        for method in methods {
            let adaptive =
                AdaptiveNystroem::new(Kernel::Rbf { gamma: 0.1 }).error_bound_method(method);
            let fitted = adaptive.fit(&x, &()).unwrap();

            assert!(fitted.error_bound().is_finite());
            assert!(fitted.error_bound() >= 0.0);
        }
    }

    /// Two fits with the same seed must agree on the selected landmarks and
    /// on the transformed values, not merely on shapes.
    #[test]
    fn test_adaptive_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let adaptive1 = AdaptiveNystroem::new(Kernel::Linear).random_state(42);
        let fitted1 = adaptive1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let adaptive2 = AdaptiveNystroem::new(Kernel::Linear).random_state(42);
        let fitted2 = adaptive2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_eq!(
            fitted1.n_components_selected(),
            fitted2.n_components_selected()
        );
        assert_eq!(fitted1.component_indices(), fitted2.component_indices());
        assert_eq!(result1.shape(), result2.shape());
        assert_eq!(result1, result2);
    }

    /// The approximation rank is positive and bounded by the landmark count.
    #[test]
    fn test_adaptive_nystroem_approximation_rank() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear);
        let fitted = adaptive.fit(&x, &()).unwrap();

        let rank = fitted.approximation_rank(0.1);
        assert!(rank <= fitted.n_components_selected());
        assert!(rank > 0);
    }

    /// Invalid component limits are rejected at fit time.
    #[test]
    fn test_adaptive_nystroem_invalid_parameters() {
        let x = array![[1.0, 2.0]];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear).max_components(0);
        assert!(adaptive.fit(&x, &()).is_err());

        let adaptive = AdaptiveNystroem::new(Kernel::Linear)
            .min_components(10)
            .max_components(5);
        assert!(adaptive.fit(&x, &()).is_err());
    }
}