1use crate::{Nystroem, RBFSampler};
7use scirs2_core::ndarray::ndarray_linalg::{Norm, SVD};
8use scirs2_core::ndarray::Array2;
9use sklears_core::traits::Fit;
10use sklears_core::{
11 error::{Result, SklearsError},
12 traits::Transform,
13};
14use std::time::Instant;
15
/// Rule used to pick the next component count after each progressive step.
///
/// The per-variant notes describe how `ProgressiveRBFSampler` interprets them.
#[derive(Debug, Clone, PartialEq)]
pub enum ProgressiveStrategy {
    /// Multiply the current component count by two.
    Doubling,
    /// Add a constant number of components.
    FixedIncrement {
        /// Components added per step.
        increment: usize,
    },
    /// Pick the increment from the improvement measured in the previous step.
    AdaptiveIncrement {
        /// Increment used when the last improvement exceeded `improvement_threshold`.
        min_increment: usize,
        /// Upper bound; the midpoint between `min_increment` and this value is used
        /// when the last improvement did not exceed the threshold.
        max_increment: usize,
        /// Improvement level separating the two cases above.
        improvement_threshold: f64,
    },
    /// Multiply the current component count by `base` (truncated to an integer).
    Exponential {
        /// Growth factor.
        base: f64,
    },
    /// Use `initial_components` plus successive Fibonacci numbers (2, 3, 5, ...).
    Fibonacci,
}
37
/// Condition that ends the progressive search.
#[derive(Debug, Clone, PartialEq)]
pub enum StoppingCriterion {
    /// Stop as soon as the quality score is at least `quality` (reported as converged).
    TargetQuality {
        /// Quality score to reach.
        quality: f64,
    },
    /// Stop when, after the first step, the improvement over the previous step
    /// drops below `threshold` (reported as converged).
    ImprovementThreshold {
        /// Minimum improvement required to continue.
        threshold: f64,
    },
    /// Bound on the number of steps (reported as not converged).
    MaxIterations {
        /// Step budget.
        max_iter: usize,
    },
    /// Stop once the component count is at least `max_components`
    /// (reported as not converged).
    MaxComponents {
        /// Component budget.
        max_components: usize,
    },
    /// Any subset of the conditions above; `None` entries are skipped.
    Combined {
        /// Optional target quality.
        quality: Option<f64>,
        /// Optional minimum improvement.
        improvement_threshold: Option<f64>,
        /// Optional step budget.
        max_iter: Option<usize>,
        /// Optional component budget.
        max_components: Option<usize>,
    },
}
58
/// Score used to judge an approximation at each progressive step.
///
/// The enum carries no data, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProgressiveQualityMetric {
    /// Alignment between the exact and the approximated kernel matrix.
    KernelAlignment,
    /// `1 / (1 + error)` with the error taken from `norm_l2` of the kernel difference.
    FrobeniusError,
    /// `1 / (1 + error)` with the error being the largest singular value of the
    /// kernel difference.
    SpectralError,
    /// Entropy-based effective rank of the transformed features, divided by the
    /// number of feature columns.
    EffectiveRank,
    /// Placeholder: `ProgressiveRBFSampler` currently scores this as the constant `1.0`.
    RelativeImprovement,
    /// Placeholder: `ProgressiveRBFSampler` currently scores this like `KernelAlignment`.
    Custom,
}
76
77#[derive(Debug, Clone)]
79pub struct ProgressiveConfig {
81 pub initial_components: usize,
83 pub strategy: ProgressiveStrategy,
85 pub stopping_criterion: StoppingCriterion,
87 pub quality_metric: ProgressiveQualityMetric,
89 pub n_trials: usize,
91 pub random_seed: Option<u64>,
93 pub validation_fraction: f64,
95 pub store_intermediate: bool,
97}
98
99impl Default for ProgressiveConfig {
100 fn default() -> Self {
101 Self {
102 initial_components: 10,
103 strategy: ProgressiveStrategy::Doubling,
104 stopping_criterion: StoppingCriterion::Combined {
105 quality: Some(0.95),
106 improvement_threshold: Some(0.01),
107 max_iter: Some(10),
108 max_components: Some(1000),
109 },
110 quality_metric: ProgressiveQualityMetric::KernelAlignment,
111 n_trials: 3,
112 random_seed: None,
113 validation_fraction: 0.2,
114 store_intermediate: true,
115 }
116 }
117}
118
/// Measurements recorded for one evaluated component count.
#[derive(Debug, Clone, PartialEq)]
pub struct ProgressiveStep {
    /// Component count evaluated in this step.
    pub n_components: usize,
    /// Quality score averaged over the configured trials.
    pub quality_score: f64,
    /// Score gain over the previous step (the score itself for the first step).
    pub improvement: f64,
    /// Wall-clock seconds spent on this step.
    pub time_taken: f64,
    /// Zero-based step index.
    pub iteration: usize,
}
134
135#[derive(Debug, Clone)]
137pub struct ProgressiveResult {
139 pub final_components: usize,
141 pub final_quality: f64,
143 pub steps: Vec<ProgressiveStep>,
145 pub converged: bool,
147 pub stopping_reason: String,
149 pub total_time: f64,
151}
152
153#[derive(Debug, Clone)]
155pub struct ProgressiveRBFSampler {
157 gamma: f64,
158 config: ProgressiveConfig,
159}
160
161impl ProgressiveRBFSampler {
162 pub fn new() -> Self {
164 Self {
165 gamma: 1.0,
166 config: ProgressiveConfig::default(),
167 }
168 }
169
170 pub fn gamma(mut self, gamma: f64) -> Self {
172 self.gamma = gamma;
173 self
174 }
175
176 pub fn config(mut self, config: ProgressiveConfig) -> Self {
178 self.config = config;
179 self
180 }
181
182 pub fn initial_components(mut self, components: usize) -> Self {
184 self.config.initial_components = components;
185 self
186 }
187
188 pub fn strategy(mut self, strategy: ProgressiveStrategy) -> Self {
190 self.config.strategy = strategy;
191 self
192 }
193
194 pub fn stopping_criterion(mut self, criterion: StoppingCriterion) -> Self {
196 self.config.stopping_criterion = criterion;
197 self
198 }
199
200 pub fn run_progressive_approximation(&self, x: &Array2<f64>) -> Result<ProgressiveResult> {
202 let start_time = Instant::now();
203 let n_samples = x.nrows();
204
205 let split_idx = (n_samples as f64 * (1.0 - self.config.validation_fraction)) as usize;
207 let x_train = x
208 .slice(scirs2_core::ndarray::s![..split_idx, ..])
209 .to_owned();
210 let x_val = x
211 .slice(scirs2_core::ndarray::s![split_idx.., ..])
212 .to_owned();
213
214 let k_exact = self.compute_exact_kernel_matrix(&x_val)?;
216
217 let mut steps = Vec::new();
218 let mut current_components = self.config.initial_components;
219 let mut previous_quality = 0.0;
220 let mut iteration = 0;
221 let mut converged = false;
222 let mut stopping_reason = String::from("Max iterations reached");
223
224 let mut fib_prev = 1;
226 let mut fib_curr = 1;
227
228 loop {
229 let step_start = Instant::now();
230
231 let quality = self.compute_quality_for_components(
233 current_components,
234 &x_train,
235 &x_val,
236 &k_exact,
237 )?;
238
239 let improvement = if iteration == 0 {
240 quality
241 } else {
242 quality - previous_quality
243 };
244
245 let step_time = step_start.elapsed().as_secs_f64();
246
247 let step = ProgressiveStep {
249 n_components: current_components,
250 quality_score: quality,
251 improvement,
252 time_taken: step_time,
253 iteration,
254 };
255 steps.push(step);
256
257 if let Some((converged_flag, reason)) =
259 self.check_stopping_criteria(quality, improvement, iteration, current_components)
260 {
261 converged = converged_flag;
262 stopping_reason = reason;
263 break;
264 }
265
266 previous_quality = quality;
268 iteration += 1;
269
270 current_components = match &self.config.strategy {
272 ProgressiveStrategy::Doubling => current_components * 2,
273 ProgressiveStrategy::FixedIncrement { increment } => current_components + increment,
274 ProgressiveStrategy::AdaptiveIncrement {
275 min_increment,
276 max_increment,
277 improvement_threshold,
278 } => {
279 let increment = if improvement > *improvement_threshold {
280 *min_increment
281 } else {
282 (*min_increment + (*max_increment - *min_increment) / 2).max(*min_increment)
283 };
284 current_components + increment
285 }
286 ProgressiveStrategy::Exponential { base } => {
287 ((current_components as f64) * base) as usize
288 }
289 ProgressiveStrategy::Fibonacci => {
290 let next_fib = fib_prev + fib_curr;
291 fib_prev = fib_curr;
292 fib_curr = next_fib;
293 self.config.initial_components + fib_curr
294 }
295 };
296 }
297
298 let total_time = start_time.elapsed().as_secs_f64();
299
300 Ok(ProgressiveResult {
301 final_components: steps
302 .last()
303 .map(|s| s.n_components)
304 .unwrap_or(current_components),
305 final_quality: steps.last().map(|s| s.quality_score).unwrap_or(0.0),
306 steps,
307 converged,
308 stopping_reason,
309 total_time,
310 })
311 }
312
313 fn compute_exact_kernel_matrix(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
315 let n_samples = x.nrows().min(100); let x_subset = x.slice(scirs2_core::ndarray::s![..n_samples, ..]);
317
318 let mut k_exact = Array2::zeros((n_samples, n_samples));
319
320 for i in 0..n_samples {
321 for j in 0..n_samples {
322 let diff = &x_subset.row(i) - &x_subset.row(j);
323 let squared_norm = diff.dot(&diff);
324 k_exact[[i, j]] = (-self.gamma * squared_norm).exp();
325 }
326 }
327
328 Ok(k_exact)
329 }
330
331 fn compute_quality_for_components(
333 &self,
334 n_components: usize,
335 x_train: &Array2<f64>,
336 x_val: &Array2<f64>,
337 k_exact: &Array2<f64>,
338 ) -> Result<f64> {
339 let mut trial_qualities = Vec::new();
340
341 for trial in 0..self.config.n_trials {
343 let seed = self.config.random_seed.map(|s| s + trial as u64);
344 let sampler = if let Some(s) = seed {
345 RBFSampler::new(n_components)
346 .gamma(self.gamma)
347 .random_state(s)
348 } else {
349 RBFSampler::new(n_components).gamma(self.gamma)
350 };
351
352 let fitted = sampler.fit(x_train, &())?;
353 let x_val_transformed = fitted.transform(x_val)?;
354
355 let quality = self.compute_quality_metric(x_val, &x_val_transformed, k_exact)?;
356 trial_qualities.push(quality);
357 }
358
359 Ok(trial_qualities.iter().sum::<f64>() / trial_qualities.len() as f64)
361 }
362
363 fn compute_quality_metric(
365 &self,
366 x: &Array2<f64>,
367 x_transformed: &Array2<f64>,
368 k_exact: &Array2<f64>,
369 ) -> Result<f64> {
370 match &self.config.quality_metric {
371 ProgressiveQualityMetric::KernelAlignment => {
372 self.compute_kernel_alignment(x_transformed, k_exact)
373 }
374 ProgressiveQualityMetric::FrobeniusError => {
375 self.compute_frobenius_error(x_transformed, k_exact)
376 }
377 ProgressiveQualityMetric::SpectralError => {
378 self.compute_spectral_error(x_transformed, k_exact)
379 }
380 ProgressiveQualityMetric::EffectiveRank => self.compute_effective_rank(x_transformed),
381 ProgressiveQualityMetric::RelativeImprovement => {
382 Ok(1.0)
384 }
385 ProgressiveQualityMetric::Custom => {
386 self.compute_kernel_alignment(x_transformed, k_exact)
388 }
389 }
390 }
391
392 fn compute_kernel_alignment(
394 &self,
395 x_transformed: &Array2<f64>,
396 k_exact: &Array2<f64>,
397 ) -> Result<f64> {
398 let n_samples = k_exact.nrows().min(x_transformed.nrows());
399 let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
400
401 let k_approx = x_subset.dot(&x_subset.t());
403
404 let k_exact_norm = k_exact.norm_l2();
406 let k_approx_norm = k_approx.norm_l2();
407
408 if k_exact_norm > 1e-12 && k_approx_norm > 1e-12 {
409 let alignment = (k_exact * &k_approx).sum() / (k_exact_norm * k_approx_norm);
410 Ok(alignment)
411 } else {
412 Ok(0.0)
413 }
414 }
415
416 fn compute_frobenius_error(
418 &self,
419 x_transformed: &Array2<f64>,
420 k_exact: &Array2<f64>,
421 ) -> Result<f64> {
422 let n_samples = k_exact.nrows().min(x_transformed.nrows());
423 let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
424
425 let k_approx = x_subset.dot(&x_subset.t());
427
428 let diff = k_exact - &k_approx.slice(scirs2_core::ndarray::s![..n_samples, ..n_samples]);
430 let error = diff.norm_l2();
431 let quality = 1.0 / (1.0 + error); Ok(quality)
434 }
435
436 fn compute_spectral_error(
438 &self,
439 x_transformed: &Array2<f64>,
440 k_exact: &Array2<f64>,
441 ) -> Result<f64> {
442 let n_samples = k_exact.nrows().min(x_transformed.nrows());
443 let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
444
445 let k_approx = x_subset.dot(&x_subset.t());
447
448 let diff = k_exact - &k_approx.slice(scirs2_core::ndarray::s![..n_samples, ..n_samples]);
450 let (_, s, _) = diff
451 .svd(false, false)
452 .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
453
454 let spectral_error = s.iter().fold(0.0f64, |acc, &x| acc.max(x));
455 let quality = 1.0 / (1.0 + spectral_error);
456
457 Ok(quality)
458 }
459
460 fn compute_effective_rank(&self, x_transformed: &Array2<f64>) -> Result<f64> {
462 let (_, s, _) = x_transformed
464 .svd(true, true)
465 .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
466
467 let s_sum = s.sum();
469 if s_sum == 0.0 {
470 return Ok(0.0);
471 }
472
473 let s_normalized = &s / s_sum;
474 let entropy = -s_normalized
475 .iter()
476 .filter(|&&x| x > 1e-12)
477 .map(|&x| x * x.ln())
478 .sum::<f64>();
479
480 let effective_rank = entropy.exp();
481 Ok(effective_rank / x_transformed.ncols() as f64) }
483
484 fn check_stopping_criteria(
486 &self,
487 quality: f64,
488 improvement: f64,
489 iteration: usize,
490 components: usize,
491 ) -> Option<(bool, String)> {
492 match &self.config.stopping_criterion {
493 StoppingCriterion::TargetQuality { quality: target } => {
494 if quality >= *target {
495 Some((true, format!("Target quality {} reached", target)))
496 } else {
497 None
498 }
499 }
500 StoppingCriterion::ImprovementThreshold { threshold } => {
501 if iteration > 0 && improvement < *threshold {
502 Some((
503 true,
504 format!("Improvement {} below threshold {}", improvement, threshold),
505 ))
506 } else {
507 None
508 }
509 }
510 StoppingCriterion::MaxIterations { max_iter } => {
511 if iteration + 1 >= *max_iter {
512 Some((false, format!("Maximum iterations {} reached", max_iter)))
513 } else {
514 None
515 }
516 }
517 StoppingCriterion::MaxComponents { max_components } => {
518 if components >= *max_components {
519 Some((
520 false,
521 format!("Maximum components {} reached", max_components),
522 ))
523 } else {
524 None
525 }
526 }
527 StoppingCriterion::Combined {
528 quality: target_quality,
529 improvement_threshold,
530 max_iter,
531 max_components,
532 } => {
533 if let Some(target) = target_quality {
535 if quality >= *target {
536 return Some((true, format!("Target quality {} reached", target)));
537 }
538 }
539
540 if let Some(threshold) = improvement_threshold {
542 if iteration > 0 && improvement < *threshold {
543 return Some((
544 true,
545 format!("Improvement {} below threshold {}", improvement, threshold),
546 ));
547 }
548 }
549
550 if let Some(max) = max_iter {
552 if iteration >= *max {
553 return Some((false, format!("Maximum iterations {} reached", max)));
554 }
555 }
556
557 if let Some(max) = max_components {
559 if components >= *max {
560 return Some((false, format!("Maximum components {} reached", max)));
561 }
562 }
563
564 None
565 }
566 }
567 }
568}
569
570pub struct FittedProgressiveRBFSampler {
572 fitted_rbf: crate::rbf_sampler::RBFSampler<sklears_core::traits::Trained>,
573 progressive_result: ProgressiveResult,
574}
575
576impl Fit<Array2<f64>, ()> for ProgressiveRBFSampler {
577 type Fitted = FittedProgressiveRBFSampler;
578
579 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
580 let progressive_result = self.run_progressive_approximation(x)?;
582
583 let rbf_sampler = RBFSampler::new(progressive_result.final_components).gamma(self.gamma);
585 let fitted_rbf = rbf_sampler.fit(x, &())?;
586
587 Ok(FittedProgressiveRBFSampler {
588 fitted_rbf,
589 progressive_result,
590 })
591 }
592}
593
594impl Transform<Array2<f64>, Array2<f64>> for FittedProgressiveRBFSampler {
595 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
596 self.fitted_rbf.transform(x)
597 }
598}
599
600impl FittedProgressiveRBFSampler {
601 pub fn progressive_result(&self) -> &ProgressiveResult {
603 &self.progressive_result
604 }
605
606 pub fn final_components(&self) -> usize {
608 self.progressive_result.final_components
609 }
610
611 pub fn final_quality(&self) -> f64 {
613 self.progressive_result.final_quality
614 }
615
616 pub fn converged(&self) -> bool {
618 self.progressive_result.converged
619 }
620
621 pub fn steps(&self) -> &[ProgressiveStep] {
623 &self.progressive_result.steps
624 }
625
626 pub fn stopping_reason(&self) -> &str {
628 &self.progressive_result.stopping_reason
629 }
630}
631
632#[derive(Debug, Clone)]
634pub struct ProgressiveNystroem {
636 kernel: crate::nystroem::Kernel,
637 config: ProgressiveConfig,
638}
639
640impl ProgressiveNystroem {
641 pub fn new() -> Self {
643 Self {
644 kernel: crate::nystroem::Kernel::Rbf { gamma: 1.0 },
645 config: ProgressiveConfig::default(),
646 }
647 }
648
649 pub fn gamma(mut self, gamma: f64) -> Self {
651 self.kernel = crate::nystroem::Kernel::Rbf { gamma };
652 self
653 }
654
655 pub fn kernel(mut self, kernel: crate::nystroem::Kernel) -> Self {
657 self.kernel = kernel;
658 self
659 }
660
661 pub fn config(mut self, config: ProgressiveConfig) -> Self {
663 self.config = config;
664 self
665 }
666
667 pub fn run_progressive_approximation(&self, x: &Array2<f64>) -> Result<ProgressiveResult> {
669 let start_time = Instant::now();
670
671 let mut steps = Vec::new();
672 let mut current_components = self.config.initial_components;
673 let mut previous_quality = 0.0;
674 let mut iteration = 0;
675 let mut converged = false;
676 let mut stopping_reason = String::from("Max iterations reached");
677
678 loop {
679 let step_start = Instant::now();
680
681 let quality = self.compute_nystroem_quality(current_components, x)?;
683
684 let improvement = if iteration == 0 {
685 quality
686 } else {
687 quality - previous_quality
688 };
689
690 let step_time = step_start.elapsed().as_secs_f64();
691
692 let step = ProgressiveStep {
694 n_components: current_components,
695 quality_score: quality,
696 improvement,
697 time_taken: step_time,
698 iteration,
699 };
700 steps.push(step);
701
702 if let Some((converged_flag, reason)) =
704 self.check_stopping_criteria(quality, improvement, iteration, current_components)
705 {
706 converged = converged_flag;
707 stopping_reason = reason;
708 break;
709 }
710
711 previous_quality = quality;
713 iteration += 1;
714
715 current_components = match &self.config.strategy {
717 ProgressiveStrategy::Doubling => current_components * 2,
718 ProgressiveStrategy::FixedIncrement { increment } => current_components + increment,
719 _ => current_components * 2, };
721 }
722
723 let total_time = start_time.elapsed().as_secs_f64();
724
725 Ok(ProgressiveResult {
726 final_components: steps
727 .last()
728 .map(|s| s.n_components)
729 .unwrap_or(current_components),
730 final_quality: steps.last().map(|s| s.quality_score).unwrap_or(0.0),
731 steps,
732 converged,
733 stopping_reason,
734 total_time,
735 })
736 }
737
738 fn compute_nystroem_quality(&self, n_components: usize, x: &Array2<f64>) -> Result<f64> {
740 let mut trial_qualities = Vec::new();
741
742 for trial in 0..self.config.n_trials {
744 let seed = self.config.random_seed.map(|s| s + trial as u64);
745 let nystroem = if let Some(s) = seed {
746 Nystroem::new(self.kernel.clone(), n_components).random_state(s)
747 } else {
748 Nystroem::new(self.kernel.clone(), n_components)
749 };
750
751 let fitted = nystroem.fit(x, &())?;
752 let x_transformed = fitted.transform(x)?;
753
754 let quality = self.compute_effective_rank(&x_transformed)?;
756 trial_qualities.push(quality);
757 }
758
759 Ok(trial_qualities.iter().sum::<f64>() / trial_qualities.len() as f64)
760 }
761
762 fn compute_effective_rank(&self, x_transformed: &Array2<f64>) -> Result<f64> {
764 let (_, s, _) = x_transformed
765 .svd(true, true)
766 .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
767
768 let s_sum = s.sum();
769 if s_sum == 0.0 {
770 return Ok(0.0);
771 }
772
773 let s_normalized = &s / s_sum;
774 let entropy = -s_normalized
775 .iter()
776 .filter(|&&x| x > 1e-12)
777 .map(|&x| x * x.ln())
778 .sum::<f64>();
779
780 let effective_rank = entropy.exp();
781 Ok(effective_rank / x_transformed.ncols() as f64)
782 }
783
784 fn check_stopping_criteria(
786 &self,
787 quality: f64,
788 improvement: f64,
789 iteration: usize,
790 components: usize,
791 ) -> Option<(bool, String)> {
792 match &self.config.stopping_criterion {
793 StoppingCriterion::TargetQuality { quality: target } => {
794 if quality >= *target {
795 Some((true, format!("Target quality {} reached", target)))
796 } else {
797 None
798 }
799 }
800 StoppingCriterion::MaxIterations { max_iter } => {
801 if iteration + 1 >= *max_iter {
802 Some((false, format!("Maximum iterations {} reached", max_iter)))
803 } else {
804 None
805 }
806 }
807 _ => None, }
809 }
810}
811
812pub struct FittedProgressiveNystroem {
814 fitted_nystroem: crate::nystroem::Nystroem<sklears_core::traits::Trained>,
815 progressive_result: ProgressiveResult,
816}
817
818impl Fit<Array2<f64>, ()> for ProgressiveNystroem {
819 type Fitted = FittedProgressiveNystroem;
820
821 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
822 let progressive_result = self.run_progressive_approximation(x)?;
824
825 let nystroem = Nystroem::new(self.kernel, progressive_result.final_components);
827 let fitted_nystroem = nystroem.fit(x, &())?;
828
829 Ok(FittedProgressiveNystroem {
830 fitted_nystroem,
831 progressive_result,
832 })
833 }
834}
835
836impl Transform<Array2<f64>, Array2<f64>> for FittedProgressiveNystroem {
837 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
838 self.fitted_nystroem.transform(x)
839 }
840}
841
842impl FittedProgressiveNystroem {
843 pub fn progressive_result(&self) -> &ProgressiveResult {
845 &self.progressive_result
846 }
847
848 pub fn final_components(&self) -> usize {
850 self.progressive_result.final_components
851 }
852
853 pub fn final_quality(&self) -> f64 {
855 self.progressive_result.final_quality
856 }
857
858 pub fn converged(&self) -> bool {
860 self.progressive_result.converged
861 }
862}
863
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a `rows x cols` matrix whose row-major entries are `0, step, 2 * step, ...`.
    fn ramp(rows: usize, cols: usize, step: f64) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * cols).map(|i| (i as f64) * step).collect();
        Array2::from_shape_vec((rows, cols), values).unwrap()
    }

    /// Fitting and transforming works end to end and honours the step budget.
    #[test]
    fn test_progressive_rbf_sampler() {
        let x = ramp(100, 4, 0.01);

        let sampler = ProgressiveRBFSampler::new()
            .gamma(0.5)
            .config(ProgressiveConfig {
                initial_components: 5,
                strategy: ProgressiveStrategy::Doubling,
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                quality_metric: ProgressiveQualityMetric::KernelAlignment,
                n_trials: 2,
                validation_fraction: 0.3,
                ..Default::default()
            });

        let fitted = sampler.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.nrows(), 100);
        assert!(fitted.final_components() >= 5);
        assert!(fitted.final_quality() >= 0.0);
        assert_eq!(fitted.steps().len(), 3);
    }

    /// The Nystroem variant fits and transforms end to end.
    #[test]
    fn test_progressive_nystroem() {
        let x = ramp(80, 3, 0.02);

        let nystroem = ProgressiveNystroem::new()
            .gamma(1.0)
            .config(ProgressiveConfig {
                initial_components: 10,
                strategy: ProgressiveStrategy::FixedIncrement { increment: 5 },
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 4 },
                n_trials: 2,
                ..Default::default()
            });

        let fitted = nystroem.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.nrows(), 80);
        assert!(fitted.final_components() >= 10);
        assert!(fitted.final_quality() >= 0.0);
    }

    /// Each growth strategy produces exactly the budgeted number of steps.
    #[test]
    fn test_progressive_strategies() {
        let x = ramp(50, 2, 0.05);

        let strategies = vec![
            ProgressiveStrategy::Doubling,
            ProgressiveStrategy::FixedIncrement { increment: 3 },
            ProgressiveStrategy::Exponential { base: 1.5 },
            ProgressiveStrategy::Fibonacci,
        ];

        for strategy in strategies {
            let sampler = ProgressiveRBFSampler::new()
                .gamma(0.8)
                .config(ProgressiveConfig {
                    initial_components: 5,
                    strategy,
                    stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                    n_trials: 1,
                    ..Default::default()
                });

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 5);
            assert!(result.final_quality >= 0.0);
            assert_eq!(result.steps.len(), 3);
        }
    }

    /// Each single stopping criterion ends the search with a reason.
    #[test]
    fn test_stopping_criteria() {
        let x = ramp(60, 3, 0.03);

        let criteria = vec![
            StoppingCriterion::TargetQuality { quality: 0.8 },
            StoppingCriterion::ImprovementThreshold { threshold: 0.01 },
            StoppingCriterion::MaxIterations { max_iter: 5 },
            StoppingCriterion::MaxComponents { max_components: 50 },
        ];

        for stopping_criterion in criteria {
            let sampler = ProgressiveRBFSampler::new()
                .gamma(0.5)
                .config(ProgressiveConfig {
                    initial_components: 10,
                    strategy: ProgressiveStrategy::Doubling,
                    stopping_criterion,
                    n_trials: 1,
                    ..Default::default()
                });

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 10);
            assert!(result.final_quality >= 0.0);
            assert!(!result.stopping_reason.is_empty());
        }
    }

    /// Every implemented metric yields non-negative scores and timings.
    #[test]
    fn test_quality_metrics() {
        let x = ramp(40, 2, 0.05);

        let metrics = vec![
            ProgressiveQualityMetric::KernelAlignment,
            ProgressiveQualityMetric::FrobeniusError,
            ProgressiveQualityMetric::SpectralError,
            ProgressiveQualityMetric::EffectiveRank,
        ];

        for quality_metric in metrics {
            let sampler = ProgressiveRBFSampler::new()
                .gamma(0.3)
                .config(ProgressiveConfig {
                    initial_components: 5,
                    strategy: ProgressiveStrategy::Doubling,
                    stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                    quality_metric,
                    n_trials: 1,
                    ..Default::default()
                });

            let result = sampler.run_progressive_approximation(&x).unwrap();

            assert!(result.final_components >= 5);
            assert!(result.final_quality >= 0.0);

            for step in result.steps.iter() {
                assert!(step.quality_score >= 0.0);
                assert!(step.time_taken >= 0.0);
            }
        }
    }

    /// Quality may fluctuate between steps but must not drop by more than 0.1.
    #[test]
    fn test_progressive_improvement() {
        let x = ramp(70, 3, 0.02);

        let sampler = ProgressiveRBFSampler::new()
            .gamma(0.7)
            .config(ProgressiveConfig {
                initial_components: 10,
                strategy: ProgressiveStrategy::Doubling,
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 4 },
                quality_metric: ProgressiveQualityMetric::KernelAlignment,
                n_trials: 2,
                ..Default::default()
            });

        let result = sampler.run_progressive_approximation(&x).unwrap();

        for pair in result.steps.windows(2) {
            let previous_quality = pair[0].quality_score;
            let current_quality = pair[1].quality_score;

            assert!(
                current_quality >= previous_quality - 0.1,
                "Quality should not decrease significantly: {} -> {}",
                previous_quality,
                current_quality
            );
        }
    }

    /// Two searches with the same seed agree on size, quality and step count.
    #[test]
    fn test_progressive_reproducibility() {
        let x = ramp(50, 2, 0.04);

        let config = ProgressiveConfig {
            initial_components: 5,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
            n_trials: 2,
            random_seed: Some(42),
            ..Default::default()
        };

        let first = ProgressiveRBFSampler::new()
            .gamma(0.6)
            .config(config.clone());
        let second = ProgressiveRBFSampler::new().gamma(0.6).config(config);

        let result1 = first.run_progressive_approximation(&x).unwrap();
        let result2 = second.run_progressive_approximation(&x).unwrap();

        assert_eq!(result1.final_components, result2.final_components);
        assert_abs_diff_eq!(
            result1.final_quality,
            result2.final_quality,
            epsilon = 1e-10
        );
        assert_eq!(result1.steps.len(), result2.steps.len());
    }
}