1use crate::nystroem::{Kernel, Nystroem, SamplingStrategy};
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use sklears_core::{
6 error::{Result, SklearsError},
7 traits::{Estimator, Fit, Trained, Transform, Untrained},
8 types::Float,
9};
10use std::marker::PhantomData;
11
12use scirs2_core::random::RngExt;
13use scirs2_core::random::{thread_rng, SeedableRng};
/// Strategy used to combine the individual Nyström estimators of an ensemble.
///
/// The type is a plain fieldless enum, so it is `Copy` and comparable; this lets
/// callers and tests inspect the configured method with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnsembleMethod {
    /// Every estimator contributes with the same weight `1 / n_estimators`.
    Average,
    /// Estimators contribute proportionally to their quality score.
    WeightedAverage,
    /// The outputs of all estimators are placed side by side, column-wise.
    Concatenate,
    /// Only the estimator with the highest quality score is used.
    BestApproximation,
}
26
/// Score used to rate a fitted Nyström estimator inside an ensemble.
///
/// The type is a plain fieldless enum, so it is `Copy` and comparable; this lets
/// callers and tests inspect the configured metric with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QualityMetric {
    /// Frobenius norm of the estimator's component matrix multiplied by its transpose.
    FrobeniusNorm,
    /// Trace of the kernel matrix evaluated on the estimator's components.
    Trace,
    /// Largest-magnitude eigenvalue of that kernel matrix, estimated by power iteration.
    SpectralNorm,
    /// Computed as the trace of that kernel matrix (the two coincide for positive
    /// semi-definite matrices).
    NuclearNorm,
}
39
/// Ensemble of Nyström kernel approximations.
///
/// The `State` type parameter tracks whether the ensemble has been fitted
/// (`Untrained` before `fit`, `Trained` afterwards).
#[derive(Debug, Clone)]
pub struct EnsembleNystroem<State = Untrained> {
    /// Kernel handed to every member estimator and used by the kernel-based quality metrics.
    pub kernel: Kernel,
    /// Number of member estimators to fit; `fit` rejects zero.
    pub n_estimators: usize,
    /// Number of components requested from each member estimator; `fit` rejects zero.
    pub n_components: usize,
    /// How the member estimators are combined.
    pub ensemble_method: EnsembleMethod,
    /// Optional user-supplied sampling strategies, cycled over the members.
    /// When `None`, a built-in rotation of strategies is used.
    pub sampling_strategies: Option<Vec<SamplingStrategy>>,
    /// Metric used to score each fitted member.
    pub quality_metric: QualityMetric,
    /// Optional base seed; member `i` is seeded with `seed.wrapping_add(i)`.
    pub random_state: Option<u64>,

    // Fitted members; populated by `fit`.
    estimators_: Option<Vec<Nystroem<Trained>>>,
    // One weight per member; populated by `fit`.
    weights_: Option<Vec<Float>>,
    // Output width recorded by `fit`.
    n_features_out_: Option<usize>,

