1use crate::nystroem::{Kernel, Nystroem, SamplingStrategy};
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Rng;
6use sklears_core::{
7 error::{Result, SklearsError},
8 traits::{Estimator, Fit, Trained, Transform, Untrained},
9 types::Float,
10};
11use std::marker::PhantomData;
12
13use scirs2_core::random::{thread_rng, SeedableRng};
/// Strategy used to combine the outputs of the individual Nystroem
/// estimators of an [`EnsembleNystroem`].
///
/// The enum is fieldless, so it also derives `PartialEq`/`Eq` to let callers
/// compare the configured method directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnsembleMethod {
    /// Every estimator receives the same weight `1 / n_estimators`; the
    /// transformed outputs are summed with those weights.
    Average,
    /// Estimators are weighted proportionally to their quality score; when
    /// the scores do not add up to a positive total the weights are uniform.
    WeightedAverage,
    /// The transformed outputs are placed side by side, producing
    /// `n_estimators * n_components` output columns.
    Concatenate,
    /// Only the estimator with the highest quality score is used.
    BestApproximation,
}
26
/// Score used to rate a fitted Nystroem estimator inside an
/// [`EnsembleNystroem`]; larger scores are treated as better.
///
/// The enum is fieldless, so it also derives `PartialEq`/`Eq` to let callers
/// compare the configured metric directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QualityMetric {
    /// Frobenius norm of `components · componentsᵀ`.
    FrobeniusNorm,
    /// Sum of the diagonal of the kernel matrix evaluated on the components.
    Trace,
    /// Largest eigenvalue magnitude of the kernel matrix evaluated on the
    /// components, estimated with a bounded power iteration.
    SpectralNorm,
    /// Currently computed exactly like [`QualityMetric::Trace`] (the diagonal
    /// sum of the kernel matrix evaluated on the components).
    NuclearNorm,
}
39
40#[derive(Debug, Clone)]
71pub struct EnsembleNystroem<State = Untrained> {
72 pub kernel: Kernel,
74 pub n_estimators: usize,
76 pub n_components: usize,
78 pub ensemble_method: EnsembleMethod,
80 pub sampling_strategies: Option<Vec<SamplingStrategy>>,
82 pub quality_metric: QualityMetric,
84 pub random_state: Option<u64>,
86
87 estimators_: Option<Vec<Nystroem<Trained>>>,
89 weights_: Option<Vec<Float>>,
90 n_features_out_: Option<usize>,
91
92 _state: PhantomData<State>,
93}
94
95impl EnsembleNystroem<Untrained> {
96 pub fn new(kernel: Kernel, n_estimators: usize, n_components: usize) -> Self {
97 Self {
98 kernel,
99 n_estimators,
100 n_components,
101 ensemble_method: EnsembleMethod::WeightedAverage,
102 sampling_strategies: None,
103 quality_metric: QualityMetric::FrobeniusNorm,
104 random_state: None,
105 estimators_: None,
106 weights_: None,
107 n_features_out_: None,
108 _state: PhantomData,
109 }
110 }
111
112 pub fn ensemble_method(mut self, method: EnsembleMethod) -> Self {
114 self.ensemble_method = method;
115 self
116 }
117
118 pub fn sampling_strategies(mut self, strategies: Vec<SamplingStrategy>) -> Self {
120 self.sampling_strategies = Some(strategies);
121 self
122 }
123
124 pub fn quality_metric(mut self, metric: QualityMetric) -> Self {
126 self.quality_metric = metric;
127 self
128 }
129
130 pub fn random_state(mut self, seed: u64) -> Self {
132 self.random_state = Some(seed);
133 self
134 }
135
136 fn generate_sampling_strategies(&self) -> Vec<SamplingStrategy> {
138 if let Some(ref strategies) = self.sampling_strategies {
139 strategies.clone()
140 } else {
141 let mut strategies = Vec::new();
143 let base_strategies = [
144 SamplingStrategy::Random,
145 SamplingStrategy::KMeans,
146 SamplingStrategy::LeverageScore,
147 SamplingStrategy::ColumnNorm,
148 ];
149
150 for i in 0..self.n_estimators {
151 strategies.push(base_strategies[i % base_strategies.len()].clone());
152 }
153 strategies
154 }
155 }
156
157 fn compute_quality_score(
159 &self,
160 estimator: &Nystroem<Trained>,
161 _x: &Array2<Float>,
162 ) -> Result<Float> {
163 match self.quality_metric {
164 QualityMetric::FrobeniusNorm => {
165 let components = estimator.components();
167 let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
168 Ok(norm)
169 }
170 QualityMetric::Trace => {
171 let components = estimator.components();
172 let kernel_matrix = self.kernel.compute_kernel(components, components);
173 Ok(kernel_matrix.diag().sum())
174 }
175 QualityMetric::SpectralNorm => {
176 let components = estimator.components();
178 let kernel_matrix = self.kernel.compute_kernel(components, components);
179 self.power_iteration_spectral_norm(&kernel_matrix)
180 }
181 QualityMetric::NuclearNorm => {
182 let components = estimator.components();
183 let kernel_matrix = self.kernel.compute_kernel(components, components);
184 Ok(kernel_matrix.diag().sum())
186 }
187 }
188 }
189
190 fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
192 let n = matrix.nrows();
193 if n == 0 {
194 return Ok(0.0);
195 }
196
197 let mut v = Array1::ones(n) / (n as Float).sqrt();
198 let max_iter = 100;
199 let tolerance = 1e-6;
200
201 for _ in 0..max_iter {
202 let v_new = matrix.dot(&v);
203 let norm = (v_new.dot(&v_new)).sqrt();
204
205 if norm < tolerance {
206 break;
207 }
208
209 let v_normalized = &v_new / norm;
210 let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
211 v = v_normalized;
212
213 if diff < tolerance {
214 break;
215 }
216 }
217
218 let eigenvalue = v.dot(&matrix.dot(&v));
219 Ok(eigenvalue.abs())
220 }
221}
222
223impl Estimator for EnsembleNystroem<Untrained> {
224 type Config = ();
225 type Error = SklearsError;
226 type Float = Float;
227
228 fn config(&self) -> &Self::Config {
229 &()
230 }
231}
232
233impl Fit<Array2<Float>, ()> for EnsembleNystroem<Untrained> {
234 type Fitted = EnsembleNystroem<Trained>;
235
236 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
237 if self.n_estimators == 0 {
238 return Err(SklearsError::InvalidInput(
239 "n_estimators must be positive".to_string(),
240 ));
241 }
242
243 if self.n_components == 0 {
244 return Err(SklearsError::InvalidInput(
245 "n_components must be positive".to_string(),
246 ));
247 }
248
249 let mut rng = if let Some(seed) = self.random_state {
250 RealStdRng::seed_from_u64(seed)
251 } else {
252 RealStdRng::from_seed(thread_rng().gen())
253 };
254
255 let sampling_strategies = self.generate_sampling_strategies();
256 let mut estimators = Vec::new();
257 let mut quality_scores = Vec::new();
258
259 for i in 0..self.n_estimators {
261 let strategy = sampling_strategies[i % sampling_strategies.len()].clone();
262 let seed = if self.random_state.is_some() {
263 self.random_state.unwrap().wrapping_add(i as u64)
265 } else {
266 rng.gen::<u64>()
267 };
268
269 let nystroem = Nystroem::new(self.kernel.clone(), self.n_components)
270 .sampling_strategy(strategy)
271 .random_state(seed);
272
273 let fitted_nystroem = nystroem.fit(x, &())?;
274
275 let quality = self.compute_quality_score(&fitted_nystroem, x)?;
277 quality_scores.push(quality);
278 estimators.push(fitted_nystroem);
279 }
280
281 let weights = match self.ensemble_method {
283 EnsembleMethod::Average => vec![1.0 / self.n_estimators as Float; self.n_estimators],
284 EnsembleMethod::WeightedAverage => {
285 let total_quality: Float = quality_scores.iter().sum();
286 if total_quality > 0.0 {
287 quality_scores.iter().map(|&q| q / total_quality).collect()
288 } else {
289 vec![1.0 / self.n_estimators as Float; self.n_estimators]
290 }
291 }
292 EnsembleMethod::Concatenate => vec![1.0; self.n_estimators],
293 EnsembleMethod::BestApproximation => {
294 let best_idx = quality_scores
295 .iter()
296 .enumerate()
297 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
298 .map(|(idx, _)| idx)
299 .unwrap_or(0);
300 let mut weights = vec![0.0; self.n_estimators];
301 weights[best_idx] = 1.0;
302 weights
303 }
304 };
305
306 let n_features_out = match self.ensemble_method {
308 EnsembleMethod::Concatenate => self.n_estimators * self.n_components,
309 _ => self.n_components,
310 };
311
312 Ok(EnsembleNystroem {
313 kernel: self.kernel,
314 n_estimators: self.n_estimators,
315 n_components: self.n_components,
316 ensemble_method: self.ensemble_method,
317 sampling_strategies: self.sampling_strategies,
318 quality_metric: self.quality_metric,
319 random_state: self.random_state,
320 estimators_: Some(estimators),
321 weights_: Some(weights),
322 n_features_out_: Some(n_features_out),
323 _state: PhantomData,
324 })
325 }
326}
327
328impl Transform<Array2<Float>, Array2<Float>> for EnsembleNystroem<Trained> {
329 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
330 let estimators = self.estimators_.as_ref().unwrap();
331 let weights = self.weights_.as_ref().unwrap();
332 let n_features_out = self.n_features_out_.unwrap();
333 let (n_samples, _) = x.dim();
334
335 match self.ensemble_method {
336 EnsembleMethod::Average | EnsembleMethod::WeightedAverage => {
337 let mut result = Array2::zeros((n_samples, self.n_components));
338
339 for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
340 if weight > 0.0 {
341 let transformed = estimator.transform(x)?;
342 result += &(transformed * weight);
343 }
344 }
345
346 Ok(result)
347 }
348 EnsembleMethod::Concatenate => {
349 let mut result = Array2::zeros((n_samples, n_features_out));
350 let mut col_offset = 0;
351
352 for estimator in estimators.iter() {
353 let transformed = estimator.transform(x)?;
354 let n_cols = transformed.ncols();
355 result
356 .slice_mut(s![.., col_offset..col_offset + n_cols])
357 .assign(&transformed);
358 col_offset += n_cols;
359 }
360
361 Ok(result)
362 }
363 EnsembleMethod::BestApproximation => {
364 let best_idx = weights
365 .iter()
366 .enumerate()
367 .find(|(_, &w)| w > 0.0)
368 .map(|(idx, _)| idx)
369 .unwrap_or(0);
370
371 estimators[best_idx].transform(x)
372 }
373 }
374 }
375}
376
377impl EnsembleNystroem<Trained> {
378 pub fn estimators(&self) -> &[Nystroem<Trained>] {
380 self.estimators_.as_ref().unwrap()
381 }
382
383 pub fn weights(&self) -> &[Float] {
385 self.weights_.as_ref().unwrap()
386 }
387
388 pub fn n_features_out(&self) -> usize {
390 self.n_features_out_.unwrap()
391 }
392
393 pub fn quality_scores(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
395 let estimators = self.estimators_.as_ref().unwrap();
396 let mut scores = Vec::new();
397
398 for estimator in estimators.iter() {
399 let score = self.compute_quality_score_for_estimator(estimator, x)?;
400 scores.push(score);
401 }
402
403 Ok(scores)
404 }
405
406 fn compute_quality_score_for_estimator(
408 &self,
409 estimator: &Nystroem<Trained>,
410 _x: &Array2<Float>,
411 ) -> Result<Float> {
412 match self.quality_metric {
413 QualityMetric::FrobeniusNorm => {
414 let components = estimator.components();
415 let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
416 Ok(norm)
417 }
418 QualityMetric::Trace => {
419 let components = estimator.components();
420 let kernel_matrix = self.kernel.compute_kernel(components, components);
421 Ok(kernel_matrix.diag().sum())
422 }
423 QualityMetric::SpectralNorm => {
424 let components = estimator.components();
425 let kernel_matrix = self.kernel.compute_kernel(components, components);
426 self.power_iteration_spectral_norm(&kernel_matrix)
427 }
428 QualityMetric::NuclearNorm => {
429 let components = estimator.components();
430 let kernel_matrix = self.kernel.compute_kernel(components, components);
431 Ok(kernel_matrix.diag().sum())
432 }
433 }
434 }
435
436 fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
438 let n = matrix.nrows();
439 if n == 0 {
440 return Ok(0.0);
441 }
442
443 let mut v = Array1::ones(n) / (n as Float).sqrt();
444 let max_iter = 100;
445 let tolerance = 1e-6;
446
447 for _ in 0..max_iter {
448 let v_new = matrix.dot(&v);
449 let norm = (v_new.dot(&v_new)).sqrt();
450
451 if norm < tolerance {
452 break;
453 }
454
455 let v_normalized = &v_new / norm;
456 let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
457 v = v_normalized;
458
459 if diff < tolerance {
460 break;
461 }
462 }
463
464 let eigenvalue = v.dot(&matrix.dot(&v));
465 Ok(eigenvalue.abs())
466 }
467}
468
469use scirs2_core::ndarray::s;
471
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The default ensemble keeps `n_components` output columns.
    #[test]
    fn test_ensemble_nystroem_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let model = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.nrows(), 4);
        assert_eq!(features.ncols(), 2);
    }

    /// Plain averaging keeps the per-member output shape.
    #[test]
    fn test_ensemble_nystroem_average() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 3)
            .ensemble_method(EnsembleMethod::Average)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[3, 3]);
    }

    /// Concatenation yields `n_estimators * n_components` columns.
    #[test]
    fn test_ensemble_nystroem_concatenate() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let model = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .ensemble_method(EnsembleMethod::Concatenate)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[3, 6]);
    }

    /// Quality-based weights are normalised to sum to one.
    #[test]
    fn test_ensemble_nystroem_weighted_average() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3, 2)
            .ensemble_method(EnsembleMethod::WeightedAverage)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[4, 2]);

        let total: Float = model.weights().iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    /// Exactly one member carries weight one under best-approximation.
    #[test]
    fn test_ensemble_nystroem_best_approximation() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let model = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .ensemble_method(EnsembleMethod::BestApproximation)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[3, 2]);

        let positive: Vec<Float> = model
            .weights()
            .iter()
            .copied()
            .filter(|&w| w > 0.0)
            .collect();
        assert_eq!(positive.len(), 1);
        assert!((positive[0] - 1.0).abs() < 1e-10);
    }

    /// User-supplied strategies are accepted and one member is fitted per estimator.
    #[test]
    fn test_ensemble_nystroem_custom_strategies() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];

        let custom = vec![SamplingStrategy::Random, SamplingStrategy::LeverageScore];

        let model = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .sampling_strategies(custom)
            .fit(&data, &())
            .unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[4, 3]);
        assert_eq!(model.estimators().len(), 2);
    }

    /// Two runs with the same seed produce the same features.
    #[test]
    fn test_ensemble_nystroem_reproducibility() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let run = || {
            EnsembleNystroem::new(Kernel::Linear, 2, 3)
                .random_state(42)
                .fit(&data, &())
                .unwrap()
                .transform(&data)
                .unwrap()
        };
        let first = run();
        let second = run();

        assert_eq!(first.shape(), second.shape());
        for (a, b) in first.iter().zip(second.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    /// Trace scores are finite and non-negative for an RBF kernel.
    #[test]
    fn test_ensemble_nystroem_quality_metrics() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 2)
            .quality_metric(QualityMetric::Trace)
            .fit(&data, &())
            .unwrap();
        let scores = model.quality_scores(&data).unwrap();

        assert_eq!(scores.len(), 2);
        for score in &scores {
            assert!(score.is_finite());
            assert!(*score >= 0.0);
        }
    }

    /// Zero estimators or zero components are rejected by `fit`.
    #[test]
    fn test_ensemble_nystroem_invalid_parameters() {
        let data = array![[1.0, 2.0]];

        let no_estimators = EnsembleNystroem::new(Kernel::Linear, 0, 2);
        assert!(no_estimators.fit(&data, &()).is_err());

        let no_components = EnsembleNystroem::new(Kernel::Linear, 2, 0);
        assert!(no_components.fit(&data, &()).is_err());
    }
}