sklears_kernel_approximation/
multi_kernel_learning.rs

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::Rng;
use scirs2_core::random::SeedableRng;
use scirs2_core::StandardNormal;
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;

// Multiple Kernel Learning (MKL) methods for kernel approximation.
//
// This module provides methods for learning optimal combinations of multiple
// kernels, enabling automatic kernel selection and weighting for improved
// machine learning performance.

/// Base kernel type for multiple kernel learning
#[derive(Debug, Clone)]
pub enum BaseKernel {
    /// RBF kernel with specific gamma parameter
    RBF { gamma: f64 },
    /// Polynomial kernel with degree, gamma, and coef0
    Polynomial { degree: f64, gamma: f64, coef0: f64 },
    /// Laplacian kernel with gamma parameter
    Laplacian { gamma: f64 },
    /// Linear kernel
    Linear,
    /// Sigmoid kernel with gamma and coef0
    Sigmoid { gamma: f64, coef0: f64 },
    /// Custom kernel with user-defined function
    Custom {
        name: String,
        kernel_fn: fn(&Array1<f64>, &Array1<f64>) -> f64,
    },
}

impl BaseKernel {
    /// Evaluate the kernel function between two samples
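    ///
    /// A minimal sketch (marked `ignore` since it is illustrative rather than
    /// a compiled doc test; assumes the `array!` macro from `scirs2_core::ndarray`):
    ///
    /// ```ignore
    /// let x = array![1.0, 2.0];
    /// let y = array![2.0, 0.0];
    /// // RBF: exp(-gamma * ||x - y||^2) = exp(-0.5 * 5.0)
    /// let k = BaseKernel::RBF { gamma: 0.5 }.evaluate(&x, &y);
    /// assert!((k - (-2.5f64).exp()).abs() < 1e-12);
    /// ```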
    pub fn evaluate(&self, x: &Array1<f64>, y: &Array1<f64>) -> f64 {
        match self {
            BaseKernel::RBF { gamma } => {
                let diff = x - y;
                let squared_dist = diff.mapv(|v| v * v).sum();
                (-gamma * squared_dist).exp()
            }
            BaseKernel::Polynomial {
                degree,
                gamma,
                coef0,
            } => {
                let dot_product = x.dot(y);
                (gamma * dot_product + coef0).powf(*degree)
            }
            BaseKernel::Laplacian { gamma } => {
                let diff = x - y;
                let manhattan_dist = diff.mapv(|v| v.abs()).sum();
                (-gamma * manhattan_dist).exp()
            }
            BaseKernel::Linear => x.dot(y),
            BaseKernel::Sigmoid { gamma, coef0 } => {
                let dot_product = x.dot(y);
                (gamma * dot_product + coef0).tanh()
            }
            BaseKernel::Custom { kernel_fn, .. } => kernel_fn(x, y),
        }
    }

    /// Get kernel name for identification
    pub fn name(&self) -> String {
        match self {
            BaseKernel::RBF { gamma } => format!("RBF(gamma={:.4})", gamma),
            BaseKernel::Polynomial {
                degree,
                gamma,
                coef0,
            } => {
                format!(
                    "Polynomial(degree={:.1}, gamma={:.4}, coef0={:.4})",
                    degree, gamma, coef0
                )
            }
            BaseKernel::Laplacian { gamma } => format!("Laplacian(gamma={:.4})", gamma),
            BaseKernel::Linear => "Linear".to_string(),
            BaseKernel::Sigmoid { gamma, coef0 } => {
                format!("Sigmoid(gamma={:.4}, coef0={:.4})", gamma, coef0)
            }
            BaseKernel::Custom { name, .. } => format!("Custom({})", name),
        }
    }
}

/// Kernel combination strategy
#[derive(Debug, Clone)]
pub enum CombinationStrategy {
    /// Linear combination: K = Σ αᵢ Kᵢ
    Linear,
    /// Product combination: K = Π Kᵢ^αᵢ
    Product,
    /// Convex combination with non-negativity and unit sum constraint: αᵢ ≥ 0, Σ αᵢ = 1
    Convex,
    /// Conic combination with non-negativity: αᵢ ≥ 0
    Conic,
    /// Hierarchical combination with tree structure
    Hierarchical,
}

/// Kernel weight learning algorithm
#[derive(Debug, Clone)]
pub enum WeightLearningAlgorithm {
    /// Uniform weights (baseline)
    Uniform,
    /// Centered kernel alignment optimization
    CenteredKernelAlignment,
    /// Maximum mean discrepancy minimization
    MaximumMeanDiscrepancy,
    /// SimpleMKL algorithm with QCQP optimization
    SimpleMKL { regularization: f64 },
    /// EasyMKL algorithm with radius constraint
    EasyMKL { radius: f64 },
    /// SPKM (Spectral Projected Kernel Machine)
    SpectralProjected,
    /// Localized MKL with spatial weights
    LocalizedMKL { bandwidth: f64 },
    /// Adaptive MKL with cross-validation
    AdaptiveMKL { cv_folds: usize },
}

/// Kernel approximation method for each base kernel
#[derive(Debug, Clone)]
pub enum ApproximationMethod {
    /// Random Fourier Features
    RandomFourierFeatures { n_components: usize },
    /// Nyström approximation
    Nystroem { n_components: usize },
    /// Structured random features
    StructuredFeatures { n_components: usize },
    /// Exact kernel (no approximation)
    Exact,
}

/// Configuration for multiple kernel learning
#[derive(Debug, Clone)]
pub struct MultiKernelConfig {
    /// How the base kernels are combined into a single kernel
    pub combination_strategy: CombinationStrategy,
    /// Algorithm used to learn the per-kernel weights
    pub weight_learning: WeightLearningAlgorithm,
    /// Approximation applied to each base kernel
    pub approximation_method: ApproximationMethod,
    /// Maximum number of optimization iterations
    pub max_iterations: usize,
    /// Convergence tolerance for weight optimization
    pub tolerance: f64,
    /// Whether to normalize kernel matrices to unit diagonal
    pub normalize_kernels: bool,
    /// Whether to center kernel matrices in feature space
    pub center_kernels: bool,
    /// Regularization strength for weight learning
    pub regularization: f64,
}

impl Default for MultiKernelConfig {
    fn default() -> Self {
        Self {
            combination_strategy: CombinationStrategy::Convex,
            weight_learning: WeightLearningAlgorithm::CenteredKernelAlignment,
            approximation_method: ApproximationMethod::RandomFourierFeatures { n_components: 100 },
            max_iterations: 100,
            tolerance: 1e-6,
            normalize_kernels: true,
            center_kernels: true,
            regularization: 1e-3,
        }
    }
}

/// Multiple Kernel Learning approximation
///
/// Learns optimal weights for combining multiple kernel approximations
/// to create a single, improved kernel approximation.
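///
/// A usage sketch (marked `ignore` since it is illustrative, not a compiled
/// doc test; `x_train` and `x_test` stand for arbitrary `Array2<f64>` data):
///
/// ```ignore
/// let kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];
/// let mut mkl = MultipleKernelLearning::new(kernels).with_random_state(0);
/// mkl.fit(&x_train, None)?;                    // unsupervised weight learning
/// let weights = mkl.kernel_weights().unwrap(); // learned per-kernel weights
/// let features = mkl.transform(&x_test)?;      // combined feature map
/// ```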
pub struct MultipleKernelLearning {
    base_kernels: Vec<BaseKernel>,
    config: MultiKernelConfig,
    weights: Option<Array1<f64>>,
    kernel_matrices: Option<Vec<Array2<f64>>>,
    combined_features: Option<Array2<f64>>,
    random_state: Option<u64>,
    rng: StdRng,
    training_data: Option<Array2<f64>>,
    kernel_statistics: HashMap<String, KernelStatistics>,
}

/// Statistics for individual kernels
#[derive(Debug, Clone)]
pub struct KernelStatistics {
    /// Alignment of the kernel with the target (or a proxy thereof)
    pub alignment: f64,
    /// Approximate eigenspectrum of the kernel matrix
    pub eigenspectrum: Array1<f64>,
    /// Effective rank of the kernel matrix
    pub effective_rank: f64,
    /// Diversity measure (variance of the kernel diagonal)
    pub diversity: f64,
    /// Complexity measure (condition number approximation)
    pub complexity: f64,
}

impl Default for KernelStatistics {
    fn default() -> Self {
        Self::new()
    }
}

impl KernelStatistics {
    /// Create a zero-initialized statistics record
    pub fn new() -> Self {
        Self {
            alignment: 0.0,
            eigenspectrum: Array1::zeros(0),
            effective_rank: 0.0,
            diversity: 0.0,
            complexity: 0.0,
        }
    }
}

impl MultipleKernelLearning {
    /// Create a new multiple kernel learning instance
    pub fn new(base_kernels: Vec<BaseKernel>) -> Self {
        let rng = StdRng::seed_from_u64(42);
        Self {
            base_kernels,
            config: MultiKernelConfig::default(),
            weights: None,
            kernel_matrices: None,
            combined_features: None,
            random_state: None,
            rng,
            training_data: None,
            kernel_statistics: HashMap::new(),
        }
    }

    /// Set configuration
    pub fn with_config(mut self, config: MultiKernelConfig) -> Self {
        self.config = config;
        self
    }

    /// Set random state for reproducibility
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self.rng = StdRng::seed_from_u64(random_state);
        self
    }

    /// Fit the multiple kernel learning model
    pub fn fit(&mut self, x: &Array2<f64>, y: Option<&Array1<f64>>) -> Result<()> {
        // Store training data
        self.training_data = Some(x.clone());

        // Compute or approximate individual kernel matrices
        let mut kernel_matrices = Vec::new();
        let base_kernels = self.base_kernels.clone(); // Clone to avoid borrowing issues

        for (i, base_kernel) in base_kernels.iter().enumerate() {
            let kernel_matrix = match &self.config.approximation_method {
                ApproximationMethod::RandomFourierFeatures { n_components } => {
                    self.compute_rff_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::Nystroem { n_components } => {
                    self.compute_nystroem_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::StructuredFeatures { n_components } => {
                    self.compute_structured_approximation(x, base_kernel, *n_components)?
                }
                ApproximationMethod::Exact => self.compute_exact_kernel_matrix(x, base_kernel)?,
            };

            // Process kernel matrix (normalize, center)
            let processed_matrix = self.process_kernel_matrix(kernel_matrix)?;

            // Compute kernel statistics
            let stats = self.compute_kernel_statistics(&processed_matrix, y)?;
            self.kernel_statistics
                .insert(format!("kernel_{}", i), stats);

            kernel_matrices.push(processed_matrix);
        }

        self.kernel_matrices = Some(kernel_matrices);

        // Learn optimal kernel weights
        self.learn_weights(y)?;

        // Compute combined features/kernel
        self.compute_combined_representation()?;

        Ok(())
    }

    /// Transform data using learned kernel combination
    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let weights = self
            .weights
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;
        let training_data = self
            .training_data
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let mut combined_features = None;

        for (base_kernel, &weight) in self.base_kernels.iter().zip(weights.iter()) {
            if weight.abs() < 1e-12 {
                continue; // Skip kernels with negligible weight
            }

            let features = match &self.config.approximation_method {
                ApproximationMethod::RandomFourierFeatures { n_components } => {
                    self.transform_rff(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::Nystroem { n_components } => {
                    self.transform_nystroem(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::StructuredFeatures { n_components } => {
                    self.transform_structured(x, training_data, base_kernel, *n_components)?
                }
                ApproximationMethod::Exact => {
                    return Err(SklearsError::NotImplemented(
                        "Exact kernel transform not implemented for new data".to_string(),
                    ));
                }
            };

            let weighted_features = &features * weight;

            match &self.config.combination_strategy {
                CombinationStrategy::Linear
                | CombinationStrategy::Convex
                | CombinationStrategy::Conic => {
                    combined_features = match combined_features {
                        Some(existing) => Some(existing + weighted_features),
                        None => Some(weighted_features),
                    };
                }
                CombinationStrategy::Product => {
                    // Simplified product combination: multiply exponentiated weighted features
                    combined_features = match combined_features {
                        Some(existing) => Some(existing * weighted_features.mapv(|v| v.exp())),
                        None => Some(weighted_features.mapv(|v| v.exp())),
                    };
                }
                CombinationStrategy::Hierarchical => {
                    // Simplified hierarchical combination (treated as additive)
                    combined_features = match combined_features {
                        Some(existing) => Some(existing + weighted_features),
                        None => Some(weighted_features),
                    };
                }
            }
        }

        combined_features.ok_or_else(|| {
            SklearsError::Other("No features generated - all kernel weights are zero".to_string())
        })
    }

    /// Get learned kernel weights
    pub fn kernel_weights(&self) -> Option<&Array1<f64>> {
        self.weights.as_ref()
    }

    /// Get kernel statistics
    pub fn kernel_stats(&self) -> &HashMap<String, KernelStatistics> {
        &self.kernel_statistics
    }

    /// Get the most important kernels based on their learned weights
    pub fn important_kernels(&self, threshold: f64) -> Vec<(usize, &BaseKernel, f64)> {
        if let Some(weights) = &self.weights {
            self.base_kernels
                .iter()
                .enumerate()
                .zip(weights.iter())
                .filter_map(|((i, kernel), &weight)| {
                    if weight.abs() >= threshold {
                        Some((i, kernel, weight))
                    } else {
                        None
                    }
                })
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Learn optimal kernel weights
    fn learn_weights(&mut self, y: Option<&Array1<f64>>) -> Result<()> {
        let kernel_matrices = self.kernel_matrices.as_ref().unwrap();
        let n_kernels = kernel_matrices.len();

        let weights = match &self.config.weight_learning {
            WeightLearningAlgorithm::Uniform => {
                Array1::from_elem(n_kernels, 1.0 / n_kernels as f64)
            }
            WeightLearningAlgorithm::CenteredKernelAlignment => {
                self.learn_cka_weights(kernel_matrices, y)?
            }
            WeightLearningAlgorithm::MaximumMeanDiscrepancy => {
                self.learn_mmd_weights(kernel_matrices)?
            }
            WeightLearningAlgorithm::SimpleMKL { regularization } => {
                self.learn_simple_mkl_weights(kernel_matrices, y, *regularization)?
            }
            WeightLearningAlgorithm::EasyMKL { radius } => {
                self.learn_easy_mkl_weights(kernel_matrices, y, *radius)?
            }
            WeightLearningAlgorithm::SpectralProjected => {
                self.learn_spectral_weights(kernel_matrices)?
            }
            WeightLearningAlgorithm::LocalizedMKL { bandwidth } => {
                self.learn_localized_weights(kernel_matrices, *bandwidth)?
            }
            WeightLearningAlgorithm::AdaptiveMKL { cv_folds } => {
                self.learn_adaptive_weights(kernel_matrices, y, *cv_folds)?
            }
        };

        // Apply combination strategy constraints
        let final_weights = self.apply_combination_constraints(weights)?;

        self.weights = Some(final_weights);
        Ok(())
    }

    /// Learn weights using Centered Kernel Alignment
    fn learn_cka_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        y: Option<&Array1<f64>>,
    ) -> Result<Array1<f64>> {
        if let Some(labels) = y {
            // Supervised CKA: align with label kernel
            let label_kernel = self.compute_label_kernel(labels)?;
            let mut alignments = Array1::zeros(kernel_matrices.len());

            for (i, kernel) in kernel_matrices.iter().enumerate() {
                alignments[i] = self.centered_kernel_alignment(kernel, &label_kernel)?;
            }

            // Softmax normalization
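            // w_i = exp(a_i - max a) / sum_j exp(a_j - max a); subtracting the
            // maximum alignment first keeps the exponentials from overflowing.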
            let max_alignment = alignments.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            let exp_alignments = alignments.mapv(|v| (v - max_alignment).exp());
            let sum_exp = exp_alignments.sum();

            Ok(exp_alignments / sum_exp)
        } else {
            // Unsupervised: use kernel diversity
            let mut weights = Array1::zeros(kernel_matrices.len());

            for (i, kernel) in kernel_matrices.iter().enumerate() {
                // Use the normalized trace (mean of the diagonal) as a simple quality measure
                weights[i] = kernel.diag().sum() / kernel.nrows() as f64;
            }

            let sum_weights = weights.sum();
            if sum_weights > 0.0 {
                weights /= sum_weights;
            } else {
                weights.fill(1.0 / kernel_matrices.len() as f64);
            }

            Ok(weights)
        }
    }

    /// Learn weights using Maximum Mean Discrepancy
    fn learn_mmd_weights(&self, kernel_matrices: &[Array2<f64>]) -> Result<Array1<f64>> {
        // Simplified MMD-based weight learning
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            // Use kernel matrix statistics as a proxy for MMD
            let trace = kernel.diag().sum();
            let frobenius_norm = kernel.mapv(|v| v * v).sum().sqrt();
            weights[i] = trace / frobenius_norm;
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }

    /// Simplified SimpleMKL implementation
    fn learn_simple_mkl_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        regularization: f64,
    ) -> Result<Array1<f64>> {
        // Simplified version: use kernel quality metrics instead of QCQP optimization
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            let eigenvalues = self.compute_simplified_eigenvalues(kernel)?;
            // Participation ratio of the squared (approximate) eigenvalues:
            // (sum l_i^2)^2 / sum l_i^4
            let effective_rank =
                eigenvalues.mapv(|v| v * v).sum().powi(2) / eigenvalues.mapv(|v| v.powi(4)).sum();
            weights[i] = effective_rank / (1.0 + regularization);
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }

    /// Simplified EasyMKL implementation
    fn learn_easy_mkl_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        _radius: f64,
    ) -> Result<Array1<f64>> {
        // Use uniform weights for simplicity
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Learn weights using spectral methods
    fn learn_spectral_weights(&self, kernel_matrices: &[Array2<f64>]) -> Result<Array1<f64>> {
        let mut weights = Array1::zeros(kernel_matrices.len());

        for (i, kernel) in kernel_matrices.iter().enumerate() {
            let trace = kernel.diag().sum();
            weights[i] = trace;
        }

        let sum_weights = weights.sum();
        if sum_weights > 0.0 {
            weights /= sum_weights;
        } else {
            weights.fill(1.0 / kernel_matrices.len() as f64);
        }

        Ok(weights)
    }

    /// Learn localized weights
    fn learn_localized_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _bandwidth: f64,
    ) -> Result<Array1<f64>> {
        // Simplified localized MKL: uniform weights
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Learn adaptive weights with cross-validation
    fn learn_adaptive_weights(
        &self,
        kernel_matrices: &[Array2<f64>],
        _y: Option<&Array1<f64>>,
        _cv_folds: usize,
    ) -> Result<Array1<f64>> {
        // Simplified adaptive MKL: uniform weights
        Ok(Array1::from_elem(
            kernel_matrices.len(),
            1.0 / kernel_matrices.len() as f64,
        ))
    }

    /// Apply combination strategy constraints
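    ///
    /// For example, under `Convex` the raw weights [0.5, -0.3, 0.8] are clipped
    /// to [0.5, 0.0, 0.8] and renormalized to roughly [0.385, 0.0, 0.615].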
    fn apply_combination_constraints(&self, mut weights: Array1<f64>) -> Result<Array1<f64>> {
        match &self.config.combination_strategy {
            CombinationStrategy::Convex => {
                // Ensure non-negativity and unit sum
                weights.mapv_inplace(|v| v.max(0.0));
                let sum = weights.sum();
                if sum > 0.0 {
                    weights /= sum;
                } else {
                    let uniform_val = 1.0 / weights.len() as f64;
                    weights.fill(uniform_val);
                }
            }
            CombinationStrategy::Conic => {
                // Ensure non-negativity
                weights.mapv_inplace(|v| v.max(0.0));
            }
            CombinationStrategy::Linear => {
                // No constraints
            }
            CombinationStrategy::Product => {
                // Ensure all weights are strictly positive for product combination
                weights.mapv_inplace(|v| v.abs().max(1e-12));
            }
            CombinationStrategy::Hierarchical => {
                // Simplified hierarchical constraints (same as convex)
                weights.mapv_inplace(|v| v.max(0.0));
                let sum = weights.sum();
                if sum > 0.0 {
                    weights /= sum;
                } else {
                    let uniform_val = 1.0 / weights.len() as f64;
                    weights.fill(uniform_val);
                }
            }
        }

        Ok(weights)
    }

    /// Compute label kernel for supervised learning
    fn compute_label_kernel(&self, labels: &Array1<f64>) -> Result<Array2<f64>> {
        let n = labels.len();
        let mut label_kernel = Array2::zeros((n, n));

        for i in 0..n {
            for j in 0..n {
                label_kernel[[i, j]] = if (labels[i] - labels[j]).abs() < 1e-10 {
                    1.0
                } else {
                    0.0
                };
            }
        }

        Ok(label_kernel)
    }

    /// Compute centered kernel alignment
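    ///
    /// CKA(K1, K2) = <K1c, K2c>_F / (||K1c||_F * ||K2c||_F), where Kc = H K H
    /// is the centered kernel with H = I - (1/n) * 11^T.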
    fn centered_kernel_alignment(&self, k1: &Array2<f64>, k2: &Array2<f64>) -> Result<f64> {
        let n = k1.nrows() as f64;
        let ones = Array2::ones((k1.nrows(), k1.ncols())) / n;

        // Center kernels: Kc = K - JK - KJ + JKJ with J = (1/n) * 11^T
        let k1_centered = k1 - &ones.dot(k1) - &k1.dot(&ones) + &ones.dot(k1).dot(&ones);
        let k2_centered = k2 - &ones.dot(k2) - &k2.dot(&ones) + &ones.dot(k2).dot(&ones);

        // Compute alignment
        let numerator = (&k1_centered * &k2_centered).sum();
        let denominator =
            ((&k1_centered * &k1_centered).sum() * (&k2_centered * &k2_centered).sum()).sqrt();

        if denominator > 1e-12 {
            Ok(numerator / denominator)
        } else {
            Ok(0.0)
        }
    }

    /// Process kernel matrix (normalize and center if configured)
    fn process_kernel_matrix(&self, mut kernel: Array2<f64>) -> Result<Array2<f64>> {
        if self.config.normalize_kernels {
            // Cosine-normalize: K'(i, j) = K(i, j) / sqrt(K(i, i) * K(j, j))
            let diag = kernel.diag().to_owned();
            for i in 0..kernel.nrows() {
                for j in 0..kernel.ncols() {
                    let denom = diag[i] * diag[j];
                    if denom > 1e-12 {
                        kernel[[i, j]] /= denom.sqrt();
                    }
                }
            }
        }

        if self.config.center_kernels {
            // Double-center the kernel matrix
            let row_means = kernel.mean_axis(Axis(1)).unwrap();
            let col_means = kernel.mean_axis(Axis(0)).unwrap();
            let total_mean = kernel.mean().unwrap();

            for i in 0..kernel.nrows() {
                for j in 0..kernel.ncols() {
                    kernel[[i, j]] = kernel[[i, j]] - row_means[i] - col_means[j] + total_mean;
                }
            }
        }

        Ok(kernel)
    }

    /// Compute an approximate Gram matrix for a base kernel via random
    /// Fourier features, falling back to exact evaluation when the kernel
    /// has no known RFF construction.
    fn compute_rff_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        match kernel {
            BaseKernel::RBF { gamma } => self.compute_rbf_rff_matrix(x, *gamma, n_components),
            BaseKernel::Laplacian { gamma } => {
                self.compute_laplacian_rff_matrix(x, *gamma, n_components)
            }
            _ => {
                // Fallback to exact computation for unsupported kernels
                self.compute_exact_kernel_matrix(x, kernel)
            }
        }
    }

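    /// Random Fourier features for the RBF kernel (Rahimi & Recht): drawing
    /// w ~ N(0, 2*gamma*I) and b ~ U[0, 2*pi), the feature map
    /// phi(x) = sqrt(2/D) * cos(w^T x + b) satisfies
    /// E[phi(x) . phi(y)] ~= exp(-gamma * ||x - y||^2).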
    fn compute_rbf_rff_matrix(
        &mut self,
        x: &Array2<f64>,
        gamma: f64,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (_n_samples, n_features) = x.dim();

        // Sample random weights w ~ N(0, 2 * gamma * I)
        let mut weights = Array2::zeros((n_components, n_features));
        for i in 0..n_components {
            for j in 0..n_features {
                weights[[i, j]] = self.rng.sample::<f64, _>(StandardNormal) * (2.0 * gamma).sqrt();
            }
        }

        // Sample random phases b ~ U[0, 2 * pi)
        let mut bias = Array1::zeros(n_components);
        for i in 0..n_components {
            bias[i] = self.rng.gen_range(0.0..2.0 * std::f64::consts::PI);
        }

        // Compute features phi(x) = sqrt(2 / D) * cos(W x + b)
        let projection = x.dot(&weights.t()) + &bias;
        let features = projection.mapv(|v| v.cos()) * (2.0 / n_components as f64).sqrt();

        // Compute the Gram matrix from the features
        Ok(features.dot(&features.t()))
    }

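    /// Random Fourier features for the Laplacian kernel: its spectral density
    /// is Cauchy, so weights are drawn by inverse-CDF sampling,
    /// w = gamma * tan(pi * (u - 0.5)) with u ~ U(0, 1), clipped away from the
    /// endpoints to keep tan() finite.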
    fn compute_laplacian_rff_matrix(
        &mut self,
        x: &Array2<f64>,
        gamma: f64,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (_n_samples, n_features) = x.dim();

        // Sample random weights from a Cauchy distribution (via inverse CDF)
        let mut weights = Array2::zeros((n_components, n_features));
        for i in 0..n_components {
            for j in 0..n_features {
                let u: f64 = self.rng.gen_range(0.001..0.999);
                weights[[i, j]] = ((std::f64::consts::PI * (u - 0.5)).tan()) * gamma;
            }
        }

        // Sample random phases b ~ U[0, 2 * pi)
        let mut bias = Array1::zeros(n_components);
        for i in 0..n_components {
            bias[i] = self.rng.gen_range(0.0..2.0 * std::f64::consts::PI);
        }

        // Compute features phi(x) = sqrt(2 / D) * cos(W x + b)
        let projection = x.dot(&weights.t()) + &bias;
        let features = projection.mapv(|v| v.cos()) * (2.0 / n_components as f64).sqrt();

        // Compute the Gram matrix from the features
        Ok(features.dot(&features.t()))
    }

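    /// Simplified Nystroem-style approximation: builds the n x m matrix C of
    /// kernel evaluations against m random landmarks and returns C * C^T. (The
    /// full Nystroem method would use C * W^+ * C^T, where W is the m x m
    /// landmark Gram matrix; that refinement is skipped here.)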
    fn compute_nystroem_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        let (n_samples, _) = x.dim();
        let n_landmarks = n_components.min(n_samples);

        // Select random landmarks (with replacement, for simplicity)
        let mut landmark_indices = Vec::new();
        for _ in 0..n_landmarks {
            landmark_indices.push(self.rng.gen_range(0..n_samples));
        }

        // Compute kernel values between all points and the landmarks
        let mut kernel_matrix = Array2::zeros((n_samples, n_landmarks));
        for i in 0..n_samples {
            for j in 0..n_landmarks {
                let landmark_idx = landmark_indices[j];
                kernel_matrix[[i, j]] =
                    kernel.evaluate(&x.row(i).to_owned(), &x.row(landmark_idx).to_owned());
            }
        }

        // For simplicity, return C * C^T rather than the full Nystroem estimate
        Ok(kernel_matrix.dot(&kernel_matrix.t()))
    }

    fn compute_structured_approximation(
        &mut self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        // Fall back to RFF for structured features
        self.compute_rff_approximation(x, kernel, n_components)
    }

    fn compute_exact_kernel_matrix(
        &self,
        x: &Array2<f64>,
        kernel: &BaseKernel,
    ) -> Result<Array2<f64>> {
        let n_samples = x.nrows();
        let mut kernel_matrix = Array2::zeros((n_samples, n_samples));

        // Exploit symmetry: evaluate only the upper triangle and mirror it
        for i in 0..n_samples {
            for j in i..n_samples {
                let value = kernel.evaluate(&x.row(i).to_owned(), &x.row(j).to_owned());
                kernel_matrix[[i, j]] = value;
                kernel_matrix[[j, i]] = value;
            }
        }

        Ok(kernel_matrix)
    }

    /// Placeholder transform methods: these return zero feature matrices of
    /// the right shape until proper out-of-sample transforms are implemented.
    fn transform_rff(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        // Placeholder: zero features of shape (n_samples, n_components)
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    fn transform_nystroem(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        // Placeholder: zero features of shape (n_samples, n_components)
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    fn transform_structured(
        &self,
        x: &Array2<f64>,
        _training_data: &Array2<f64>,
        _kernel: &BaseKernel,
        n_components: usize,
    ) -> Result<Array2<f64>> {
        // Placeholder: zero features of shape (n_samples, n_components)
        let (n_samples, _) = x.dim();
        Ok(Array2::zeros((n_samples, n_components)))
    }

    fn compute_combined_representation(&mut self) -> Result<()> {
        // Placeholder for combining features
        Ok(())
    }

    fn compute_kernel_statistics(
        &self,
        kernel: &Array2<f64>,
        _y: Option<&Array1<f64>>,
    ) -> Result<KernelStatistics> {
        let mut stats = KernelStatistics::new();

        // Compute basic statistics
        stats.alignment = kernel.diag().mean().unwrap_or(0.0);

        // Simplified eigenspectrum (just use the diagonal)
        stats.eigenspectrum = kernel.diag().to_owned();

        // Effective rank approximation
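        // For a PSD kernel with eigenvalues l_i, tr(K)^2 / ||K||_F^2
        // = (sum l_i)^2 / sum l_i^2, which lies in [1, rank(K)].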
        let trace = kernel.diag().sum();
        let frobenius_sq = kernel.mapv(|v| v * v).sum();
        stats.effective_rank = if frobenius_sq > 1e-12 {
            trace.powi(2) / frobenius_sq
        } else {
            0.0
        };

        // Diversity measure (variance of the diagonal)
        stats.diversity = kernel.diag().var(0.0);

        // Complexity measure (condition number approximation from the diagonal)
        let diag = kernel.diag();
        let max_eig = diag.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let min_eig = diag.iter().fold(f64::INFINITY, |a, &b| a.min(b.max(1e-12)));
        stats.complexity = max_eig / min_eig;

        Ok(stats)
    }

    fn compute_simplified_eigenvalues(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
        // Simplified eigenvalue computation: use the diagonal as a proxy
        Ok(matrix.diag().to_owned())
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_base_kernel_evaluation() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![1.0, 2.0, 3.0];

        let rbf_kernel = BaseKernel::RBF { gamma: 0.1 };
        let value = rbf_kernel.evaluate(&x, &y);
        assert!((value - 1.0).abs() < 1e-10); // Identical points should give 1.0

        let linear_kernel = BaseKernel::Linear;
        let value = linear_kernel.evaluate(&x, &y);
        assert!((value - 14.0).abs() < 1e-10); // 1*1 + 2*2 + 3*3 = 14
    }

    #[test]
    fn test_kernel_names() {
        let rbf = BaseKernel::RBF { gamma: 0.5 };
        assert_eq!(rbf.name(), "RBF(gamma=0.5000)");

        let linear = BaseKernel::Linear;
        assert_eq!(linear.name(), "Linear");

        let poly = BaseKernel::Polynomial {
            degree: 2.0,
            gamma: 1.0,
            coef0: 0.0,
        };
        assert_eq!(
            poly.name(),
            "Polynomial(degree=2.0, gamma=1.0000, coef0=0.0000)"
        );
    }

    #[test]
    fn test_multiple_kernel_learning_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let base_kernels = vec![
            BaseKernel::RBF { gamma: 0.1 },
            BaseKernel::Linear,
            BaseKernel::Polynomial {
                degree: 2.0,
                gamma: 1.0,
                coef0: 0.0,
            },
        ];

        let mut mkl = MultipleKernelLearning::new(base_kernels).with_random_state(42);

        mkl.fit(&x, None).unwrap();

        let weights = mkl.kernel_weights().unwrap();
        assert_eq!(weights.len(), 3);
        assert!((weights.sum() - 1.0).abs() < 1e-10); // Convex combination sums to 1
    }

    #[test]
    fn test_kernel_statistics() {
        let kernel = array![[1.0, 0.5, 0.2], [0.5, 1.0, 0.3], [0.2, 0.3, 1.0]];

        let mkl = MultipleKernelLearning::new(vec![]);
        let stats = mkl.compute_kernel_statistics(&kernel, None).unwrap();

        assert!((stats.alignment - 1.0).abs() < 1e-10); // Diagonal mean should be 1.0
        assert!(stats.effective_rank > 0.0);
        assert!(stats.diversity >= 0.0);
    }

    #[test]
    fn test_combination_strategies() {
        let weights = array![0.5, -0.3, 0.8];

        let mut mkl = MultipleKernelLearning::new(vec![]);
        mkl.config.combination_strategy = CombinationStrategy::Convex;

        let constrained = mkl.apply_combination_constraints(weights.clone()).unwrap();

        // Should be non-negative and sum to 1
        assert!(constrained.iter().all(|&x| x >= 0.0));
        assert!((constrained.sum() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mkl_config() {
        let config = MultiKernelConfig {
            combination_strategy: CombinationStrategy::Linear,
            weight_learning: WeightLearningAlgorithm::SimpleMKL {
                regularization: 0.01,
            },
            approximation_method: ApproximationMethod::Nystroem { n_components: 50 },
            max_iterations: 200,
            tolerance: 1e-8,
            normalize_kernels: false,
            center_kernels: false,
            regularization: 0.001,
        };

        assert!(matches!(
            config.combination_strategy,
            CombinationStrategy::Linear
        ));
        assert!(matches!(
            config.weight_learning,
            WeightLearningAlgorithm::SimpleMKL { .. }
        ));
        assert_eq!(config.max_iterations, 200);
        assert!(!config.normalize_kernels);
    }

    #[test]
    fn test_important_kernels() {
        let base_kernels = vec![
            BaseKernel::RBF { gamma: 0.1 },
            BaseKernel::Linear,
            BaseKernel::Polynomial {
                degree: 2.0,
                gamma: 1.0,
                coef0: 0.0,
            },
        ];

        let mut mkl = MultipleKernelLearning::new(base_kernels);
        mkl.weights = Some(array![0.6, 0.05, 0.35]);

        let important = mkl.important_kernels(0.1);
        assert_eq!(important.len(), 2); // Only kernels with weight >= 0.1
        assert_eq!(important[0].0, 0); // First kernel (RBF)
        assert_eq!(important[1].0, 2); // Third kernel (Polynomial)
    }

    #[test]
    fn test_supervised_vs_unsupervised() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![0.0, 1.0, 0.0, 1.0];

        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl_unsupervised =
            MultipleKernelLearning::new(base_kernels.clone()).with_random_state(42);
        mkl_unsupervised.fit(&x, None).unwrap();

        let mut mkl_supervised = MultipleKernelLearning::new(base_kernels).with_random_state(42);
        mkl_supervised.fit(&x, Some(&y)).unwrap();

        // Both should fit without errors
        assert!(mkl_unsupervised.kernel_weights().is_some());
        assert!(mkl_supervised.kernel_weights().is_some());
    }

    #[test]
    fn test_transform_compatibility() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let x_test = array![[2.0, 3.0], [4.0, 5.0]];

        let base_kernels = vec![BaseKernel::RBF { gamma: 0.1 }, BaseKernel::Linear];

        let mut mkl = MultipleKernelLearning::new(base_kernels)
            .with_config(MultiKernelConfig {
                approximation_method: ApproximationMethod::RandomFourierFeatures {
                    n_components: 10,
                },
                ..Default::default()
            })
            .with_random_state(42);

        mkl.fit(&x_train, None).unwrap();
        let features = mkl.transform(&x_test).unwrap();

        assert_eq!(features.nrows(), 2); // Two test samples
        assert!(features.ncols() > 0); // Some features generated
    }
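
    // The tests below are illustrative additions; they only exercise behavior
    // that is fully determined by the implementations in this module.

    #[test]
    fn test_cka_self_alignment() {
        // The centered alignment of a non-degenerate kernel with itself is 1:
        // CKA(K, K) = <Kc, Kc> / sqrt(<Kc, Kc> * <Kc, Kc>) = 1.
        let kernel = array![[1.0, 0.5, 0.2], [0.5, 1.0, 0.3], [0.2, 0.3, 1.0]];
        let mkl = MultipleKernelLearning::new(vec![]);
        let alignment = mkl.centered_kernel_alignment(&kernel, &kernel).unwrap();
        assert!((alignment - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_label_kernel() {
        // compute_label_kernel builds the 0/1 label-equality kernel.
        let labels = array![0.0, 1.0, 0.0];
        let mkl = MultipleKernelLearning::new(vec![]);
        let k = mkl.compute_label_kernel(&labels).unwrap();
        assert_eq!(k, array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]);
    }

    #[test]
    fn test_kernel_normalization() {
        // With normalization on and centering off, process_kernel_matrix
        // rescales so that every diagonal entry becomes 1.
        let kernel = array![[4.0, 2.0], [2.0, 9.0]];
        let mkl = MultipleKernelLearning::new(vec![]).with_config(MultiKernelConfig {
            normalize_kernels: true,
            center_kernels: false,
            ..Default::default()
        });
        let processed = mkl.process_kernel_matrix(kernel).unwrap();
        assert!((processed[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((processed[[1, 1]] - 1.0).abs() < 1e-10);
        // Off-diagonal entry: 2 / sqrt(4 * 9) = 1/3.
        assert!((processed[[0, 1]] - 1.0 / 3.0).abs() < 1e-10);
    }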
}