sklears_kernel_approximation/
ensemble_nystroem.rs

1//! Ensemble Nyström method for improved kernel approximation
2use crate::nystroem::{Kernel, Nystroem, SamplingStrategy};
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use sklears_core::{
6    error::{Result, SklearsError},
7    traits::{Estimator, Fit, Trained, Transform, Untrained},
8    types::Float,
9};
10use std::marker::PhantomData;
11
12use scirs2_core::random::{thread_rng, Rng, SeedableRng};
/// Strategy used to combine the outputs of the base Nyström estimators.
///
/// The enum is a plain tag without payload, so it is cheap to copy and can be
/// compared and hashed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnsembleMethod {
    /// Simple averaging: every estimator receives the same weight.
    Average,
    /// Weighted average where each estimator's weight is its quality score
    /// divided by the sum of all quality scores (higher score, more weight).
    WeightedAverage,
    /// Concatenate the feature columns of all estimators side by side.
    Concatenate,
    /// Use only the single estimator with the highest quality score.
    BestApproximation,
}
25
/// Score used to rate a fitted Nyström estimator.
///
/// All scores are computed from the estimator's component matrix only; larger
/// scores are treated as "better" when weights are derived from them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QualityMetric {
    /// Frobenius norm of `components * components^T`.
    FrobeniusNorm,
    /// Trace of the kernel matrix evaluated on the components.
    Trace,
    /// Spectral norm (largest eigenvalue magnitude) of the kernel matrix
    /// evaluated on the components, estimated by power iteration.
    SpectralNorm,
    /// Nuclear norm (sum of eigenvalues) of the kernel matrix evaluated on
    /// the components; currently approximated by its trace.
    NuclearNorm,
}
38
39/// Ensemble Nyström method for kernel approximation
40///
41/// Combines multiple Nyström approximations using different sampling strategies
42/// and component sizes to achieve better approximation quality than a single
43/// Nyström approximation.
44///
45/// # Parameters
46///
47/// * `kernel` - Kernel function to approximate
48/// * `n_estimators` - Number of base Nyström estimators (default: 5)
49/// * `n_components` - Number of samples per estimator (default: 100)
50/// * `ensemble_method` - Method for combining estimators
51/// * `sampling_strategies` - List of sampling strategies to use
52/// * `random_state` - Random seed for reproducibility
53///
54/// # Examples
55///
56/// ```rust,ignore
57/// use sklears_kernel_approximation::ensemble_nystroem::{EnsembleNystroem, EnsembleMethod};
58/// use sklears_kernel_approximation::nystroem::{Kernel, SamplingStrategy};
59/// use sklears_core::traits::{Transform, Fit, Untrained}
60/// use scirs2_core::ndarray::array;
61///
62/// let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
63///
64/// let ensemble = EnsembleNystroem::new(Kernel::Rbf { gamma: 1.0 }, 3, 2)
65///     .ensemble_method(EnsembleMethod::WeightedAverage);
66/// let fitted_ensemble = ensemble.fit(&X, &()).unwrap();
67/// let X_transformed = fitted_ensemble.transform(&X).unwrap();
68/// ```
69#[derive(Debug, Clone)]
70pub struct EnsembleNystroem<State = Untrained> {
71    /// Kernel function
72    pub kernel: Kernel,
73    /// Number of base estimators
74    pub n_estimators: usize,
75    /// Number of components per estimator
76    pub n_components: usize,
77    /// Method for ensemble combination
78    pub ensemble_method: EnsembleMethod,
79    /// Sampling strategies to use (if None, uses diverse set)
80    pub sampling_strategies: Option<Vec<SamplingStrategy>>,
81    /// Quality metric for evaluating approximations
82    pub quality_metric: QualityMetric,
83    /// Random seed
84    pub random_state: Option<u64>,
85
86    // Fitted attributes
87    estimators_: Option<Vec<Nystroem<Trained>>>,
88    weights_: Option<Vec<Float>>,
89    n_features_out_: Option<usize>,
90
91    _state: PhantomData<State>,
92}
93
94impl EnsembleNystroem<Untrained> {
95    pub fn new(kernel: Kernel, n_estimators: usize, n_components: usize) -> Self {
96        Self {
97            kernel,
98            n_estimators,
99            n_components,
100            ensemble_method: EnsembleMethod::WeightedAverage,
101            sampling_strategies: None,
102            quality_metric: QualityMetric::FrobeniusNorm,
103            random_state: None,
104            estimators_: None,
105            weights_: None,
106            n_features_out_: None,
107            _state: PhantomData,
108        }
109    }
110
111    /// Set the ensemble method
112    pub fn ensemble_method(mut self, method: EnsembleMethod) -> Self {
113        self.ensemble_method = method;
114        self
115    }
116
117    /// Set custom sampling strategies
118    pub fn sampling_strategies(mut self, strategies: Vec<SamplingStrategy>) -> Self {
119        self.sampling_strategies = Some(strategies);
120        self
121    }
122
123    /// Set the quality metric for evaluation
124    pub fn quality_metric(mut self, metric: QualityMetric) -> Self {
125        self.quality_metric = metric;
126        self
127    }
128
129    /// Set random state for reproducibility
130    pub fn random_state(mut self, seed: u64) -> Self {
131        self.random_state = Some(seed);
132        self
133    }
134
135    /// Generate diverse sampling strategies
136    fn generate_sampling_strategies(&self) -> Vec<SamplingStrategy> {
137        if let Some(ref strategies) = self.sampling_strategies {
138            strategies.clone()
139        } else {
140            // Generate diverse set of strategies
141            let mut strategies = Vec::new();
142            let base_strategies = vec![
143                SamplingStrategy::Random,
144                SamplingStrategy::KMeans,
145                SamplingStrategy::LeverageScore,
146                SamplingStrategy::ColumnNorm,
147            ];
148
149            for i in 0..self.n_estimators {
150                strategies.push(base_strategies[i % base_strategies.len()].clone());
151            }
152            strategies
153        }
154    }
155
156    /// Compute quality score for a Nyström approximation
157    fn compute_quality_score(
158        &self,
159        estimator: &Nystroem<Trained>,
160        x: &Array2<Float>,
161    ) -> Result<Float> {
162        match self.quality_metric {
163            QualityMetric::FrobeniusNorm => {
164                // Approximate quality based on component matrix properties
165                let components = estimator.components();
166                let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
167                Ok(norm)
168            }
169            QualityMetric::Trace => {
170                let components = estimator.components();
171                let kernel_matrix = self.kernel.compute_kernel(components, components);
172                Ok(kernel_matrix.diag().sum())
173            }
174            QualityMetric::SpectralNorm => {
175                // Approximate spectral norm using power iteration
176                let components = estimator.components();
177                let kernel_matrix = self.kernel.compute_kernel(components, components);
178                self.power_iteration_spectral_norm(&kernel_matrix)
179            }
180            QualityMetric::NuclearNorm => {
181                let components = estimator.components();
182                let kernel_matrix = self.kernel.compute_kernel(components, components);
183                // Nuclear norm is sum of eigenvalues, approximate with trace
184                Ok(kernel_matrix.diag().sum())
185            }
186        }
187    }
188
189    /// Approximate spectral norm using power iteration
190    fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
191        let n = matrix.nrows();
192        if n == 0 {
193            return Ok(0.0);
194        }
195
196        let mut v = Array1::ones(n) / (n as Float).sqrt();
197        let max_iter = 100;
198        let tolerance = 1e-6;
199
200        for _ in 0..max_iter {
201            let v_new = matrix.dot(&v);
202            let norm = (v_new.dot(&v_new)).sqrt();
203
204            if norm < tolerance {
205                break;
206            }
207
208            let v_normalized = &v_new / norm;
209            let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
210            v = v_normalized;
211
212            if diff < tolerance {
213                break;
214            }
215        }
216
217        let eigenvalue = v.dot(&matrix.dot(&v));
218        Ok(eigenvalue.abs())
219    }
220}
221
222impl Estimator for EnsembleNystroem<Untrained> {
223    type Config = ();
224    type Error = SklearsError;
225    type Float = Float;
226
227    fn config(&self) -> &Self::Config {
228        &()
229    }
230}
231
232impl Fit<Array2<Float>, ()> for EnsembleNystroem<Untrained> {
233    type Fitted = EnsembleNystroem<Trained>;
234
235    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
236        if self.n_estimators == 0 {
237            return Err(SklearsError::InvalidInput(
238                "n_estimators must be positive".to_string(),
239            ));
240        }
241
242        if self.n_components == 0 {
243            return Err(SklearsError::InvalidInput(
244                "n_components must be positive".to_string(),
245            ));
246        }
247
248        let mut rng = if let Some(seed) = self.random_state {
249            RealStdRng::seed_from_u64(seed)
250        } else {
251            RealStdRng::from_seed(thread_rng().gen())
252        };
253
254        let sampling_strategies = self.generate_sampling_strategies();
255        let mut estimators = Vec::new();
256        let mut quality_scores = Vec::new();
257
258        // Train base estimators
259        for i in 0..self.n_estimators {
260            let strategy = sampling_strategies[i % sampling_strategies.len()].clone();
261            let seed = if self.random_state.is_some() {
262                // Use deterministic seed sequence for reproducibility
263                self.random_state.unwrap().wrapping_add(i as u64)
264            } else {
265                rng.gen::<u64>()
266            };
267
268            let nystroem = Nystroem::new(self.kernel.clone(), self.n_components)
269                .sampling_strategy(strategy)
270                .random_state(seed);
271
272            let fitted_nystroem = nystroem.fit(x, &())?;
273
274            // Compute quality score
275            let quality = self.compute_quality_score(&fitted_nystroem, x)?;
276            quality_scores.push(quality);
277            estimators.push(fitted_nystroem);
278        }
279
280        // Compute weights based on ensemble method
281        let weights = match self.ensemble_method {
282            EnsembleMethod::Average => vec![1.0 / self.n_estimators as Float; self.n_estimators],
283            EnsembleMethod::WeightedAverage => {
284                let total_quality: Float = quality_scores.iter().sum();
285                if total_quality > 0.0 {
286                    quality_scores.iter().map(|&q| q / total_quality).collect()
287                } else {
288                    vec![1.0 / self.n_estimators as Float; self.n_estimators]
289                }
290            }
291            EnsembleMethod::Concatenate => vec![1.0; self.n_estimators],
292            EnsembleMethod::BestApproximation => {
293                let best_idx = quality_scores
294                    .iter()
295                    .enumerate()
296                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
297                    .map(|(idx, _)| idx)
298                    .unwrap_or(0);
299                let mut weights = vec![0.0; self.n_estimators];
300                weights[best_idx] = 1.0;
301                weights
302            }
303        };
304
305        // Determine output feature size
306        let n_features_out = match self.ensemble_method {
307            EnsembleMethod::Concatenate => self.n_estimators * self.n_components,
308            _ => self.n_components,
309        };
310
311        Ok(EnsembleNystroem {
312            kernel: self.kernel,
313            n_estimators: self.n_estimators,
314            n_components: self.n_components,
315            ensemble_method: self.ensemble_method,
316            sampling_strategies: self.sampling_strategies,
317            quality_metric: self.quality_metric,
318            random_state: self.random_state,
319            estimators_: Some(estimators),
320            weights_: Some(weights),
321            n_features_out_: Some(n_features_out),
322            _state: PhantomData,
323        })
324    }
325}
326
327impl Transform<Array2<Float>, Array2<Float>> for EnsembleNystroem<Trained> {
328    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
329        let estimators = self.estimators_.as_ref().unwrap();
330        let weights = self.weights_.as_ref().unwrap();
331        let n_features_out = self.n_features_out_.unwrap();
332        let (n_samples, _) = x.dim();
333
334        match self.ensemble_method {
335            EnsembleMethod::Average | EnsembleMethod::WeightedAverage => {
336                let mut result = Array2::zeros((n_samples, self.n_components));
337
338                for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
339                    if weight > 0.0 {
340                        let transformed = estimator.transform(x)?;
341                        result = result + &(transformed * weight);
342                    }
343                }
344
345                Ok(result)
346            }
347            EnsembleMethod::Concatenate => {
348                let mut result = Array2::zeros((n_samples, n_features_out));
349                let mut col_offset = 0;
350
351                for estimator in estimators.iter() {
352                    let transformed = estimator.transform(x)?;
353                    let n_cols = transformed.ncols();
354                    result
355                        .slice_mut(s![.., col_offset..col_offset + n_cols])
356                        .assign(&transformed);
357                    col_offset += n_cols;
358                }
359
360                Ok(result)
361            }
362            EnsembleMethod::BestApproximation => {
363                let best_idx = weights
364                    .iter()
365                    .enumerate()
366                    .find(|(_, &w)| w > 0.0)
367                    .map(|(idx, _)| idx)
368                    .unwrap_or(0);
369
370                estimators[best_idx].transform(x)
371            }
372        }
373    }
374}
375
376impl EnsembleNystroem<Trained> {
377    /// Get the base estimators
378    pub fn estimators(&self) -> &[Nystroem<Trained>] {
379        self.estimators_.as_ref().unwrap()
380    }
381
382    /// Get the estimator weights
383    pub fn weights(&self) -> &[Float] {
384        self.weights_.as_ref().unwrap()
385    }
386
387    /// Get the number of output features
388    pub fn n_features_out(&self) -> usize {
389        self.n_features_out_.unwrap()
390    }
391
392    /// Get quality scores for all estimators
393    pub fn quality_scores(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
394        let estimators = self.estimators_.as_ref().unwrap();
395        let mut scores = Vec::new();
396
397        for estimator in estimators.iter() {
398            let score = self.compute_quality_score_for_estimator(estimator, x)?;
399            scores.push(score);
400        }
401
402        Ok(scores)
403    }
404
405    /// Compute quality score for a single estimator (helper method)
406    fn compute_quality_score_for_estimator(
407        &self,
408        estimator: &Nystroem<Trained>,
409        x: &Array2<Float>,
410    ) -> Result<Float> {
411        match self.quality_metric {
412            QualityMetric::FrobeniusNorm => {
413                let components = estimator.components();
414                let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
415                Ok(norm)
416            }
417            QualityMetric::Trace => {
418                let components = estimator.components();
419                let kernel_matrix = self.kernel.compute_kernel(components, components);
420                Ok(kernel_matrix.diag().sum())
421            }
422            QualityMetric::SpectralNorm => {
423                let components = estimator.components();
424                let kernel_matrix = self.kernel.compute_kernel(components, components);
425                self.power_iteration_spectral_norm(&kernel_matrix)
426            }
427            QualityMetric::NuclearNorm => {
428                let components = estimator.components();
429                let kernel_matrix = self.kernel.compute_kernel(components, components);
430                Ok(kernel_matrix.diag().sum())
431            }
432        }
433    }
434
435    /// Approximate spectral norm using power iteration
436    fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
437        let n = matrix.nrows();
438        if n == 0 {
439            return Ok(0.0);
440        }
441
442        let mut v = Array1::ones(n) / (n as Float).sqrt();
443        let max_iter = 100;
444        let tolerance = 1e-6;
445
446        for _ in 0..max_iter {
447            let v_new = matrix.dot(&v);
448            let norm = (v_new.dot(&v_new)).sqrt();
449
450            if norm < tolerance {
451                break;
452            }
453
454            let v_normalized = &v_new / norm;
455            let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
456            v = v_normalized;
457
458            if diff < tolerance {
459                break;
460            }
461        }
462
463        let eigenvalue = v.dot(&matrix.dot(&v));
464        Ok(eigenvalue.abs())
465    }
466}
467
468// Add ndarray slice import
469use scirs2_core::ndarray::s;
470
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the `n_rows x 2` fixture whose rows are `[1, 2]`, `[3, 4]`,
    /// `[5, 6]`, ... in that order.
    fn sample_data(n_rows: usize) -> Array2<Float> {
        Array2::from_shape_fn((n_rows, 2), |(row, col)| (2 * row + col + 1) as Float)
    }

    #[test]
    fn test_ensemble_nystroem_basic() {
        let data = sample_data(4);

        let model = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .fit(&data, &())
            .unwrap();
        let embedded = model.transform(&data).unwrap();

        assert_eq!(embedded.nrows(), 4);
        assert_eq!(embedded.ncols(), 2); // n_components
    }

    #[test]
    fn test_ensemble_nystroem_average() {
        let data = sample_data(3);

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 3)
            .ensemble_method(EnsembleMethod::Average)
            .fit(&data, &())
            .unwrap();
        let embedded = model.transform(&data).unwrap();

        assert_eq!(embedded.shape(), &[3, 3]);
    }

    #[test]
    fn test_ensemble_nystroem_concatenate() {
        let data = sample_data(3);

        let model = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .ensemble_method(EnsembleMethod::Concatenate)
            .fit(&data, &())
            .unwrap();
        let embedded = model.transform(&data).unwrap();

        // 2 estimators * 3 components = 6
        assert_eq!(embedded.shape(), &[3, 6]);
    }

    #[test]
    fn test_ensemble_nystroem_weighted_average() {
        let data = sample_data(4);

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3, 2)
            .ensemble_method(EnsembleMethod::WeightedAverage)
            .fit(&data, &())
            .unwrap();
        let embedded = model.transform(&data).unwrap();

        assert_eq!(embedded.shape(), &[4, 2]);

        // Check that weights sum to 1 (approximately)
        let total: Float = model.weights().iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_ensemble_nystroem_best_approximation() {
        let data = sample_data(3);

        let model = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .ensemble_method(EnsembleMethod::BestApproximation)
            .fit(&data, &())
            .unwrap();
        let embedded = model.transform(&data).unwrap();

        assert_eq!(embedded.shape(), &[3, 2]);

        // Check that exactly one weight is 1.0 and others are 0.0
        let active: Vec<Float> = model
            .weights()
            .iter()
            .copied()
            .filter(|&w| w > 0.0)
            .collect();
        assert_eq!(active.len(), 1);
        assert!((active[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_ensemble_nystroem_custom_strategies() {
        let data = sample_data(4);

        let model = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .sampling_strategies(vec![
                SamplingStrategy::Random,
                SamplingStrategy::LeverageScore,
            ])
            .fit(&data, &())
            .unwrap();
        let embedded = model.transform(&data).unwrap();

        assert_eq!(embedded.shape(), &[4, 3]);
        assert_eq!(model.estimators().len(), 2);
    }

    #[test]
    fn test_ensemble_nystroem_reproducibility() {
        let data = sample_data(3);

        // Fit and transform twice with the same seed.
        let run = || {
            EnsembleNystroem::new(Kernel::Linear, 2, 3)
                .random_state(42)
                .fit(&data, &())
                .unwrap()
                .transform(&data)
                .unwrap()
        };
        let first = run();
        let second = run();

        // Results should be very similar with same random state (allowing for numerical precision)
        assert_eq!(first.shape(), second.shape());
        for (a, b) in first.iter().zip(second.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_ensemble_nystroem_quality_metrics() {
        let data = sample_data(3);

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 2)
            .quality_metric(QualityMetric::Trace)
            .fit(&data, &())
            .unwrap();
        let scores = model.quality_scores(&data).unwrap();

        assert_eq!(scores.len(), 2);
        for score in &scores {
            assert!(score.is_finite());
            assert!(*score >= 0.0);
        }
    }

    #[test]
    fn test_ensemble_nystroem_invalid_parameters() {
        let data = sample_data(1);

        // Zero estimators
        let no_estimators = EnsembleNystroem::new(Kernel::Linear, 0, 2);
        assert!(no_estimators.fit(&data, &()).is_err());

        // Zero components
        let no_components = EnsembleNystroem::new(Kernel::Linear, 2, 0);
        assert!(no_components.fit(&data, &()).is_err());
    }
}