sklears_kernel_approximation/multi_kernel_learning.rs

//! Multiple Kernel Learning (MKL) methods for kernel approximation
//!
//! This module provides methods for learning optimal combinations of multiple
//! kernels, enabling automatic kernel selection and weighting for improved
//! machine learning performance.

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::StandardNormal;
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;

/// Base kernel type for multiple kernel learning
#[derive(Debug, Clone)]
pub enum BaseKernel {
    /// RBF kernel with specific gamma parameter
    RBF { gamma: f64 },
    /// Polynomial kernel with degree, gamma, and coef0
    Polynomial { degree: f64, gamma: f64, coef0: f64 },
    /// Laplacian kernel with gamma parameter
    Laplacian { gamma: f64 },
    /// Linear kernel
    Linear,
    /// Sigmoid kernel with gamma and coef0
    Sigmoid { gamma: f64, coef0: f64 },
    /// Custom kernel with user-defined function
    Custom {
        name: String,
        kernel_fn: fn(&Array1<f64>, &Array1<f64>) -> f64,
    },
}
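
// A minimal sketch of plugging a user-defined kernel into `BaseKernel::Custom`;
// `cosine_kernel` is illustrative and not part of the crate's API. Any
// `fn(&Array1<f64>, &Array1<f64>) -> f64` works, since the variant stores a
// plain function pointer.
#[allow(dead_code)]
fn example_custom_kernel() -> BaseKernel {
    fn cosine_kernel(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        // Cosine similarity, guarded against zero-norm inputs
        let denom = (x.dot(x) * y.dot(y)).sqrt();
        if denom > 1e-12 {
            x.dot(y) / denom
        } else {
            0.0
        }
    }
    BaseKernel::Custom {
        name: "cosine".to_string(),
        kernel_fn: cosine_kernel,
    }
}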

impl BaseKernel {
    /// Evaluate kernel function between two samples
    pub fn evaluate(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        match self {
            BaseKernel::RBF { gamma } => {
                let diff = x - y;
                let squared_dist = diff.mapv(|v| v * v).sum();
                (-gamma * squared_dist).exp()
            }
            BaseKernel::Polynomial {
                degree,
                gamma,
                coef0,
            } => {
                let dot_product = x.dot(y);
                (gamma * dot_product + coef0).powf(*degree)
            }
            BaseKernel::Laplacian { gamma } => {
                let diff = x - y;
                let manhattan_dist = diff.mapv(|v| v.abs()).sum();
                (-gamma * manhattan_dist).exp()
            }
            BaseKernel::Linear => x.dot(y),
            BaseKernel::Sigmoid { gamma, coef0 } => {
                let dot_product = x.dot(y);
                (gamma * dot_product + coef0).tanh()
            }
            BaseKernel::Custom { kernel_fn, .. } => kernel_fn(x, y),
        }
    }

    /// Get kernel name for identification
    pub fn name(&self) -> String {
        match self {
            BaseKernel::RBF { gamma } => format!("RBF(gamma={:.4})", gamma),
            BaseKernel::Polynomial {
                degree,
                gamma,
                coef0,
            } => {
                format!(
                    "Polynomial(degree={:.1}, gamma={:.4}, coef0={:.4})",
                    degree, gamma, coef0
                )
            }
            BaseKernel::Laplacian { gamma } => format!("Laplacian(gamma={:.4})", gamma),
            BaseKernel::Linear => "Linear".to_string(),
            BaseKernel::Sigmoid { gamma, coef0 } => {
                format!("Sigmoid(gamma={:.4}, coef0={:.4})", gamma, coef0)
            }
            BaseKernel::Custom { name, .. } => format!("Custom({})", name),
        }
    }
}

/// Kernel combination strategy
#[derive(Debug, Clone)]
pub enum CombinationStrategy {
    /// Linear combination: K = Σ αᵢ Kᵢ
    Linear,
    /// Product combination: K = Π Kᵢ^αᵢ
    Product,
    /// Convex combination with unit sum constraint: Σ αᵢ = 1, αᵢ ≥ 0
    Convex,
    /// Conic combination with non-negativity: αᵢ ≥ 0
    Conic,
    /// Hierarchical combination with tree structure
    Hierarchical,
}
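
// A minimal sketch of what `CombinationStrategy::Convex` produces for two
// precomputed Gram matrices: K = w₁K₁ + w₂K₂ with non-negative weights that
// sum to one. The matrices and weight values here are illustrative.
#[allow(dead_code)]
fn convex_combination_example(k1: &Array2<f64>, k2: &Array2<f64>) -> Array2<f64> {
    let (w1, w2) = (0.7, 0.3); // convex weights: non-negative, summing to 1
    k1 * w1 + k2 * w2
}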

/// Kernel weight learning algorithm
#[derive(Debug, Clone)]
pub enum WeightLearningAlgorithm {
    /// Uniform weights (baseline)
    Uniform,
    /// Centered kernel alignment optimization
    CenteredKernelAlignment,
    /// Maximum mean discrepancy minimization
    MaximumMeanDiscrepancy,
    /// SimpleMKL algorithm with QCQP optimization
    SimpleMKL { regularization: f64 },
    /// EasyMKL algorithm with radius constraint
    EasyMKL { radius: f64 },
    /// SPKM (Spectral Projected Kernel Machine)
    SpectralProjected,
    /// Localized MKL with spatial weights
    LocalizedMKL { bandwidth: f64 },
    /// Adaptive MKL with cross-validation
    AdaptiveMKL { cv_folds: usize },
}

/// Kernel approximation method for each base kernel
#[derive(Debug, Clone)]
pub enum ApproximationMethod {
    /// Random Fourier Features
    RandomFourierFeatures { n_components: usize },
    /// Nyström approximation
    Nystroem { n_components: usize },
    /// Structured random features
    StructuredFeatures { n_components: usize },
    /// Exact kernel (no approximation)
    Exact,
}

/// Configuration for multiple kernel learning
#[derive(Debug, Clone)]
pub struct MultiKernelConfig {
    /// Strategy for combining the base kernels
    pub combination_strategy: CombinationStrategy,
    /// Algorithm used to learn the kernel weights
    pub weight_learning: WeightLearningAlgorithm,
    /// Approximation method applied to each base kernel
    pub approximation_method: ApproximationMethod,
    /// Maximum number of optimization iterations
    pub max_iterations: usize,
    /// Convergence tolerance for weight optimization
    pub tolerance: f64,
    /// Whether to normalize kernel matrices to unit diagonal
    pub normalize_kernels: bool,
    /// Whether to center kernel matrices in feature space
    pub center_kernels: bool,
    /// Regularization strength used during weight learning
    pub regularization: f64,
}

impl Default for MultiKernelConfig {
    fn default() -> Self {
        Self {
            combination_strategy: CombinationStrategy::Convex,
            weight_learning: WeightLearningAlgorithm::CenteredKernelAlignment,
            approximation_method: ApproximationMethod::RandomFourierFeatures { n_components: 100 },
            max_iterations: 100,
            tolerance: 1e-6,
            normalize_kernels: true,
            center_kernels: true,
            regularization: 1e-3,
        }
    }
}
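
// A minimal sketch of overriding selected configuration fields via struct
// update syntax, as also done in the tests below; the chosen values are
// illustrative, not tuned recommendations.
#[allow(dead_code)]
fn example_config() -> MultiKernelConfig {
    MultiKernelConfig {
        approximation_method: ApproximationMethod::Nystroem { n_components: 50 },
        weight_learning: WeightLearningAlgorithm::Uniform,
        ..Default::default()
    }
}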

/// Multiple Kernel Learning approximation
///
/// Learns optimal weights for combining multiple kernel approximations
/// to create a single, improved kernel approximation.
pub struct MultipleKernelLearning {
    base_kernels: Vec<BaseKernel>,
    config: MultiKernelConfig,
    weights: Option<Array1<f64>>,
    kernel_matrices: Option<Vec<Array2<f64>>>,
    combined_features: Option<Array2<f64>>,
    random_state: Option<u64>,
    rng: StdRng,
    training_data: Option<Array2<f64>>,
    kernel_statistics: HashMap<String, KernelStatistics>,
}

/// Statistics for individual kernels
#[derive(Debug, Clone)]
pub struct KernelStatistics {
    /// Alignment score for the kernel (diagonal mean in this implementation)
    pub alignment: f64,
    /// Approximate eigenspectrum of the kernel matrix
    pub eigenspectrum: Array1<f64>,
    /// Effective-rank estimate, trace² / squared Frobenius norm
    pub effective_rank: f64,
    /// Diversity measure (variance of the kernel diagonal)
    pub diversity: f64,
    /// Complexity measure (condition-number approximation)
    pub complexity: f64,
}

impl KernelStatistics {
    /// Create an empty statistics record with zeroed fields
    pub fn new() -> Self {
        Self {
            alignment: 0.0,
            eigenspectrum: Array1::zeros(0),
            effective_rank: 0.0,
            diversity: 0.0,
            complexity: 0.0,
        }
    }
}
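
// A minimal worked example of the effective-rank statistic used throughout
// this module: r_eff = (Σᵢ λᵢ)² / Σᵢ λᵢ². For a spectrum of k equal nonzero
// eigenvalues it returns exactly k; the values below are illustrative.
#[allow(dead_code)]
fn effective_rank_example() -> f64 {
    let eigenvalues = Array1::from(vec![2.0, 2.0, 2.0, 0.0]);
    let trace: f64 = eigenvalues.sum();
    let sq_sum: f64 = eigenvalues.mapv(|v| v * v).sum();
    trace.powi(2) / sq_sum // = 36 / 12 = 3, matching the 3 nonzero eigenvalues
}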

impl MultipleKernelLearning {
    /// Create a new multiple kernel learning instance
    pub fn new(base_kernels: Vec<BaseKernel>) -> Self {
        let rng = StdRng::seed_from_u64(42);
        Self {
            base_kernels,
            config: MultiKernelConfig::default(),
            weights: None,
            kernel_matrices: None,
            combined_features: None,
            random_state: None,
            rng,
            training_data: None,
            kernel_statistics: HashMap::new(),
        }
    }

    /// Set configuration
    pub fn with_config(mut self, config: MultiKernelConfig) -> Self {
        self.config = config;
        self
    }

    /// Set random state for reproducibility
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self.rng = StdRng::seed_from_u64(random_state);
        self
    }

    /// Fit the multiple kernel learning model
    pub fn fit(&mut self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
        // Store training data
        self.training_data = Some(x.clone());

        // Compute or approximate individual kernel matrices
        let mut kernel_matrices = Vec::new();
        let base_kernels = self.base_kernels.clone(); // Clone to avoid borrowing issues

        for (i, base_kernel) in base_kernels.iter().enumerate() {
            let kernel_matrix = match &self.config.approximation_method {
                ApproximationMethod::RandomFourierFeatures { n_components } => {
                    self.compute_rff_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::Nystroem { n_components } => {
                    self.compute_nystroem_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::StructuredFeatures { n_components } => {
                    self.compute_structured_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::Exact => self.compute_exact_kernel_matrix(x, base_kernel)?,
            };

            // Process kernel matrix (normalize, center)
            let processed_matrix = self.process_kernel_matrix(kernel_matrix)?;

            // Compute kernel statistics
            let stats = self.compute_kernel_statistics(&processed_matrix, y)?;
            self.kernel_statistics
                .insert(format!("kernel_{}", i), stats);

            kernel_matrices.push(processed_matrix);
        }

        self.kernel_matrices = Some(kernel_matrices);

        // Learn optimal kernel weights
        self.learn_weights(y)?;

        // Compute combined features/kernel
        self.compute_combined_representation()?;

        Ok(())
    }
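
    // Typical usage (mirrored by the tests at the bottom of this file):
    //
    //     let mut mkl = MultipleKernelLearning::new(kernels).with_random_state(0);
    //     mkl.fit(&x, Some(&y))?;
    //     let weights = mkl.kernel_weights().unwrap();
    //
    // `kernels`, `x`, and `y` are illustrative; `fit` accepts `None` for the
    // labels, in which case the unsupervised weighting paths are used.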

    /// Transform data using learned kernel combination
    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let weights = self
            .weights
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;
        let training_data = self
            .training_data
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let mut combined_features = None;

        for (base_kernel, &weight) in self.base_kernels.iter().zip(weights.iter()) {
            if weight.abs() < 1e-12 {
                continue; // Skip kernels with negligible weight
            }

            let features = match &self.config.approximation_method {
                ApproximationMethod::RandomFourierFeatures { n_components } => {
                    self.transform_rff(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::Nystroem { n_components } => {
                    self.transform_nystroem(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::StructuredFeatures { n_components } => {
                    self.transform_structured(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::Exact => {
                    return Err(SklearsError::NotImplemented(
                        "Exact kernel transform not implemented for new data".to_string(),
                    ));
                }
            };

            let weighted_features = &features * weight;

            match &self.config.combination_strategy {
                CombinationStrategy::Linear
                | CombinationStrategy::Convex
                | CombinationStrategy::Conic => {
                    combined_features = match combined_features {
                        Some(existing) => Some(existing + weighted_features),
                        None => Some(weighted_features),
                    };
                }
                CombinationStrategy::Product => {
                    // Product combination in the exp domain:
                    // Π exp(wᵢφᵢ) = exp(Σ wᵢφᵢ), accumulated multiplicatively
                    combined_features = match combined_features {
                        Some(existing) => Some(existing * weighted_features.mapv(|v| v.exp())),
                        None => Some(weighted_features.mapv(|v| v.exp())),
                    };
                }
                CombinationStrategy::Hierarchical => {
                    // Simplified hierarchical combination (falls back to additive)
                    combined_features = match combined_features {
                        Some(existing) => Some(existing + weighted_features),
                        None => Some(weighted_features),
                    };
                }
            }
        }

        combined_features.ok_or_else(|| {
            SklearsError::Other("No features generated - all kernel weights are zero".to_string())
        })
    }

    /// Get learned kernel weights
    pub fn kernel_weights(&self) -> Option<&Array1<f64>> {
        self.weights.as_ref()
    }

    /// Get kernel statistics
    pub fn kernel_stats(&self) -> &HashMap<String, KernelStatistics> {
        &self.kernel_statistics
    }

    /// Get most important kernels based on weights
    pub fn important_kernels(&self, threshold: f64) -> Vec<(usize, &BaseKernel, f64)> {
        if let Some(weights) = &self.weights {
            self.base_kernels
                .iter()
                .enumerate()
                .zip(weights.iter())
                .filter_map(|((i, kernel), &weight)| {
                    if weight.abs() >= threshold {
                        Some((i, kernel, weight))
                    } else {
                        None
                    }
                })
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Learn optimal kernel weights
    fn learn_weights(&mut self, y: Option<&Array1<f64>>) -> Result<()> {
        let kernel_matrices = self
            .kernel_matrices
            .as_ref()
            .expect("kernel matrices are computed in fit() before learn_weights runs");
        let n_kernels = kernel_matrices.len();

        let weights = match &self.config.weight_learning {
            WeightLearningAlgorithm::Uniform => {
                Array1::from_elem(n_kernels, 1.0 / n_kernels as f64)
            }
            WeightLearningAlgorithm::CenteredKernelAlignment => {
                self.learn_cka_weights(kernel_matrices, y)?
            }
            WeightLearningAlgorithm::MaximumMeanDiscrepancy => {
                self.learn_mmd_weights(kernel_matrices)?
            }
            WeightLearningAlgorithm::SimpleMKL { regularization } => {
                self.learn_simple_mkl_weights(kernel_matrices, y, *regularization)?
            }
            WeightLearningAlgorithm::EasyMKL { radius } => {
                self.learn_easy_mkl_weights(kernel_matrices, y, *radius)?
            }
            WeightLearningAlgorithm::SpectralProjected => {
                self.learn_spectral_weights(kernel_matrices)?
            }
            WeightLearningAlgorithm::LocalizedMKL { bandwidth } => {
                self.learn_localized_weights(kernel_matrices, *bandwidth)?
            }
            WeightLearningAlgorithm::AdaptiveMKL { cv_folds } => {
                self.learn_adaptive_weights(kernel_matrices, y, *cv_folds)?
            }
        };

        // Apply combination strategy constraints
        let final_weights = self.apply_combination_constraints(weights)?;

        self.weights = Some(final_weights);
        Ok(())
    }

    /// Learn weights using Centered Kernel Alignment
    fn learn_cka_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        y: Option<&Array1<f64>>,
    ) -> Result<Array1<f64>> {
        if let Some(labels) = y {
            // Supervised CKA: align with label kernel
            let label_kernel = self.compute_label_kernel(labels)?;
            let mut alignments = Array1::zeros(kernel_matrices.len());

            for (i, kernel) in kernel_matrices.iter().enumerate() {
                alignments[i] = self.centered_kernel_alignment(kernel, &label_kernel)?;
            }

            // Softmax normalization
            let max_alignment = alignments.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            let exp_alignments = alignments.mapv(|v| (v - max_alignment).exp());
            let sum_exp = exp_alignments.sum();

            Ok(exp_alignments / sum_exp)
        } else {
            // Unsupervised: use kernel diversity
            let mut weights = Array1::zeros(kernel_matrices.len());

            for (i, kernel) in kernel_matrices.iter().enumerate() {
                // Use the mean of the diagonal as a simple quality measure
                weights[i] = kernel.diag().sum() / kernel.nrows() as f64;
            }

            let sum_weights = weights.sum();
            if sum_weights > 0.0 {
                weights /= sum_weights;
            } else {
                weights.fill(1.0 / kernel_matrices.len() as f64);
            }

            Ok(weights)
        }
    }

    /// Learn weights using Maximum Mean Discrepancy
    fn learn_mmd_weights(&self, kernel_matrices: &[Array2<f64>]) -> Result<Array1<f64>> {
        // Simplified MMD-based weight learning
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            // Use kernel matrix statistics as a proxy for MMD
            let trace = kernel.diag().sum();
            let frobenius_norm = kernel.mapv(|v| v * v).sum().sqrt();
            weights[i] = if frobenius_norm > 1e-12 {
                trace / frobenius_norm
            } else {
                0.0
            };
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }
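
    // A minimal sketch (unused) of the quantity the proxy above stands in for:
    // the squared empirical MMD between two samples under a kernel k, in its
    // biased V-statistic form, MMD² = mean(Kxx) + mean(Kyy) − 2·mean(Kxy).
    // The Gram blocks are assumed precomputed; this is not part of the API.
    #[allow(dead_code)]
    fn empirical_mmd_squared(kxx: &Array2<f64>, kyy: &Array2<f64>, kxy: &Array2<f64>) -> f64 {
        let mean = |k: &Array2<f64>| k.sum() / k.len() as f64;
        mean(kxx) + mean(kyy) - 2.0 * mean(kxy)
    }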

    /// Simplified SimpleMKL implementation
    fn learn_simple_mkl_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        regularization: f64,
    ) -> Result<Array1<f64>> {
        // Simplified version: use kernel quality metrics
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            let eigenvalues = self.compute_simplified_eigenvalues(kernel)?;
            // Participation ratio of the squared spectrum: (Σλ²)² / Σλ⁴
            let denom = eigenvalues.mapv(|v| v.powi(4)).sum();
            let effective_rank = if denom > 1e-12 {
                eigenvalues.mapv(|v| v * v).sum().powi(2) / denom
            } else {
                0.0
            };
            weights[i] = effective_rank / (1.0 + regularization);
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }

    /// Simplified EasyMKL implementation
    fn learn_easy_mkl_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        _radius: f64,
    ) -> Result<Array1<f64>> {
        // Use uniform weights for simplicity
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Learn weights using spectral methods
    fn learn_spectral_weights(&self, kernel_matrices: &[Array2<f64>]) -> Result<Array1<f64>> {
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            let trace = kernel.diag().sum();
            weights[i] = trace;
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }

    /// Learn localized weights
    fn learn_localized_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _bandwidth: f64,
    ) -> Result<Array1<f64>> {
        // Simplified localized MKL
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Learn adaptive weights with cross-validation
    fn learn_adaptive_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        _cv_folds: usize,
    ) -> Result<Array1<f64>> {
        // Simplified adaptive MKL
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Apply combination strategy constraints
    fn apply_combination_constraints(&self, mut weights: Array1<f64>) -> Result<Array1<f64>> {
        match &self.config.combination_strategy {
            CombinationStrategy::Convex => {
                // Clip to non-negative, then renormalize to sum to 1. Note that
                // clip-and-renormalize is a heuristic, not the exact Euclidean
                // projection onto the simplex (see the sketch below).
                weights.mapv_inplace(|v| v.max(0.0));
                let sum = weights.sum();
                if sum > 0.0 {
                    weights /= sum;
                } else {
                    weights.fill(1.0 / weights.len() as f64);
                }
            }
            CombinationStrategy::Conic => {
                // Ensure non-negative
                weights.mapv_inplace(|v| v.max(0.0));
            }
            CombinationStrategy::Linear => {
                // No constraints
            }
            CombinationStrategy::Product => {
                // Ensure all weights are strictly positive for product combination
                weights.mapv_inplace(|v| v.abs().max(1e-12));
            }
            CombinationStrategy::Hierarchical => {
                // Simplified hierarchical constraints: same as convex
                weights.mapv_inplace(|v| v.max(0.0));
                let sum = weights.sum();
                if sum > 0.0 {
                    weights /= sum;
                } else {
                    weights.fill(1.0 / weights.len() as f64);
                }
            }
        }

        Ok(weights)
    }
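
    // A minimal sketch (unused) of the exact Euclidean projection onto the
    // probability simplex, the constraint set of `CombinationStrategy::Convex`,
    // following the standard sort-based algorithm (Held et al.). Provided for
    // comparison with the clip-and-renormalize heuristic above.
    #[allow(dead_code)]
    fn project_onto_simplex(v: &Array1<f64>) -> Array1<f64> {
        // Sort entries in descending order
        let mut sorted: Vec<f64> = v.to_vec();
        sorted.sort_by(|a, b| b.partial_cmp(a).unwrap());

        // Find the threshold theta = (sum of the rho largest entries - 1) / rho
        let mut cumsum = 0.0;
        let mut theta = 0.0;
        for (k, &u) in sorted.iter().enumerate() {
            cumsum += u;
            let t = (cumsum - 1.0) / (k as f64 + 1.0);
            if u - t > 0.0 {
                theta = t;
            }
        }

        // Shift and clip: the result is non-negative and sums to 1
        v.mapv(|x| (x - theta).max(0.0))
    }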

    /// Compute label kernel for supervised learning
    fn compute_label_kernel(&self, labels: &Array1<f64>) -> Result<Array2<f64>> {
        let n = labels.len();
        let mut label_kernel = Array2::zeros((n, n));

        for i in 0..n {
            for j in 0..n {
                label_kernel[[i, j]] = if (labels[i] - labels[j]).abs() < 1e-10 {
                    1.0
                } else {
                    0.0
                };
            }
        }

        Ok(label_kernel)
    }

    /// Compute centered kernel alignment
    ///
    /// CKA(K₁, K₂) = ⟨K̃₁, K̃₂⟩_F / (‖K̃₁‖_F ‖K̃₂‖_F), where K̃ = HKH is the
    /// doubly centered kernel with H = I − (1/n)𝟙𝟙ᵀ.
    fn centered_kernel_alignment(&self, k1: &Array2<f64>, k2: &Array2<f64>) -> Result<f64> {
        let n = k1.nrows() as f64;
        let ones = Array2::ones((k1.nrows(), k1.ncols())) / n;

        // Center kernels: K - JK - KJ + JKJ with J = (1/n)𝟙𝟙ᵀ
        let k1_centered = k1 - &ones.dot(k1) - &k1.dot(&ones) + &ones.dot(k1).dot(&ones);
        let k2_centered = k2 - &ones.dot(k2) - &k2.dot(&ones) + &ones.dot(k2).dot(&ones);

        // Frobenius inner product over the product of Frobenius norms
        let numerator = (&k1_centered * &k2_centered).sum();
        let denominator =
            ((&k1_centered * &k1_centered).sum() * (&k2_centered * &k2_centered).sum()).sqrt();

        if denominator > 1e-12 {
            Ok(numerator / denominator)
        } else {
            Ok(0.0)
        }
    }

    /// Process kernel matrix (normalize and center if configured)
    fn process_kernel_matrix(&self, mut kernel: Array2<f64>) -> Result<Array2<f64>> {
        if self.config.normalize_kernels {
            // Cosine-normalize so the diagonal becomes 1: K[i,j] /= sqrt(K[i,i] * K[j,j]).
            // The diagonal is copied first so the matrix can be mutated in place.
            let diag = kernel.diag().to_owned();
            for i in 0..kernel.nrows() {
                for j in 0..kernel.ncols() {
                    let denom = diag[i] * diag[j];
                    if denom > 1e-12 {
                        kernel[[i, j]] /= denom.sqrt();
                    }
                }
            }
        }

        if self.config.center_kernels {
            // Double-center: K[i,j] - row_mean[i] - col_mean[j] + total_mean
            let row_means = kernel.mean_axis(Axis(1)).unwrap();
            let col_means = kernel.mean_axis(Axis(0)).unwrap();
            let total_mean = kernel.mean().unwrap();

            for i in 0..kernel.nrows() {
                for j in 0..kernel.ncols() {
                    kernel[[i, j]] = kernel[[i, j]] - row_means[i] - col_means[j] + total_mean;
                }
            }
        }

        Ok(kernel)
    }

    /// Compute a Random-Fourier-Features Gram-matrix approximation
    ///
    /// Supported for shift-invariant kernels (RBF, Laplacian); other kernels
    /// fall back to the exact Gram matrix.
    fn compute_rff_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        match kernel {
            BaseKernel::RBF { gamma } => self.compute_rbf_rff_matrix(x, *gamma, n_components),
            BaseKernel::Laplacian { gamma } => {
                self.compute_laplacian_rff_matrix(x, *gamma, n_components)
            }
            _ => {
                // Fallback to exact computation for unsupported kernels
                self.compute_exact_kernel_matrix(x, kernel)
            }
        }
    }

    fn compute_rbf_rff_matrix(
        &mut self,
        x: &Array2<f64>,
        gamma: f64,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (_, n_features) = x.dim();

        // Sample random weights from N(0, 2*gamma*I)
        let mut weights = Array2::zeros((n_components, n_features));
        for i in 0..n_components {
            for j in 0..n_features {
                weights[[i, j]] = self.rng.sample::<f64, _>(StandardNormal) * (2.0 * gamma).sqrt();
            }
        }

        // Sample random phase offsets from U[0, 2*pi)
        let mut bias = Array1::zeros(n_components);
        for i in 0..n_components {
            bias[i] = self.rng.gen_range(0.0..2.0 * std::f64::consts::PI);
        }

        // Compute features z(x) = sqrt(2/D) * cos(W x + b)
        let projection = x.dot(&weights.t()) + &bias;
        let features = projection.mapv(|v| v.cos()) * (2.0 / n_components as f64).sqrt();

        // Compute Gram matrix from features
        Ok(features.dot(&features.t()))
    }
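
    // The construction above is the Rahimi–Recht random-features identity:
    // by Bochner's theorem the RBF kernel k(x, y) = exp(−γ‖x − y‖²) is the
    // Fourier transform of a Gaussian spectral density N(0, 2γI). With the
    // rows of W drawn from that density and b ~ U[0, 2π), the features
    // z(x) = √(2/D)·cos(Wx + b) satisfy E[z(x)·z(y)] = k(x, y), so
    // `features.dot(&features.t())` is a Monte Carlo estimate of the exact
    // Gram matrix whose error shrinks as D = n_components grows.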

    fn compute_laplacian_rff_matrix(
        &mut self,
        x: &Array2<f64>,
        gamma: f64,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        // The Laplacian kernel's spectral density is a product of Cauchy
        // distributions, sampled here by inverse CDF with lightly truncated tails
        let (_, n_features) = x.dim();

        let mut weights = Array2::zeros((n_components, n_features));
        for i in 0..n_components {
            for j in 0..n_features {
                let u: f64 = self.rng.gen_range(0.001..0.999);
                weights[[i, j]] = ((std::f64::consts::PI * (u - 0.5)).tan()) * gamma;
            }
        }

        // Sample random phase offsets from U[0, 2*pi)
        let mut bias = Array1::zeros(n_components);
        for i in 0..n_components {
            bias[i] = self.rng.gen_range(0.0..2.0 * std::f64::consts::PI);
        }

        // Compute features z(x) = sqrt(2/D) * cos(W x + b)
        let projection = x.dot(&weights.t()) + &bias;
        let features = projection.mapv(|v| v.cos()) * (2.0 / n_components as f64).sqrt();

        // Compute Gram matrix from features
        Ok(features.dot(&features.t()))
    }

    fn compute_nystroem_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        let n_landmarks = n_components.min(n_samples);

        // Select random landmarks (sampled with replacement for simplicity)
        let mut landmark_indices = Vec::new();
        for _ in 0..n_landmarks {
            landmark_indices.push(self.rng.gen_range(0..n_samples));
        }

        // Compute kernel between all points and landmarks
        let mut kernel_matrix = Array2::zeros((n_samples, n_landmarks));
        for i in 0..n_samples {
            for j in 0..n_landmarks {
                let landmark_idx = landmark_indices[j];
                kernel_matrix[[i, j]] =
                    kernel.evaluate(&x.row(i).to_owned(), &x.row(landmark_idx).to_owned());
            }
        }

        // Simplified Nyström: return K_nm * K_nm^T. The full method would use
        // K_nm * W^+ * K_nm^T, where W is the landmark-landmark kernel block.
        Ok(kernel_matrix.dot(&kernel_matrix.t()))
    }

    fn compute_structured_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        // Fallback to RFF for structured features
        self.compute_rff_approximation(x, kernel, n_components)
    }

    fn compute_exact_kernel_matrix(
        &self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
    ) -> Result<Array2<f64>> {
        let n_samples = x.nrows();
        let mut kernel_matrix = Array2::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            for j in i..n_samples {
                let value = kernel.evaluate(&x.row(i).to_owned(), &x.row(j).to_owned());
                kernel_matrix[[i, j]] = value;
                kernel_matrix[[j, i]] = value;
            }
        }

        Ok(kernel_matrix)
    }

    /// Placeholder transform methods: these return zero features of the right
    /// shape; a full implementation would reuse the random weights or landmarks
    /// drawn during `fit`.
    fn transform_rff(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    fn transform_nystroem(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    fn transform_structured(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    fn compute_combined_representation(&mut self) -> Result<()> {
        // Placeholder for combining features
        Ok(())
    }

    fn compute_kernel_statistics(
        &self,
        kernel: &Array2<f64>,
        _y: Option<&Array1<f64>>,
    ) -> Result<KernelStatistics> {
        let mut stats = KernelStatistics::new();

        // Alignment proxy: mean of the kernel diagonal
        stats.alignment = kernel.diag().mean().unwrap_or(0.0);

        // Simplified eigenspectrum (just use the diagonal)
        stats.eigenspectrum = kernel.diag().to_owned();

        // Effective rank approximation: trace^2 / squared Frobenius norm
        let trace = kernel.diag().sum();
        let frobenius_sq = kernel.mapv(|v| v * v).sum();
        stats.effective_rank = if frobenius_sq > 1e-12 {
            trace.powi(2) / frobenius_sq
        } else {
            0.0
        };

        // Diversity measure (variance of the diagonal)
        stats.diversity = kernel.diag().var(0.0);

        // Complexity measure (condition-number approximation from the diagonal)
        let diag = kernel.diag();
        let max_eig = diag.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let min_eig = diag.iter().fold(f64::INFINITY, |a, &b| a.min(b.max(1e-12)));
        stats.complexity = max_eig / min_eig;

        Ok(stats)
    }

    fn compute_simplified_eigenvalues(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
        // Simplified eigenvalue computation: use the diagonal as a crude proxy
        Ok(matrix.diag().to_owned())
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_base_kernel_evaluation() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![1.0, 2.0, 3.0];

        let rbf_kernel = BaseKernel::RBF { gamma: 0.1 };
        let value = rbf_kernel.evaluate(&x, &y);
        assert!((value - 1.0).abs() < 1e-10); // Identical points should give 1.0

        let linear_kernel = BaseKernel::Linear;
        let value = linear_kernel.evaluate(&x, &y);
        assert!((value - 14.0).abs() < 1e-10); // 1*1 + 2*2 + 3*3 = 14
    }

    #[test]
    fn test_kernel_names() {
        let rbf = BaseKernel::RBF { gamma: 0.5 };
        assert_eq!(rbf.name(), "RBF(gamma=0.5000)");

        let linear = BaseKernel::Linear;
        assert_eq!(linear.name(), "Linear");

        let poly = BaseKernel::Polynomial {
            degree: 2.0,
            gamma: 1.0,
            coef0: 0.0,
        };
        assert_eq!(
            poly.name(),
            "Polynomial(degree=2.0, gamma=1.0000, coef0=0.0000)"
        );
    }

    #[test]
    fn test_multiple_kernel_learning_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let base_kernels = vec![
            BaseKernel::RBF { gamma: 0.1 },
            BaseKernel::Linear,
            BaseKernel::Polynomial {
                degree: 2.0,
                gamma: 1.0,
                coef0: 0.0,
            },
        ];

        let mut mkl = MultipleKernelLearning::new(base_kernels).with_random_state(42);

        mkl.fit(&x, None).unwrap();

        let weights = mkl.kernel_weights().unwrap();
        assert_eq!(weights.len(), 3);
        assert!((weights.sum() - 1.0).abs() < 1e-10); // Should sum to 1 for convex combination
    }

    #[test]
    fn test_kernel_statistics() {
        let kernel = array![[1.0, 0.5, 0.2], [0.5, 1.0, 0.3], [0.2, 0.3, 1.0]];

        let mkl = MultipleKernelLearning::new(vec![]);
        let stats = mkl.compute_kernel_statistics(&kernel, None).unwrap();

        assert!((stats.alignment - 1.0).abs() < 1e-10); // Diagonal mean should be 1.0
        assert!(stats.effective_rank > 0.0);
        assert!(stats.diversity >= 0.0);
    }

    #[test]
    fn test_combination_strategies() {
        let weights = array![0.5, -0.3, 0.8];

        let mut mkl = MultipleKernelLearning::new(vec![]);
        mkl.config.combination_strategy = CombinationStrategy::Convex;

        let constrained = mkl.apply_combination_constraints(weights.clone()).unwrap();

        // Should be non-negative and sum to 1
        assert!(constrained.iter().all(|&x| x >= 0.0));
        assert!((constrained.sum() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mkl_config() {
        let config = MultiKernelConfig {
            combination_strategy: CombinationStrategy::Linear,
            weight_learning: WeightLearningAlgorithm::SimpleMKL {
                regularization: 0.01,
            },
            approximation_method: ApproximationMethod::Nystroem { n_components: 50 },
            max_iterations: 200,
            tolerance: 1e-8,
            normalize_kernels: false,
            center_kernels: false,
            regularization: 0.001,
        };

        assert!(matches!(
            config.combination_strategy,
            CombinationStrategy::Linear
        ));
        assert!(matches!(
            config.weight_learning,
            WeightLearningAlgorithm::SimpleMKL { .. }
        ));
        assert_eq!(config.max_iterations, 200);
        assert!(!config.normalize_kernels);
    }

    #[test]
    fn test_important_kernels() {
        let base_kernels = vec![
            BaseKernel::RBF { gamma: 0.1 },
            BaseKernel::Linear,
            BaseKernel::Polynomial {
                degree: 2.0,
                gamma: 1.0,
                coef0: 0.0,
            },
        ];

        let mut mkl = MultipleKernelLearning::new(base_kernels);
        mkl.weights = Some(array![0.6, 0.05, 0.35]);

        let important = mkl.important_kernels(0.1);
        assert_eq!(important.len(), 2); // Only kernels with weight >= 0.1
        assert_eq!(important[0].0, 0); // First kernel (RBF)
        assert_eq!(important[1].0, 2); // Third kernel (Polynomial)
    }

    #[test]
    fn test_supervised_vs_unsupervised() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 1.0, 0.0, 1.0];

        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl_unsupervised =
            MultipleKernelLearning::new(base_kernels.clone()).with_random_state(42);
        mkl_unsupervised.fit(&x, None).unwrap();

        let mut mkl_supervised = MultipleKernelLearning::new(base_kernels).with_random_state(42);
        mkl_supervised.fit(&x, Some(&y)).unwrap();

        // Both should work without errors
        assert!(mkl_unsupervised.kernel_weights().is_some());
        assert!(mkl_supervised.kernel_weights().is_some());
    }

    #[test]
    fn test_transform_compatibility() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let x_test = array![[2.0, 3.0], [4.0, 5.0]];

        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl = MultipleKernelLearning::new(base_kernels)
            .with_config(MultiKernelConfig {
                approximation_method: ApproximationMethod::RandomFourierFeatures {
                    n_components: 10,
                },
                ..Default::default()
            })
            .with_random_state(42);

        mkl.fit(&x_train, None).unwrap();
        let features = mkl.transform(&x_test).unwrap();

        assert_eq!(features.nrows(), 2); // Two test samples
        assert!(features.ncols() > 0); // Some features generated
    }
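
    // A sketch of the Monte Carlo behavior of the RFF Gram approximation
    // discussed above compute_rbf_rff_matrix: with a generous component count,
    // the approximate Gram entries should land near the exact RBF kernel
    // values. The tolerance and sizes here are illustrative, not guarantees.
    #[test]
    fn test_rff_gram_approximates_exact_rbf() {
        let x = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let kernel = BaseKernel::RBF { gamma: 0.5 };

        let mut mkl = MultipleKernelLearning::new(vec![]).with_random_state(7);
        let approx = mkl.compute_rbf_rff_matrix(&x, 0.5, 4000).unwrap();
        let exact = mkl.compute_exact_kernel_matrix(&x, &kernel).unwrap();

        for i in 0..3 {
            for j in 0..3 {
                assert!((approx[[i, j]] - exact[[i, j]]).abs() < 0.15);
            }
        }
    }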
}