sklears_kernel_approximation/
ensemble_nystroem.rs

1//! Ensemble Nyström method for improved kernel approximation
2use crate::nystroem::{Kernel, Nystroem, SamplingStrategy};
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Rng;
6use sklears_core::{
7    error::{Result, SklearsError},
8    traits::{Estimator, Fit, Trained, Transform, Untrained},
9    types::Float,
10};
11use std::marker::PhantomData;
12
13use scirs2_core::random::{thread_rng, SeedableRng};
/// Strategy used to combine the outputs of the base Nyström estimators.
///
/// The enum is a plain, fieldless selector, so it is `Copy` and comparable:
/// callers can pass it by value and test it with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleMethod {
    /// Simple averaging of all approximations
    Average,
    /// Weighted average based on approximation quality (higher quality gets more weight)
    WeightedAverage,
    /// Concatenate all approximations
    Concatenate,
    /// Use the best approximation based on some quality metric
    BestApproximation,
}
26
/// Metric used to score each fitted Nyström approximation.
///
/// The enum is a plain, fieldless selector, so it is `Copy` and comparable:
/// callers can pass it by value and test it with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QualityMetric {
    /// Frobenius norm of the approximation error
    FrobeniusNorm,
    /// Trace of the approximation
    Trace,
    /// Spectral norm (largest eigenvalue)
    SpectralNorm,
    /// Nuclear norm (sum of eigenvalues)
    NuclearNorm,
}
39
/// Ensemble Nyström method for kernel approximation
///
/// Combines multiple Nyström approximations using different sampling strategies
/// and component sizes to achieve better approximation quality than a single
/// Nyström approximation.
///
/// The `State` type parameter tracks whether the ensemble has been fitted:
/// the fitted attributes are `None` on a value built with `new` and are
/// populated by `fit`.
///
/// # Parameters
///
/// * `kernel` - Kernel function to approximate
/// * `n_estimators` - Number of base Nyström estimators (argument of `new`;
///   `fit` rejects zero)
/// * `n_components` - Number of samples per estimator (argument of `new`;
///   `fit` rejects zero)
/// * `ensemble_method` - Method for combining estimators
///   (`new` selects `EnsembleMethod::WeightedAverage`)
/// * `sampling_strategies` - List of sampling strategies to use
///   (`None` after `new`, in which case a built-in rotation is used)
/// * `quality_metric` - Metric used to score each base estimator
///   (`new` selects `QualityMetric::FrobeniusNorm`)
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::ensemble_nystroem::{EnsembleNystroem, EnsembleMethod};
/// use sklears_kernel_approximation::nystroem::{Kernel, SamplingStrategy};
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
///
/// let ensemble = EnsembleNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3, 2)
///     .ensemble_method(EnsembleMethod::WeightedAverage);
/// let fitted_ensemble = ensemble.fit(&X, &()).unwrap();
/// let X_transformed = fitted_ensemble.transform(&X).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct EnsembleNystroem<State = Untrained> {
    /// Kernel function
    pub kernel: Kernel,
    /// Number of base estimators
    pub n_estimators: usize,
    /// Number of components per estimator
    pub n_components: usize,
    /// Method for ensemble combination
    pub ensemble_method: EnsembleMethod,
    /// Sampling strategies to use (if None, uses diverse set)
    pub sampling_strategies: Option<Vec<SamplingStrategy>>,
    /// Quality metric for evaluating approximations
    pub quality_metric: QualityMetric,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted attributes (all `None` until `fit` has run)
    /// Fitted base estimators, one per ensemble member
    estimators_: Option<Vec<Nystroem<Trained>>>,
    /// Combination weight of each base estimator
    weights_: Option<Vec<Float>>,
    /// Number of columns produced by `transform`
    n_features_out_: Option<usize>,

    /// Zero-sized marker carrying the `State` type parameter
    _state: PhantomData<State>,
}
94
95impl EnsembleNystroem<Untrained> {
96    pub fn new(kernel: Kernel, n_estimators: usize, n_components: usize) -> Self {
97        Self {
98            kernel,
99            n_estimators,
100            n_components,
101            ensemble_method: EnsembleMethod::WeightedAverage,
102            sampling_strategies: None,
103            quality_metric: QualityMetric::FrobeniusNorm,
104            random_state: None,
105            estimators_: None,
106            weights_: None,
107            n_features_out_: None,
108            _state: PhantomData,
109        }
110    }
111
112    /// Set the ensemble method
113    pub fn ensemble_method(mut self, method: EnsembleMethod) -> Self {
114        self.ensemble_method = method;
115        self
116    }
117
118    /// Set custom sampling strategies
119    pub fn sampling_strategies(mut self, strategies: Vec<SamplingStrategy>) -> Self {
120        self.sampling_strategies = Some(strategies);
121        self
122    }
123
124    /// Set the quality metric for evaluation
125    pub fn quality_metric(mut self, metric: QualityMetric) -> Self {
126        self.quality_metric = metric;
127        self
128    }
129
130    /// Set random state for reproducibility
131    pub fn random_state(mut self, seed: u64) -> Self {
132        self.random_state = Some(seed);
133        self
134    }
135
136    /// Generate diverse sampling strategies
137    fn generate_sampling_strategies(&self) -> Vec<SamplingStrategy> {
138        if let Some(ref strategies) = self.sampling_strategies {
139            strategies.clone()
140        } else {
141            // Generate diverse set of strategies
142            let mut strategies = Vec::new();
143            let base_strategies = [
144                SamplingStrategy::Random,
145                SamplingStrategy::KMeans,
146                SamplingStrategy::LeverageScore,
147                SamplingStrategy::ColumnNorm,
148            ];
149
150            for i in 0..self.n_estimators {
151                strategies.push(base_strategies[i % base_strategies.len()].clone());
152            }
153            strategies
154        }
155    }
156
157    /// Compute quality score for a Nyström approximation
158    fn compute_quality_score(
159        &self,
160        estimator: &Nystroem<Trained>,
161        _x: &Array2<Float>,
162    ) -> Result<Float> {
163        match self.quality_metric {
164            QualityMetric::FrobeniusNorm => {
165                // Approximate quality based on component matrix properties
166                let components = estimator.components();
167                let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
168                Ok(norm)
169            }
170            QualityMetric::Trace => {
171                let components = estimator.components();
172                let kernel_matrix = self.kernel.compute_kernel(components, components);
173                Ok(kernel_matrix.diag().sum())
174            }
175            QualityMetric::SpectralNorm => {
176                // Approximate spectral norm using power iteration
177                let components = estimator.components();
178                let kernel_matrix = self.kernel.compute_kernel(components, components);
179                self.power_iteration_spectral_norm(&kernel_matrix)
180            }
181            QualityMetric::NuclearNorm => {
182                let components = estimator.components();
183                let kernel_matrix = self.kernel.compute_kernel(components, components);
184                // Nuclear norm is sum of eigenvalues, approximate with trace
185                Ok(kernel_matrix.diag().sum())
186            }
187        }
188    }
189
190    /// Approximate spectral norm using power iteration
191    fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
192        let n = matrix.nrows();
193        if n == 0 {
194            return Ok(0.0);
195        }
196
197        let mut v = Array1::ones(n) / (n as Float).sqrt();
198        let max_iter = 100;
199        let tolerance = 1e-6;
200
201        for _ in 0..max_iter {
202            let v_new = matrix.dot(&v);
203            let norm = (v_new.dot(&v_new)).sqrt();
204
205            if norm < tolerance {
206                break;
207            }
208
209            let v_normalized = &v_new / norm;
210            let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
211            v = v_normalized;
212
213            if diff < tolerance {
214                break;
215            }
216        }
217
218        let eigenvalue = v.dot(&matrix.dot(&v));
219        Ok(eigenvalue.abs())
220    }
221}
222
223impl Estimator for EnsembleNystroem<Untrained> {
224    type Config = ();
225    type Error = SklearsError;
226    type Float = Float;
227
228    fn config(&self) -> &Self::Config {
229        &()
230    }
231}
232
233impl Fit<Array2<Float>, ()> for EnsembleNystroem<Untrained> {
234    type Fitted = EnsembleNystroem<Trained>;
235
236    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
237        if self.n_estimators == 0 {
238            return Err(SklearsError::InvalidInput(
239                "n_estimators must be positive".to_string(),
240            ));
241        }
242
243        if self.n_components == 0 {
244            return Err(SklearsError::InvalidInput(
245                "n_components must be positive".to_string(),
246            ));
247        }
248
249        let mut rng = if let Some(seed) = self.random_state {
250            RealStdRng::seed_from_u64(seed)
251        } else {
252            RealStdRng::from_seed(thread_rng().gen())
253        };
254
255        let sampling_strategies = self.generate_sampling_strategies();
256        let mut estimators = Vec::new();
257        let mut quality_scores = Vec::new();
258
259        // Train base estimators
260        for i in 0..self.n_estimators {
261            let strategy = sampling_strategies[i % sampling_strategies.len()].clone();
262            let seed = if self.random_state.is_some() {
263                // Use deterministic seed sequence for reproducibility
264                self.random_state.unwrap().wrapping_add(i as u64)
265            } else {
266                rng.gen::<u64>()
267            };
268
269            let nystroem = Nystroem::new(self.kernel.clone(), self.n_components)
270                .sampling_strategy(strategy)
271                .random_state(seed);
272
273            let fitted_nystroem = nystroem.fit(x, &())?;
274
275            // Compute quality score
276            let quality = self.compute_quality_score(&fitted_nystroem, x)?;
277            quality_scores.push(quality);
278            estimators.push(fitted_nystroem);
279        }
280
281        // Compute weights based on ensemble method
282        let weights = match self.ensemble_method {
283            EnsembleMethod::Average => vec![1.0 / self.n_estimators as Float; self.n_estimators],
284            EnsembleMethod::WeightedAverage => {
285                let total_quality: Float = quality_scores.iter().sum();
286                if total_quality > 0.0 {
287                    quality_scores.iter().map(|&q| q / total_quality).collect()
288                } else {
289                    vec![1.0 / self.n_estimators as Float; self.n_estimators]
290                }
291            }
292            EnsembleMethod::Concatenate => vec![1.0; self.n_estimators],
293            EnsembleMethod::BestApproximation => {
294                let best_idx = quality_scores
295                    .iter()
296                    .enumerate()
297                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
298                    .map(|(idx, _)| idx)
299                    .unwrap_or(0);
300                let mut weights = vec![0.0; self.n_estimators];
301                weights[best_idx] = 1.0;
302                weights
303            }
304        };
305
306        // Determine output feature size
307        let n_features_out = match self.ensemble_method {
308            EnsembleMethod::Concatenate => self.n_estimators * self.n_components,
309            _ => self.n_components,
310        };
311
312        Ok(EnsembleNystroem {
313            kernel: self.kernel,
314            n_estimators: self.n_estimators,
315            n_components: self.n_components,
316            ensemble_method: self.ensemble_method,
317            sampling_strategies: self.sampling_strategies,
318            quality_metric: self.quality_metric,
319            random_state: self.random_state,
320            estimators_: Some(estimators),
321            weights_: Some(weights),
322            n_features_out_: Some(n_features_out),
323            _state: PhantomData,
324        })
325    }
326}
327
328impl Transform<Array2<Float>, Array2<Float>> for EnsembleNystroem<Trained> {
329    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
330        let estimators = self.estimators_.as_ref().unwrap();
331        let weights = self.weights_.as_ref().unwrap();
332        let n_features_out = self.n_features_out_.unwrap();
333        let (n_samples, _) = x.dim();
334
335        match self.ensemble_method {
336            EnsembleMethod::Average | EnsembleMethod::WeightedAverage => {
337                let mut result = Array2::zeros((n_samples, self.n_components));
338
339                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
340                    if weight > 0.0 {
341                        let transformed = estimator.transform(x)?;
342                        result += &(transformed * weight);
343                    }
344                }
345
346                Ok(result)
347            }
348            EnsembleMethod::Concatenate => {
349                let mut result = Array2::zeros((n_samples, n_features_out));
350                let mut col_offset = 0;
351
352                for estimator in estimators.iter() {
353                    let transformed = estimator.transform(x)?;
354                    let n_cols = transformed.ncols();
355                    result
356                        .slice_mut(s![.., col_offset..col_offset + n_cols])
357                        .assign(&transformed);
358                    col_offset += n_cols;
359                }
360
361                Ok(result)
362            }
363            EnsembleMethod::BestApproximation => {
364                let best_idx = weights
365                    .iter()
366                    .enumerate()
367                    .find(|(_, &w)| w > 0.0)
368                    .map(|(idx, _)| idx)
369                    .unwrap_or(0);
370
371                estimators[best_idx].transform(x)
372            }
373        }
374    }
375}
376
377impl EnsembleNystroem<Trained> {
378    /// Get the base estimators
379    pub fn estimators(&self) -> &[Nystroem<Trained>] {
380        self.estimators_.as_ref().unwrap()
381    }
382
383    /// Get the estimator weights
384    pub fn weights(&self) -> &[Float] {
385        self.weights_.as_ref().unwrap()
386    }
387
388    /// Get the number of output features
389    pub fn n_features_out(&self) -> usize {
390        self.n_features_out_.unwrap()
391    }
392
393    /// Get quality scores for all estimators
394    pub fn quality_scores(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
395        let estimators = self.estimators_.as_ref().unwrap();
396        let mut scores = Vec::new();
397
398        for estimator in estimators.iter() {
399            let score = self.compute_quality_score_for_estimator(estimator, x)?;
400            scores.push(score);
401        }
402
403        Ok(scores)
404    }
405
406    /// Compute quality score for a single estimator (helper method)
407    fn compute_quality_score_for_estimator(
408        &self,
409        estimator: &Nystroem<Trained>,
410        _x: &Array2<Float>,
411    ) -> Result<Float> {
412        match self.quality_metric {
413            QualityMetric::FrobeniusNorm => {
414                let components = estimator.components();
415                let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
416                Ok(norm)
417            }
418            QualityMetric::Trace => {
419                let components = estimator.components();
420                let kernel_matrix = self.kernel.compute_kernel(components, components);
421                Ok(kernel_matrix.diag().sum())
422            }
423            QualityMetric::SpectralNorm => {
424                let components = estimator.components();
425                let kernel_matrix = self.kernel.compute_kernel(components, components);
426                self.power_iteration_spectral_norm(&kernel_matrix)
427            }
428            QualityMetric::NuclearNorm => {
429                let components = estimator.components();
430                let kernel_matrix = self.kernel.compute_kernel(components, components);
431                Ok(kernel_matrix.diag().sum())
432            }
433        }
434    }
435
436    /// Approximate spectral norm using power iteration
437    fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
438        let n = matrix.nrows();
439        if n == 0 {
440            return Ok(0.0);
441        }
442
443        let mut v = Array1::ones(n) / (n as Float).sqrt();
444        let max_iter = 100;
445        let tolerance = 1e-6;
446
447        for _ in 0..max_iter {
448            let v_new = matrix.dot(&v);
449            let norm = (v_new.dot(&v_new)).sqrt();
450
451            if norm < tolerance {
452                break;
453            }
454
455            let v_normalized = &v_new / norm;
456            let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
457            v = v_normalized;
458
459            if diff < tolerance {
460                break;
461            }
462        }
463
464        let eigenvalue = v.dot(&matrix.dot(&v));
465        Ok(eigenvalue.abs())
466    }
467}
468
469// Add ndarray slice import
470use scirs2_core::ndarray::s;
471
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Default (weighted-average) ensemble keeps `n_components` columns.
    #[test]
    fn test_ensemble_nystroem_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let model = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        // 4 samples, n_components = 2
        assert_eq!(features.dim(), (4, 2));
    }

    /// Plain averaging keeps `n_components` columns.
    #[test]
    fn test_ensemble_nystroem_average() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 3)
            .ensemble_method(EnsembleMethod::Average)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.dim(), (3, 3));
    }

    /// Concatenation yields `n_estimators * n_components` columns.
    #[test]
    fn test_ensemble_nystroem_concatenate() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let model = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .ensemble_method(EnsembleMethod::Concatenate)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        // 2 estimators * 3 components = 6
        assert_eq!(features.dim(), (3, 6));
    }

    /// Weighted averaging produces weights that sum to one.
    #[test]
    fn test_ensemble_nystroem_weighted_average() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3, 2)
            .ensemble_method(EnsembleMethod::WeightedAverage)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.dim(), (4, 2));

        let total = model.weights().iter().sum::<Float>();
        assert!((total - 1.0).abs() < 1e-6);
    }

    /// Best-approximation mode activates exactly one estimator.
    #[test]
    fn test_ensemble_nystroem_best_approximation() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let model = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .ensemble_method(EnsembleMethod::BestApproximation)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.dim(), (3, 2));

        // Exactly one weight is 1.0 and the others are 0.0.
        let active: Vec<Float> = model
            .weights()
            .iter()
            .copied()
            .filter(|&w| w > 0.0)
            .collect();
        assert_eq!(active.len(), 1);
        assert!((active[0] - 1.0).abs() < 1e-10);
    }

    /// Caller-supplied sampling strategies are accepted.
    #[test]
    fn test_ensemble_nystroem_custom_strategies() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let custom = vec![SamplingStrategy::Random, SamplingStrategy::LeverageScore];

        let model = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .sampling_strategies(custom)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.dim(), (4, 3));
        assert_eq!(model.estimators().len(), 2);
    }

    /// Two fits with the same seed give (numerically) the same output.
    #[test]
    fn test_ensemble_nystroem_reproducibility() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let run = |seed: u64| {
            EnsembleNystroem::new(Kernel::Linear, 2, 3)
                .random_state(seed)
                .fit(&data, &())
                .unwrap()
                .transform(&data)
                .unwrap()
        };

        let first = run(42);
        let second = run(42);

        assert_eq!(first.shape(), second.shape());
        first.iter().zip(second.iter()).for_each(|(a, b)| {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        });
    }

    /// Trace-based quality scores are finite and non-negative.
    #[test]
    fn test_ensemble_nystroem_quality_metrics() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 2)
            .quality_metric(QualityMetric::Trace)
            .fit(&data, &())
            .unwrap();
        let scores = model.quality_scores(&data).unwrap();

        assert_eq!(scores.len(), 2);
        for &score in &scores {
            assert!(score.is_finite());
            assert!(score >= 0.0);
        }
    }

    /// Zero estimators or zero components are rejected by `fit`.
    #[test]
    fn test_ensemble_nystroem_invalid_parameters() {
        let data = array![[1.0, 2.0]];

        let no_estimators = EnsembleNystroem::new(Kernel::Linear, 0, 2);
        assert!(no_estimators.fit(&data, &()).is_err());

        let no_components = EnsembleNystroem::new(Kernel::Linear, 2, 0);
        assert!(no_components.fit(&data, &()).is_err());
    }
}