1use crate::nystroem::{Kernel, SamplingStrategy};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::random::rngs::StdRng as RealStdRng;
6use scirs2_core::random::seq::SliceRandom;
7use scirs2_core::random::Rng;
8use scirs2_core::random::{thread_rng, SeedableRng};
9use sklears_core::{
10 error::{Result, SklearsError},
11 traits::{Estimator, Fit, Trained, Transform, Untrained},
12 types::Float,
13};
14use std::marker::PhantomData;
15
/// Method used to estimate the approximation error for a candidate set of
/// components.
///
/// The enum is a plain tag, so it is `Copy` and comparable; this lets callers
/// store it, compare it and use it in assertions without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorBoundMethod {
    /// Estimate taken from the approximate eigenvalues of the kernel matrix
    /// between the selected components.
    SpectralBound,
    /// Frobenius norm of the kernel matrix between the selected components,
    /// divided by the square root of the number of components.
    FrobeniusBound,
    /// Estimate measured on a subsample of the input rows.
    EmpiricalBound,
    /// Condition-number estimate of the kernel matrix between the selected
    /// components, divided by the number of components.
    PerturbationBound,
}
29
30#[derive(Debug, Clone)]
32pub enum ComponentSelectionStrategy {
34 Fixed,
36 ErrorTolerance { tolerance: Float },
38 EigenvalueDecay { threshold: Float },
40 RankBased { max_rank: usize },
42}
43
/// Nyström-style kernel approximation that chooses its number of components
/// from the data.
///
/// The `State` parameter is a compile-time marker (`Untrained` by default);
/// the fitted-state fields below are `None` until `fit` fills them in.
#[derive(Debug, Clone)]
pub struct AdaptiveNystroem<State = Untrained> {
    /// Kernel used to compute every kernel matrix.
    pub kernel: Kernel,
    /// Upper limit on the number of components (further capped at the number
    /// of input rows during selection).
    pub max_components: usize,
    /// Lower limit / starting point for the number of components.
    pub min_components: usize,
    /// Rule that decides how many components are kept.
    pub selection_strategy: ComponentSelectionStrategy,
    /// Method used to estimate the approximation error.
    pub error_bound_method: ErrorBoundMethod,
    /// How candidate rows are drawn from the input.
    pub sampling_strategy: SamplingStrategy,
    /// Seed for the random number generator; `None` seeds from `thread_rng`.
    pub random_state: Option<u64>,

    /// Rows of the training input chosen as components (set by `fit`).
    components_: Option<Array2<Float>>,
    /// Kernel matrix between the components with a small constant added to
    /// its diagonal (set by `fit`); `transform` right-multiplies by it.
    normalization_: Option<Array2<Float>>,
    /// Row indices of the chosen components in the training input.
    component_indices_: Option<Vec<usize>>,
    /// Number of components that were selected.
    n_components_selected_: Option<usize>,
    /// Error estimate computed for the selected components.
    error_bound_: Option<Float>,
    /// Approximate eigenvalues of the kernel matrix between the components.
    eigenvalues_: Option<Array1<Float>>,

