// sklears_kernel_approximation/multi_kernel_learning.rs

1use scirs2_core::ndarray::{Array1, Array2, Axis};
2use scirs2_core::random::rngs::StdRng;
3use scirs2_core::random::SeedableRng;
4use scirs2_core::RngExt;
5use scirs2_core::StandardNormal;
6use sklears_core::error::{Result, SklearsError};
7use std::collections::HashMap;
8
/// Multiple Kernel Learning (MKL) methods for kernel approximation
///
/// This module provides methods for learning optimal combinations of multiple
/// kernels, enabling automatic kernel selection and weighting for improved
/// machine learning performance.
///
/// Base kernel type for multiple kernel learning.
#[derive(Debug, Clone)]
/// Each variant carries the hyper-parameters needed by [`BaseKernel::evaluate`].
pub enum BaseKernel {
    /// RBF kernel: k(x, y) = exp(-gamma * ||x - y||^2)
    RBF { gamma: f64 },
    /// Polynomial kernel: k(x, y) = (gamma * <x, y> + coef0)^degree
    Polynomial { degree: f64, gamma: f64, coef0: f64 },
    /// Laplacian kernel: k(x, y) = exp(-gamma * ||x - y||_1)
    Laplacian { gamma: f64 },
    /// Linear kernel: k(x, y) = <x, y>
    Linear,
    /// Sigmoid kernel: k(x, y) = tanh(gamma * <x, y> + coef0)
    Sigmoid { gamma: f64, coef0: f64 },
    /// Custom kernel with user-defined function; `name` is used for display only
    Custom {
        name: String,
        kernel_fn: fn(&Array1<f64>, &Array1<f64>) -> f64,
    },
}
35
36impl BaseKernel {
37    /// Evaluate kernel function between two samples
38    pub fn evaluate(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
39        match self {
40            BaseKernel::RBF { gamma } => {
41                let diff = x - y;
42                let squared_dist = diff.mapv(|x| x * x).sum();
43                (-gamma * squared_dist).exp()
44            }
45            BaseKernel::Polynomial {
46                degree,
47                gamma,
48                coef0,
49            } => {
50                let dot_product = x.dot(y);
51                (gamma * dot_product + coef0).powf(*degree)
52            }
53            BaseKernel::Laplacian { gamma } => {
54                let diff = x - y;
55                let manhattan_dist = diff.mapv(|x| x.abs()).sum();
56                (-gamma * manhattan_dist).exp()
57            }
58            BaseKernel::Linear => x.dot(y),
59            BaseKernel::Sigmoid { gamma, coef0 } => {
60                let dot_product = x.dot(y);
61                (gamma * dot_product + coef0).tanh()
62            }
63            BaseKernel::Custom { kernel_fn, .. } => kernel_fn(x, y),
64        }
65    }
66
67    /// Get kernel name for identification
68    pub fn name(&self) -> String {
69        match self {
70            BaseKernel::RBF { gamma } => format!("RBF(gamma={:.4})", gamma),
71            BaseKernel::Polynomial {
72                degree,
73                gamma,
74                coef0,
75            } => {
76                format!(
77                    "Polynomial(degree={:.1}, gamma={:.4}, coef0={:.4})",
78                    degree, gamma, coef0
79                )
80            }
81            BaseKernel::Laplacian { gamma } => format!("Laplacian(gamma={:.4})", gamma),
82            BaseKernel::Linear => "Linear".to_string(),
83            BaseKernel::Sigmoid { gamma, coef0 } => {
84                format!("Sigmoid(gamma={:.4}, coef0={:.4})", gamma, coef0)
85            }
86            BaseKernel::Custom { name, .. } => format!("Custom({})", name),
87        }
88    }
89}
90
/// Kernel combination strategy
///
/// Controls both how learned weights are constrained
/// (see `apply_combination_constraints`) and how per-kernel features are
/// merged during `transform`.
#[derive(Debug, Clone)]
/// CombinationStrategy
pub enum CombinationStrategy {
    /// Linear combination: K = Σ αᵢ Kᵢ (weights unconstrained)
    Linear,
    /// Product combination: K = Π Kᵢ^αᵢ (weights forced strictly positive)
    Product,
    /// Convex combination with unit sum constraint: Σ αᵢ = 1, αᵢ ≥ 0
    Convex,
    /// Conic combination with non-negativity only: αᵢ ≥ 0
    Conic,
    /// Hierarchical combination with tree structure (currently simplified to
    /// the same projection/merge as Convex)
    Hierarchical,
}
106
/// Kernel weight learning algorithm
///
/// Selects how the per-kernel combination weights are learned during `fit`.
/// Note: several variants are simplified/heuristic implementations in this
/// module (see the corresponding `learn_*_weights` methods).
#[derive(Debug, Clone)]
/// WeightLearningAlgorithm
pub enum WeightLearningAlgorithm {
    /// Uniform weights (baseline)
    Uniform,
    /// Centered kernel alignment optimization (supervised when labels given)
    CenteredKernelAlignment,
    /// Maximum mean discrepancy minimization (heuristic proxy here)
    MaximumMeanDiscrepancy,
    /// SimpleMKL algorithm with QCQP optimization; `regularization` damps scores
    SimpleMKL { regularization: f64 },
    /// EasyMKL algorithm with radius constraint (currently uniform fallback)
    EasyMKL { radius: f64 },
    /// SPKM (Spectral Projected Kernel Machine): trace-based weighting
    SpectralProjected,
    /// Localized MKL with spatial weights (currently uniform fallback)
    LocalizedMKL { bandwidth: f64 },
    /// Adaptive MKL with cross-validation (currently uniform fallback)
    AdaptiveMKL { cv_folds: usize },
}
128
/// Kernel approximation method for each base kernel
///
/// Determines how each base kernel's Gram matrix (and, for `transform`,
/// its feature map) is computed during fitting.
#[derive(Debug, Clone)]
/// ApproximationMethod
pub enum ApproximationMethod {
    /// Random Fourier Features with `n_components` random frequencies
    RandomFourierFeatures { n_components: usize },
    /// Nyström approximation with `n_components` landmark points
    Nystroem { n_components: usize },
    /// Structured random features (currently delegates to RFF)
    StructuredFeatures { n_components: usize },
    /// Exact kernel (no approximation); `transform` on new data is unsupported
    Exact,
}
142
/// Configuration for multiple kernel learning
#[derive(Debug, Clone)]
/// MultiKernelConfig
pub struct MultiKernelConfig {
    /// How the base kernels are combined (linear, convex, product, ...)
    pub combination_strategy: CombinationStrategy,
    /// Algorithm used to learn the per-kernel weights
    pub weight_learning: WeightLearningAlgorithm,
    /// Approximation used for each base kernel's Gram matrix
    pub approximation_method: ApproximationMethod,
    /// Iteration cap for iterative weight learners
    /// (may be unused by the simplified learners in this module)
    pub max_iterations: usize,
    /// Convergence tolerance for iterative weight learners
    /// (may be unused by the simplified learners in this module)
    pub tolerance: f64,
    /// Cosine-normalize each kernel matrix: K[i,j] /= sqrt(K[i,i]·K[j,j])
    pub normalize_kernels: bool,
    /// Double-center each kernel matrix in feature space
    pub center_kernels: bool,
    /// Regularization strength available to weight learners
    pub regularization: f64,
}
164
impl Default for MultiKernelConfig {
    /// Defaults: convex combination, CKA weight learning, 100-component RFF,
    /// with both kernel normalization and centering enabled.
    fn default() -> Self {
        Self {
            combination_strategy: CombinationStrategy::Convex,
            weight_learning: WeightLearningAlgorithm::CenteredKernelAlignment,
            approximation_method: ApproximationMethod::RandomFourierFeatures { n_components: 100 },
            max_iterations: 100,
            tolerance: 1e-6,
            normalize_kernels: true,
            center_kernels: true,
            regularization: 1e-3,
        }
    }
}
179
/// Multiple Kernel Learning approximation
///
/// Learns optimal weights for combining multiple kernel approximations
/// to create a single, improved kernel approximation.
pub struct MultipleKernelLearning {
    /// Candidate base kernels to combine
    base_kernels: Vec<BaseKernel>,
    /// Learning and combination configuration
    config: MultiKernelConfig,
    /// Learned per-kernel weights (`None` until `fit` completes)
    weights: Option<Array1<f64>>,
    /// Processed per-kernel Gram matrices produced by `fit`
    kernel_matrices: Option<Vec<Array2<f64>>>,
    /// Combined feature representation
    /// (currently never populated — `compute_combined_representation` is a no-op)
    combined_features: Option<Array2<f64>>,
    /// Seed recorded by `with_random_state` (for reproducibility bookkeeping)
    random_state: Option<u64>,
    /// RNG used for random-feature sampling and landmark selection
    rng: StdRng,
    /// Training data stored by `fit` for out-of-sample transforms
    training_data: Option<Array2<f64>>,
    /// Per-kernel statistics, keyed "kernel_{index}"
    kernel_statistics: HashMap<String, KernelStatistics>,
}
195
/// Statistics for individual kernels
#[derive(Debug, Clone)]
/// KernelStatistics
pub struct KernelStatistics {
    /// Mean of the kernel diagonal (used as an alignment proxy)
    pub alignment: f64,
    /// Simplified eigenspectrum (currently just the kernel diagonal)
    pub eigenspectrum: Array1<f64>,
    /// Effective rank estimate: trace(K)² / ||K||_F²
    pub effective_rank: f64,
    /// Diversity measure (variance of the kernel diagonal)
    pub diversity: f64,
    /// Complexity measure (condition-number approximation)
    pub complexity: f64,
}
211
212impl Default for KernelStatistics {
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
impl KernelStatistics {
    /// Create all-zero statistics with an empty eigenspectrum.
    pub fn new() -> Self {
        Self {
            alignment: 0.0,
            eigenspectrum: Array1::zeros(0),
            effective_rank: 0.0,
            diversity: 0.0,
            complexity: 0.0,
        }
    }
}
229
230impl MultipleKernelLearning {
231    /// Create a new multiple kernel learning instance
232    pub fn new(base_kernels: Vec<BaseKernel>) -> Self {
233        let rng = StdRng::seed_from_u64(42);
234        Self {
235            base_kernels,
236            config: MultiKernelConfig::default(),
237            weights: None,
238            kernel_matrices: None,
239            combined_features: None,
240            random_state: None,
241            rng,
242            training_data: None,
243            kernel_statistics: HashMap::new(),
244        }
245    }
246
247    /// Set configuration
248    pub fn with_config(mut self, config: MultiKernelConfig) -> Self {
249        self.config = config;
250        self
251    }
252
253    /// Set random state for reproducibility
254    pub fn with_random_state(mut self, random_state: u64) -> Self {
255        self.random_state = Some(random_state);
256        self.rng = StdRng::seed_from_u64(random_state);
257        self
258    }
259
    /// Fit the multiple kernel learning model
    ///
    /// Pipeline: (1) compute/approximate each base kernel's Gram matrix on
    /// `x`, (2) normalize/center it per config, (3) record per-kernel
    /// statistics, (4) learn combination weights (supervised when `y` is
    /// given), (5) build the combined representation.
    pub fn fit(&mut self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
        let (_n_samples, _) = x.dim();

        // Store training data — `transform` needs it for out-of-sample maps.
        self.training_data = Some(x.clone());

        // Compute or approximate individual kernel matrices
        let mut kernel_matrices = Vec::new();
        // Clone so the loop can call `&mut self` approximation methods while
        // iterating the kernel list.
        let base_kernels = self.base_kernels.clone(); // Clone to avoid borrowing issues

        for (i, base_kernel) in base_kernels.iter().enumerate() {
            let kernel_matrix = match &self.config.approximation_method {
                ApproximationMethod::RandomFourierFeatures { n_components } => {
                    self.compute_rff_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::Nystroem { n_components } => {
                    self.compute_nystroem_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::StructuredFeatures { n_components } => {
                    self.compute_structured_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::Exact => self.compute_exact_kernel_matrix(x, base_kernel)?,
            };

            // Process kernel matrix (normalize, center)
            let processed_matrix = self.process_kernel_matrix(kernel_matrix)?;

            // Compute kernel statistics, keyed by position as "kernel_{i}"
            let stats = self.compute_kernel_statistics(&processed_matrix, y)?;
            self.kernel_statistics
                .insert(format!("kernel_{}", i), stats);

            kernel_matrices.push(processed_matrix);
        }

        // Must be set before learn_weights(), which reads it.
        self.kernel_matrices = Some(kernel_matrices);

        // Learn optimal kernel weights
        self.learn_weights(y)?;

        // Compute combined features/kernel
        self.compute_combined_representation()?;

        Ok(())
    }
306
    /// Transform data using learned kernel combination
    ///
    /// Maps `x` through each base kernel's (approximate) feature map, scales
    /// by the learned weight, and merges per the combination strategy.
    ///
    /// # Errors
    /// - `NotFitted` if `fit` has not been called.
    /// - `NotImplemented` for the `Exact` approximation method.
    /// - `Other` if every kernel weight is (near) zero.
    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let weights = self
            .weights
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;
        let training_data = self
            .training_data
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        // Accumulator starts empty; first contributing kernel seeds it.
        let mut combined_features = None;

        for (base_kernel, &weight) in self.base_kernels.iter().zip(weights.iter()) {
            if weight.abs() < 1e-12 {
                continue; // Skip kernels with negligible weight
            }

            let features = match &self.config.approximation_method {
                ApproximationMethod::RandomFourierFeatures { n_components } => {
                    self.transform_rff(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::Nystroem { n_components } => {
                    self.transform_nystroem(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::StructuredFeatures { n_components } => {
                    self.transform_structured(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::Exact => {
                    return Err(SklearsError::NotImplemented(
                        "Exact kernel transform not implemented for new data".to_string(),
                    ));
                }
            };

            let weighted_features = &features * weight;

            match &self.config.combination_strategy {
                // Additive strategies: sum the weighted feature maps.
                CombinationStrategy::Linear
                | CombinationStrategy::Convex
                | CombinationStrategy::Conic => {
                    combined_features = match combined_features {
                        Some(existing) => Some(existing + weighted_features),
                        None => Some(weighted_features),
                    };
                }
                // Product strategy: multiply element-wise exponentials of the
                // weighted features (NOTE(review): this is a heuristic, not a
                // literal kernel product — confirm intended semantics).
                CombinationStrategy::Product => {
                    combined_features = match combined_features {
                        Some(existing) => Some(existing * weighted_features.mapv(|x| x.exp())),
                        None => Some(weighted_features.mapv(|x| x.exp())),
                    };
                }
                CombinationStrategy::Hierarchical => {
                    // Simplified hierarchical combination: same as additive.
                    combined_features = match combined_features {
                        Some(existing) => Some(existing + weighted_features),
                        None => Some(weighted_features),
                    };
                }
            }
        }

        combined_features.ok_or_else(|| {
            SklearsError::Other("No features generated - all kernel weights are zero".to_string())
        })
    }
377
    /// Get learned kernel weights
    ///
    /// Returns `None` until `fit` has completed successfully.
    pub fn kernel_weights(&self) -> Option<&Array1<f64>> {
        self.weights.as_ref()
    }
382
    /// Get kernel statistics
    ///
    /// Keyed "kernel_{index}" by position in `base_kernels`; empty before `fit`.
    pub fn kernel_stats(&self) -> &HashMap<String, KernelStatistics> {
        &self.kernel_statistics
    }
387
388    /// Get most important kernels based on weights
389    pub fn important_kernels(&self, threshold: f64) -> Vec<(usize, &BaseKernel, f64)> {
390        if let Some(weights) = &self.weights {
391            self.base_kernels
392                .iter()
393                .enumerate()
394                .zip(weights.iter())
395                .filter_map(|((i, kernel), &weight)| {
396                    if weight.abs() >= threshold {
397                        Some((i, kernel, weight))
398                    } else {
399                        None
400                    }
401                })
402                .collect()
403        } else {
404            Vec::new()
405        }
406    }
407
408    /// Learn optimal kernel weights
409    fn learn_weights(&mut self, y: Option<&Array1<f64>>) -> Result<()> {
410        let kernel_matrices = self
411            .kernel_matrices
412            .as_ref()
413            .expect("operation should succeed");
414        let n_kernels = kernel_matrices.len();
415
416        let weights = match &self.config.weight_learning {
417            WeightLearningAlgorithm::Uniform => {
418                Array1::from_elem(n_kernels, 1.0 / n_kernels as f64)
419            }
420            WeightLearningAlgorithm::CenteredKernelAlignment => {
421                self.learn_cka_weights(kernel_matrices, y)?
422            }
423            WeightLearningAlgorithm::MaximumMeanDiscrepancy => {
424                self.learn_mmd_weights(kernel_matrices)?
425            }
426            WeightLearningAlgorithm::SimpleMKL { regularization } => {
427                self.learn_simple_mkl_weights(kernel_matrices, y, *regularization)?
428            }
429            WeightLearningAlgorithm::EasyMKL { radius } => {
430                self.learn_easy_mkl_weights(kernel_matrices, y, *radius)?
431            }
432            WeightLearningAlgorithm::SpectralProjected => {
433                self.learn_spectral_weights(kernel_matrices)?
434            }
435            WeightLearningAlgorithm::LocalizedMKL { bandwidth } => {
436                self.learn_localized_weights(kernel_matrices, *bandwidth)?
437            }
438            WeightLearningAlgorithm::AdaptiveMKL { cv_folds } => {
439                self.learn_adaptive_weights(kernel_matrices, y, *cv_folds)?
440            }
441        };
442
443        // Apply combination strategy constraints
444        let final_weights = self.apply_combination_constraints(weights)?;
445
446        self.weights = Some(final_weights);
447        Ok(())
448    }
449
450    /// Learn weights using Centered Kernel Alignment
451    fn learn_cka_weights(
452        &self,
453        kernel_matrices: &[Array2<f64>],
454        y: Option<&Array1<f64>>,
455    ) -> Result<Array1<f64>> {
456        if let Some(labels) = y {
457            // Supervised CKA: align with label kernel
458            let label_kernel = self.compute_label_kernel(labels)?;
459            let mut alignments = Array1::zeros(kernel_matrices.len());
460
461            for (i, kernel) in kernel_matrices.iter().enumerate() {
462                alignments[i] = self.centered_kernel_alignment(kernel, &label_kernel)?;
463            }
464
465            // Softmax normalization
466            let max_alignment = alignments.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
467            let exp_alignments = alignments.mapv(|x| (x - max_alignment).exp());
468            let sum_exp = exp_alignments.sum();
469
470            Ok(exp_alignments / sum_exp)
471        } else {
472            // Unsupervised: use kernel diversity
473            let mut weights = Array1::zeros(kernel_matrices.len());
474
475            for (i, kernel) in kernel_matrices.iter().enumerate() {
476                // Use trace as a simple quality measure
477                weights[i] = kernel.diag().sum() / kernel.nrows() as f64;
478            }
479
480            let sum_weights = weights.sum();
481            if sum_weights > 0.0 {
482                weights /= sum_weights;
483            } else {
484                weights.fill(1.0 / kernel_matrices.len() as f64);
485            }
486
487            Ok(weights)
488        }
489    }
490
491    /// Learn weights using Maximum Mean Discrepancy
492    fn learn_mmd_weights(&self, kernel_matrices: &[Array2<f64>]) -> Result<Array1<f64>> {
493        // Simplified MMD-based weight learning
494        let mut weights = Array1::zeros(kernel_matrices.len());
495
496        for (i, kernel) in kernel_matrices.iter().enumerate() {
497            // Use kernel matrix statistics as proxy for MMD
498            let trace = kernel.diag().sum();
499            let frobenius_norm = kernel.mapv(|x| x * x).sum().sqrt();
500            weights[i] = trace / frobenius_norm;
501        }
502
503        let sum_weights = weights.sum();
504        if sum_weights > 0.0 {
505            weights /= sum_weights;
506        } else {
507            weights.fill(1.0 / kernel_matrices.len() as f64);
508        }
509
510        Ok(weights)
511    }
512
513    /// Simplified SimpleMKL implementation
514    fn learn_simple_mkl_weights(
515        &self,
516        kernel_matrices: &[Array2<f64>],
517        _y: Option<&Array1<f64>>,
518        regularization: f64,
519    ) -> Result<Array1<f64>> {
520        // Simplified version: use kernel quality metrics
521        let mut weights = Array1::zeros(kernel_matrices.len());
522
523        for (i, kernel) in kernel_matrices.iter().enumerate() {
524            let eigenvalues = self.compute_simplified_eigenvalues(kernel)?;
525            let effective_rank =
526                eigenvalues.mapv(|x| x * x).sum().powi(2) / eigenvalues.mapv(|x| x.powi(4)).sum();
527            weights[i] = effective_rank / (1.0 + regularization);
528        }
529
530        let sum_weights = weights.sum();
531        if sum_weights > 0.0 {
532            weights /= sum_weights;
533        } else {
534            weights.fill(1.0 / kernel_matrices.len() as f64);
535        }
536
537        Ok(weights)
538    }
539
540    /// Simplified EasyMKL implementation
541    fn learn_easy_mkl_weights(
542        &self,
543        kernel_matrices: &[Array2<f64>],
544        _y: Option<&Array1<f64>>,
545        _radius: f64,
546    ) -> Result<Array1<f64>> {
547        // Use uniform weights for simplicity
548        Ok(Array1::from_elem(
549            kernel_matrices.len(),
550            1.0 / kernel_matrices.len() as f64,
551        ))
552    }
553
554    /// Learn weights using spectral methods
555    fn learn_spectral_weights(&self, kernel_matrices: &[Array2<f64>]) -> Result<Array1<f64>> {
556        let mut weights = Array1::zeros(kernel_matrices.len());
557
558        for (i, kernel) in kernel_matrices.iter().enumerate() {
559            let trace = kernel.diag().sum();
560            weights[i] = trace;
561        }
562
563        let sum_weights = weights.sum();
564        if sum_weights > 0.0 {
565            weights /= sum_weights;
566        } else {
567            weights.fill(1.0 / kernel_matrices.len() as f64);
568        }
569
570        Ok(weights)
571    }
572
573    /// Learn localized weights
574    fn learn_localized_weights(
575        &self,
576        kernel_matrices: &[Array2<f64>],
577        _bandwidth: f64,
578    ) -> Result<Array1<f64>> {
579        // Simplified localized MKL
580        Ok(Array1::from_elem(
581            kernel_matrices.len(),
582            1.0 / kernel_matrices.len() as f64,
583        ))
584    }
585
586    /// Learn adaptive weights with cross-validation
587    fn learn_adaptive_weights(
588        &self,
589        kernel_matrices: &[Array2<f64>],
590        _y: Option<&Array1<f64>>,
591        _cv_folds: usize,
592    ) -> Result<Array1<f64>> {
593        // Simplified adaptive MKL
594        Ok(Array1::from_elem(
595            kernel_matrices.len(),
596            1.0 / kernel_matrices.len() as f64,
597        ))
598    }
599
600    /// Apply combination strategy constraints
601    fn apply_combination_constraints(&self, mut weights: Array1<f64>) -> Result<Array1<f64>> {
602        match &self.config.combination_strategy {
603            CombinationStrategy::Convex => {
604                // Ensure non-negative and sum to 1
605                weights.mapv_inplace(|x| x.max(0.0));
606                let sum = weights.sum();
607                if sum > 0.0 {
608                    weights /= sum;
609                } else {
610                    let uniform_val = 1.0 / weights.len() as f64;
611                    weights.fill(uniform_val);
612                }
613            }
614            CombinationStrategy::Conic => {
615                // Ensure non-negative
616                weights.mapv_inplace(|x| x.max(0.0));
617            }
618            CombinationStrategy::Linear => {
619                // No constraints
620            }
621            CombinationStrategy::Product => {
622                // Ensure all weights are positive for product combination
623                weights.mapv_inplace(|x| x.abs().max(1e-12));
624            }
625            CombinationStrategy::Hierarchical => {
626                // Simplified hierarchical constraints
627                weights.mapv_inplace(|x| x.max(0.0));
628                let sum = weights.sum();
629                if sum > 0.0 {
630                    weights /= sum;
631                } else {
632                    let uniform_val = 1.0 / weights.len() as f64;
633                    weights.fill(uniform_val);
634                }
635            }
636        }
637
638        Ok(weights)
639    }
640
641    /// Compute label kernel for supervised learning
642    fn compute_label_kernel(&self, labels: &Array1<f64>) -> Result<Array2<f64>> {
643        let n = labels.len();
644        let mut label_kernel = Array2::zeros((n, n));
645
646        for i in 0..n {
647            for j in 0..n {
648                label_kernel[[i, j]] = if (labels[i] - labels[j]).abs() < 1e-10 {
649                    1.0
650                } else {
651                    0.0
652                };
653            }
654        }
655
656        Ok(label_kernel)
657    }
658
    /// Compute centered kernel alignment
    ///
    /// CKA(K1, K2) = <K1c, K2c>_F / (||K1c||_F · ||K2c||_F), where Kc is the
    /// doubly-centered kernel Kc = K − J·K − K·J + J·K·J with J = 11ᵀ/n.
    /// Returns 0.0 when either centered kernel is numerically zero.
    fn centered_kernel_alignment(&self, k1: &Array2<f64>, k2: &Array2<f64>) -> Result<f64> {
        let n = k1.nrows() as f64;
        // J: all-ones matrix scaled by 1/n (the centering operator).
        let ones = Array2::ones((k1.nrows(), k1.ncols())) / n;

        // Center kernels (double centering in feature space)
        let k1_centered = k1 - &ones.dot(k1) - &k1.dot(&ones) + &ones.dot(k1).dot(&ones);
        let k2_centered = k2 - &ones.dot(k2) - &k2.dot(&ones) + &ones.dot(k2).dot(&ones);

        // Compute alignment: Frobenius inner product over the norm product
        let numerator = (&k1_centered * &k2_centered).sum();
        let denominator =
            ((&k1_centered * &k1_centered).sum() * (&k2_centered * &k2_centered).sum()).sqrt();

        // Guard against division by ~0 for degenerate (constant) kernels.
        if denominator > 1e-12 {
            Ok(numerator / denominator)
        } else {
            Ok(0.0)
        }
    }
679
    /// Process kernel matrix (normalize and center if configured)
    ///
    /// Normalization: K[i,j] /= sqrt(K[i,i]·K[j,j]) (cosine normalization),
    /// skipping entries whose diagonal product is ~0.
    /// Centering: K[i,j] −= row_mean[i] + col_mean[j] − total_mean
    /// (double centering in feature space).
    fn process_kernel_matrix(&self, mut kernel: Array2<f64>) -> Result<Array2<f64>> {
        if self.config.normalize_kernels {
            // Outer product of the diagonal gives K[i,i]·K[j,j] per entry.
            let diag = kernel.diag();
            let norm_matrix = diag.insert_axis(Axis(1)).dot(&diag.insert_axis(Axis(0)));
            for i in 0..kernel.nrows() {
                for j in 0..kernel.ncols() {
                    if norm_matrix[[i, j]] > 1e-12 {
                        kernel[[i, j]] /= norm_matrix[[i, j]].sqrt();
                    }
                }
            }
        }

        if self.config.center_kernels {
            // Center kernel matrix via double centering.
            let _n = kernel.nrows() as f64;
            let row_means = kernel.mean_axis(Axis(1)).expect("operation should succeed");
            let col_means = kernel.mean_axis(Axis(0)).expect("operation should succeed");
            let total_mean = kernel.mean().expect("operation should succeed");

            for i in 0..kernel.nrows() {
                for j in 0..kernel.ncols() {
                    kernel[[i, j]] = kernel[[i, j]] - row_means[i] - col_means[j] + total_mean;
                }
            }
        }

        Ok(kernel)
    }
711
712    /// Placeholder implementations for different approximation methods
713    fn compute_rff_approximation(
714        &mut self,
715        x: &Array2<f64>,
716        kernel: &BaseKernel,
717        n_components: usize,
718    ) -> Result<Array2<f64>> {
719        match kernel {
720            BaseKernel::RBF { gamma } => self.compute_rbf_rff_matrix(x, *gamma, n_components),
721            BaseKernel::Laplacian { gamma } => {
722                self.compute_laplacian_rff_matrix(x, *gamma, n_components)
723            }
724            _ => {
725                // Fallback to exact computation for unsupported kernels
726                self.compute_exact_kernel_matrix(x, kernel)
727            }
728        }
729    }
730
    /// RBF Gram matrix approximation via Random Fourier Features.
    ///
    /// Samples frequencies scaled by sqrt(2·gamma) from a standard normal and
    /// phases from U[0, 2π), builds z(x) = sqrt(2/D)·cos(Wx + b), and returns
    /// Z·Zᵀ ≈ K_rbf. Draws from `self.rng`, so output depends on RNG state.
    fn compute_rbf_rff_matrix(
        &mut self,
        x: &Array2<f64>,
        gamma: f64,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (_n_samples, n_features) = x.dim();

        // Generate random weights (spectral frequencies); the sqrt(2·gamma)
        // scale matches the Fourier transform of exp(−gamma·||x−y||²).
        let mut weights = Array2::zeros((n_components, n_features));
        for i in 0..n_components {
            for j in 0..n_features {
                weights[[i, j]] = self.rng.sample::<f64, _>(StandardNormal) * (2.0 * gamma).sqrt();
            }
        }

        // Generate random phase offsets in [0, 2π)
        let mut bias = Array1::zeros(n_components);
        for i in 0..n_components {
            bias[i] = self.rng.random_range(0.0..2.0 * std::f64::consts::PI);
        }

        // Compute features: cos(Wx + b), scaled so Z·Zᵀ approximates K
        let projection = x.dot(&weights.t()) + &bias;
        let features = projection.mapv(|x| x.cos()) * (2.0 / n_components as f64).sqrt();

        // Compute Gram matrix from features
        Ok(features.dot(&features.t()))
    }
760
    /// Laplacian Gram matrix approximation via Random Fourier Features.
    ///
    /// The Laplacian kernel's spectral density is Cauchy, so frequencies are
    /// drawn by inverse-CDF sampling: tan(π·(u − 0.5)) for u ~ U(0, 1),
    /// scaled by gamma. `u` is clamped to (0.001, 0.999) to avoid the tan
    /// singularities at 0 and 1. Draws from `self.rng`.
    fn compute_laplacian_rff_matrix(
        &mut self,
        x: &Array2<f64>,
        gamma: f64,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        // Use Cauchy distribution for Laplacian kernel
        let (_n_samples, n_features) = x.dim();

        // Generate random weights from Cauchy distribution (inverse CDF)
        let mut weights = Array2::zeros((n_components, n_features));
        for i in 0..n_components {
            for j in 0..n_features {
                let u: f64 = self.rng.random_range(0.001..0.999);
                weights[[i, j]] = ((std::f64::consts::PI * (u - 0.5)).tan()) * gamma;
            }
        }

        // Generate random phase offsets in [0, 2π)
        let mut bias = Array1::zeros(n_components);
        for i in 0..n_components {
            bias[i] = self.rng.random_range(0.0..2.0 * std::f64::consts::PI);
        }

        // Compute features: cos(Wx + b), scaled so Z·Zᵀ approximates K
        let projection = x.dot(&weights.t()) + &bias;
        let features = projection.mapv(|x| x.cos()) * (2.0 / n_components as f64).sqrt();

        // Compute Gram matrix from features
        Ok(features.dot(&features.t()))
    }
792
793    fn compute_nystroem_approximation(
794        &mut self,
795        x: &Array2<f64>,
796        kernel: &BaseKernel,
797        n_components: usize,
798    ) -> Result<Array2<f64>> {
799        let (n_samples, _) = x.dim();
800        let n_landmarks = n_components.min(n_samples);
801
802        // Select random landmarks
803        let mut landmark_indices = Vec::new();
804        for _ in 0..n_landmarks {
805            landmark_indices.push(self.rng.random_range(0..n_samples));
806        }
807
808        // Compute kernel between all points and landmarks
809        let mut kernel_matrix = Array2::zeros((n_samples, n_landmarks));
810        for i in 0..n_samples {
811            for j in 0..n_landmarks {
812                let landmark_idx = landmark_indices[j];
813                kernel_matrix[[i, j]] =
814                    kernel.evaluate(&x.row(i).to_owned(), &x.row(landmark_idx).to_owned());
815            }
816        }
817
818        // For simplicity, return kernel matrix with landmarks (not full Nyström)
819        Ok(kernel_matrix.dot(&kernel_matrix.t()))
820    }
821
822    fn compute_structured_approximation(
823        &mut self,
824        x: &Array2<f64>,
825        kernel: &BaseKernel,
826        n_components: usize,
827    ) -> Result<Array2<f64>> {
828        // Fallback to RFF for structured features
829        self.compute_rff_approximation(x, kernel, n_components)
830    }
831
832    fn compute_exact_kernel_matrix(
833        &self,
834        x: &Array2<f64>,
835        kernel: &BaseKernel,
836    ) -> Result<Array2<f64>> {
837        let n_samples = x.nrows();
838        let mut kernel_matrix = Array2::zeros((n_samples, n_samples));
839
840        for i in 0..n_samples {
841            for j in i..n_samples {
842                let value = kernel.evaluate(&x.row(i).to_owned(), &x.row(j).to_owned());
843                kernel_matrix[[i, j]] = value;
844                kernel_matrix[[j, i]] = value;
845            }
846        }
847
848        Ok(kernel_matrix)
849    }
850
851    /// Placeholder transform methods
852    fn transform_rff(
853        &self,
854        x: &Array2<f64>,
855        _training_data: &Array2<f64>,
856        _kernel: &BaseKernel,
857        n_components: usize,
858    ) -> Result<Array2<f64>> {
859        // Return random features as placeholder
860        let (n_samples, _) = x.dim();
861        Ok(Array2::zeros((n_samples, n_components)))
862    }
863
864    fn transform_nystroem(
865        &self,
866        x: &Array2<f64>,
867        _training_data: &Array2<f64>,
868        _kernel: &BaseKernel,
869        n_components: usize,
870    ) -> Result<Array2<f64>> {
871        // Return placeholder features
872        let (n_samples, _) = x.dim();
873        Ok(Array2::zeros((n_samples, n_components)))
874    }
875
876    fn transform_structured(
877        &self,
878        x: &Array2<f64>,
879        _training_data: &Array2<f64>,
880        _kernel: &BaseKernel,
881        n_components: usize,
882    ) -> Result<Array2<f64>> {
883        // Return placeholder features
884        let (n_samples, _) = x.dim();
885        Ok(Array2::zeros((n_samples, n_components)))
886    }
887
    /// Combine the per-kernel feature blocks into a single representation.
    ///
    /// Currently a no-op placeholder: the combination step is not yet
    /// implemented, so this always succeeds without mutating any state.
    fn compute_combined_representation(&mut self) -> Result<()> {
        // Placeholder for combining features
        Ok(())
    }
892
893    fn compute_kernel_statistics(
894        &self,
895        kernel: &Array2<f64>,
896        _y: Option<&Array1<f64>>,
897    ) -> Result<KernelStatistics> {
898        let mut stats = KernelStatistics::new();
899
900        // Compute basic statistics
901        stats.alignment = kernel.diag().mean().unwrap_or(0.0);
902
903        // Simplified eigenspectrum (just use diagonal)
904        stats.eigenspectrum = kernel.diag().to_owned();
905
906        // Effective rank approximation
907        let trace = kernel.diag().sum();
908        let frobenius_sq = kernel.mapv(|x| x * x).sum();
909        stats.effective_rank = if frobenius_sq > 1e-12 {
910            trace.powi(2) / frobenius_sq
911        } else {
912            0.0
913        };
914
915        // Diversity measure (variance of diagonal)
916        stats.diversity = kernel.diag().var(0.0);
917
918        // Complexity measure (condition number approximation)
919        let diag = kernel.diag();
920        let max_eig = diag.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
921        let min_eig = diag.iter().fold(f64::INFINITY, |a, &b| a.min(b.max(1e-12)));
922        stats.complexity = max_eig / min_eig;
923
924        Ok(stats)
925    }
926
927    fn compute_simplified_eigenvalues(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
928        // Simplified eigenvalue computation using diagonal
929        Ok(matrix.diag().to_owned())
930    }
931}
932
// NOTE(review): the blanket `#[allow(non_snake_case)]` was removed — every
// item in this module is already snake_case, and unexplained lint allows
// mask future violations.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Kernel evaluation sanity checks: k(x, x) = 1 for RBF, dot product for Linear.
    #[test]
    fn test_base_kernel_evaluation() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![1.0, 2.0, 3.0];

        let rbf_kernel = BaseKernel::RBF { gamma: 0.1 };
        let value = rbf_kernel.evaluate(&x, &y);
        assert!((value - 1.0).abs() < 1e-10); // Same points should give 1.0

        let linear_kernel = BaseKernel::Linear;
        let value = linear_kernel.evaluate(&x, &y);
        assert!((value - 14.0).abs() < 1e-10); // 1*1 + 2*2 + 3*3 = 14
    }

    /// Human-readable kernel names include their formatted hyperparameters.
    #[test]
    fn test_kernel_names() {
        let rbf = BaseKernel::RBF { gamma: 0.5 };
        assert_eq!(rbf.name(), "RBF(gamma=0.5000)");

        let linear = BaseKernel::Linear;
        assert_eq!(linear.name(), "Linear");

        let poly = BaseKernel::Polynomial {
            degree: 2.0,
            gamma: 1.0,
            coef0: 0.0,
        };
        assert_eq!(
            poly.name(),
            "Polynomial(degree=2.0, gamma=1.0000, coef0=0.0000)"
        );
    }

    /// Fitting learns one weight per base kernel, normalized to sum to 1.
    #[test]
    fn test_multiple_kernel_learning_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let base_kernels = vec![
            BaseKernel::RBF { gamma: 0.1 },
            BaseKernel::Linear,
            BaseKernel::Polynomial {
                degree: 2.0,
                gamma: 1.0,
                coef0: 0.0,
            },
        ];

        let mut mkl = MultipleKernelLearning::new(base_kernels).with_random_state(42);

        mkl.fit(&x, None).expect("operation should succeed");

        let weights = mkl.kernel_weights().expect("operation should succeed");
        assert_eq!(weights.len(), 3);
        assert!((weights.sum() - 1.0).abs() < 1e-10); // Should sum to 1 for convex combination
    }

    /// Statistics of a well-conditioned 3x3 kernel matrix are sensible.
    #[test]
    fn test_kernel_statistics() {
        let kernel = array![[1.0, 0.5, 0.2], [0.5, 1.0, 0.3], [0.2, 0.3, 1.0]];

        let mkl = MultipleKernelLearning::new(vec![]);
        let stats = mkl
            .compute_kernel_statistics(&kernel, None)
            .expect("operation should succeed");

        assert!((stats.alignment - 1.0).abs() < 1e-10); // Diagonal mean should be 1.0
        assert!(stats.effective_rank > 0.0);
        assert!(stats.diversity >= 0.0);
    }

    /// The convex strategy projects arbitrary weights onto the simplex.
    #[test]
    fn test_combination_strategies() {
        let weights = array![0.5, -0.3, 0.8];

        let mut mkl = MultipleKernelLearning::new(vec![]);
        mkl.config.combination_strategy = CombinationStrategy::Convex;

        let constrained = mkl
            .apply_combination_constraints(weights.clone())
            .expect("operation should succeed");

        // Should be non-negative and sum to 1
        assert!(constrained.iter().all(|&x| x >= 0.0));
        assert!((constrained.sum() - 1.0).abs() < 1e-10);
    }

    /// Config fields round-trip exactly as constructed.
    #[test]
    fn test_mkl_config() {
        let config = MultiKernelConfig {
            combination_strategy: CombinationStrategy::Linear,
            weight_learning: WeightLearningAlgorithm::SimpleMKL {
                regularization: 0.01,
            },
            approximation_method: ApproximationMethod::Nystroem { n_components: 50 },
            max_iterations: 200,
            tolerance: 1e-8,
            normalize_kernels: false,
            center_kernels: false,
            regularization: 0.001,
        };

        assert!(matches!(
            config.combination_strategy,
            CombinationStrategy::Linear
        ));
        assert!(matches!(
            config.weight_learning,
            WeightLearningAlgorithm::SimpleMKL { .. }
        ));
        assert_eq!(config.max_iterations, 200);
        assert!(!config.normalize_kernels);
    }

    /// Only kernels whose weight meets the threshold are reported as important.
    #[test]
    fn test_important_kernels() {
        let base_kernels = vec![
            BaseKernel::RBF { gamma: 0.1 },
            BaseKernel::Linear,
            BaseKernel::Polynomial {
                degree: 2.0,
                gamma: 1.0,
                coef0: 0.0,
            },
        ];

        let mut mkl = MultipleKernelLearning::new(base_kernels);
        mkl.weights = Some(array![0.6, 0.05, 0.35]);

        let important = mkl.important_kernels(0.1);
        assert_eq!(important.len(), 2); // Only kernels with weight >= 0.1
        assert_eq!(important[0].0, 0); // First kernel (RBF)
        assert_eq!(important[1].0, 2); // Third kernel (Polynomial)
    }

    /// Fitting must succeed both with and without target labels.
    #[test]
    fn test_supervised_vs_unsupervised() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 1.0, 0.0, 1.0];

        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl_unsupervised =
            MultipleKernelLearning::new(base_kernels.clone()).with_random_state(42);
        mkl_unsupervised
            .fit(&x, None)
            .expect("operation should succeed");

        let mut mkl_supervised = MultipleKernelLearning::new(base_kernels).with_random_state(42);
        mkl_supervised
            .fit(&x, Some(&y))
            .expect("operation should succeed");

        // Both should work without errors
        assert!(mkl_unsupervised.kernel_weights().is_some());
        assert!(mkl_supervised.kernel_weights().is_some());
    }

    /// transform() on unseen data yields one feature row per test sample.
    #[test]
    fn test_transform_compatibility() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let x_test = array![[2.0, 3.0], [4.0, 5.0]];

        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl = MultipleKernelLearning::new(base_kernels)
            .with_config(MultiKernelConfig {
                approximation_method: ApproximationMethod::RandomFourierFeatures {
                    n_components: 10,
                },
                ..Default::default()
            })
            .with_random_state(42);

        mkl.fit(&x_train, None).expect("operation should succeed");
        let features = mkl.transform(&x_test).expect("operation should succeed");

        assert_eq!(features.nrows(), 2); // Two test samples
        assert!(features.ncols() > 0); // Some features generated
    }
}