use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::StandardNormal;
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;

/// A base kernel that can participate in a multiple-kernel combination.
#[derive(Debug, Clone)]
pub enum BaseKernel {
    /// Radial basis function kernel: exp(-gamma * ||x - y||^2)
    RBF { gamma: f64 },
    /// Polynomial kernel: (gamma * <x, y> + coef0)^degree
    Polynomial { degree: f64, gamma: f64, coef0: f64 },
    /// Laplacian kernel: exp(-gamma * ||x - y||_1)
    Laplacian { gamma: f64 },
    /// Linear kernel: <x, y>
    Linear,
    /// Sigmoid kernel: tanh(gamma * <x, y> + coef0)
    Sigmoid { gamma: f64, coef0: f64 },
    /// User-supplied kernel function identified by name.
    Custom {
        name: String,
        kernel_fn: fn(&Array1<f64>, &Array1<f64>) -> f64,
    },
}

impl BaseKernel {
    /// Evaluate the kernel function on a pair of sample vectors.
    pub fn evaluate(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        match self {
            BaseKernel::RBF { gamma } => {
                let diff = x - y;
                let squared_dist = diff.mapv(|v| v * v).sum();
                (-gamma * squared_dist).exp()
            }
            BaseKernel::Polynomial {
                degree,
                gamma,
                coef0,
            } => {
                let dot_product = x.dot(y);
                (gamma * dot_product + coef0).powf(*degree)
            }
            BaseKernel::Laplacian { gamma } => {
                let diff = x - y;
                let manhattan_dist = diff.mapv(|v| v.abs()).sum();
                (-gamma * manhattan_dist).exp()
            }
            BaseKernel::Linear => x.dot(y),
            BaseKernel::Sigmoid { gamma, coef0 } => {
                let dot_product = x.dot(y);
                (gamma * dot_product + coef0).tanh()
            }
            BaseKernel::Custom { kernel_fn, .. } => kernel_fn(x, y),
        }
    }

    /// Human-readable name including the kernel's parameters.
    pub fn name(&self) -> String {
        match self {
            BaseKernel::RBF { gamma } => format!("RBF(gamma={:.4})", gamma),
            BaseKernel::Polynomial {
                degree,
                gamma,
                coef0,
            } => {
                format!(
                    "Polynomial(degree={:.1}, gamma={:.4}, coef0={:.4})",
                    degree, gamma, coef0
                )
            }
            BaseKernel::Laplacian { gamma } => format!("Laplacian(gamma={:.4})", gamma),
            BaseKernel::Linear => "Linear".to_string(),
            BaseKernel::Sigmoid { gamma, coef0 } => {
                format!("Sigmoid(gamma={:.4}, coef0={:.4})", gamma, coef0)
            }
            BaseKernel::Custom { name, .. } => format!("Custom({})", name),
        }
    }
}

/// How the base kernels are combined into a single representation.
#[derive(Debug, Clone)]
pub enum CombinationStrategy {
    /// Unconstrained linear combination.
    Linear,
    /// Element-wise product of kernel contributions.
    Product,
    /// Non-negative weights constrained to sum to one.
    Convex,
    /// Non-negative weights without a sum constraint.
    Conic,
    /// Hierarchical combination (currently falls back to summation).
    Hierarchical,
}

/// Algorithm used to learn the per-kernel weights.
#[derive(Debug, Clone)]
pub enum WeightLearningAlgorithm {
    /// Equal weight for every kernel.
    Uniform,
    /// Weights from centered kernel-target alignment.
    CenteredKernelAlignment,
    /// Weights from a maximum-mean-discrepancy-style criterion.
    MaximumMeanDiscrepancy,
    /// SimpleMKL-style weights with an explicit regularization term.
    SimpleMKL { regularization: f64 },
    /// EasyMKL-style weights with a radius parameter.
    EasyMKL { radius: f64 },
    /// Weights derived from the kernel spectra.
    SpectralProjected,
    /// Localized MKL with a bandwidth parameter.
    LocalizedMKL { bandwidth: f64 },
    /// Weights selected by cross-validation.
    AdaptiveMKL { cv_folds: usize },
}

/// Kernel matrix approximation scheme.
#[derive(Debug, Clone)]
pub enum ApproximationMethod {
    /// Random Fourier features for shift-invariant kernels.
    RandomFourierFeatures { n_components: usize },
    /// Nystroem landmark approximation.
    Nystroem { n_components: usize },
    /// Structured random feature approximation.
    StructuredFeatures { n_components: usize },
    /// Exact (dense) kernel matrix.
    Exact,
}

/// Configuration for multiple kernel learning.
#[derive(Debug, Clone)]
pub struct MultiKernelConfig {
    /// How the base kernels are combined.
    pub combination_strategy: CombinationStrategy,
    /// Algorithm for learning the kernel weights.
    pub weight_learning: WeightLearningAlgorithm,
    /// Kernel matrix approximation scheme.
    pub approximation_method: ApproximationMethod,
    /// Maximum number of optimization iterations.
    pub max_iterations: usize,
    /// Convergence tolerance.
    pub tolerance: f64,
    /// Normalize each kernel matrix to a unit diagonal.
    pub normalize_kernels: bool,
    /// Center each kernel matrix in feature space.
    pub center_kernels: bool,
    /// Regularization strength.
    pub regularization: f64,
}

impl Default for MultiKernelConfig {
    fn default() -> Self {
        Self {
            combination_strategy: CombinationStrategy::Convex,
            weight_learning: WeightLearningAlgorithm::CenteredKernelAlignment,
            approximation_method: ApproximationMethod::RandomFourierFeatures { n_components: 100 },
            max_iterations: 100,
            tolerance: 1e-6,
            normalize_kernels: true,
            center_kernels: true,
            regularization: 1e-3,
        }
    }
}

/// Multiple kernel learning (MKL): learns a weighted combination of base
/// kernels together with an (approximate) feature representation.
pub struct MultipleKernelLearning {
    /// The base kernels to combine.
    base_kernels: Vec<BaseKernel>,
    /// Learning configuration.
    config: MultiKernelConfig,
    /// Learned kernel weights (set by `fit`).
    weights: Option<Array1<f64>>,
    /// Per-kernel (possibly approximated) Gram matrices from `fit`.
    kernel_matrices: Option<Vec<Array2<f64>>>,
    /// Combined feature representation (currently an unused placeholder).
    combined_features: Option<Array2<f64>>,
    /// Optional RNG seed for reproducibility.
    random_state: Option<u64>,
    rng: StdRng,
    /// Training data retained for transforming new samples.
    training_data: Option<Array2<f64>>,
    /// Diagnostics for each fitted kernel, keyed by "kernel_{i}".
    kernel_statistics: HashMap<String, KernelStatistics>,
}

/// Summary statistics describing a single (processed) kernel matrix.
#[derive(Debug, Clone)]
pub struct KernelStatistics {
    /// Mean self-similarity, used as an alignment proxy.
    pub alignment: f64,
    /// Approximate eigenspectrum (currently the matrix diagonal).
    pub eigenspectrum: Array1<f64>,
    /// Effective rank estimate: trace(K)^2 / ||K||_F^2.
    pub effective_rank: f64,
    /// Variance of the diagonal entries.
    pub diversity: f64,
    /// Condition-number-style ratio of extreme diagonal entries.
    pub complexity: f64,
}

impl KernelStatistics {
    pub fn new() -> Self {
        Self {
            alignment: 0.0,
            eigenspectrum: Array1::zeros(0),
            effective_rank: 0.0,
            diversity: 0.0,
            complexity: 0.0,
        }
    }
}

impl MultipleKernelLearning {
    /// Create a new MKL model over the given base kernels, with the default
    /// configuration and a fixed default seed (override via `with_random_state`).
    pub fn new(base_kernels: Vec<BaseKernel>) -> Self {
        let rng = StdRng::seed_from_u64(42);
        Self {
            base_kernels,
            config: MultiKernelConfig::default(),
            weights: None,
            kernel_matrices: None,
            combined_features: None,
            random_state: None,
            rng,
            training_data: None,
            kernel_statistics: HashMap::new(),
        }
    }

    /// Replace the default configuration (builder style).
    pub fn with_config(mut self, config: MultiKernelConfig) -> Self {
        self.config = config;
        self
    }

    /// Seed the internal RNG so randomized approximations are reproducible.
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self.rng = StdRng::seed_from_u64(random_state);
        self
    }

    /// Fit the model on `x`: build one (approximate) kernel matrix per base
    /// kernel, collect per-kernel statistics, learn the combination weights,
    /// and cache the combined representation. Labels `y` are optional and
    /// enable supervised weight learning.
    pub fn fit(&mut self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
        self.training_data = Some(x.clone());

        // Clone the kernels and the approximation choice up front so the
        // `&mut self` helper calls below do not conflict with borrows of self.
        let base_kernels = self.base_kernels.clone();
        let approximation = self.config.approximation_method.clone();

        let mut kernel_matrices = Vec::new();
        for (i, base_kernel) in base_kernels.iter().enumerate() {
            let kernel_matrix = match &approximation {
                ApproximationMethod::RandomFourierFeatures { n_components } => {
                    self.compute_rff_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::Nystroem { n_components } => {
                    self.compute_nystroem_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::StructuredFeatures { n_components } => {
                    self.compute_structured_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::Exact => self.compute_exact_kernel_matrix(x, base_kernel)?,
            };

            // Apply optional normalization and centering.
            let processed_matrix = self.process_kernel_matrix(kernel_matrix)?;

            // Record diagnostics for this kernel under the key "kernel_{i}".
            let stats = self.compute_kernel_statistics(&processed_matrix, y)?;
            self.kernel_statistics
                .insert(format!("kernel_{}", i), stats);

            kernel_matrices.push(processed_matrix);
        }

        self.kernel_matrices = Some(kernel_matrices);

        // Learn the per-kernel weights, then build the combined view.
        self.learn_weights(y)?;
        self.compute_combined_representation()?;

        Ok(())
    }

    /// Map new samples into the learned combined feature space. Requires a
    /// prior call to `fit`.
    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let weights = self
            .weights
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;
        let training_data = self
            .training_data
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let mut combined_features = None;

        for (base_kernel, &weight) in self.base_kernels.iter().zip(weights.iter()) {
            // Skip kernels whose weight is effectively zero.
            if weight.abs() < 1e-12 {
                continue;
            }

            let features = match &self.config.approximation_method {
                ApproximationMethod::RandomFourierFeatures { n_components } => {
                    self.transform_rff(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::Nystroem { n_components } => {
                    self.transform_nystroem(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::StructuredFeatures { n_components } => {
                    self.transform_structured(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::Exact => {
                    return Err(SklearsError::NotImplemented(
                        "Exact kernel transform not implemented for new data".to_string(),
                    ));
                }
            };

            let weighted_features = &features * weight;

            // Accumulate according to the combination strategy.
            match &self.config.combination_strategy {
                CombinationStrategy::Linear
                | CombinationStrategy::Convex
                | CombinationStrategy::Conic => {
                    combined_features = match combined_features {
                        Some(existing) => Some(existing + weighted_features),
                        None => Some(weighted_features),
                    };
                }
                CombinationStrategy::Product => {
                    // Multiply exponentiated contributions so the product
                    // stays positive.
                    combined_features = match combined_features {
                        Some(existing) => Some(existing * weighted_features.mapv(|v| v.exp())),
                        None => Some(weighted_features.mapv(|v| v.exp())),
                    };
                }
                CombinationStrategy::Hierarchical => {
                    // Hierarchical combination currently falls back to summation.
                    combined_features = match combined_features {
                        Some(existing) => Some(existing + weighted_features),
                        None => Some(weighted_features),
                    };
                }
            }
        }

        combined_features.ok_or_else(|| {
            SklearsError::Other("No features generated - all kernel weights are zero".to_string())
        })
    }

    /// The learned kernel weights, if the model has been fitted.
    pub fn kernel_weights(&self) -> Option<&Array1<f64>> {
        self.weights.as_ref()
    }

    /// Per-kernel diagnostics collected during `fit`, keyed by "kernel_{i}".
    pub fn kernel_stats(&self) -> &HashMap<String, KernelStatistics> {
        &self.kernel_statistics
    }

    /// Kernels whose absolute weight is at least `threshold`, together with
    /// their index and weight. Returns an empty list before fitting.
    pub fn important_kernels(&self, threshold: f64) -> Vec<(usize, &BaseKernel, f64)> {
        if let Some(weights) = &self.weights {
            self.base_kernels
                .iter()
                .enumerate()
                .zip(weights.iter())
                .filter_map(|((i, kernel), &weight)| {
                    if weight.abs() >= threshold {
                        Some((i, kernel, weight))
                    } else {
                        None
                    }
                })
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Dispatch to the configured weight-learning algorithm, then project the
    /// result onto the constraint set of the combination strategy.
    fn learn_weights(&mut self, y: Option<&Array1<f64>>) -> Result<()> {
        // `fit` populates the kernel matrices before calling this method.
        let kernel_matrices = self.kernel_matrices.as_ref().unwrap();
        let n_kernels = kernel_matrices.len();

        let weights = match &self.config.weight_learning {
            WeightLearningAlgorithm::Uniform => {
                Array1::from_elem(n_kernels, 1.0 / n_kernels as f64)
            }
            WeightLearningAlgorithm::CenteredKernelAlignment => {
                self.learn_cka_weights(kernel_matrices, y)?
            }
            WeightLearningAlgorithm::MaximumMeanDiscrepancy => {
                self.learn_mmd_weights(kernel_matrices)?
            }
            WeightLearningAlgorithm::SimpleMKL { regularization } => {
                self.learn_simple_mkl_weights(kernel_matrices, y, *regularization)?
            }
            WeightLearningAlgorithm::EasyMKL { radius } => {
                self.learn_easy_mkl_weights(kernel_matrices, y, *radius)?
            }
            WeightLearningAlgorithm::SpectralProjected => {
                self.learn_spectral_weights(kernel_matrices)?
            }
            WeightLearningAlgorithm::LocalizedMKL { bandwidth } => {
                self.learn_localized_weights(kernel_matrices, *bandwidth)?
            }
            WeightLearningAlgorithm::AdaptiveMKL { cv_folds } => {
                self.learn_adaptive_weights(kernel_matrices, y, *cv_folds)?
            }
        };

        // Enforce the combination strategy's constraints (e.g. the simplex).
        let final_weights = self.apply_combination_constraints(weights)?;

        self.weights = Some(final_weights);
        Ok(())
    }

    /// Centered kernel alignment (CKA) weights. With labels, kernels are
    /// scored by alignment to the ideal label kernel and passed through a
    /// softmax; without labels, a trace-based heuristic is used.
    fn learn_cka_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        y: Option<&Array1<f64>>,
    ) -> Result<Array1<f64>> {
        if let Some(labels) = y {
            // Ideal kernel: 1 where labels agree, 0 otherwise.
            let label_kernel = self.compute_label_kernel(labels)?;
            let mut alignments = Array1::zeros(kernel_matrices.len());

            for (i, kernel) in kernel_matrices.iter().enumerate() {
                alignments[i] = self.centered_kernel_alignment(kernel, &label_kernel)?;
            }

            // Numerically stable softmax over the alignments.
            let max_alignment = alignments.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            let exp_alignments = alignments.mapv(|v| (v - max_alignment).exp());
            let sum_exp = exp_alignments.sum();

            Ok(exp_alignments / sum_exp)
        } else {
            // Unsupervised fallback: weight by mean self-similarity.
            let mut weights = Array1::zeros(kernel_matrices.len());

            for (i, kernel) in kernel_matrices.iter().enumerate() {
                weights[i] = kernel.diag().sum() / kernel.nrows() as f64;
            }

            // Normalize; fall back to uniform when the scores are degenerate.
            let sum_weights = weights.sum();
            if sum_weights > 0.0 {
                weights /= sum_weights;
            } else {
                weights.fill(1.0 / kernel_matrices.len() as f64);
            }

            Ok(weights)
        }
    }

    /// MMD-inspired heuristic: score each kernel by trace / Frobenius norm
    /// (a crude concentration measure) and normalize the scores.
    fn learn_mmd_weights(&self, kernel_matrices: &[Array2<f64>]) -> Result<Array1<f64>> {
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            let trace = kernel.diag().sum();
            let frobenius_norm = kernel.mapv(|v| v * v).sum().sqrt();
            weights[i] = trace / frobenius_norm;
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }

    /// SimpleMKL-inspired heuristic: weight each kernel by a participation
    /// ratio of its (approximate) eigenvalues, shrunk by the regularization.
    fn learn_simple_mkl_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        regularization: f64,
    ) -> Result<Array1<f64>> {
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            let eigenvalues = self.compute_simplified_eigenvalues(kernel)?;
            // Participation ratio: (sum lambda_i^2)^2 / sum lambda_i^4.
            let effective_rank =
                eigenvalues.mapv(|v| v * v).sum().powi(2) / eigenvalues.mapv(|v| v.powi(4)).sum();
            weights[i] = effective_rank / (1.0 + regularization);
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }

    /// Placeholder for EasyMKL: currently returns uniform weights and
    /// ignores the labels and the radius parameter.
    fn learn_easy_mkl_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        _radius: f64,
    ) -> Result<Array1<f64>> {
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Spectral heuristic: weight each kernel in proportion to its trace.
    fn learn_spectral_weights(&self, kernel_matrices: &[Array2<f64>]) -> Result<Array1<f64>> {
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            weights[i] = kernel.diag().sum();
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }

    /// Placeholder for localized MKL: currently returns uniform weights and
    /// ignores the bandwidth parameter.
    fn learn_localized_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _bandwidth: f64,
    ) -> Result<Array1<f64>> {
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Placeholder for adaptive MKL: currently returns uniform weights
    /// instead of cross-validated ones.
    fn learn_adaptive_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        _cv_folds: usize,
    ) -> Result<Array1<f64>> {
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Project raw weights onto the constraint set implied by the
    /// combination strategy.
    fn apply_combination_constraints(&self, mut weights: Array1<f64>) -> Result<Array1<f64>> {
        match &self.config.combination_strategy {
            CombinationStrategy::Convex => {
                // Clip negatives, then normalize onto the probability simplex.
                weights.mapv_inplace(|v| v.max(0.0));
                let sum = weights.sum();
                if sum > 0.0 {
                    weights /= sum;
                } else {
                    weights.fill(1.0 / weights.len() as f64);
                }
            }
            CombinationStrategy::Conic => {
                // Non-negativity only; the overall scale is unconstrained.
                weights.mapv_inplace(|v| v.max(0.0));
            }
            CombinationStrategy::Linear => {
                // Unconstrained: keep the weights as learned.
            }
            CombinationStrategy::Product => {
                // Keep magnitudes bounded away from zero so products stay finite.
                weights.mapv_inplace(|v| v.abs().max(1e-12));
            }
            CombinationStrategy::Hierarchical => {
                // Same projection as Convex for the flat fallback.
                weights.mapv_inplace(|v| v.max(0.0));
                let sum = weights.sum();
                if sum > 0.0 {
                    weights /= sum;
                } else {
                    weights.fill(1.0 / weights.len() as f64);
                }
            }
        }

        Ok(weights)
    }

    /// Ideal label kernel: 1 where two samples share a label, 0 otherwise.
    fn compute_label_kernel(&self, labels: &Array1<f64>) -> Result<Array2<f64>> {
        let n = labels.len();
        let mut label_kernel = Array2::zeros((n, n));

        for i in 0..n {
            for j in 0..n {
                label_kernel[[i, j]] = if (labels[i] - labels[j]).abs() < 1e-10 {
                    1.0
                } else {
                    0.0
                };
            }
        }

        Ok(label_kernel)
    }

    /// Centered kernel alignment:
    /// CKA(K1, K2) = <K1c, K2c>_F / (||K1c||_F * ||K2c||_F),
    /// where Kc = H K H and H = I - (1/n) 11^T is the centering matrix.
    fn centered_kernel_alignment(&self, k1: &Array2<f64>, k2: &Array2<f64>) -> Result<f64> {
        let n = k1.nrows() as f64;
        let ones = Array2::ones((k1.nrows(), k1.ncols())) / n;

        let k1_centered = k1 - &ones.dot(k1) - &k1.dot(&ones) + &ones.dot(k1).dot(&ones);
        let k2_centered = k2 - &ones.dot(k2) - &k2.dot(&ones) + &ones.dot(k2).dot(&ones);

        let numerator = (&k1_centered * &k2_centered).sum();
        let denominator =
            ((&k1_centered * &k1_centered).sum() * (&k2_centered * &k2_centered).sum()).sqrt();

        // Guard against degenerate (all-constant) kernels.
        if denominator > 1e-12 {
            Ok(numerator / denominator)
        } else {
            Ok(0.0)
        }
    }

    /// Optionally normalize (unit diagonal) and center a kernel matrix,
    /// according to the configuration.
    fn process_kernel_matrix(&self, mut kernel: Array2<f64>) -> Result<Array2<f64>> {
        if self.config.normalize_kernels {
            // Cosine normalization: K[i, j] <- K[i, j] / sqrt(K[i, i] * K[j, j]).
            let diag = kernel.diag().to_owned();
            for i in 0..kernel.nrows() {
                for j in 0..kernel.ncols() {
                    let denom = diag[i] * diag[j];
                    if denom > 1e-12 {
                        kernel[[i, j]] /= denom.sqrt();
                    }
                }
            }
        }

        if self.config.center_kernels {
            // Double centering: K <- K - row means - column means + grand mean.
            let row_means = kernel.mean_axis(Axis(1)).unwrap();
            let col_means = kernel.mean_axis(Axis(0)).unwrap();
            let total_mean = kernel.mean().unwrap();

            for i in 0..kernel.nrows() {
                for j in 0..kernel.ncols() {
                    kernel[[i, j]] = kernel[[i, j]] - row_means[i] - col_means[j] + total_mean;
                }
            }
        }

        Ok(kernel)
    }

    /// Build an approximate Gram matrix via random Fourier features. Only
    /// the shift-invariant kernels (RBF, Laplacian) have an RFF construction
    /// here; every other kernel falls back to the exact matrix.
    fn compute_rff_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        match kernel {
            BaseKernel::RBF { gamma } => self.compute_rbf_rff_matrix(x, *gamma, n_components),
            BaseKernel::Laplacian { gamma } => {
                self.compute_laplacian_rff_matrix(x, *gamma, n_components)
            }
            _ => {
                // No known spectral sampling scheme: compute exactly.
                self.compute_exact_kernel_matrix(x, kernel)
            }
        }
    }

    /// Approximate an RBF Gram matrix with random Fourier features:
    /// phi(x) = sqrt(2 / D) * cos(W x + b), with the rows of W drawn from
    /// N(0, 2 * gamma * I) and b uniform on [0, 2 * pi); then K ~= phi phi^T.
    fn compute_rbf_rff_matrix(
        &mut self,
        x: &Array2<f64>,
        gamma: f64,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let n_features = x.ncols();

        // Random frequencies with standard deviation sqrt(2 * gamma).
        let mut weights = Array2::zeros((n_components, n_features));
        for i in 0..n_components {
            for j in 0..n_features {
                weights[[i, j]] = self.rng.sample::<f64, _>(StandardNormal) * (2.0 * gamma).sqrt();
            }
        }

        // Random phase offsets.
        let mut bias = Array1::zeros(n_components);
        for i in 0..n_components {
            bias[i] = self.rng.gen_range(0.0..2.0 * std::f64::consts::PI);
        }

        // Project, take cosines, and reconstruct the approximate Gram matrix.
        let projection = x.dot(&weights.t()) + &bias;
        let features = projection.mapv(|v| v.cos()) * (2.0 / n_components as f64).sqrt();

        Ok(features.dot(&features.t()))
    }

    /// Approximate a Laplacian Gram matrix with random features. The
    /// Laplacian kernel's spectral density is Cauchy, so frequencies are
    /// drawn by inverse-CDF sampling: w = gamma * tan(pi * (u - 1/2)).
    fn compute_laplacian_rff_matrix(
        &mut self,
        x: &Array2<f64>,
        gamma: f64,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let n_features = x.ncols();

        // Cauchy-distributed random frequencies (u is clamped away from 0
        // and 1 to keep tan() finite).
        let mut weights = Array2::zeros((n_components, n_features));
        for i in 0..n_components {
            for j in 0..n_features {
                let u: f64 = self.rng.gen_range(0.001..0.999);
                weights[[i, j]] = ((std::f64::consts::PI * (u - 0.5)).tan()) * gamma;
            }
        }

        // Random phase offsets.
        let mut bias = Array1::zeros(n_components);
        for i in 0..n_components {
            bias[i] = self.rng.gen_range(0.0..2.0 * std::f64::consts::PI);
        }

        // Project, take cosines, and reconstruct the approximate Gram matrix.
        let projection = x.dot(&weights.t()) + &bias;
        let features = projection.mapv(|v| v.cos()) * (2.0 / n_components as f64).sqrt();

        Ok(features.dot(&features.t()))
    }

    /// Simplified Nystroem approximation: sample `n_components` landmark
    /// rows (with replacement), form the cross-kernel C between all samples
    /// and the landmarks, and return C C^T. The standard estimator
    /// K ~= C W^+ C^T also applies the pseudo-inverse of the landmark block
    /// W, which this simplified version omits.
    fn compute_nystroem_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        let n_landmarks = n_components.min(n_samples);

        // Sample landmark indices uniformly, with replacement.
        let mut landmark_indices = Vec::new();
        for _ in 0..n_landmarks {
            landmark_indices.push(self.rng.gen_range(0..n_samples));
        }

        // Cross-kernel between every sample and every landmark.
        let mut kernel_matrix = Array2::zeros((n_samples, n_landmarks));
        for i in 0..n_samples {
            for j in 0..n_landmarks {
                let landmark_idx = landmark_indices[j];
                kernel_matrix[[i, j]] =
                    kernel.evaluate(&x.row(i).to_owned(), &x.row(landmark_idx).to_owned());
            }
        }

        Ok(kernel_matrix.dot(&kernel_matrix.t()))
    }

    /// Structured random features are not implemented separately yet; this
    /// delegates to the plain random Fourier feature construction.
    fn compute_structured_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        self.compute_rff_approximation(x, kernel, n_components)
    }

    /// Dense Gram matrix, exploiting symmetry by filling only the upper
    /// triangle and mirroring it.
    fn compute_exact_kernel_matrix(
        &self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
    ) -> Result<Array2<f64>> {
        let n_samples = x.nrows();
        let mut kernel_matrix = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            for j in i..n_samples {
                let value = kernel.evaluate(&x.row(i).to_owned(), &x.row(j).to_owned());
                kernel_matrix[[i, j]] = value;
                kernel_matrix[[j, i]] = value;
            }
        }

        Ok(kernel_matrix)
    }

    /// Placeholder RFF transform for new samples. A complete implementation
    /// would reuse the random weights and phases drawn during `fit`; this
    /// stub returns a zero matrix of the expected shape.
    fn transform_rff(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    /// Placeholder Nystroem transform. A complete implementation would
    /// evaluate the kernel between `x` and the stored landmarks.
    fn transform_nystroem(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    /// Placeholder structured-feature transform; see `transform_rff`.
    fn transform_structured(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    /// Placeholder: the combined representation is currently derived lazily
    /// in `transform`, so nothing is precomputed here yet.
    fn compute_combined_representation(&mut self) -> Result<()> {
        Ok(())
    }

    /// Compute cheap diagnostics from a processed kernel matrix. Several
    /// quantities use the diagonal as an inexpensive stand-in for the full
    /// eigenspectrum.
    fn compute_kernel_statistics(
        &self,
        kernel: &Array2<f64>,
        _y: Option<&Array1<f64>>,
    ) -> Result<KernelStatistics> {
        let mut stats = KernelStatistics::new();

        // Mean self-similarity (used as an unsupervised alignment proxy).
        stats.alignment = kernel.diag().mean().unwrap_or(0.0);

        // Diagonal as an approximate eigenspectrum.
        stats.eigenspectrum = kernel.diag().to_owned();

        // Effective rank: trace(K)^2 / ||K||_F^2.
        let trace = kernel.diag().sum();
        let frobenius_sq = kernel.mapv(|v| v * v).sum();
        stats.effective_rank = if frobenius_sq > 1e-12 {
            trace.powi(2) / frobenius_sq
        } else {
            0.0
        };

        // Variance of the diagonal entries.
        stats.diversity = kernel.diag().var(0.0);

        // Condition-number-style ratio of the extreme diagonal entries.
        let diag = kernel.diag();
        let max_eig = diag.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let min_eig = diag.iter().fold(f64::INFINITY, |a, &b| a.min(b.max(1e-12)));
        stats.complexity = max_eig / min_eig;

        Ok(stats)
    }

    /// Cheap eigenvalue proxy: return the diagonal instead of running a full
    /// eigendecomposition.
    fn compute_simplified_eigenvalues(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
        Ok(matrix.diag().to_owned())
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_base_kernel_evaluation() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![1.0, 2.0, 3.0];

        // Identical inputs: the RBF kernel evaluates to exp(0) = 1.
        let rbf_kernel = BaseKernel::RBF { gamma: 0.1 };
        let value = rbf_kernel.evaluate(&x, &y);
        assert!((value - 1.0).abs() < 1e-10);

        // The linear kernel is the dot product: 1 + 4 + 9 = 14.
        let linear_kernel = BaseKernel::Linear;
        let value = linear_kernel.evaluate(&x, &y);
        assert!((value - 14.0).abs() < 1e-10);
    }

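    // Added illustration (not part of the original suite): kernel values at
    // distinct points should match the closed-form expressions documented on
    // `BaseKernel`. A small sketch to make the formulas concrete.
    #[test]
    fn test_base_kernel_evaluation_distinct_points() {
        let x = array![0.0, 0.0];
        let y = array![3.0, 4.0];

        // RBF: exp(-gamma * ||x - y||^2) = exp(-0.1 * 25)
        let rbf = BaseKernel::RBF { gamma: 0.1 };
        assert!((rbf.evaluate(&x, &y) - (-2.5f64).exp()).abs() < 1e-12);

        // Laplacian: exp(-gamma * ||x - y||_1) = exp(-0.1 * 7)
        let laplacian = BaseKernel::Laplacian { gamma: 0.1 };
        assert!((laplacian.evaluate(&x, &y) - (-0.7f64).exp()).abs() < 1e-12);

        // Sigmoid: tanh(gamma * <x, y> + coef0) = tanh(0 + 0.5)
        let sigmoid = BaseKernel::Sigmoid { gamma: 1.0, coef0: 0.5 };
        assert!((sigmoid.evaluate(&x, &y) - 0.5f64.tanh()).abs() < 1e-12);
    }
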
    #[test]
    fn test_kernel_names() {
        let rbf = BaseKernel::RBF { gamma: 0.5 };
        assert_eq!(rbf.name(), "RBF(gamma=0.5000)");

        let linear = BaseKernel::Linear;
        assert_eq!(linear.name(), "Linear");

        let poly = BaseKernel::Polynomial {
            degree: 2.0,
            gamma: 1.0,
            coef0: 0.0,
        };
        assert_eq!(
            poly.name(),
            "Polynomial(degree=2.0, gamma=1.0000, coef0=0.0000)"
        );
    }

    #[test]
    fn test_multiple_kernel_learning_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let base_kernels = vec![
            BaseKernel::RBF { gamma: 0.1 },
            BaseKernel::Linear,
            BaseKernel::Polynomial {
                degree: 2.0,
                gamma: 1.0,
                coef0: 0.0,
            },
        ];

        let mut mkl = MultipleKernelLearning::new(base_kernels).with_random_state(42);

        mkl.fit(&x, None).unwrap();

        // One weight per kernel, normalized onto the simplex by the default
        // Convex combination strategy.
        let weights = mkl.kernel_weights().unwrap();
        assert_eq!(weights.len(), 3);
        assert!((weights.sum() - 1.0).abs() < 1e-10);
    }

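    // Added illustration: the Uniform strategy ignores the data and assigns
    // 1/n to every kernel (a sketch of the weight-learning contract, not an
    // original test).
    #[test]
    fn test_uniform_weights() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl = MultipleKernelLearning::new(base_kernels)
            .with_config(MultiKernelConfig {
                weight_learning: WeightLearningAlgorithm::Uniform,
                ..Default::default()
            })
            .with_random_state(42);

        mkl.fit(&x, None).unwrap();
        let weights = mkl.kernel_weights().unwrap();
        assert!(weights.iter().all(|&w| (w - 0.5).abs() < 1e-12));
    }
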
    #[test]
    fn test_kernel_statistics() {
        let kernel = array![[1.0, 0.5, 0.2], [0.5, 1.0, 0.3], [0.2, 0.3, 1.0]];

        let mkl = MultipleKernelLearning::new(vec![]);
        let stats = mkl.compute_kernel_statistics(&kernel, None).unwrap();

        // Unit diagonal, so the mean self-similarity is exactly 1.
        assert!((stats.alignment - 1.0).abs() < 1e-10);
        assert!(stats.effective_rank > 0.0);
        assert!(stats.diversity >= 0.0);
    }

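    // Added illustration: a kernel is perfectly aligned with itself, so the
    // centered alignment CKA(K, K) is 1 for any non-degenerate K (a sketch,
    // not part of the original suite).
    #[test]
    fn test_self_alignment_is_one() {
        let kernel = array![[1.0, 0.5], [0.5, 1.0]];
        let mkl = MultipleKernelLearning::new(vec![]);
        let cka = mkl.centered_kernel_alignment(&kernel, &kernel).unwrap();
        assert!((cka - 1.0).abs() < 1e-10);
    }
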
    #[test]
    fn test_combination_strategies() {
        let weights = array![0.5, -0.3, 0.8];

        let mut mkl = MultipleKernelLearning::new(vec![]);
        mkl.config.combination_strategy = CombinationStrategy::Convex;

        let constrained = mkl.apply_combination_constraints(weights.clone()).unwrap();

        // Convex: non-negative weights summing to one.
        assert!(constrained.iter().all(|&x| x >= 0.0));
        assert!((constrained.sum() - 1.0).abs() < 1e-10);
    }

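    // Added illustration: unlike Convex above, the Conic strategy clips
    // negative weights but leaves the overall scale alone (a sketch for
    // contrast, not an original test).
    #[test]
    fn test_conic_constraints_keep_scale() {
        let weights = array![0.5, -0.3, 0.8];

        let mut mkl = MultipleKernelLearning::new(vec![]);
        mkl.config.combination_strategy = CombinationStrategy::Conic;

        let constrained = mkl.apply_combination_constraints(weights).unwrap();
        assert_eq!(constrained, array![0.5, 0.0, 0.8]);
    }
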
    #[test]
    fn test_mkl_config() {
        let config = MultiKernelConfig {
            combination_strategy: CombinationStrategy::Linear,
            weight_learning: WeightLearningAlgorithm::SimpleMKL {
                regularization: 0.01,
            },
            approximation_method: ApproximationMethod::Nystroem { n_components: 50 },
            max_iterations: 200,
            tolerance: 1e-8,
            normalize_kernels: false,
            center_kernels: false,
            regularization: 0.001,
        };

        assert!(matches!(
            config.combination_strategy,
            CombinationStrategy::Linear
        ));
        assert!(matches!(
            config.weight_learning,
            WeightLearningAlgorithm::SimpleMKL { .. }
        ));
        assert_eq!(config.max_iterations, 200);
        assert!(!config.normalize_kernels);
    }

    #[test]
    fn test_important_kernels() {
        let base_kernels = vec![
            BaseKernel::RBF { gamma: 0.1 },
            BaseKernel::Linear,
            BaseKernel::Polynomial {
                degree: 2.0,
                gamma: 1.0,
                coef0: 0.0,
            },
        ];

        let mut mkl = MultipleKernelLearning::new(base_kernels);
        mkl.weights = Some(array![0.6, 0.05, 0.35]);

        // Only the kernels with weight >= 0.1 survive: indices 0 and 2.
        let important = mkl.important_kernels(0.1);
        assert_eq!(important.len(), 2);
        assert_eq!(important[0].0, 0);
        assert_eq!(important[1].0, 2);
    }

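    // Added illustration: the exact Gram matrix is symmetric, and the RBF
    // kernel has a unit diagonal since K(x, x) = exp(0) = 1 (a sketch, not
    // an original test).
    #[test]
    fn test_exact_kernel_matrix_symmetry() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mkl = MultipleKernelLearning::new(vec![]);
        let kernel = BaseKernel::RBF { gamma: 0.5 };

        let k = mkl.compute_exact_kernel_matrix(&x, &kernel).unwrap();
        for i in 0..3 {
            assert!((k[[i, i]] - 1.0).abs() < 1e-12);
            for j in 0..3 {
                assert!((k[[i, j]] - k[[j, i]]).abs() < 1e-12);
            }
        }
    }
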
    #[test]
    fn test_supervised_vs_unsupervised() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 1.0, 0.0, 1.0];

        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl_unsupervised =
            MultipleKernelLearning::new(base_kernels.clone()).with_random_state(42);
        mkl_unsupervised.fit(&x, None).unwrap();

        let mut mkl_supervised = MultipleKernelLearning::new(base_kernels).with_random_state(42);
        mkl_supervised.fit(&x, Some(&y)).unwrap();

        assert!(mkl_unsupervised.kernel_weights().is_some());
        assert!(mkl_supervised.kernel_weights().is_some());
    }

    #[test]
    fn test_transform_compatibility() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let x_test = array![[2.0, 3.0], [4.0, 5.0]];

        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl = MultipleKernelLearning::new(base_kernels)
            .with_config(MultiKernelConfig {
                approximation_method: ApproximationMethod::RandomFourierFeatures {
                    n_components: 10,
                },
                ..Default::default()
            })
            .with_random_state(42);

        mkl.fit(&x_train, None).unwrap();
        let features = mkl.transform(&x_test).unwrap();

        // One row per test sample, with a non-empty feature dimension.
        assert_eq!(features.nrows(), 2);
        assert!(features.ncols() > 0);
    }
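
    // Added illustration: with the default configuration (normalize, then
    // center), every row of a processed kernel matrix sums to approximately
    // zero (a sketch of `process_kernel_matrix`, not an original test).
    #[test]
    fn test_process_kernel_matrix_centers_rows() {
        let kernel = array![[2.0, 1.0], [1.0, 2.0]];

        let mkl = MultipleKernelLearning::new(vec![]);
        let processed = mkl.process_kernel_matrix(kernel).unwrap();

        for row in processed.axis_iter(Axis(0)) {
            assert!(row.sum().abs() < 1e-10);
        }
    }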
}