sklears_kernel_approximation/
progressive.rs

1//! Progressive kernel approximation methods
2//!
3//! This module provides progressive approximation strategies that start with
4//! coarse approximations and progressively refine them based on quality criteria.
5
6use crate::{Nystroem, RBFSampler};
7use scirs2_core::ndarray::Array2;
8use scirs2_linalg::compat::{Norm, SVD};
9use sklears_core::traits::Fit;
10use sklears_core::{
11    error::{Result, SklearsError},
12    traits::Transform,
13};
14use std::time::Instant;
15
/// Strategy used to grow the number of components between progressive steps.
#[derive(Debug, Clone)]
pub enum ProgressiveStrategy {
    /// Double the number of components at each step.
    Doubling,
    /// Add a fixed number of components at each step.
    FixedIncrement {
        /// Number of components added per step.
        increment: usize,
    },
    /// Adaptive increment based on the quality improvement of the last step.
    AdaptiveIncrement {
        /// Increment used when the last improvement exceeds `improvement_threshold`.
        min_increment: usize,
        /// Upper end of the increment range: when the last improvement is at or
        /// below the threshold, the step is halfway between `min_increment` and
        /// this value.
        max_increment: usize,
        /// Improvement above which the smaller `min_increment` is used.
        improvement_threshold: f64,
    },
    /// Exponential growth: multiply the number of components by `base`.
    Exponential {
        /// Growth factor applied at each step (the product is truncated to an integer).
        base: f64,
    },
    /// Fibonacci-based growth: `initial_components` plus successive Fibonacci
    /// numbers (2, 3, 5, 8, ...).
    Fibonacci,
}
37
/// Criterion deciding when a progressive search ends.
#[derive(Debug, Clone)]
pub enum StoppingCriterion {
    /// Stop when the quality score reaches `quality` (counts as converged).
    TargetQuality {
        /// Quality score to reach.
        quality: f64,
    },
    /// Stop when the improvement over the previous step falls below `threshold`
    /// (never checked on the first step; counts as converged).
    ImprovementThreshold {
        /// Minimum improvement required to keep going.
        threshold: f64,
    },
    /// Stop after at most `max_iter` steps (not counted as converged).
    MaxIterations {
        /// Maximum number of steps to evaluate.
        max_iter: usize,
    },
    /// Stop once the number of components reaches `max_components`
    /// (not counted as converged).
    MaxComponents {
        /// Component count at which the search stops.
        max_components: usize,
    },
    /// Combined criteria: the search stops as soon as ANY of the configured
    /// (`Some`) criteria is met; `None` entries are ignored.
    Combined {
        /// Target quality score, as in [`StoppingCriterion::TargetQuality`].
        quality: Option<f64>,
        /// Minimum improvement, as in [`StoppingCriterion::ImprovementThreshold`].
        improvement_threshold: Option<f64>,
        /// Maximum number of steps, as in [`StoppingCriterion::MaxIterations`].
        max_iter: Option<usize>,
        /// Component cap, as in [`StoppingCriterion::MaxComponents`].
        max_components: Option<usize>,
    },
}
58
/// Quality metric used to score an approximation (higher is better for every variant).
#[derive(Debug, Clone)]
pub enum ProgressiveQualityMetric {
    /// Kernel alignment between the exact and approximate kernel matrices.
    KernelAlignment,
    /// Frobenius norm of the approximation error, reported as `1 / (1 + error)`.
    FrobeniusError,
    /// Spectral norm of the approximation error, reported as `1 / (1 + error)`.
    SpectralError,
    /// Effective (entropy-based) rank of the approximation, normalized to `[0, 1]`.
    EffectiveRank,
    /// Relative improvement over the previous iteration.
    ///
    /// NOTE: not computed by `ProgressiveRBFSampler`, which reports a constant
    /// score of `1.0` for this metric.
    RelativeImprovement,
    /// Custom quality function.
    ///
    /// NOTE: placeholder; `ProgressiveRBFSampler` currently falls back to kernel
    /// alignment for this metric.
    Custom,
}
76
77/// Configuration for progressive approximation
78#[derive(Debug, Clone)]
79/// ProgressiveConfig
80pub struct ProgressiveConfig {
81    /// Initial number of components
82    pub initial_components: usize,
83    /// Progressive strategy
84    pub strategy: ProgressiveStrategy,
85    /// Stopping criterion
86    pub stopping_criterion: StoppingCriterion,
87    /// Quality metric to optimize
88    pub quality_metric: ProgressiveQualityMetric,
89    /// Number of trials per iteration for stability
90    pub n_trials: usize,
91    /// Random seed for reproducibility
92    pub random_seed: Option<u64>,
93    /// Validation fraction for quality assessment
94    pub validation_fraction: f64,
95    /// Whether to store intermediate results
96    pub store_intermediate: bool,
97}
98
99impl Default for ProgressiveConfig {
100    fn default() -> Self {
101        Self {
102            initial_components: 10,
103            strategy: ProgressiveStrategy::Doubling,
104            stopping_criterion: StoppingCriterion::Combined {
105                quality: Some(0.95),
106                improvement_threshold: Some(0.01),
107                max_iter: Some(10),
108                max_components: Some(1000),
109            },
110            quality_metric: ProgressiveQualityMetric::KernelAlignment,
111            n_trials: 3,
112            random_seed: None,
113            validation_fraction: 0.2,
114            store_intermediate: true,
115        }
116    }
117}
118
/// Result of a single progressive step.
#[derive(Debug, Clone)]
pub struct ProgressiveStep {
    /// Number of components evaluated in this step.
    pub n_components: usize,
    /// Quality score achieved (averaged over the configured trials).
    pub quality_score: f64,
    /// Improvement over the previous step; for the first step this equals
    /// `quality_score`.
    pub improvement: f64,
    /// Wall-clock time taken by this step, in seconds.
    pub time_taken: f64,
    /// Zero-based iteration number.
    pub iteration: usize,
}
134
135/// Results from progressive approximation
136#[derive(Debug, Clone)]
137/// ProgressiveResult
138pub struct ProgressiveResult {
139    /// Final number of components
140    pub final_components: usize,
141    /// Final quality score
142    pub final_quality: f64,
143    /// All progressive steps
144    pub steps: Vec<ProgressiveStep>,
145    /// Whether convergence was achieved
146    pub converged: bool,
147    /// Stopping reason
148    pub stopping_reason: String,
149    /// Total time taken
150    pub total_time: f64,
151}
152
153/// Progressive RBF sampler
154#[derive(Debug, Clone)]
155/// ProgressiveRBFSampler
156pub struct ProgressiveRBFSampler {
157    gamma: f64,
158    config: ProgressiveConfig,
159}
160
161impl Default for ProgressiveRBFSampler {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
167impl ProgressiveRBFSampler {
168    /// Create a new progressive RBF sampler
169    pub fn new() -> Self {
170        Self {
171            gamma: 1.0,
172            config: ProgressiveConfig::default(),
173        }
174    }
175
176    /// Set gamma parameter
177    pub fn gamma(mut self, gamma: f64) -> Self {
178        self.gamma = gamma;
179        self
180    }
181
182    /// Set configuration
183    pub fn config(mut self, config: ProgressiveConfig) -> Self {
184        self.config = config;
185        self
186    }
187
188    /// Set initial components
189    pub fn initial_components(mut self, components: usize) -> Self {
190        self.config.initial_components = components;
191        self
192    }
193
194    /// Set progressive strategy
195    pub fn strategy(mut self, strategy: ProgressiveStrategy) -> Self {
196        self.config.strategy = strategy;
197        self
198    }
199
200    /// Set stopping criterion
201    pub fn stopping_criterion(mut self, criterion: StoppingCriterion) -> Self {
202        self.config.stopping_criterion = criterion;
203        self
204    }
205
206    /// Run progressive approximation
207    pub fn run_progressive_approximation(&self, x: &Array2<f64>) -> Result<ProgressiveResult> {
208        let start_time = Instant::now();
209        let n_samples = x.nrows();
210
211        // Split data for validation
212        let split_idx = (n_samples as f64 * (1.0 - self.config.validation_fraction)) as usize;
213        let x_train = x
214            .slice(scirs2_core::ndarray::s![..split_idx, ..])
215            .to_owned();
216        let x_val = x
217            .slice(scirs2_core::ndarray::s![split_idx.., ..])
218            .to_owned();
219
220        // Compute exact kernel matrix for validation (small subset)
221        let k_exact = self.compute_exact_kernel_matrix(&x_val)?;
222
223        let mut steps = Vec::new();
224        let mut current_components = self.config.initial_components;
225        let mut previous_quality = 0.0;
226        let mut iteration = 0;
227        let result;
228
229        // Fibonacci sequence state (for Fibonacci strategy)
230        let mut fib_prev = 1;
231        let mut fib_curr = 1;
232
233        loop {
234            let step_start = Instant::now();
235
236            // Compute quality for current number of components
237            let quality = self.compute_quality_for_components(
238                current_components,
239                &x_train,
240                &x_val,
241                &k_exact,
242            )?;
243
244            let improvement = if iteration == 0 {
245                quality
246            } else {
247                quality - previous_quality
248            };
249
250            let step_time = step_start.elapsed().as_secs_f64();
251
252            // Store step result
253            let step = ProgressiveStep {
254                n_components: current_components,
255                quality_score: quality,
256                improvement,
257                time_taken: step_time,
258                iteration,
259            };
260            steps.push(step);
261
262            // Check stopping criteria
263            if let Some(stop_result) =
264                self.check_stopping_criteria(quality, improvement, iteration, current_components)
265            {
266                result = Some(stop_result);
267                break;
268            }
269
270            // Update for next iteration
271            previous_quality = quality;
272            iteration += 1;
273
274            // Determine next number of components
275            current_components = match &self.config.strategy {
276                ProgressiveStrategy::Doubling => current_components * 2,
277                ProgressiveStrategy::FixedIncrement { increment } => current_components + increment,
278                ProgressiveStrategy::AdaptiveIncrement {
279                    min_increment,
280                    max_increment,
281                    improvement_threshold,
282                } => {
283                    let increment = if improvement > *improvement_threshold {
284                        *min_increment
285                    } else {
286                        (*min_increment + (*max_increment - *min_increment) / 2).max(*min_increment)
287                    };
288                    current_components + increment
289                }
290                ProgressiveStrategy::Exponential { base } => {
291                    ((current_components as f64) * base) as usize
292                }
293                ProgressiveStrategy::Fibonacci => {
294                    let next_fib = fib_prev + fib_curr;
295                    fib_prev = fib_curr;
296                    fib_curr = next_fib;
297                    self.config.initial_components + fib_curr
298                }
299            };
300        }
301
302        let total_time = start_time.elapsed().as_secs_f64();
303        let (converged, stopping_reason) =
304            result.unwrap_or((false, "Max iterations reached".to_string()));
305
306        Ok(ProgressiveResult {
307            final_components: steps
308                .last()
309                .map(|s| s.n_components)
310                .unwrap_or(current_components),
311            final_quality: steps.last().map(|s| s.quality_score).unwrap_or(0.0),
312            steps,
313            converged,
314            stopping_reason,
315            total_time,
316        })
317    }
318
319    /// Compute exact kernel matrix for validation
320    fn compute_exact_kernel_matrix(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
321        let n_samples = x.nrows().min(100); // Limit for computational efficiency
322        let x_subset = x.slice(scirs2_core::ndarray::s![..n_samples, ..]);
323
324        let mut k_exact = Array2::zeros((n_samples, n_samples));
325
326        for i in 0..n_samples {
327            for j in 0..n_samples {
328                let diff = &x_subset.row(i) - &x_subset.row(j);
329                let squared_norm = diff.dot(&diff);
330                k_exact[[i, j]] = (-self.gamma * squared_norm).exp();
331            }
332        }
333
334        Ok(k_exact)
335    }
336
337    /// Compute quality for a given number of components
338    fn compute_quality_for_components(
339        &self,
340        n_components: usize,
341        x_train: &Array2<f64>,
342        x_val: &Array2<f64>,
343        k_exact: &Array2<f64>,
344    ) -> Result<f64> {
345        let mut trial_qualities = Vec::new();
346
347        // Run multiple trials for stability
348        for trial in 0..self.config.n_trials {
349            let seed = self.config.random_seed.map(|s| s + trial as u64);
350            let sampler = if let Some(s) = seed {
351                RBFSampler::new(n_components)
352                    .gamma(self.gamma)
353                    .random_state(s)
354            } else {
355                RBFSampler::new(n_components).gamma(self.gamma)
356            };
357
358            let fitted = sampler.fit(x_train, &())?;
359            let x_val_transformed = fitted.transform(x_val)?;
360
361            let quality = self.compute_quality_metric(x_val, &x_val_transformed, k_exact)?;
362            trial_qualities.push(quality);
363        }
364
365        // Return average quality across trials
366        Ok(trial_qualities.iter().sum::<f64>() / trial_qualities.len() as f64)
367    }
368
369    /// Compute quality metric
370    fn compute_quality_metric(
371        &self,
372        _x: &Array2<f64>,
373        x_transformed: &Array2<f64>,
374        k_exact: &Array2<f64>,
375    ) -> Result<f64> {
376        match &self.config.quality_metric {
377            ProgressiveQualityMetric::KernelAlignment => {
378                self.compute_kernel_alignment(x_transformed, k_exact)
379            }
380            ProgressiveQualityMetric::FrobeniusError => {
381                self.compute_frobenius_error(x_transformed, k_exact)
382            }
383            ProgressiveQualityMetric::SpectralError => {
384                self.compute_spectral_error(x_transformed, k_exact)
385            }
386            ProgressiveQualityMetric::EffectiveRank => self.compute_effective_rank(x_transformed),
387            ProgressiveQualityMetric::RelativeImprovement => {
388                // This is handled at a higher level
389                Ok(1.0)
390            }
391            ProgressiveQualityMetric::Custom => {
392                // Placeholder for custom quality function
393                self.compute_kernel_alignment(x_transformed, k_exact)
394            }
395        }
396    }
397
398    /// Compute kernel alignment
399    fn compute_kernel_alignment(
400        &self,
401        x_transformed: &Array2<f64>,
402        k_exact: &Array2<f64>,
403    ) -> Result<f64> {
404        let n_samples = k_exact.nrows().min(x_transformed.nrows());
405        let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
406
407        // Compute approximate kernel matrix
408        let k_approx = x_subset.dot(&x_subset.t());
409
410        // Compute alignment
411        let k_exact_norm = k_exact.norm_l2();
412        let k_approx_norm = k_approx.norm_l2();
413
414        if k_exact_norm > 1e-12 && k_approx_norm > 1e-12 {
415            let alignment = (k_exact * &k_approx).sum() / (k_exact_norm * k_approx_norm);
416            Ok(alignment)
417        } else {
418            Ok(0.0)
419        }
420    }
421
422    /// Compute Frobenius error (as quality score, so higher is better)
423    fn compute_frobenius_error(
424        &self,
425        x_transformed: &Array2<f64>,
426        k_exact: &Array2<f64>,
427    ) -> Result<f64> {
428        let n_samples = k_exact.nrows().min(x_transformed.nrows());
429        let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
430
431        // Compute approximate kernel matrix
432        let k_approx = x_subset.dot(&x_subset.t());
433
434        // Compute error and convert to quality (higher is better)
435        let diff = k_exact - &k_approx.slice(scirs2_core::ndarray::s![..n_samples, ..n_samples]);
436        let error = diff.norm_l2();
437        let quality = 1.0 / (1.0 + error); // Convert error to quality score
438
439        Ok(quality)
440    }
441
442    /// Compute spectral error (as quality score)
443    fn compute_spectral_error(
444        &self,
445        x_transformed: &Array2<f64>,
446        k_exact: &Array2<f64>,
447    ) -> Result<f64> {
448        let n_samples = k_exact.nrows().min(x_transformed.nrows());
449        let x_subset = x_transformed.slice(scirs2_core::ndarray::s![..n_samples, ..]);
450
451        // Compute approximate kernel matrix
452        let k_approx = x_subset.dot(&x_subset.t());
453
454        // Compute spectral norm (largest singular value) of the error
455        let diff = k_exact - &k_approx.slice(scirs2_core::ndarray::s![..n_samples, ..n_samples]);
456        let (_, s, _) = diff
457            .svd(false)
458            .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
459
460        let spectral_error = s.iter().fold(0.0f64, |acc, &x| acc.max(x));
461        let quality = 1.0 / (1.0 + spectral_error);
462
463        Ok(quality)
464    }
465
466    /// Compute effective rank
467    fn compute_effective_rank(&self, x_transformed: &Array2<f64>) -> Result<f64> {
468        // Compute SVD of transformed data
469        let (_, s, _) = x_transformed
470            .svd(true)
471            .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
472
473        // Compute effective rank using entropy
474        let s_sum = s.sum();
475        if s_sum == 0.0 {
476            return Ok(0.0);
477        }
478
479        let s_normalized = &s / s_sum;
480        let entropy = -s_normalized
481            .iter()
482            .filter(|&&x| x > 1e-12)
483            .map(|&x| x * x.ln())
484            .sum::<f64>();
485
486        let effective_rank = entropy.exp();
487        Ok(effective_rank / x_transformed.ncols() as f64) // Normalize by max possible rank
488    }
489
490    /// Check stopping criteria
491    fn check_stopping_criteria(
492        &self,
493        quality: f64,
494        improvement: f64,
495        iteration: usize,
496        components: usize,
497    ) -> Option<(bool, String)> {
498        match &self.config.stopping_criterion {
499            StoppingCriterion::TargetQuality { quality: target } => {
500                if quality >= *target {
501                    Some((true, format!("Target quality {} reached", target)))
502                } else {
503                    None
504                }
505            }
506            StoppingCriterion::ImprovementThreshold { threshold } => {
507                if iteration > 0 && improvement < *threshold {
508                    Some((
509                        true,
510                        format!("Improvement {} below threshold {}", improvement, threshold),
511                    ))
512                } else {
513                    None
514                }
515            }
516            StoppingCriterion::MaxIterations { max_iter } => {
517                if iteration + 1 >= *max_iter {
518                    Some((false, format!("Maximum iterations {} reached", max_iter)))
519                } else {
520                    None
521                }
522            }
523            StoppingCriterion::MaxComponents { max_components } => {
524                if components >= *max_components {
525                    Some((
526                        false,
527                        format!("Maximum components {} reached", max_components),
528                    ))
529                } else {
530                    None
531                }
532            }
533            StoppingCriterion::Combined {
534                quality: target_quality,
535                improvement_threshold,
536                max_iter,
537                max_components,
538            } => {
539                // Check target quality
540                if let Some(target) = target_quality {
541                    if quality >= *target {
542                        return Some((true, format!("Target quality {} reached", target)));
543                    }
544                }
545
546                // Check improvement threshold
547                if let Some(threshold) = improvement_threshold {
548                    if iteration > 0 && improvement < *threshold {
549                        return Some((
550                            true,
551                            format!("Improvement {} below threshold {}", improvement, threshold),
552                        ));
553                    }
554                }
555
556                // Check max iterations
557                if let Some(max) = max_iter {
558                    if iteration >= *max {
559                        return Some((false, format!("Maximum iterations {} reached", max)));
560                    }
561                }
562
563                // Check max components
564                if let Some(max) = max_components {
565                    if components >= *max {
566                        return Some((false, format!("Maximum components {} reached", max)));
567                    }
568                }
569
570                None
571            }
572        }
573    }
574}
575
576/// Fitted progressive RBF sampler
577pub struct FittedProgressiveRBFSampler {
578    fitted_rbf: crate::rbf_sampler::RBFSampler<sklears_core::traits::Trained>,
579    progressive_result: ProgressiveResult,
580}
581
582impl Fit<Array2<f64>, ()> for ProgressiveRBFSampler {
583    type Fitted = FittedProgressiveRBFSampler;
584
585    fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
586        // Run progressive approximation
587        let progressive_result = self.run_progressive_approximation(x)?;
588
589        // Fit RBF sampler with final configuration
590        let rbf_sampler = RBFSampler::new(progressive_result.final_components).gamma(self.gamma);
591        let fitted_rbf = rbf_sampler.fit(x, &())?;
592
593        Ok(FittedProgressiveRBFSampler {
594            fitted_rbf,
595            progressive_result,
596        })
597    }
598}
599
600impl Transform<Array2<f64>, Array2<f64>> for FittedProgressiveRBFSampler {
601    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
602        self.fitted_rbf.transform(x)
603    }
604}
605
606impl FittedProgressiveRBFSampler {
607    /// Get the progressive result
608    pub fn progressive_result(&self) -> &ProgressiveResult {
609        &self.progressive_result
610    }
611
612    /// Get the final number of components
613    pub fn final_components(&self) -> usize {
614        self.progressive_result.final_components
615    }
616
617    /// Get the final quality score
618    pub fn final_quality(&self) -> f64 {
619        self.progressive_result.final_quality
620    }
621
622    /// Check if progressive approximation converged
623    pub fn converged(&self) -> bool {
624        self.progressive_result.converged
625    }
626
627    /// Get all progressive steps
628    pub fn steps(&self) -> &[ProgressiveStep] {
629        &self.progressive_result.steps
630    }
631
632    /// Get the stopping reason
633    pub fn stopping_reason(&self) -> &str {
634        &self.progressive_result.stopping_reason
635    }
636}
637
638/// Progressive Nyström method
639#[derive(Debug, Clone)]
640/// ProgressiveNystroem
641pub struct ProgressiveNystroem {
642    kernel: crate::nystroem::Kernel,
643    config: ProgressiveConfig,
644}
645
646impl Default for ProgressiveNystroem {
647    fn default() -> Self {
648        Self::new()
649    }
650}
651
652impl ProgressiveNystroem {
653    /// Create a new progressive Nyström method
654    pub fn new() -> Self {
655        Self {
656            kernel: crate::nystroem::Kernel::Rbf { gamma: 1.0 },
657            config: ProgressiveConfig::default(),
658        }
659    }
660
661    /// Set gamma parameter (for RBF kernel)
662    pub fn gamma(mut self, gamma: f64) -> Self {
663        self.kernel = crate::nystroem::Kernel::Rbf { gamma };
664        self
665    }
666
667    /// Set kernel type
668    pub fn kernel(mut self, kernel: crate::nystroem::Kernel) -> Self {
669        self.kernel = kernel;
670        self
671    }
672
673    /// Set configuration
674    pub fn config(mut self, config: ProgressiveConfig) -> Self {
675        self.config = config;
676        self
677    }
678
679    /// Run progressive approximation for Nyström method
680    pub fn run_progressive_approximation(&self, x: &Array2<f64>) -> Result<ProgressiveResult> {
681        let start_time = Instant::now();
682
683        let mut steps = Vec::new();
684        let mut current_components = self.config.initial_components;
685        let mut previous_quality = 0.0;
686        let mut iteration = 0;
687        let result;
688
689        loop {
690            let step_start = Instant::now();
691
692            // Compute quality for current number of components
693            let quality = self.compute_nystroem_quality(current_components, x)?;
694
695            let improvement = if iteration == 0 {
696                quality
697            } else {
698                quality - previous_quality
699            };
700
701            let step_time = step_start.elapsed().as_secs_f64();
702
703            // Store step result
704            let step = ProgressiveStep {
705                n_components: current_components,
706                quality_score: quality,
707                improvement,
708                time_taken: step_time,
709                iteration,
710            };
711            steps.push(step);
712
713            // Check stopping criteria (using same logic as RBF sampler)
714            if let Some(stop_result) =
715                self.check_stopping_criteria(quality, improvement, iteration, current_components)
716            {
717                result = Some(stop_result);
718                break;
719            }
720
721            // Update for next iteration
722            previous_quality = quality;
723            iteration += 1;
724
725            // Determine next number of components (same logic as RBF sampler)
726            current_components = match &self.config.strategy {
727                ProgressiveStrategy::Doubling => current_components * 2,
728                ProgressiveStrategy::FixedIncrement { increment } => current_components + increment,
729                _ => current_components * 2, // Simplified for Nyström
730            };
731        }
732
733        let total_time = start_time.elapsed().as_secs_f64();
734        let (converged, stopping_reason) =
735            result.unwrap_or((false, "Max iterations reached".to_string()));
736
737        Ok(ProgressiveResult {
738            final_components: steps
739                .last()
740                .map(|s| s.n_components)
741                .unwrap_or(current_components),
742            final_quality: steps.last().map(|s| s.quality_score).unwrap_or(0.0),
743            steps,
744            converged,
745            stopping_reason,
746            total_time,
747        })
748    }
749
750    /// Compute quality for Nyström with given components
751    fn compute_nystroem_quality(&self, n_components: usize, x: &Array2<f64>) -> Result<f64> {
752        let mut trial_qualities = Vec::new();
753
754        // Run multiple trials for stability
755        for trial in 0..self.config.n_trials {
756            let seed = self.config.random_seed.map(|s| s + trial as u64);
757            let nystroem = if let Some(s) = seed {
758                Nystroem::new(self.kernel.clone(), n_components).random_state(s)
759            } else {
760                Nystroem::new(self.kernel.clone(), n_components)
761            };
762
763            let fitted = nystroem.fit(x, &())?;
764            let x_transformed = fitted.transform(x)?;
765
766            // Use effective rank as quality measure
767            let quality = self.compute_effective_rank(&x_transformed)?;
768            trial_qualities.push(quality);
769        }
770
771        Ok(trial_qualities.iter().sum::<f64>() / trial_qualities.len() as f64)
772    }
773
774    /// Compute effective rank (same as RBF sampler)
775    fn compute_effective_rank(&self, x_transformed: &Array2<f64>) -> Result<f64> {
776        let (_, s, _) = x_transformed
777            .svd(true)
778            .map_err(|_| SklearsError::InvalidInput("SVD computation failed".to_string()))?;
779
780        let s_sum = s.sum();
781        if s_sum == 0.0 {
782            return Ok(0.0);
783        }
784
785        let s_normalized = &s / s_sum;
786        let entropy = -s_normalized
787            .iter()
788            .filter(|&&x| x > 1e-12)
789            .map(|&x| x * x.ln())
790            .sum::<f64>();
791
792        let effective_rank = entropy.exp();
793        Ok(effective_rank / x_transformed.ncols() as f64)
794    }
795
796    /// Check stopping criteria (same as RBF sampler)
797    fn check_stopping_criteria(
798        &self,
799        quality: f64,
800        _improvement: f64,
801        iteration: usize,
802        _components: usize,
803    ) -> Option<(bool, String)> {
804        match &self.config.stopping_criterion {
805            StoppingCriterion::TargetQuality { quality: target } => {
806                if quality >= *target {
807                    Some((true, format!("Target quality {} reached", target)))
808                } else {
809                    None
810                }
811            }
812            StoppingCriterion::MaxIterations { max_iter } => {
813                if iteration + 1 >= *max_iter {
814                    Some((false, format!("Maximum iterations {} reached", max_iter)))
815                } else {
816                    None
817                }
818            }
819            _ => None, // Simplified for Nyström
820        }
821    }
822}
823
824/// Fitted progressive Nyström method
825pub struct FittedProgressiveNystroem {
826    fitted_nystroem: crate::nystroem::Nystroem<sklears_core::traits::Trained>,
827    progressive_result: ProgressiveResult,
828}
829
830impl Fit<Array2<f64>, ()> for ProgressiveNystroem {
831    type Fitted = FittedProgressiveNystroem;
832
833    fn fit(self, x: &Array2<f64>, _y: &()) -> Result<Self::Fitted> {
834        // Run progressive approximation
835        let progressive_result = self.run_progressive_approximation(x)?;
836
837        // Fit Nyström method with final configuration
838        let nystroem = Nystroem::new(self.kernel, progressive_result.final_components);
839        let fitted_nystroem = nystroem.fit(x, &())?;
840
841        Ok(FittedProgressiveNystroem {
842            fitted_nystroem,
843            progressive_result,
844        })
845    }
846}
847
848impl Transform<Array2<f64>, Array2<f64>> for FittedProgressiveNystroem {
849    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
850        self.fitted_nystroem.transform(x)
851    }
852}
853
854impl FittedProgressiveNystroem {
855    /// Get the progressive result
856    pub fn progressive_result(&self) -> &ProgressiveResult {
857        &self.progressive_result
858    }
859
860    /// Get the final number of components
861    pub fn final_components(&self) -> usize {
862        self.progressive_result.final_components
863    }
864
865    /// Get the final quality score
866    pub fn final_quality(&self) -> f64 {
867        self.progressive_result.final_quality
868    }
869
870    /// Check if progressive approximation converged
871    pub fn converged(&self) -> bool {
872        self.progressive_result.converged
873    }
874}
875
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a `rows x cols` matrix whose row-major entries are `0, step, 2*step, ...`.
    fn ramp(rows: usize, cols: usize, step: f64) -> Array2<f64> {
        let values = (0..rows * cols)
            .map(|k| (k as f64) * step)
            .collect::<Vec<f64>>();
        Array2::from_shape_vec((rows, cols), values).unwrap()
    }

    #[test]
    fn test_progressive_rbf_sampler() {
        let data = ramp(100, 4, 0.01);

        let sampler = ProgressiveRBFSampler::new().gamma(0.5).config(ProgressiveConfig {
            initial_components: 5,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
            quality_metric: ProgressiveQualityMetric::KernelAlignment,
            n_trials: 2,
            validation_fraction: 0.3,
            ..Default::default()
        });

        let model = sampler.fit(&data, &()).unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.nrows(), 100);
        assert!(model.final_components() >= 5);
        assert!(model.final_quality() >= 0.0);
        // With MaxIterations { max_iter: 3 } the run is expected to record exactly three steps.
        assert_eq!(model.steps().len(), 3);
    }

    #[test]
    fn test_progressive_nystroem() {
        let data = ramp(80, 3, 0.02);

        let settings = ProgressiveConfig {
            initial_components: 10,
            strategy: ProgressiveStrategy::FixedIncrement { increment: 5 },
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 4 },
            n_trials: 2,
            ..Default::default()
        };

        let approximator = ProgressiveNystroem::new().gamma(1.0).config(settings);
        let model = approximator.fit(&data, &()).unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.nrows(), 80);
        assert!(model.final_components() >= 10);
        assert!(model.final_quality() >= 0.0);
    }

    #[test]
    fn test_progressive_strategies() {
        let data = ramp(50, 2, 0.05);

        let strategies = [
            ProgressiveStrategy::Doubling,
            ProgressiveStrategy::FixedIncrement { increment: 3 },
            ProgressiveStrategy::Exponential { base: 1.5 },
            ProgressiveStrategy::Fibonacci,
        ];

        for strategy in strategies {
            let sampler = ProgressiveRBFSampler::new().gamma(0.8).config(ProgressiveConfig {
                initial_components: 5,
                strategy,
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                n_trials: 1,
                ..Default::default()
            });

            let outcome = sampler.run_progressive_approximation(&data).unwrap();

            assert!(outcome.final_components >= 5);
            assert!(outcome.final_quality >= 0.0);
            assert_eq!(outcome.steps.len(), 3);
        }
    }

    #[test]
    fn test_stopping_criteria() {
        let data = ramp(60, 3, 0.03);

        let criteria = [
            StoppingCriterion::TargetQuality { quality: 0.8 },
            StoppingCriterion::ImprovementThreshold { threshold: 0.01 },
            StoppingCriterion::MaxIterations { max_iter: 5 },
            StoppingCriterion::MaxComponents { max_components: 50 },
        ];

        for stopping_criterion in criteria {
            let sampler = ProgressiveRBFSampler::new().gamma(0.5).config(ProgressiveConfig {
                initial_components: 10,
                strategy: ProgressiveStrategy::Doubling,
                stopping_criterion,
                n_trials: 1,
                ..Default::default()
            });

            let outcome = sampler.run_progressive_approximation(&data).unwrap();

            assert!(outcome.final_components >= 10);
            assert!(outcome.final_quality >= 0.0);
            // Every run must explain why it stopped.
            assert!(!outcome.stopping_reason.is_empty());
        }
    }

    #[test]
    fn test_quality_metrics() {
        let data = ramp(40, 2, 0.05);

        let metrics = [
            ProgressiveQualityMetric::KernelAlignment,
            ProgressiveQualityMetric::FrobeniusError,
            ProgressiveQualityMetric::SpectralError,
            ProgressiveQualityMetric::EffectiveRank,
        ];

        for quality_metric in metrics {
            let sampler = ProgressiveRBFSampler::new().gamma(0.3).config(ProgressiveConfig {
                initial_components: 5,
                strategy: ProgressiveStrategy::Doubling,
                stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
                quality_metric,
                n_trials: 1,
                ..Default::default()
            });

            let outcome = sampler.run_progressive_approximation(&data).unwrap();

            assert!(outcome.final_components >= 5);
            assert!(outcome.final_quality >= 0.0);

            // Each recorded step must carry a non-negative score and timing.
            outcome.steps.iter().for_each(|step| {
                assert!(step.quality_score >= 0.0);
                assert!(step.time_taken >= 0.0);
            });
        }
    }

    #[test]
    fn test_progressive_improvement() {
        let data = ramp(70, 3, 0.02);

        let sampler = ProgressiveRBFSampler::new().gamma(0.7).config(ProgressiveConfig {
            initial_components: 10,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 4 },
            quality_metric: ProgressiveQualityMetric::KernelAlignment,
            n_trials: 2,
            ..Default::default()
        });

        let outcome = sampler.run_progressive_approximation(&data).unwrap();

        // Consecutive steps may dip by at most 0.1 in quality; the randomized
        // approximation is not guaranteed to be strictly monotone.
        for pair in outcome.steps.windows(2) {
            let previous_quality = pair[0].quality_score;
            let current_quality = pair[1].quality_score;
            assert!(
                current_quality >= previous_quality - 0.1,
                "Quality should not decrease significantly: {} -> {}",
                previous_quality,
                current_quality
            );
        }
    }

    #[test]
    fn test_progressive_reproducibility() {
        let data = ramp(50, 2, 0.04);

        let seeded = ProgressiveConfig {
            initial_components: 5,
            strategy: ProgressiveStrategy::Doubling,
            stopping_criterion: StoppingCriterion::MaxIterations { max_iter: 3 },
            n_trials: 2,
            random_seed: Some(42),
            ..Default::default()
        };

        // Two independently built samplers sharing one seed must agree.
        let first = ProgressiveRBFSampler::new().gamma(0.6).config(seeded.clone());
        let second = ProgressiveRBFSampler::new().gamma(0.6).config(seeded);

        let lhs = first.run_progressive_approximation(&data).unwrap();
        let rhs = second.run_progressive_approximation(&data).unwrap();

        assert_eq!(lhs.final_components, rhs.final_components);
        assert_abs_diff_eq!(lhs.final_quality, rhs.final_quality, epsilon = 1e-10);
        assert_eq!(lhs.steps.len(), rhs.steps.len());
    }
}
1092        assert_eq!(result1.steps.len(), result2.steps.len());
1093    }
1094}