1use crate::nystroem::{Kernel, Nystroem, SamplingStrategy};
3use scirs2_core::ndarray::{Array1, Array2};
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use sklears_core::{
6 error::{Result, SklearsError},
7 traits::{Estimator, Fit, Trained, Transform, Untrained},
8 types::Float,
9};
10use std::marker::PhantomData;
11
12use scirs2_core::random::{thread_rng, Rng, SeedableRng};
/// Strategy used to combine the feature maps of the individual Nystroem estimators.
///
/// The enum is a plain fieldless configuration value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnsembleMethod {
    /// Uniformly averages every estimator's features.
    Average,
    /// Averages features weighted by each estimator's normalised quality score.
    WeightedAverage,
    /// Places every estimator's features side by side.
    Concatenate,
    /// Uses only the estimator with the highest quality score.
    BestApproximation,
}
25
/// Score used to rank the fitted Nystroem estimators of an ensemble.
///
/// All metrics are computed from each estimator's landmark components.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QualityMetric {
    /// Frobenius norm of the landmarks' linear Gram matrix.
    FrobeniusNorm,
    /// Trace of the kernel matrix between the landmarks.
    Trace,
    /// Largest eigenvalue magnitude of the landmark kernel matrix, estimated by power iteration.
    SpectralNorm,
    /// Nuclear norm of the landmark kernel matrix.
    ///
    /// It is computed as the trace, which is exact only for positive semi-definite kernels.
    NuclearNorm,
}
38
39#[derive(Debug, Clone)]
70pub struct EnsembleNystroem<State = Untrained> {
71 pub kernel: Kernel,
73 pub n_estimators: usize,
75 pub n_components: usize,
77 pub ensemble_method: EnsembleMethod,
79 pub sampling_strategies: Option<Vec<SamplingStrategy>>,
81 pub quality_metric: QualityMetric,
83 pub random_state: Option<u64>,
85
86 estimators_: Option<Vec<Nystroem<Trained>>>,
88 weights_: Option<Vec<Float>>,
89 n_features_out_: Option<usize>,
90
91 _state: PhantomData<State>,
92}
93
94impl EnsembleNystroem<Untrained> {
95 pub fn new(kernel: Kernel, n_estimators: usize, n_components: usize) -> Self {
96 Self {
97 kernel,
98 n_estimators,
99 n_components,
100 ensemble_method: EnsembleMethod::WeightedAverage,
101 sampling_strategies: None,
102 quality_metric: QualityMetric::FrobeniusNorm,
103 random_state: None,
104 estimators_: None,
105 weights_: None,
106 n_features_out_: None,
107 _state: PhantomData,
108 }
109 }
110
111 pub fn ensemble_method(mut self, method: EnsembleMethod) -> Self {
113 self.ensemble_method = method;
114 self
115 }
116
117 pub fn sampling_strategies(mut self, strategies: Vec<SamplingStrategy>) -> Self {
119 self.sampling_strategies = Some(strategies);
120 self
121 }
122
123 pub fn quality_metric(mut self, metric: QualityMetric) -> Self {
125 self.quality_metric = metric;
126 self
127 }
128
129 pub fn random_state(mut self, seed: u64) -> Self {
131 self.random_state = Some(seed);
132 self
133 }
134
135 fn generate_sampling_strategies(&self) -> Vec<SamplingStrategy> {
137 if let Some(ref strategies) = self.sampling_strategies {
138 strategies.clone()
139 } else {
140 let mut strategies = Vec::new();
142 let base_strategies = vec![
143 SamplingStrategy::Random,
144 SamplingStrategy::KMeans,
145 SamplingStrategy::LeverageScore,
146 SamplingStrategy::ColumnNorm,
147 ];
148
149 for i in 0..self.n_estimators {
150 strategies.push(base_strategies[i % base_strategies.len()].clone());
151 }
152 strategies
153 }
154 }
155
156 fn compute_quality_score(
158 &self,
159 estimator: &Nystroem<Trained>,
160 x: &Array2<Float>,
161 ) -> Result<Float> {
162 match self.quality_metric {
163 QualityMetric::FrobeniusNorm => {
164 let components = estimator.components();
166 let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
167 Ok(norm)
168 }
169 QualityMetric::Trace => {
170 let components = estimator.components();
171 let kernel_matrix = self.kernel.compute_kernel(components, components);
172 Ok(kernel_matrix.diag().sum())
173 }
174 QualityMetric::SpectralNorm => {
175 let components = estimator.components();
177 let kernel_matrix = self.kernel.compute_kernel(components, components);
178 self.power_iteration_spectral_norm(&kernel_matrix)
179 }
180 QualityMetric::NuclearNorm => {
181 let components = estimator.components();
182 let kernel_matrix = self.kernel.compute_kernel(components, components);
183 Ok(kernel_matrix.diag().sum())
185 }
186 }
187 }
188
189 fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
191 let n = matrix.nrows();
192 if n == 0 {
193 return Ok(0.0);
194 }
195
196 let mut v = Array1::ones(n) / (n as Float).sqrt();
197 let max_iter = 100;
198 let tolerance = 1e-6;
199
200 for _ in 0..max_iter {
201 let v_new = matrix.dot(&v);
202 let norm = (v_new.dot(&v_new)).sqrt();
203
204 if norm < tolerance {
205 break;
206 }
207
208 let v_normalized = &v_new / norm;
209 let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
210 v = v_normalized;
211
212 if diff < tolerance {
213 break;
214 }
215 }
216
217 let eigenvalue = v.dot(&matrix.dot(&v));
218 Ok(eigenvalue.abs())
219 }
220}
221
222impl Estimator for EnsembleNystroem<Untrained> {
223 type Config = ();
224 type Error = SklearsError;
225 type Float = Float;
226
227 fn config(&self) -> &Self::Config {
228 &()
229 }
230}
231
232impl Fit<Array2<Float>, ()> for EnsembleNystroem<Untrained> {
233 type Fitted = EnsembleNystroem<Trained>;
234
235 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
236 if self.n_estimators == 0 {
237 return Err(SklearsError::InvalidInput(
238 "n_estimators must be positive".to_string(),
239 ));
240 }
241
242 if self.n_components == 0 {
243 return Err(SklearsError::InvalidInput(
244 "n_components must be positive".to_string(),
245 ));
246 }
247
248 let mut rng = if let Some(seed) = self.random_state {
249 RealStdRng::seed_from_u64(seed)
250 } else {
251 RealStdRng::from_seed(thread_rng().gen())
252 };
253
254 let sampling_strategies = self.generate_sampling_strategies();
255 let mut estimators = Vec::new();
256 let mut quality_scores = Vec::new();
257
258 for i in 0..self.n_estimators {
260 let strategy = sampling_strategies[i % sampling_strategies.len()].clone();
261 let seed = if self.random_state.is_some() {
262 self.random_state.unwrap().wrapping_add(i as u64)
264 } else {
265 rng.gen::<u64>()
266 };
267
268 let nystroem = Nystroem::new(self.kernel.clone(), self.n_components)
269 .sampling_strategy(strategy)
270 .random_state(seed);
271
272 let fitted_nystroem = nystroem.fit(x, &())?;
273
274 let quality = self.compute_quality_score(&fitted_nystroem, x)?;
276 quality_scores.push(quality);
277 estimators.push(fitted_nystroem);
278 }
279
280 let weights = match self.ensemble_method {
282 EnsembleMethod::Average => vec![1.0 / self.n_estimators as Float; self.n_estimators],
283 EnsembleMethod::WeightedAverage => {
284 let total_quality: Float = quality_scores.iter().sum();
285 if total_quality > 0.0 {
286 quality_scores.iter().map(|&q| q / total_quality).collect()
287 } else {
288 vec![1.0 / self.n_estimators as Float; self.n_estimators]
289 }
290 }
291 EnsembleMethod::Concatenate => vec![1.0; self.n_estimators],
292 EnsembleMethod::BestApproximation => {
293 let best_idx = quality_scores
294 .iter()
295 .enumerate()
296 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
297 .map(|(idx, _)| idx)
298 .unwrap_or(0);
299 let mut weights = vec![0.0; self.n_estimators];
300 weights[best_idx] = 1.0;
301 weights
302 }
303 };
304
305 let n_features_out = match self.ensemble_method {
307 EnsembleMethod::Concatenate => self.n_estimators * self.n_components,
308 _ => self.n_components,
309 };
310
311 Ok(EnsembleNystroem {
312 kernel: self.kernel,
313 n_estimators: self.n_estimators,
314 n_components: self.n_components,
315 ensemble_method: self.ensemble_method,
316 sampling_strategies: self.sampling_strategies,
317 quality_metric: self.quality_metric,
318 random_state: self.random_state,
319 estimators_: Some(estimators),
320 weights_: Some(weights),
321 n_features_out_: Some(n_features_out),
322 _state: PhantomData,
323 })
324 }
325}
326
327impl Transform<Array2<Float>, Array2<Float>> for EnsembleNystroem<Trained> {
328 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
329 let estimators = self.estimators_.as_ref().unwrap();
330 let weights = self.weights_.as_ref().unwrap();
331 let n_features_out = self.n_features_out_.unwrap();
332 let (n_samples, _) = x.dim();
333
334 match self.ensemble_method {
335 EnsembleMethod::Average | EnsembleMethod::WeightedAverage => {
336 let mut result = Array2::zeros((n_samples, self.n_components));
337
338 for (estimator, &weight) in estimators.iter().zip(weights.iter()) {
339 if weight > 0.0 {
340 let transformed = estimator.transform(x)?;
341 result = result + &(transformed * weight);
342 }
343 }
344
345 Ok(result)
346 }
347 EnsembleMethod::Concatenate => {
348 let mut result = Array2::zeros((n_samples, n_features_out));
349 let mut col_offset = 0;
350
351 for estimator in estimators.iter() {
352 let transformed = estimator.transform(x)?;
353 let n_cols = transformed.ncols();
354 result
355 .slice_mut(s![.., col_offset..col_offset + n_cols])
356 .assign(&transformed);
357 col_offset += n_cols;
358 }
359
360 Ok(result)
361 }
362 EnsembleMethod::BestApproximation => {
363 let best_idx = weights
364 .iter()
365 .enumerate()
366 .find(|(_, &w)| w > 0.0)
367 .map(|(idx, _)| idx)
368 .unwrap_or(0);
369
370 estimators[best_idx].transform(x)
371 }
372 }
373 }
374}
375
376impl EnsembleNystroem<Trained> {
377 pub fn estimators(&self) -> &[Nystroem<Trained>] {
379 self.estimators_.as_ref().unwrap()
380 }
381
382 pub fn weights(&self) -> &[Float] {
384 self.weights_.as_ref().unwrap()
385 }
386
387 pub fn n_features_out(&self) -> usize {
389 self.n_features_out_.unwrap()
390 }
391
392 pub fn quality_scores(&self, x: &Array2<Float>) -> Result<Vec<Float>> {
394 let estimators = self.estimators_.as_ref().unwrap();
395 let mut scores = Vec::new();
396
397 for estimator in estimators.iter() {
398 let score = self.compute_quality_score_for_estimator(estimator, x)?;
399 scores.push(score);
400 }
401
402 Ok(scores)
403 }
404
405 fn compute_quality_score_for_estimator(
407 &self,
408 estimator: &Nystroem<Trained>,
409 x: &Array2<Float>,
410 ) -> Result<Float> {
411 match self.quality_metric {
412 QualityMetric::FrobeniusNorm => {
413 let components = estimator.components();
414 let norm = components.dot(&components.t()).mapv(|v| v * v).sum().sqrt();
415 Ok(norm)
416 }
417 QualityMetric::Trace => {
418 let components = estimator.components();
419 let kernel_matrix = self.kernel.compute_kernel(components, components);
420 Ok(kernel_matrix.diag().sum())
421 }
422 QualityMetric::SpectralNorm => {
423 let components = estimator.components();
424 let kernel_matrix = self.kernel.compute_kernel(components, components);
425 self.power_iteration_spectral_norm(&kernel_matrix)
426 }
427 QualityMetric::NuclearNorm => {
428 let components = estimator.components();
429 let kernel_matrix = self.kernel.compute_kernel(components, components);
430 Ok(kernel_matrix.diag().sum())
431 }
432 }
433 }
434
435 fn power_iteration_spectral_norm(&self, matrix: &Array2<Float>) -> Result<Float> {
437 let n = matrix.nrows();
438 if n == 0 {
439 return Ok(0.0);
440 }
441
442 let mut v = Array1::ones(n) / (n as Float).sqrt();
443 let max_iter = 100;
444 let tolerance = 1e-6;
445
446 for _ in 0..max_iter {
447 let v_new = matrix.dot(&v);
448 let norm = (v_new.dot(&v_new)).sqrt();
449
450 if norm < tolerance {
451 break;
452 }
453
454 let v_normalized = &v_new / norm;
455 let diff = (&v_normalized - &v).dot(&(&v_normalized - &v)).sqrt();
456 v = v_normalized;
457
458 if diff < tolerance {
459 break;
460 }
461 }
462
463 let eigenvalue = v.dot(&matrix.dot(&v));
464 Ok(eigenvalue.abs())
465 }
466}
467
468use scirs2_core::ndarray::s;
470
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fits `ensemble` on `data`, transforms the same data, and returns both results.
    fn fit_then_transform(
        ensemble: EnsembleNystroem<Untrained>,
        data: &Array2<Float>,
    ) -> (EnsembleNystroem<Trained>, Array2<Float>) {
        let fitted = ensemble.fit(data, &()).unwrap();
        let features = fitted.transform(data).unwrap();
        (fitted, features)
    }

    #[test]
    fn test_ensemble_nystroem_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let (_, features) =
            fit_then_transform(EnsembleNystroem::new(Kernel::Linear, 3, 2), &data);

        assert_eq!(features.dim(), (4, 2));
    }

    #[test]
    fn test_ensemble_nystroem_average() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let ensemble = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 3)
            .ensemble_method(EnsembleMethod::Average);

        let (_, features) = fit_then_transform(ensemble, &data);

        assert_eq!(features.dim(), (3, 3));
    }

    #[test]
    fn test_ensemble_nystroem_concatenate() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let ensemble = EnsembleNystroem::new(Kernel::Linear, 2, 3)
            .ensemble_method(EnsembleMethod::Concatenate);

        let (_, features) = fit_then_transform(ensemble, &data);

        // Two estimators with three components each, laid side by side.
        assert_eq!(features.dim(), (3, 6));
    }

    #[test]
    fn test_ensemble_nystroem_weighted_average() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let ensemble = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.5 }, 3, 2)
            .ensemble_method(EnsembleMethod::WeightedAverage);

        let (fitted, features) = fit_then_transform(ensemble, &data);
        assert_eq!(features.dim(), (4, 2));

        // Normalised quality weights must form a convex combination.
        let total: Float = fitted.weights().iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_ensemble_nystroem_best_approximation() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let ensemble = EnsembleNystroem::new(Kernel::Linear, 3, 2)
            .ensemble_method(EnsembleMethod::BestApproximation);

        let (fitted, features) = fit_then_transform(ensemble, &data);
        assert_eq!(features.dim(), (3, 2));

        // Exactly one estimator is selected, with weight one.
        let active: Vec<Float> = fitted
            .weights()
            .iter()
            .copied()
            .filter(|&w| w > 0.0)
            .collect();
        assert_eq!(active.len(), 1);
        assert!((active[0] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_ensemble_nystroem_custom_strategies() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let ensemble = EnsembleNystroem::new(Kernel::Linear, 2, 3).sampling_strategies(vec![
            SamplingStrategy::Random,
            SamplingStrategy::LeverageScore,
        ]);

        let (fitted, features) = fit_then_transform(ensemble, &data);

        assert_eq!(features.dim(), (4, 3));
        assert_eq!(fitted.estimators().len(), 2);
    }

    #[test]
    fn test_ensemble_nystroem_reproducibility() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let seeded_run = || {
            let ensemble = EnsembleNystroem::new(Kernel::Linear, 2, 3).random_state(42);
            fit_then_transform(ensemble, &data).1
        };

        let first = seeded_run();
        let second = seeded_run();

        assert_eq!(first.shape(), second.shape());
        for (a, b) in first.iter().zip(second.iter()) {
            assert!(
                (a - b).abs() < 1e-6,
                "Values differ too much: {} vs {}",
                a,
                b
            );
        }
    }

    #[test]
    fn test_ensemble_nystroem_quality_metrics() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let ensemble = EnsembleNystroem::new(Kernel::Rbf { gamma: 0.1 }, 2, 2)
            .quality_metric(QualityMetric::Trace);

        let fitted = ensemble.fit(&data, &()).unwrap();
        let scores = fitted.quality_scores(&data).unwrap();

        assert_eq!(scores.len(), 2);
        for &score in &scores {
            assert!(score.is_finite());
            assert!(score >= 0.0);
        }
    }

    #[test]
    fn test_ensemble_nystroem_invalid_parameters() {
        let data = array![[1.0, 2.0]];

        // Each pair zeroes out one of the two size parameters.
        for &(n_estimators, n_components) in &[(0, 2), (2, 0)] {
            let ensemble = EnsembleNystroem::new(Kernel::Linear, n_estimators, n_components);
            assert!(ensemble.fit(&data, &()).is_err());
        }
    }
}