    /// Zero-sized marker tying the value to its `State`.
    _state: PhantomData<State>,
}
102
103impl AdaptiveNystroem<Untrained> {
104 pub fn new(kernel: Kernel) -> Self {
106 Self {
107 kernel,
108 max_components: 500,
109 min_components: 10,
110 selection_strategy: ComponentSelectionStrategy::ErrorTolerance { tolerance: 0.1 },
111 error_bound_method: ErrorBoundMethod::SpectralBound,
112 sampling_strategy: SamplingStrategy::LeverageScore,
113 random_state: None,
114 components_: None,
115 normalization_: None,
116 component_indices_: None,
117 n_components_selected_: None,
118 error_bound_: None,
119 eigenvalues_: None,
120 _state: PhantomData,
121 }
122 }
123
124 pub fn max_components(mut self, max_components: usize) -> Self {
126 self.max_components = max_components;
127 self
128 }
129
130 pub fn min_components(mut self, min_components: usize) -> Self {
132 self.min_components = min_components;
133 self
134 }
135
136 pub fn selection_strategy(mut self, strategy: ComponentSelectionStrategy) -> Self {
138 self.selection_strategy = strategy;
139 self
140 }
141
142 pub fn error_bound_method(mut self, method: ErrorBoundMethod) -> Self {
144 self.error_bound_method = method;
145 self
146 }
147
148 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
150 self.sampling_strategy = strategy;
151 self
152 }
153
154 pub fn random_state(mut self, seed: u64) -> Self {
156 self.random_state = Some(seed);
157 self
158 }
159
160 fn select_components_adaptively(
162 &self,
163 x: &Array2<Float>,
164 rng: &mut RealStdRng,
165 ) -> Result<(Vec<usize>, usize)> {
166 let (n_samples, _) = x.dim();
167 let max_comp = self.max_components.min(n_samples);
168
169 match &self.selection_strategy {
170 ComponentSelectionStrategy::Fixed => {
171 let n_comp = self.max_components.min(n_samples);
172 let indices = self.sample_indices(x, n_comp, rng)?;
173 Ok((indices, n_comp))
174 }
175 ComponentSelectionStrategy::ErrorTolerance { tolerance } => {
176 self.select_by_error_tolerance(x, *tolerance, rng)
177 }
178 ComponentSelectionStrategy::EigenvalueDecay { threshold } => {
179 self.select_by_eigenvalue_decay(x, *threshold, rng)
180 }
181 ComponentSelectionStrategy::RankBased { max_rank } => {
182 let n_comp = (*max_rank).min(max_comp);
183 let indices = self.sample_indices(x, n_comp, rng)?;
184 Ok((indices, n_comp))
185 }
186 }
187 }
188
189 fn sample_indices(
191 &self,
192 x: &Array2<Float>,
193 n_components: usize,
194 rng: &mut RealStdRng,
195 ) -> Result<Vec<usize>> {
196 let (n_samples, _) = x.dim();
197
198 match &self.sampling_strategy {
199 SamplingStrategy::Random => {
200 let mut indices: Vec<usize> = (0..n_samples).collect();
201 indices.shuffle(rng);
202 Ok(indices[..n_components].to_vec())
203 }
204 SamplingStrategy::LeverageScore => {
205 let mut scores = Vec::new();
207 for i in 0..n_samples {
208 let row_norm = x.row(i).dot(&x.row(i)).sqrt();
209 scores.push(row_norm + 1e-10);
210 }
211
212 let total_score: Float = scores.iter().sum();
213 let mut selected = Vec::new();
214
215 for _ in 0..n_components {
216 let mut cumsum = 0.0;
217 let target = rng.gen::<f64>() * total_score;
218
219 for (i, &score) in scores.iter().enumerate() {
220 cumsum += score;
221 if cumsum >= target && !selected.contains(&i) {
222 selected.push(i);
223 break;
224 }
225 }
226 }
227
228 while selected.len() < n_components {
230 let idx = rng.gen_range(0..n_samples);
231 if !selected.contains(&idx) {
232 selected.push(idx);
233 }
234 }
235
236 Ok(selected)
237 }
238 _ => {
239 let mut indices: Vec<usize> = (0..n_samples).collect();
241 indices.shuffle(rng);
242 Ok(indices[..n_components].to_vec())
243 }
244 }
245 }
246
247 fn select_by_error_tolerance(
249 &self,
250 x: &Array2<Float>,
251 tolerance: Float,
252 rng: &mut RealStdRng,
253 ) -> Result<(Vec<usize>, usize)> {
254 let mut n_comp = self.min_components;
255 let max_comp = self.max_components.min(x.nrows());
256
257 while n_comp <= max_comp {
258 let indices = self.sample_indices(x, n_comp, rng)?;
259 let error_bound = self.estimate_error_bound(x, &indices)?;
260
261 if error_bound <= tolerance {
262 return Ok((indices, n_comp));
263 }
264
265 n_comp = (n_comp * 2).min(max_comp);
266 }
267
268 let indices = self.sample_indices(x, max_comp, rng)?;
270 Ok((indices, max_comp))
271 }
272
273 fn select_by_eigenvalue_decay(
275 &self,
276 x: &Array2<Float>,
277 threshold: Float,
278 rng: &mut RealStdRng,
279 ) -> Result<(Vec<usize>, usize)> {
280 let max_comp = self.max_components.min(x.nrows());
281 let indices = self.sample_indices(x, max_comp, rng)?;
282
283 let mut components = Array2::zeros((max_comp, x.ncols()));
285 for (i, &idx) in indices.iter().enumerate() {
286 components.row_mut(i).assign(&x.row(idx));
287 }
288
289 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
291 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
292
293 let mut n_comp = self.min_components;
295 let max_eigenvalue = eigenvalues.iter().fold(0.0_f64, |a: Float, &b| a.max(b));
296
297 for (i, &eigenval) in eigenvalues.iter().enumerate() {
298 if eigenval / max_eigenvalue < threshold {
299 n_comp = i.max(self.min_components);
300 break;
301 }
302 }
303
304 n_comp = n_comp.min(max_comp);
305 Ok((indices[..n_comp].to_vec(), n_comp))
306 }
307
308 fn estimate_error_bound(&self, x: &Array2<Float>, indices: &[usize]) -> Result<Float> {
310 let n_comp = indices.len();
311 let mut components = Array2::zeros((n_comp, x.ncols()));
312
313 for (i, &idx) in indices.iter().enumerate() {
314 components.row_mut(i).assign(&x.row(idx));
315 }
316
317 match self.error_bound_method {
318 ErrorBoundMethod::SpectralBound => {
319 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
320 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
321
322 let truncated_eigenvalues = &eigenvalues[n_comp.min(eigenvalues.len())..];
324 let error_bound = truncated_eigenvalues.iter().sum::<Float>().sqrt();
325 Ok(error_bound)
326 }
327 ErrorBoundMethod::FrobeniusBound => {
328 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
329 let frobenius_norm = kernel_matrix.mapv(|v| v * v).sum().sqrt();
330
331 let error_bound = frobenius_norm / (n_comp as Float).sqrt();
333 Ok(error_bound)
334 }
335 ErrorBoundMethod::EmpiricalBound => {
336 let subsampled_error = self.compute_subsampled_error(x, indices)?;
338 Ok(subsampled_error)
339 }
340 ErrorBoundMethod::PerturbationBound => {
341 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
343 let condition_number = self.estimate_condition_number(&kernel_matrix);
344 let perturbation_bound = condition_number / (n_comp as Float);
345 Ok(perturbation_bound)
346 }
347 }
348 }
349
350 fn compute_subsampled_error(&self, x: &Array2<Float>, indices: &[usize]) -> Result<Float> {
352 let n_comp = indices.len();
353 let mut components = Array2::zeros((n_comp, x.ncols()));
354
355 for (i, &idx) in indices.iter().enumerate() {
356 components.row_mut(i).assign(&x.row(idx));
357 }
358
359 let subsample_size = (x.nrows() / 10).max(5).min(x.nrows());
361 let mut error_sum = 0.0;
362
363 for i in 0..subsample_size {
364 let x_i = x.row(i);
365
366 let exact_kernel = self.kernel.compute_kernel(
368 &x_i.to_shape((1, x_i.len())).unwrap().to_owned(),
369 &components,
370 );
371
372 let approx_kernel = &exact_kernel * 0.9; let error = (&exact_kernel - &approx_kernel).mapv(|v| v * v).sum();
376 error_sum += error;
377 }
378
379 Ok((error_sum / subsample_size as Float).sqrt())
380 }
381
382 fn approximate_eigenvalues(&self, matrix: &Array2<Float>) -> Vec<Float> {
384 let n = matrix.nrows();
385 if n == 0 {
386 return vec![];
387 }
388
389 let mut eigenvalues = Vec::new();
390 let max_eigenvalues = n.min(10); for _ in 0..max_eigenvalues {
393 let mut v = Array1::ones(n) / (n as Float).sqrt();
394 let max_iter = 50;
395
396 for _ in 0..max_iter {
397 let v_new = matrix.dot(&v);
398 let norm = (v_new.dot(&v_new)).sqrt();
399
400 if norm < 1e-12 {
401 break;
402 }
403
404 v = &v_new / norm;
405 }
406
407 let eigenvalue = v.dot(&matrix.dot(&v));
408 eigenvalues.push(eigenvalue.abs());
409 }
410
411 eigenvalues.sort_by(|a, b| b.partial_cmp(a).unwrap());
412 eigenvalues
413 }
414
415 fn estimate_condition_number(&self, matrix: &Array2<Float>) -> Float {
417 let eigenvalues = self.approximate_eigenvalues(matrix);
418 if eigenvalues.len() < 2 {
419 return 1.0;
420 }
421
422 let max_eigenval = eigenvalues[0];
423 let min_eigenval = eigenvalues[eigenvalues.len() - 1];
424
425 if min_eigenval > 1e-12 {
426 max_eigenval / min_eigenval
427 } else {
428 1e12 }
430 }
431}
432
/// Marks the unfitted approximator as an estimator. It exposes no separate
/// configuration object, so `Config` is the unit type.
impl Estimator for AdaptiveNystroem<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns the (empty) configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
442
443impl Fit<Array2<Float>, ()> for AdaptiveNystroem<Untrained> {
444 type Fitted = AdaptiveNystroem<Trained>;
445
446 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
447 if self.max_components == 0 {
448 return Err(SklearsError::InvalidInput(
449 "max_components must be positive".to_string(),
450 ));
451 }
452
453 if self.min_components > self.max_components {
454 return Err(SklearsError::InvalidInput(
455 "min_components must be <= max_components".to_string(),
456 ));
457 }
458
459 let mut rng = if let Some(seed) = self.random_state {
460 RealStdRng::seed_from_u64(seed)
461 } else {
462 RealStdRng::from_seed(thread_rng().gen())
463 };
464
465 let (component_indices, n_components_selected) =
467 self.select_components_adaptively(x, &mut rng)?;
468
469 let mut components = Array2::zeros((n_components_selected, x.ncols()));
471 for (i, &idx) in component_indices.iter().enumerate() {
472 components.row_mut(i).assign(&x.row(idx));
473 }
474
475 let kernel_matrix = self.kernel.compute_kernel(&components, &components);
477 let eigenvalues = self.approximate_eigenvalues(&kernel_matrix);
478
479 let eps = 1e-12;
481 let mut kernel_reg = kernel_matrix.clone();
482 for i in 0..n_components_selected {
483 kernel_reg[[i, i]] += eps;
484 }
485
486 let error_bound = self.estimate_error_bound(x, &component_indices)?;
488
489 Ok(AdaptiveNystroem {
490 kernel: self.kernel,
491 max_components: self.max_components,
492 min_components: self.min_components,
493 selection_strategy: self.selection_strategy,
494 error_bound_method: self.error_bound_method,
495 sampling_strategy: self.sampling_strategy,
496 random_state: self.random_state,
497 components_: Some(components),
498 normalization_: Some(kernel_reg),
499 component_indices_: Some(component_indices),
500 n_components_selected_: Some(n_components_selected),
501 error_bound_: Some(error_bound),
502 eigenvalues_: Some(Array1::from_vec(eigenvalues)),
503 _state: PhantomData,
504 })
505 }
506}
507
508impl Transform<Array2<Float>, Array2<Float>> for AdaptiveNystroem<Trained> {
509 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
510 let components = self.components_.as_ref().unwrap();
511 let normalization = self.normalization_.as_ref().unwrap();
512
513 if x.ncols() != components.ncols() {
514 return Err(SklearsError::InvalidInput(format!(
515 "X has {} features, but AdaptiveNystroem was fitted with {} features",
516 x.ncols(),
517 components.ncols()
518 )));
519 }
520
521 let k_x_components = self.kernel.compute_kernel(x, components);
523
524 let result = k_x_components.dot(normalization);
526
527 Ok(result)
528 }
529}
530
531impl AdaptiveNystroem<Trained> {
532 pub fn components(&self) -> &Array2<Float> {
534 self.components_.as_ref().unwrap()
535 }
536
537 pub fn component_indices(&self) -> &[usize] {
539 self.component_indices_.as_ref().unwrap()
540 }
541
542 pub fn n_components_selected(&self) -> usize {
544 self.n_components_selected_.unwrap()
545 }
546
547 pub fn error_bound(&self) -> Float {
549 self.error_bound_.unwrap()
550 }
551
552 pub fn eigenvalues(&self) -> &Array1<Float> {
554 self.eigenvalues_.as_ref().unwrap()
555 }
556
557 pub fn approximation_rank(&self, threshold: Float) -> usize {
559 let eigenvals = self.eigenvalues();
560 if eigenvals.is_empty() {
561 return 0;
562 }
563
564 let max_eigenval = eigenvals.iter().fold(0.0_f64, |a: Float, &b| a.max(b));
565 eigenvals
566 .iter()
567 .take_while(|&&eigenval| eigenval / max_eigenval > threshold)
568 .count()
569 }
570}
571
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fitting and transforming a small dataset keeps the component count
    /// within the configured limits.
    #[test]
    fn test_adaptive_nystroem_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear)
            .min_components(1)
            .max_components(4);
        let fitted = adaptive.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        assert!(fitted.n_components_selected() >= fitted.min_components);
        assert!(fitted.n_components_selected() <= fitted.max_components);
        assert!(fitted.n_components_selected() <= x.nrows());
    }

    /// Either the tolerance is met or the component cap was reached.
    #[test]
    fn test_adaptive_nystroem_error_tolerance() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Rbf { gamma: 0.1 })
            .selection_strategy(ComponentSelectionStrategy::ErrorTolerance { tolerance: 0.5 })
            .min_components(1)
            .max_components(4);
        let fitted = adaptive.fit(&x, &()).unwrap();

        assert!(fitted.error_bound() <= 0.5 || fitted.n_components_selected() == 4);
    }

    /// The eigenvalue-decay strategy fits, transforms and stores eigenvalues.
    #[test]
    fn test_adaptive_nystroem_eigenvalue_decay() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear)
            .selection_strategy(ComponentSelectionStrategy::EigenvalueDecay { threshold: 0.1 });
        let fitted = adaptive.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.nrows(), 4);
        assert!(!fitted.eigenvalues().is_empty());
    }

    /// The rank-based strategy selects exactly `max_rank` components when
    /// enough rows are available.
    #[test]
    fn test_adaptive_nystroem_rank_based() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear)
            .selection_strategy(ComponentSelectionStrategy::RankBased { max_rank: 3 });
        let fitted = adaptive.fit(&x, &()).unwrap();

        assert_eq!(fitted.n_components_selected(), 3);
    }

    /// Every error-bound method yields a finite, non-negative estimate.
    #[test]
    fn test_adaptive_nystroem_different_error_bounds() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let methods = vec![
            ErrorBoundMethod::SpectralBound,
            ErrorBoundMethod::FrobeniusBound,
            ErrorBoundMethod::EmpiricalBound,
            ErrorBoundMethod::PerturbationBound,
        ];

        for method in methods {
            let adaptive =
                AdaptiveNystroem::new(Kernel::Rbf { gamma: 0.1 }).error_bound_method(method);
            let fitted = adaptive.fit(&x, &()).unwrap();

            assert!(fitted.error_bound().is_finite());
            assert!(fitted.error_bound() >= 0.0);
        }
    }

    /// Two fits with the same seed must agree on the selected rows and on
    /// the transformed values, not just on the output shape.
    #[test]
    fn test_adaptive_nystroem_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let adaptive1 = AdaptiveNystroem::new(Kernel::Linear).random_state(42);
        let fitted1 = adaptive1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let adaptive2 = AdaptiveNystroem::new(Kernel::Linear).random_state(42);
        let fitted2 = adaptive2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_eq!(
            fitted1.n_components_selected(),
            fitted2.n_components_selected()
        );
        assert_eq!(fitted1.component_indices(), fitted2.component_indices());
        assert_eq!(result1.shape(), result2.shape());
        assert_eq!(result1, result2);
    }

    /// The approximation rank is positive and never exceeds the number of
    /// selected components.
    #[test]
    fn test_adaptive_nystroem_approximation_rank() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let adaptive = AdaptiveNystroem::new(Kernel::Linear);
        let fitted = adaptive.fit(&x, &()).unwrap();

        let rank = fitted.approximation_rank(0.1);
        assert!(rank <= fitted.n_components_selected());
        assert!(rank > 0);
    }

    /// Invalid component limits are rejected by `fit`.
    #[test]
    fn test_adaptive_nystroem_invalid_parameters() {
        let x = array![[1.0, 2.0]];

        // Zero components can never be satisfied.
        let adaptive = AdaptiveNystroem::new(Kernel::Linear).max_components(0);
        assert!(adaptive.fit(&x, &()).is_err());

        // The minimum must not exceed the maximum.
        let adaptive = AdaptiveNystroem::new(Kernel::Linear)
            .min_components(10)
            .max_components(5);
        assert!(adaptive.fit(&x, &()).is_err());
    }
}