    // Zero-sized marker carrying the training state in the type.
    _state: PhantomData<State>,
}
94
95impl EnsembleNystroem<Untrained> {
96 pub fn new(kernel: Kernel, n_estimators: usize, n_components: usize) -> Self {
97 Self {
98 kernel,
99 n_estimators,
100 n_components,
101 ensemble_method: EnsembleMethod::WeightedAverage,
102 sampling_strategies: None,
103 quality_metric: QualityMetric::FrobeniusNorm,
104 random_state: None,
105 estimators_: None,
106 weights_: None,
107 n_features_out_: None,
108 _state: PhantomData,
109 }
110 }
111
112 pub fn ensemble_method(mut self, method: EnsembleMethod) -> Self {
114 self.ensemble_method = method;
115 self
116 }
117
118 pub fn sampling_strategies(mut self, strategies: Vec<SamplingStrategy>) -> Self {
120 self.sampling_strategies = Some(strategies);
121 self
122 }
123
124 pub fn quality_metric(mut self, metric: QualityMetric) -> Self {
126 self.quality_metric = metric;
127 self
128 }
129
130 pub fn random_state(mut self, seed: u64) -> Self {
132 self.random_state = Some(seed);
133 self
134 }
135
136 fn generate_sampling_strategies(&self) -> Vec<SamplingStrategy> {
138 if let Some(ref strategies) = self.sampling_strategies {
139 strategies.clone()
140 } else {
141 let mut strategies = Vec::new();
143 let base_strategies = [
144 SamplingStrategy::Random,
145 SamplingStrategy::KMeans,
146 SamplingStrategy::LeverageScore,
147 SamplingStrategy::ColumnNorm,
148 ];
149
150 for i in 0..self.n_estimators {
151 strategies.push(base_strategies[i % base_strategies.len()].clone());
152 }
153 strategies
154 }
155 }
156
157 fn compute_quality_score(
159 &self,
160 estimator: &Nystroem<Trained>,
161 _x: &Array2<Float>,
162 ) -> Result<Float> {
163 match self.quality_metric {
164 QualityMetric::FrobeniusNorm => {
165 let components = estimator.components();
167 let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
168 Ok(norm)
169 }
170 QualityMetric::Trace => {
171 let components = estimator.components();
172 let kernel_matrix = self.kernel.compute_kernel(components, components);
173 Ok(kernel_matrix.diag().sum())
174 }
175 QualityMetric::SpectralNorm => {
176 let components = estimator.components();
178 let kernel_matrix = self.kernel.compute_kernel(components, components);
179 self.power_iteration_spectral_norm(&kernel_matrix)
180 }
181 QualityMetric::NuclearNorm => {
182 let components = estimator.components();
183 let kernel_matrix = self.kernel.compute_kernel(components, components);
184 Ok(kernel_matrix.diag().sum())
186 }
187 }
188 }
189
190 fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
192 let n = matrix.nrows();
193 if n == 0 {
194 return Ok(0.0);
195 }
196
197 let mut v = Array1::ones(n) / (n as Float).sqrt();
198 let max_iter = 100;
199 let tolerance = 1e-6;
200
201 for _ in 0..max_iter {
202 let v_new = matrix.dot(&v);
203 let norm = (v_new.dot(&v_new)).sqrt();
204
205 if norm < tolerance {
206 break;
207 }
208
209 let v_normalized = &v_new / norm;
210 let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
211 v = v_normalized;
212
213 if diff < tolerance {
214 break;
215 }
216 }
217
218 let eigenvalue = v.dot(&matrix.dot(&v));
219 Ok(eigenvalue.abs())
220 }
221}
222
223impl Estimator for EnsembleNystroem<Untrained> {
224 type Config = ();
225 type Error = SklearsError;
226 type Float = Float;
227
228 fn config(&self) -> &Self::Config {
229 &()
230 }
231}
232
233impl Fit<Array2<Float>, ()> for EnsembleNystroem<Untrained> {
234 type Fitted = EnsembleNystroem<Trained>;
235
236 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
237 if self.n_estimators == 0 {
238 return Err(SklearsError::InvalidInput(
239 "n_estimators must be positive".to_string(),
240 ));
241 }
242
243 if self.n_components == 0 {
244 return Err(SklearsError::InvalidInput(
245 "n_components must be positive".to_string(),
246 ));
247 }
248
249 let mut rng = if let Some(seed) = self.random_state {
250 RealStdRng::seed_from_u64(seed)
251 } else {
252 RealStdRng::from_seed(thread_rng().random())
253 };
254
255 let sampling_strategies = self.generate_sampling_strategies();
256 let mut estimators = Vec::new();
257 let mut quality_scores = Vec::new();
258
259 for i in 0..self.n_estimators {
261 let strategy = sampling_strategies[i % sampling_strategies.len()].clone();
262 let seed = if self.random_state.is_some() {
263 self.random_state
265 .expect("operation should succeed")
266 .wrapping_add(i as u64)
267 } else {
268 rng.random::<u64>()
269 };
270
271 let nystroem = Nystroem::new(self.kernel.clone(), self.n_components)
272 .sampling_strategy(strategy)
273 .random_state(seed);
274
275 let fitted_nystroem = nystroem.fit(x, &())?;
276
277 let quality = self.compute_quality_score(&fitted_nystroem, x)?;
279 quality_scores.push(quality);
280 estimators.push(fitted_nystroem);
281 }
282
283 let weights = match self.ensemble_method {
285 EnsembleMethod::Average => vec![1.0 / self.n_estimators as Float; self.n_estimators],
286 EnsembleMethod::WeightedAverage => {
287 let total_quality: Float = quality_scores.iter().sum();
288 if total_quality > 0.0 {
289 quality_scores.iter().map(|&q| q / total_quality).collect()
290 } else {
291 vec![1.0 / self.n_estimators as Float; self.n_estimators]
292 }
293 }
294 EnsembleMethod::Concatenate => vec![1.0; self.n_estimators],
295 EnsembleMethod::BestApproximation => {
296 let best_idx = quality_scores
297 .iter()
298 .enumerate()
299 .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("operation should succeed"))
300 .map(|(idx, _)| idx)
301 .unwrap_or(0);
302 let mut weights = vec![0.0; self.n_estimators];
303 weights[best_idx] = 1.0;
304 weights
305 }
306 };
307
308 let n_features_out = match self.ensemble_method {
310 EnsembleMethod::Concatenate => self.n_estimators * self.n_components,
311 _ => self.n_components,
312 };
313
314 Ok(EnsembleNystroem {
315 kernel: self.kernel,
316 n_estimators: self.n_estimators,
317 n_components: self.n_components,
318 ensemble_method: self.ensemble_method,
319 sampling_strategies: self.sampling_strategies,
320 quality_metric: self.quality_metric,
321 random_state: self.random_state,
322 estimators_: Some(estimators),
323 weights_: Some(weights),
324 n_features_out_: Some(n_features_out),
325 _state: PhantomData,
326 })
327 }
328}
329
330impl Transform<Array2<Float>, Array2<Float>> for EnsembleNystroem<Trained> {
331 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
332 let estimators = self.estimators_.as_ref().expect("operation should succeed");
333 let weights = self.weights_.as_ref().expect("operation should succeed");
334 let n_features_out = self.n_features_out_.expect("operation should succeed");
335 let (n_samples, _) = x.dim();
336
337 match self.ensemble_method {
338 EnsembleMethod::Average | EnsembleMethod::WeightedAverage => {
339 let mut result = Array2::zeros((n_samples, self.n_components));
340
341 for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
342 if weight > 0.0 {
343 let transformed = estimator.transform(x)?;
344 result += &(transformed * weight);
345 }
346 }
347
348 Ok(result)
349 }
350 EnsembleMethod::Concatenate => {
351 let mut result = Array2::zeros((n_samples, n_features_out));
352 let mut col_offset = 0;
353
354 for estimator in estimators.iter() {
355 let transformed = estimator.transform(x)?;
356 let n_cols = transformed.ncols();
357 result
358 .slice_mut(s![.., col_offset..col_offset + n_cols])
359 .assign(&transformed);
360 col_offset += n_cols;
361 }
362
363 Ok(result)
364 }
365 EnsembleMethod::BestApproximation => {
366 let best_idx = weights
367 .iter()
368 .enumerate()
369 .find(|(_, &w)| w > 0.0)
370 .map(|(idx, _)| idx)
371 .unwrap_or(0);
372
373 estimators[best_idx].transform(x)
374 }
375 }
376 }
377}
378
379impl EnsembleNystroem<Trained> {
380 pub fn estimators(&self) -> &[Nystroem<Trained>] {
382 self.estimators_.as_ref().expect("operation should succeed")
383 }
384
385 pub fn weights(&self) -> &[Float] {
387 self.weights_.as_ref().expect("operation should succeed")
388 }
389
390 pub fn n_features_out(&self) -> usize {
392 self.n_features_out_.expect("operation should succeed")
393 }
394
395 pub fn quality_scores(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
397 let estimators = self.estimators_.as_ref().expect("operation should succeed");
398 let mut scores = Vec::new();
399
400 for estimator in estimators.iter() {
401 let score = self.compute_quality_score_for_estimator(estimator, x)?;
402 scores.push(score);
403 }
404
405 Ok(scores)
406 }
407
408 fn compute_quality_score_for_estimator(
410 &self,
411 estimator: &Nystroem<Trained>,
412 _x: &Array2<Float>,
413 ) -> Result<Float> {
414 match self.quality_metric {
415 QualityMetric::FrobeniusNorm => {
416 let components = estimator.components();
417 let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
418 Ok(norm)
419 }
420 QualityMetric::Trace => {
421 let components = estimator.components();
422 let kernel_matrix = self.kernel.compute_kernel(components, components);
423 Ok(kernel_matrix.diag().sum())
424 }
425 QualityMetric::SpectralNorm => {
426 let components = estimator.components();
427 let kernel_matrix = self.kernel.compute_kernel(components, components);
428 self.power_iteration_spectral_norm(&kernel_matrix)
429 }
430 QualityMetric::NuclearNorm => {
431 let components = estimator.components();
432 let kernel_matrix = self.kernel.compute_kernel(components, components);
433 Ok(kernel_matrix.diag().sum())
434 }
435 }
436 }
437
438 fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
440 let n = matrix.nrows();
441 if n == 0 {
442 return Ok(0.0);
443 }
444
445 let mut v = Array1::ones(n) / (n as Float).sqrt();
446 let max_iter = 100;
447 let tolerance = 1e-6;
448
449 for _ in 0..max_iter {
450 let v_new = matrix.dot(&v);
451 let norm = (v_new.dot(&v_new)).sqrt();
452
453 if norm < tolerance {
454 break;
455 }
456
457 let v_normalized = &v_new / norm;
458 let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
459 v = v_normalized;
460
461 if diff < tolerance {
462 break;
463 }
464 }
465
466 let eigenvalue = v.dot(&matrix.dot(&v));
467 Ok(eigenvalue.abs())
468 }
469}
470
471use scirs2_core::ndarray::s;
473
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The default configuration yields `n_components` output columns.
    #[test]
    fn test_ensemble_nystroem_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let model = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .fit(&data, &())
            .expect("operation should succeed");
        let features = model.transform(&data).expect("operation should succeed");

        assert_eq!(features.dim(), (4, 2));
    }

    /// Plain averaging keeps the per-estimator output shape.
    #[test]
    fn test_ensemble_nystroem_average() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 3)
            .ensemble_method(EnsembleMethod::Average)
            .fit(&data, &())
            .expect("operation should succeed");
        let features = model.transform(&data).expect("operation should succeed");

        assert_eq!(features.dim(), (3, 3));
    }

    /// Concatenation produces `n_estimators * n_components` columns.
    #[test]
    fn test_ensemble_nystroem_concatenate() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let model = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .ensemble_method(EnsembleMethod::Concatenate)
            .fit(&data, &())
            .expect("operation should succeed");
        let features = model.transform(&data).expect("operation should succeed");

        assert_eq!(features.dim(), (3, 6));
    }

    /// Weighted averaging keeps the output shape and normalizes the weights.
    #[test]
    fn test_ensemble_nystroem_weighted_average() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3, 2)
            .ensemble_method(EnsembleMethod::WeightedAverage)
            .fit(&data, &())
            .expect("operation should succeed");
        let features = model.transform(&data).expect("operation should succeed");

        assert_eq!(features.dim(), (4, 2));

        let total_weight: Float = model.weights().iter().sum();
        assert!((total_weight - 1.0).abs() < 1e-6);
    }

    /// Best-approximation mode activates exactly one estimator with weight one.
    #[test]
    fn test_ensemble_nystroem_best_approximation() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let model = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .ensemble_method(EnsembleMethod::BestApproximation)
            .fit(&data, &())
            .expect("operation should succeed");
        let features = model.transform(&data).expect("operation should succeed");

        assert_eq!(features.dim(), (3, 2));

        let active: Vec<Float> = model
            .weights()
            .iter()
            .copied()
            .filter(|&weight| weight > 0.0)
            .collect();
        assert_eq!(active.len(), 1);
        assert!((active[0] - 1.0).abs() < 1e-10);
    }

    /// User-supplied sampling strategies are accepted and one estimator is fitted per slot.
    #[test]
    fn test_ensemble_nystroem_custom_strategies() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let custom = vec![SamplingStrategy::Random, SamplingStrategy::LeverageScore];

        let model = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .sampling_strategies(custom)
            .fit(&data, &())
            .expect("operation should succeed");
        let features = model.transform(&data).expect("operation should succeed");

        assert_eq!(features.dim(), (4, 3));
        assert_eq!(model.estimators().len(), 2);
    }

    /// Two runs with the same seed produce matching outputs.
    #[test]
    fn test_ensemble_nystroem_reproducibility() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let run_with_seed = |seed: u64| {
            EnsembleNystroem::new(Kernel::Linear, 2, 3)
                .random_state(seed)
                .fit(&data, &())
                .expect("operation should succeed")
                .transform(&data)
                .expect("operation should succeed")
        };

        let first = run_with_seed(42);
        let second = run_with_seed(42);

        assert_eq!(first.shape(), second.shape());
        for (a, b) in first.iter().zip(second.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    /// Trace-based quality scores are finite and non-negative for an RBF kernel.
    #[test]
    fn test_ensemble_nystroem_quality_metrics() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let model = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 2)
            .quality_metric(QualityMetric::Trace)
            .fit(&data, &())
            .expect("operation should succeed");
        let scores = model
            .quality_scores(&data)
            .expect("operation should succeed");

        assert_eq!(scores.len(), 2);
        for score in &scores {
            assert!(score.is_finite());
            assert!(*score >= 0.0);
        }
    }

    /// Zero estimators or zero components are rejected by `fit`.
    #[test]
    fn test_ensemble_nystroem_invalid_parameters() {
        let data = array![[1.0, 2.0]];

        let zero_estimators = EnsembleNystroem::new(Kernel::Linear, 0, 2);
        assert!(zero_estimators.fit(&data, &()).is_err());

        let zero_components = EnsembleNystroem::new(Kernel::Linear, 2, 0);
        assert!(zero_components.fit(&data, &()).is_err());
    }
}