Skip to main content

sklears_kernel_approximation/
ensemble_nystroem.rs

1//! Ensemble Nyström method for improved kernel approximation
2use crate::nystroem::{Kernel, Nystroem, SamplingStrategy};
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use sklears_core::{
6    error::{Result, SklearsError},
7    traits::{Estimator, Fit, Trained, Transform, Untrained},
8    types::Float,
9};
10use std::marker::PhantomData;
11
12use scirs2_core::random::RngExt;
13use scirs2_core::random::{thread_rng, SeedableRng};
/// Strategy used to combine the outputs of the base Nyström estimators.
///
/// The enum is field-less, so it is cheap to copy and can be compared with
/// `==` (useful for callers that branch on the configured method).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsembleMethod {
    /// Simple averaging: every estimator receives the same weight.
    Average,
    /// Weighted average based on the quality score of each estimator
    /// (a higher score gets more weight).
    WeightedAverage,
    /// Concatenate the feature maps of all estimators column-wise.
    Concatenate,
    /// Use only the estimator with the highest quality score.
    BestApproximation,
}
26
/// Score used to rank fitted Nyström estimators against each other.
///
/// All scores are derived from the estimator's component matrix; a larger
/// score is treated as "better" when weights are computed. The enum is
/// field-less, so it is cheap to copy and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QualityMetric {
    /// Frobenius norm of the Gram matrix `C Cᵀ` of the estimator's components.
    FrobeniusNorm,
    /// Trace of the kernel matrix evaluated on the estimator's components.
    Trace,
    /// Spectral norm (largest eigenvalue magnitude) of the kernel matrix on the
    /// components, estimated with power iteration.
    SpectralNorm,
    /// Nuclear norm (sum of eigenvalues) of the kernel matrix on the
    /// components, approximated by its trace.
    NuclearNorm,
}
39
40/// Ensemble Nyström method for kernel approximation
41///
42/// Combines multiple Nyström approximations using different sampling strategies
43/// and component sizes to achieve better approximation quality than a single
44/// Nyström approximation.
45///
46/// # Parameters
47///
48/// * `kernel` - Kernel function to approximate
49/// * `n_estimators` - Number of base Nyström estimators (default: 5)
50/// * `n_components` - Number of samples per estimator (default: 100)
51/// * `ensemble_method` - Method for combining estimators
52/// * `sampling_strategies` - List of sampling strategies to use
53/// * `random_state` - Random seed for reproducibility
54///
55/// # Examples
56///
57/// ```rust,ignore
58/// use sklears_kernel_approximation::ensemble_nystroem::{EnsembleNystroem, EnsembleMethod};
59/// use sklears_kernel_approximation::nystroem::{Kernel, SamplingStrategy};
60/// use sklears_core::traits::{Transform, Fit, Untrained}
61/// use scirs2_core::ndarray::array;
62///
63/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
64///
65/// let ensemble = EnsembleNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3, 2)
66///     .ensemble_method(EnsembleMethod::WeightedAverage);
67/// let fitted_ensemble = ensemble.fit(&X, &()).unwrap();
68/// let X_transformed = fitted_ensemble.transform(&X).unwrap();
69/// ```
70#[derive(Debug, Clone)]
71pub struct EnsembleNystroem<State = Untrained> {
72    /// Kernel function
73    pub kernel: Kernel,
74    /// Number of base estimators
75    pub n_estimators: usize,
76    /// Number of components per estimator
77    pub n_components: usize,
78    /// Method for ensemble combination
79    pub ensemble_method: EnsembleMethod,
80    /// Sampling strategies to use (if None, uses diverse set)
81    pub sampling_strategies: Option<Vec<SamplingStrategy>>,
82    /// Quality metric for evaluating approximations
83    pub quality_metric: QualityMetric,
84    /// Random seed
85    pub random_state: Option<u64>,
86
87    // Fitted attributes
88    estimators_: Option<Vec<Nystroem<Trained>>>,
89    weights_: Option<Vec<Float>>,
90    n_features_out_: Option<usize>,
91
92    _state: PhantomData<State>,
93}
94
95impl EnsembleNystroem<Untrained> {
96    pub fn new(kernel: Kernel, n_estimators: usize, n_components: usize) -> Self {
97        Self {
98            kernel,
99            n_estimators,
100            n_components,
101            ensemble_method: EnsembleMethod::WeightedAverage,
102            sampling_strategies: None,
103            quality_metric: QualityMetric::FrobeniusNorm,
104            random_state: None,
105            estimators_: None,
106            weights_: None,
107            n_features_out_: None,
108            _state: PhantomData,
109        }
110    }
111
112    /// Set the ensemble method
113    pub fn ensemble_method(mut self, method: EnsembleMethod) -> Self {
114        self.ensemble_method = method;
115        self
116    }
117
118    /// Set custom sampling strategies
119    pub fn sampling_strategies(mut self, strategies: Vec<SamplingStrategy>) -> Self {
120        self.sampling_strategies = Some(strategies);
121        self
122    }
123
124    /// Set the quality metric for evaluation
125    pub fn quality_metric(mut self, metric: QualityMetric) -> Self {
126        self.quality_metric = metric;
127        self
128    }
129
130    /// Set random state for reproducibility
131    pub fn random_state(mut self, seed: u64) -> Self {
132        self.random_state = Some(seed);
133        self
134    }
135
136    /// Generate diverse sampling strategies
137    fn generate_sampling_strategies(&self) -> Vec<SamplingStrategy> {
138        if let Some(ref strategies) = self.sampling_strategies {
139            strategies.clone()
140        } else {
141            // Generate diverse set of strategies
142            let mut strategies = Vec::new();
143            let base_strategies = [
144                SamplingStrategy::Random,
145                SamplingStrategy::KMeans,
146                SamplingStrategy::LeverageScore,
147                SamplingStrategy::ColumnNorm,
148            ];
149
150            for i in 0..self.n_estimators {
151                strategies.push(base_strategies[i % base_strategies.len()].clone());
152            }
153            strategies
154        }
155    }
156
157    /// Compute quality score for a Nyström approximation
158    fn compute_quality_score(
159        &self,
160        estimator: &Nystroem<Trained>,
161        _x: &Array2<Float>,
162    ) -> Result<Float> {
163        match self.quality_metric {
164            QualityMetric::FrobeniusNorm => {
165                // Approximate quality based on component matrix properties
166                let components = estimator.components();
167                let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
168                Ok(norm)
169            }
170            QualityMetric::Trace => {
171                let components = estimator.components();
172                let kernel_matrix = self.kernel.compute_kernel(components, components);
173                Ok(kernel_matrix.diag().sum())
174            }
175            QualityMetric::SpectralNorm => {
176                // Approximate spectral norm using power iteration
177                let components = estimator.components();
178                let kernel_matrix = self.kernel.compute_kernel(components, components);
179                self.power_iteration_spectral_norm(&kernel_matrix)
180            }
181            QualityMetric::NuclearNorm => {
182                let components = estimator.components();
183                let kernel_matrix = self.kernel.compute_kernel(components, components);
184                // Nuclear norm is sum of eigenvalues, approximate with trace
185                Ok(kernel_matrix.diag().sum())
186            }
187        }
188    }
189
190    /// Approximate spectral norm using power iteration
191    fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
192        let n = matrix.nrows();
193        if n == 0 {
194            return Ok(0.0);
195        }
196
197        let mut v = Array1::ones(n) / (n as Float).sqrt();
198        let max_iter = 100;
199        let tolerance = 1e-6;
200
201        for _ in 0..max_iter {
202            let v_new = matrix.dot(&v);
203            let norm = (v_new.dot(&v_new)).sqrt();
204
205            if norm < tolerance {
206                break;
207            }
208
209            let v_normalized = &v_new / norm;
210            let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
211            v = v_normalized;
212
213            if diff < tolerance {
214                break;
215            }
216        }
217
218        let eigenvalue = v.dot(&matrix.dot(&v));
219        Ok(eigenvalue.abs())
220    }
221}
222
223impl Estimator for EnsembleNystroem<Untrained> {
224    type Config = ();
225    type Error = SklearsError;
226    type Float = Float;
227
228    fn config(&self) -> &Self::Config {
229        &()
230    }
231}
232
233impl Fit<Array2<Float>, ()> for EnsembleNystroem<Untrained> {
234    type Fitted = EnsembleNystroem<Trained>;
235
236    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
237        if self.n_estimators == 0 {
238            return Err(SklearsError::InvalidInput(
239                "n_estimators must be positive".to_string(),
240            ));
241        }
242
243        if self.n_components == 0 {
244            return Err(SklearsError::InvalidInput(
245                "n_components must be positive".to_string(),
246            ));
247        }
248
249        let mut rng = if let Some(seed) = self.random_state {
250            RealStdRng::seed_from_u64(seed)
251        } else {
252            RealStdRng::from_seed(thread_rng().random())
253        };
254
255        let sampling_strategies = self.generate_sampling_strategies();
256        let mut estimators = Vec::new();
257        let mut quality_scores = Vec::new();
258
259        // Train base estimators
260        for i in 0..self.n_estimators {
261            let strategy = sampling_strategies[i % sampling_strategies.len()].clone();
262            let seed = if self.random_state.is_some() {
263                // Use deterministic seed sequence for reproducibility
264                self.random_state
265                    .expect("operation should succeed")
266                    .wrapping_add(i as u64)
267            } else {
268                rng.random::<u64>()
269            };
270
271            let nystroem = Nystroem::new(self.kernel.clone(), self.n_components)
272                .sampling_strategy(strategy)
273                .random_state(seed);
274
275            let fitted_nystroem = nystroem.fit(x, &())?;
276
277            // Compute quality score
278            let quality = self.compute_quality_score(&fitted_nystroem, x)?;
279            quality_scores.push(quality);
280            estimators.push(fitted_nystroem);
281        }
282
283        // Compute weights based on ensemble method
284        let weights = match self.ensemble_method {
285            EnsembleMethod::Average => vec![1.0 / self.n_estimators as Float; self.n_estimators],
286            EnsembleMethod::WeightedAverage => {
287                let total_quality: Float = quality_scores.iter().sum();
288                if total_quality > 0.0 {
289                    quality_scores.iter().map(|&q| q / total_quality).collect()
290                } else {
291                    vec![1.0 / self.n_estimators as Float; self.n_estimators]
292                }
293            }
294            EnsembleMethod::Concatenate => vec![1.0; self.n_estimators],
295            EnsembleMethod::BestApproximation => {
296                let best_idx = quality_scores
297                    .iter()
298                    .enumerate()
299                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
300                    .map(|(idx, _)| idx)
301                    .unwrap_or(0);
302                let mut weights = vec![0.0; self.n_estimators];
303                weights[best_idx] = 1.0;
304                weights
305            }
306        };
307
308        // Determine output feature size
309        let n_features_out = match self.ensemble_method {
310            EnsembleMethod::Concatenate => self.n_estimators * self.n_components,
311            _ => self.n_components,
312        };
313
314        Ok(EnsembleNystroem {
315            kernel: self.kernel,
316            n_estimators: self.n_estimators,
317            n_components: self.n_components,
318            ensemble_method: self.ensemble_method,
319            sampling_strategies: self.sampling_strategies,
320            quality_metric: self.quality_metric,
321            random_state: self.random_state,
322            estimators_: Some(estimators),
323            weights_: Some(weights),
324            n_features_out_: Some(n_features_out),
325            _state: PhantomData,
326        })
327    }
328}
329
330impl Transform<Array2<Float>, Array2<Float>> for EnsembleNystroem<Trained> {
331    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
332        let estimators = self.estimators_.as_ref().expect("operation should succeed");
333        let weights = self.weights_.as_ref().expect("operation should succeed");
334        let n_features_out = self.n_features_out_.expect("operation should succeed");
335        let (n_samples, _) = x.dim();
336
337        match self.ensemble_method {
338            EnsembleMethod::Average | EnsembleMethod::WeightedAverage => {
339                let mut result = Array2::zeros((n_samples, self.n_components));
340
341                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
342                    if weight > 0.0 {
343                        let transformed = estimator.transform(x)?;
344                        result += &(transformed * weight);
345                    }
346                }
347
348                Ok(result)
349            }
350            EnsembleMethod::Concatenate => {
351                let mut result = Array2::zeros((n_samples, n_features_out));
352                let mut col_offset = 0;
353
354                for estimator in estimators.iter() {
355                    let transformed = estimator.transform(x)?;
356                    let n_cols = transformed.ncols();
357                    result
358                        .slice_mut(s![.., col_offset..col_offset + n_cols])
359                        .assign(&transformed);
360                    col_offset += n_cols;
361                }
362
363                Ok(result)
364            }
365            EnsembleMethod::BestApproximation => {
366                let best_idx = weights
367                    .iter()
368                    .enumerate()
369                    .find(|(_, &w)| w > 0.0)
370                    .map(|(idx, _)| idx)
371                    .unwrap_or(0);
372
373                estimators[best_idx].transform(x)
374            }
375        }
376    }
377}
378
379impl EnsembleNystroem<Trained> {
380    /// Get the base estimators
381    pub fn estimators(&self) -> &[Nystroem<Trained>] {
382        self.estimators_.as_ref().expect("operation should succeed")
383    }
384
385    /// Get the estimator weights
386    pub fn weights(&self) -> &[Float] {
387        self.weights_.as_ref().expect("operation should succeed")
388    }
389
390    /// Get the number of output features
391    pub fn n_features_out(&self) -> usize {
392        self.n_features_out_.expect("operation should succeed")
393    }
394
395    /// Get quality scores for all estimators
396    pub fn quality_scores(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
397        let estimators = self.estimators_.as_ref().expect("operation should succeed");
398        let mut scores = Vec::new();
399
400        for estimator in estimators.iter() {
401            let score = self.compute_quality_score_for_estimator(estimator, x)?;
402            scores.push(score);
403        }
404
405        Ok(scores)
406    }
407
408    /// Compute quality score for a single estimator (helper method)
409    fn compute_quality_score_for_estimator(
410        &self,
411        estimator: &Nystroem<Trained>,
412        _x: &Array2<Float>,
413    ) -> Result<Float> {
414        match self.quality_metric {
415            QualityMetric::FrobeniusNorm => {
416                let components = estimator.components();
417                let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
418                Ok(norm)
419            }
420            QualityMetric::Trace => {
421                let components = estimator.components();
422                let kernel_matrix = self.kernel.compute_kernel(components, components);
423                Ok(kernel_matrix.diag().sum())
424            }
425            QualityMetric::SpectralNorm => {
426                let components = estimator.components();
427                let kernel_matrix = self.kernel.compute_kernel(components, components);
428                self.power_iteration_spectral_norm(&kernel_matrix)
429            }
430            QualityMetric::NuclearNorm => {
431                let components = estimator.components();
432                let kernel_matrix = self.kernel.compute_kernel(components, components);
433                Ok(kernel_matrix.diag().sum())
434            }
435        }
436    }
437
438    /// Approximate spectral norm using power iteration
439    fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
440        let n = matrix.nrows();
441        if n == 0 {
442            return Ok(0.0);
443        }
444
445        let mut v = Array1::ones(n) / (n as Float).sqrt();
446        let max_iter = 100;
447        let tolerance = 1e-6;
448
449        for _ in 0..max_iter {
450            let v_new = matrix.dot(&v);
451            let norm = (v_new.dot(&v_new)).sqrt();
452
453            if norm < tolerance {
454                break;
455            }
456
457            let v_normalized = &v_new / norm;
458            let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
459            v = v_normalized;
460
461            if diff < tolerance {
462                break;
463            }
464        }
465
466        let eigenvalue = v.dot(&matrix.dot(&v));
467        Ok(eigenvalue.abs())
468    }
469}
470
471// Add ndarray slice import
472use scirs2_core::ndarray::s;
473
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Three 2-D samples shared by most tests.
    fn three_samples() -> Array2<Float> {
        array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],]
    }

    /// Four 2-D samples for tests that need more rows than components.
    fn four_samples() -> Array2<Float> {
        array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],]
    }

    /// Default settings produce one column per component.
    #[test]
    fn test_ensemble_nystroem_basic() {
        let data = four_samples();

        let fitted = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .fit(&data, &())
            .expect("operation should succeed");
        let mapped = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(mapped.nrows(), 4);
        assert_eq!(mapped.ncols(), 2); // n_components
    }

    /// Plain averaging keeps the per-estimator output width.
    #[test]
    fn test_ensemble_nystroem_average() {
        let data = three_samples();

        let fitted = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 3)
            .ensemble_method(EnsembleMethod::Average)
            .fit(&data, &())
            .expect("operation should succeed");
        let mapped = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(mapped.shape(), &[3, 3]);
    }

    /// Concatenation multiplies the output width by the number of estimators.
    #[test]
    fn test_ensemble_nystroem_concatenate() {
        let data = three_samples();

        let fitted = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .ensemble_method(EnsembleMethod::Concatenate)
            .fit(&data, &())
            .expect("operation should succeed");
        let mapped = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(mapped.shape(), &[3, 6]); // 2 estimators * 3 components = 6
    }

    /// Quality-based weights are normalised to sum to one.
    #[test]
    fn test_ensemble_nystroem_weighted_average() {
        let data = four_samples();

        let fitted = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3, 2)
            .ensemble_method(EnsembleMethod::WeightedAverage)
            .fit(&data, &())
            .expect("operation should succeed");
        let mapped = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(mapped.shape(), &[4, 2]);

        // Check that weights sum to 1 (approximately)
        let total: Float = fitted.weights().iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    /// Best-approximation mode activates exactly one estimator.
    #[test]
    fn test_ensemble_nystroem_best_approximation() {
        let data = three_samples();

        let fitted = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .ensemble_method(EnsembleMethod::BestApproximation)
            .fit(&data, &())
            .expect("operation should succeed");
        let mapped = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(mapped.shape(), &[3, 2]);

        // Check that exactly one weight is 1.0 and others are 0.0
        let active: Vec<Float> = fitted
            .weights()
            .iter()
            .copied()
            .filter(|&w| w > 0.0)
            .collect();
        assert_eq!(active.len(), 1);
        assert!((active[0] - 1.0).abs() < 1e-10);
    }

    /// User-supplied sampling strategies are accepted.
    #[test]
    fn test_ensemble_nystroem_custom_strategies() {
        let data = four_samples();

        let fitted = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .sampling_strategies(vec![
                SamplingStrategy::Random,
                SamplingStrategy::LeverageScore,
            ])
            .fit(&data, &())
            .expect("operation should succeed");
        let mapped = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(mapped.shape(), &[4, 3]);
        assert_eq!(fitted.estimators().len(), 2);
    }

    /// Two fits with the same seed give (numerically) the same features.
    #[test]
    fn test_ensemble_nystroem_reproducibility() {
        let data = three_samples();

        let fit_and_map = || {
            EnsembleNystroem::new(Kernel::Linear, 2, 3)
                .random_state(42)
                .fit(&data, &())
                .expect("operation should succeed")
                .transform(&data)
                .expect("operation should succeed")
        };
        let first = fit_and_map();
        let second = fit_and_map();

        // Results should be very similar with same random state (allowing for numerical precision)
        assert_eq!(first.shape(), second.shape());
        for (a, b) in first.iter().zip(second.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    /// Trace scores are finite and non-negative, one per estimator.
    #[test]
    fn test_ensemble_nystroem_quality_metrics() {
        let data = three_samples();

        let fitted = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 2)
            .quality_metric(QualityMetric::Trace)
            .fit(&data, &())
            .expect("operation should succeed");
        let scores = fitted
            .quality_scores(&data)
            .expect("operation should succeed");

        assert_eq!(scores.len(), 2);
        for &score in &scores {
            assert!(score.is_finite());
            assert!(score >= 0.0);
        }
    }

    /// Zero estimators or zero components are rejected by `fit`.
    #[test]
    fn test_ensemble_nystroem_invalid_parameters() {
        let data = array![[1.0, 2.0]];

        // Zero estimators
        let no_estimators = EnsembleNystroem::new(Kernel::Linear, 0, 2);
        assert!(no_estimators.fit(&data, &()).is_err());

        // Zero components
        let no_components = EnsembleNystroem::new(Kernel::Linear, 2, 0);
        assert!(no_components.fit(&data, &()).is_err());
    }
}