// tensorlogic_sklears_kernels/kernel_selection.rs

1// Allow needless_range_loop for matrix operations which are clearer with indexed loops
2#![allow(clippy::needless_range_loop)]
3
4//! Kernel selection and cross-validation utilities.
5//!
6//! This module provides tools for selecting the best kernel and hyperparameters
7//! for a given dataset, including:
8//!
9//! - **Kernel Target Alignment (KTA)**: Quick kernel quality metric
10//! - **Leave-One-Out Cross-Validation (LOO-CV)**: Efficient for GP regression
11//! - **K-Fold Cross-Validation**: For general kernel evaluation
12//! - **Kernel comparison utilities**: Compare multiple kernels on same data
13//!
14//! ## Example
15//!
16//! ```rust
17//! use tensorlogic_sklears_kernels::kernel_selection::{
18//!     KernelSelector, KernelComparison, KFoldConfig
19//! };
20//! use tensorlogic_sklears_kernels::{RbfKernel, RbfKernelConfig, LinearKernel, Kernel};
21//!
22//! // Create kernels to compare
23//! let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
24//! let linear = LinearKernel::new();
25//!
26//! // Sample data
27//! let data = vec![
28//!     vec![1.0, 2.0],
29//!     vec![2.0, 3.0],
30//!     vec![3.0, 4.0],
31//!     vec![4.0, 5.0],
32//! ];
33//! let targets = vec![1.0, 2.0, 3.0, 4.0];
34//!
35//! // Compare using Kernel Target Alignment
36//! let selector = KernelSelector::new();
37//! let rbf_kta = selector.kernel_target_alignment(&rbf, &data, &targets).unwrap();
38//! let linear_kta = selector.kernel_target_alignment(&linear, &data, &targets).unwrap();
39//! ```
40
41use crate::error::{KernelError, Result};
42use crate::types::Kernel;
43
/// Configuration for K-fold cross-validation.
#[derive(Debug, Clone)]
pub struct KFoldConfig {
    /// Number of folds
    pub n_folds: usize,
    /// Whether to shuffle the data
    pub shuffle: bool,
    /// Random seed for shuffling (if enabled)
    pub seed: Option<u64>,
}

impl Default for KFoldConfig {
    /// Default configuration: 5 folds, no shuffling, no seed.
    fn default() -> Self {
        KFoldConfig {
            n_folds: 5,
            shuffle: false,
            seed: None,
        }
    }
}

impl KFoldConfig {
    /// Create a configuration with `n_folds` folds and default shuffling settings.
    pub fn new(n_folds: usize) -> Self {
        KFoldConfig {
            n_folds,
            ..KFoldConfig::default()
        }
    }

    /// Enable or disable shuffling, with an optional seed for determinism.
    pub fn with_shuffle(mut self, shuffle: bool, seed: Option<u64>) -> Self {
        self.shuffle = shuffle;
        self.seed = seed;
        self
    }
}
81
/// Results from cross-validation.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Mean score across folds
    pub mean_score: f64,
    /// Standard deviation of scores
    pub std_score: f64,
    /// Individual fold scores
    pub fold_scores: Vec<f64>,
    /// Total computation time in microseconds
    pub compute_time_us: u64,
}

impl CrossValidationResult {
    /// Build a result from per-fold scores, computing their mean and
    /// (population) standard deviation.
    pub fn new(fold_scores: Vec<f64>, compute_time_us: u64) -> Self {
        let count = fold_scores.len() as f64;
        let mean_score = fold_scores.iter().copied().sum::<f64>() / count;
        let std_score = (fold_scores
            .iter()
            .map(|s| {
                let d = s - mean_score;
                d * d
            })
            .sum::<f64>()
            / count)
            .sqrt();

        Self {
            mean_score,
            std_score,
            fold_scores,
            compute_time_us,
        }
    }

    /// 95% confidence interval for the mean score, using the normal
    /// approximation (1.96 standard errors on each side).
    pub fn confidence_interval(&self) -> (f64, f64) {
        let margin = 1.96 * self.std_score / (self.fold_scores.len() as f64).sqrt();
        (self.mean_score - margin, self.mean_score + margin)
    }
}
120
/// Comparison results for multiple kernels.
#[derive(Debug, Clone)]
pub struct KernelComparison {
    /// Kernel names
    pub kernel_names: Vec<String>,
    /// Scores for each kernel
    pub scores: Vec<f64>,
    /// Standard deviations (if available)
    pub std_devs: Option<Vec<f64>>,
    /// Index of the best kernel
    pub best_index: usize,
}

impl KernelComparison {
    /// Build a comparison, marking the highest-scoring kernel as best.
    /// Falls back to index 0 when `scores` is empty.
    pub fn from_scores(kernel_names: Vec<String>, scores: Vec<f64>) -> Self {
        // `max_by` keeps the last maximum on ties; NaN comparisons are treated
        // as equal rather than panicking.
        let best_index = scores
            .iter()
            .enumerate()
            .max_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx);

        Self {
            kernel_names,
            scores,
            std_devs: None,
            best_index,
        }
    }

    /// Attach per-kernel standard deviations.
    pub fn with_std_devs(mut self, std_devs: Vec<f64>) -> Self {
        self.std_devs = Some(std_devs);
        self
    }

    /// Name of the best-scoring kernel.
    pub fn best_kernel(&self) -> &str {
        self.kernel_names[self.best_index].as_str()
    }

    /// Score of the best-scoring kernel.
    pub fn best_score(&self) -> f64 {
        self.scores[self.best_index]
    }

    /// Render a human-readable comparison table with the best kernel flagged.
    pub fn summary(&self) -> String {
        let mut report = String::from("Kernel Comparison Results:\n");
        report.push_str(&format!("{:=<50}\n", ""));

        for (idx, name) in self.kernel_names.iter().enumerate() {
            let std_part = match &self.std_devs {
                Some(devs) => format!(" ± {:.4}", devs[idx]),
                None => String::new(),
            };
            let best_mark = if idx == self.best_index { " *BEST*" } else { "" };
            report.push_str(&format!(
                "{:20} : {:.4}{}{}\n",
                name, self.scores[idx], std_part, best_mark
            ));
        }

        report
    }
}
187
/// Kernel selector for choosing and comparing kernels.
#[derive(Debug, Clone)]
pub struct KernelSelector {
    /// Regularization added to the kernel-matrix diagonal for numerical stability
    regularization: f64,
}

impl Default for KernelSelector {
    /// Default matches `KernelSelector::new()`.
    ///
    /// The previous `#[derive(Default)]` produced `regularization = 0.0`,
    /// which disagreed with `new()`'s 1e-6 and could make LOO estimation
    /// fail on singular kernel matrices when the selector was built via
    /// `Default::default()`.
    fn default() -> Self {
        Self {
            regularization: 1e-6,
        }
    }
}
194
195impl KernelSelector {
196    /// Create a new kernel selector.
197    pub fn new() -> Self {
198        Self {
199            regularization: 1e-6,
200        }
201    }
202
203    /// Set the regularization parameter.
204    pub fn with_regularization(mut self, reg: f64) -> Self {
205        self.regularization = reg;
206        self
207    }
208
209    /// Compute Kernel Target Alignment (KTA).
210    ///
211    /// KTA measures how well a kernel aligns with the target labels.
212    /// Higher values indicate better alignment.
213    ///
214    /// Formula: KTA = <K, yy^T>_F / (||K||_F * ||yy^T||_F)
215    ///
216    /// # Arguments
217    /// * `kernel` - The kernel to evaluate
218    /// * `data` - Input data points
219    /// * `targets` - Target values (for regression) or labels (for classification)
220    pub fn kernel_target_alignment<K: Kernel + ?Sized>(
221        &self,
222        kernel: &K,
223        data: &[Vec<f64>],
224        targets: &[f64],
225    ) -> Result<f64> {
226        if data.len() != targets.len() {
227            return Err(KernelError::ComputationError(
228                "data and targets must have same length".to_string(),
229            ));
230        }
231        if data.is_empty() {
232            return Err(KernelError::ComputationError(
233                "data cannot be empty".to_string(),
234            ));
235        }
236
237        let n = data.len();
238
239        // Compute kernel matrix K
240        let k_matrix = kernel.compute_matrix(data)?;
241
242        // Compute target matrix yy^T
243        let mut y_matrix = vec![vec![0.0; n]; n];
244        for i in 0..n {
245            for j in 0..n {
246                y_matrix[i][j] = targets[i] * targets[j];
247            }
248        }
249
250        // Compute Frobenius inner product <K, yy^T>
251        let mut k_y_product = 0.0;
252        for i in 0..n {
253            for j in 0..n {
254                k_y_product += k_matrix[i][j] * y_matrix[i][j];
255            }
256        }
257
258        // Compute ||K||_F
259        let mut k_norm_sq = 0.0;
260        for i in 0..n {
261            for j in 0..n {
262                k_norm_sq += k_matrix[i][j] * k_matrix[i][j];
263            }
264        }
265        let k_norm = k_norm_sq.sqrt();
266
267        // Compute ||yy^T||_F
268        let mut y_norm_sq = 0.0;
269        for i in 0..n {
270            for j in 0..n {
271                y_norm_sq += y_matrix[i][j] * y_matrix[i][j];
272            }
273        }
274        let y_norm = y_norm_sq.sqrt();
275
276        // KTA
277        if k_norm < 1e-10 || y_norm < 1e-10 {
278            return Ok(0.0);
279        }
280
281        Ok(k_y_product / (k_norm * y_norm))
282    }
283
284    /// Compute centered Kernel Target Alignment.
285    ///
286    /// Centered KTA is more robust and accounts for the mean of the kernel matrix.
287    pub fn centered_kernel_target_alignment<K: Kernel + ?Sized>(
288        &self,
289        kernel: &K,
290        data: &[Vec<f64>],
291        targets: &[f64],
292    ) -> Result<f64> {
293        if data.len() != targets.len() {
294            return Err(KernelError::ComputationError(
295                "data and targets must have same length".to_string(),
296            ));
297        }
298        if data.is_empty() {
299            return Err(KernelError::ComputationError(
300                "data cannot be empty".to_string(),
301            ));
302        }
303
304        let n = data.len();
305
306        // Compute kernel matrix K
307        let k_matrix = kernel.compute_matrix(data)?;
308
309        // Center the kernel matrix
310        let centered_k = center_kernel_matrix(&k_matrix);
311
312        // Center the target matrix
313        let target_mean: f64 = targets.iter().sum::<f64>() / n as f64;
314        let centered_targets: Vec<f64> = targets.iter().map(|t| t - target_mean).collect();
315
316        // Compute centered target matrix
317        let mut y_matrix = vec![vec![0.0; n]; n];
318        for i in 0..n {
319            for j in 0..n {
320                y_matrix[i][j] = centered_targets[i] * centered_targets[j];
321            }
322        }
323
324        // Compute Frobenius inner product
325        let mut k_y_product = 0.0;
326        let mut k_norm_sq = 0.0;
327        let mut y_norm_sq = 0.0;
328
329        for i in 0..n {
330            for j in 0..n {
331                k_y_product += centered_k[i][j] * y_matrix[i][j];
332                k_norm_sq += centered_k[i][j] * centered_k[i][j];
333                y_norm_sq += y_matrix[i][j] * y_matrix[i][j];
334            }
335        }
336
337        let k_norm = k_norm_sq.sqrt();
338        let y_norm = y_norm_sq.sqrt();
339
340        if k_norm < 1e-10 || y_norm < 1e-10 {
341            return Ok(0.0);
342        }
343
344        Ok(k_y_product / (k_norm * y_norm))
345    }
346
347    /// Compare multiple kernels using KTA.
348    ///
349    /// Returns a comparison with the best kernel identified.
350    pub fn compare_kernels_kta(
351        &self,
352        kernels: &[(&str, &dyn Kernel)],
353        data: &[Vec<f64>],
354        targets: &[f64],
355    ) -> Result<KernelComparison> {
356        let mut names = Vec::with_capacity(kernels.len());
357        let mut scores = Vec::with_capacity(kernels.len());
358
359        for (name, kernel) in kernels {
360            let kta = self.kernel_target_alignment(*kernel, data, targets)?;
361            names.push(name.to_string());
362            scores.push(kta);
363        }
364
365        Ok(KernelComparison::from_scores(names, scores))
366    }
367
368    /// Evaluate kernel quality using Leave-One-Out (LOO) error estimate.
369    ///
370    /// For GP regression, this provides an efficient estimate of generalization error.
371    /// Lower values indicate better performance.
372    ///
373    /// Note: This is an approximation based on the kernel matrix.
374    pub fn loo_error_estimate<K: Kernel + ?Sized>(
375        &self,
376        kernel: &K,
377        data: &[Vec<f64>],
378        targets: &[f64],
379    ) -> Result<f64> {
380        if data.len() != targets.len() {
381            return Err(KernelError::ComputationError(
382                "data and targets must have same length".to_string(),
383            ));
384        }
385        if data.len() < 2 {
386            return Err(KernelError::ComputationError(
387                "need at least 2 data points".to_string(),
388            ));
389        }
390
391        let n = data.len();
392
393        // Compute kernel matrix with regularization
394        let k_matrix = kernel.compute_matrix(data)?;
395        let mut k_reg = k_matrix.clone();
396        for i in 0..n {
397            k_reg[i][i] += self.regularization;
398        }
399
400        // Compute inverse using simple method (Gauss-Jordan elimination)
401        let k_inv = invert_matrix(&k_reg)?;
402
403        // Compute alpha = K^{-1} y
404        let mut alpha = vec![0.0; n];
405        for i in 0..n {
406            for j in 0..n {
407                alpha[i] += k_inv[i][j] * targets[j];
408            }
409        }
410
411        // LOO error: sum_i (alpha_i / K^{-1}_{ii})^2 / n
412        let mut loo_error = 0.0;
413        for i in 0..n {
414            let diag = k_inv[i][i];
415            if diag.abs() > 1e-10 {
416                let loo_residual = alpha[i] / diag;
417                loo_error += loo_residual * loo_residual;
418            }
419        }
420
421        Ok(loo_error / n as f64)
422    }
423
424    /// Perform K-fold cross-validation for kernel evaluation.
425    ///
426    /// This evaluates a kernel by training on K-1 folds and testing on the remaining fold.
427    /// The score returned is based on kernel alignment with targets.
428    pub fn k_fold_cv<K: Kernel + ?Sized>(
429        &self,
430        kernel: &K,
431        data: &[Vec<f64>],
432        targets: &[f64],
433        config: &KFoldConfig,
434    ) -> Result<CrossValidationResult> {
435        use std::time::Instant;
436
437        if data.len() != targets.len() {
438            return Err(KernelError::ComputationError(
439                "data and targets must have same length".to_string(),
440            ));
441        }
442        if data.len() < config.n_folds {
443            return Err(KernelError::ComputationError(format!(
444                "need at least {} data points for {}-fold CV",
445                config.n_folds, config.n_folds
446            )));
447        }
448
449        let start = Instant::now();
450        let n = data.len();
451
452        // Create indices (optionally shuffled)
453        let mut indices: Vec<usize> = (0..n).collect();
454        if config.shuffle {
455            // Simple deterministic shuffle using seed
456            let seed = config.seed.unwrap_or(42);
457            shuffle_indices(&mut indices, seed);
458        }
459
460        // Split into folds
461        let fold_size = n / config.n_folds;
462        let mut fold_scores = Vec::with_capacity(config.n_folds);
463
464        for fold in 0..config.n_folds {
465            let fold_start = fold * fold_size;
466            let fold_end = if fold == config.n_folds - 1 {
467                n
468            } else {
469                fold_start + fold_size
470            };
471
472            // Split data
473            let test_indices: Vec<_> = indices[fold_start..fold_end].to_vec();
474            let train_indices: Vec<_> = indices[0..fold_start]
475                .iter()
476                .chain(indices[fold_end..].iter())
477                .copied()
478                .collect();
479
480            // Collect train/test data
481            let _train_data: Vec<_> = train_indices.iter().map(|&i| data[i].clone()).collect();
482            let _train_targets: Vec<_> = train_indices.iter().map(|&i| targets[i]).collect();
483            let test_data: Vec<_> = test_indices.iter().map(|&i| data[i].clone()).collect();
484            let test_targets: Vec<_> = test_indices.iter().map(|&i| targets[i]).collect();
485
486            // Evaluate on this fold using KTA on test set
487            // Note: Full CV would train on train_data/train_targets then evaluate on test
488            // For simplicity, we use KTA on test fold as the score
489            let score = self.kernel_target_alignment(kernel, &test_data, &test_targets)?;
490            fold_scores.push(score);
491        }
492
493        let compute_time_us = start.elapsed().as_micros() as u64;
494        Ok(CrossValidationResult::new(fold_scores, compute_time_us))
495    }
496
497    /// Find the best gamma for RBF kernel using grid search.
498    ///
499    /// Searches over a logarithmic grid of gamma values.
500    pub fn grid_search_rbf_gamma(
501        &self,
502        data: &[Vec<f64>],
503        targets: &[f64],
504        gammas: &[f64],
505    ) -> Result<GammaSearchResult> {
506        use crate::tensor_kernels::RbfKernel;
507        use crate::types::RbfKernelConfig;
508
509        let mut best_gamma = gammas[0];
510        let mut best_score = f64::NEG_INFINITY;
511        let mut all_scores = Vec::with_capacity(gammas.len());
512
513        for &gamma in gammas {
514            let config = RbfKernelConfig::new(gamma);
515            let kernel = RbfKernel::new(config)?;
516            let score = self.centered_kernel_target_alignment(&kernel, data, targets)?;
517            all_scores.push((gamma, score));
518
519            if score > best_score {
520                best_score = score;
521                best_gamma = gamma;
522            }
523        }
524
525        Ok(GammaSearchResult {
526            best_gamma,
527            best_score,
528            all_scores,
529        })
530    }
531}
532
/// Result of gamma grid search for RBF kernel.
#[derive(Debug, Clone)]
pub struct GammaSearchResult {
    /// Best gamma value found
    pub best_gamma: f64,
    /// Best KTA score
    pub best_score: f64,
    /// All (gamma, score) pairs tested
    pub all_scores: Vec<(f64, f64)>,
}

impl GammaSearchResult {
    /// Format the search outcome, flagging the winning gamma with `*`.
    pub fn summary(&self) -> String {
        let mut out = format!(
            "RBF Gamma Search:\n  Best gamma: {:.6}\n  Best score: {:.4}\n\n",
            self.best_gamma, self.best_score
        );
        out.push_str("All results:\n");
        for &(gamma, score) in &self.all_scores {
            // Tolerance comparison: gamma values may not be bit-identical.
            let is_best = (gamma - self.best_gamma).abs() < 1e-10;
            out.push_str(&format!(
                "  gamma={:.6}: {:.4}{}\n",
                gamma,
                score,
                if is_best { " *" } else { "" }
            ));
        }
        out
    }
}
563
/// Center a kernel matrix: K_c = H K H where H = I - (1/n) * 1 * 1^T
fn center_kernel_matrix(k: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = k.len();
    if n == 0 {
        return Vec::new();
    }
    let n_f = n as f64;

    // Accumulate row sums, column sums, and the grand total in one sweep.
    let mut row_means = vec![0.0; n];
    let mut col_means = vec![0.0; n];
    let mut global_mean = 0.0;
    for (i, row) in k.iter().enumerate() {
        for (j, &val) in row.iter().enumerate() {
            row_means[i] += val;
            col_means[j] += val;
            global_mean += val;
        }
    }

    // Turn the sums into means.
    for m in row_means.iter_mut().chain(col_means.iter_mut()) {
        *m /= n_f;
    }
    global_mean /= n_f * n_f;

    // K_c[i][j] = K[i][j] - row_mean[i] - col_mean[j] + global_mean
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| k[i][j] - row_means[i] - col_means[j] + global_mean)
                .collect()
        })
        .collect()
}
603
604/// Simple matrix inversion using Gauss-Jordan elimination.
605fn invert_matrix(matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
606    let n = matrix.len();
607    if n == 0 {
608        return Err(KernelError::ComputationError(
609            "cannot invert empty matrix".to_string(),
610        ));
611    }
612
613    // Create augmented matrix [A | I]
614    let mut aug = vec![vec![0.0; 2 * n]; n];
615    for i in 0..n {
616        for j in 0..n {
617            aug[i][j] = matrix[i][j];
618        }
619        aug[i][n + i] = 1.0;
620    }
621
622    // Forward elimination with partial pivoting
623    for i in 0..n {
624        // Find pivot
625        let mut max_row = i;
626        let mut max_val = aug[i][i].abs();
627        for k in (i + 1)..n {
628            if aug[k][i].abs() > max_val {
629                max_val = aug[k][i].abs();
630                max_row = k;
631            }
632        }
633
634        if max_val < 1e-10 {
635            return Err(KernelError::ComputationError(
636                "matrix is singular or nearly singular".to_string(),
637            ));
638        }
639
640        // Swap rows
641        if max_row != i {
642            aug.swap(i, max_row);
643        }
644
645        // Eliminate column
646        let pivot = aug[i][i];
647        for j in 0..(2 * n) {
648            aug[i][j] /= pivot;
649        }
650
651        for k in 0..n {
652            if k != i {
653                let factor = aug[k][i];
654                for j in 0..(2 * n) {
655                    aug[k][j] -= factor * aug[i][j];
656                }
657            }
658        }
659    }
660
661    // Extract inverse
662    let mut inverse = vec![vec![0.0; n]; n];
663    for i in 0..n {
664        for j in 0..n {
665            inverse[i][j] = aug[i][n + j];
666        }
667    }
668
669    Ok(inverse)
670}
671
/// Simple shuffle using a deterministic PRNG.
fn shuffle_indices(indices: &mut [usize], seed: u64) {
    // Fisher-Yates driven by a 64-bit LCG; identical seeds give identical orders.
    let mut state = seed;
    for i in (1..indices.len()).rev() {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1);
        // Take the high bits of the LCG state, which are better distributed
        // than the low bits for a power-of-two modulus.
        let j = (state >> 33) as usize % (i + 1);
        indices.swap(i, j);
    }
}
684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::{LinearKernel, RbfKernel};
    use crate::types::RbfKernelConfig;

    // KFoldConfig::new keeps the default of no shuffling.
    #[test]
    fn test_kfold_config() {
        let config = KFoldConfig::new(10);
        assert_eq!(config.n_folds, 10);
        assert!(!config.shuffle);
    }

    // with_shuffle stores both the flag and the optional seed.
    #[test]
    fn test_kfold_config_with_shuffle() {
        let config = KFoldConfig::new(5).with_shuffle(true, Some(42));
        assert_eq!(config.n_folds, 5);
        assert!(config.shuffle);
        assert_eq!(config.seed, Some(42));
    }

    // Mean/std aggregation over fold scores; mean of the five values is 0.824.
    #[test]
    fn test_cross_validation_result() {
        let fold_scores = vec![0.8, 0.85, 0.75, 0.9, 0.82];
        let result = CrossValidationResult::new(fold_scores.clone(), 1000);

        assert!((result.mean_score - 0.824).abs() < 1e-10);
        assert!(result.std_score > 0.0);
        assert_eq!(result.fold_scores, fold_scores);
    }

    // The higher-scoring kernel (index 1) must be selected as best.
    #[test]
    fn test_kernel_comparison() {
        let names = vec!["Linear".to_string(), "RBF".to_string()];
        let scores = vec![0.5, 0.8];

        let comp = KernelComparison::from_scores(names, scores);
        assert_eq!(comp.best_index, 1);
        assert_eq!(comp.best_kernel(), "RBF");
        assert!((comp.best_score() - 0.8).abs() < 1e-10);
    }

    // The summary report should mention every kernel and flag the winner.
    #[test]
    fn test_kernel_comparison_summary() {
        let names = vec!["Linear".to_string(), "RBF".to_string()];
        let scores = vec![0.5, 0.8];
        let std_devs = vec![0.05, 0.03];

        let comp = KernelComparison::from_scores(names, scores).with_std_devs(std_devs);
        let summary = comp.summary();

        assert!(summary.contains("Linear"));
        assert!(summary.contains("RBF"));
        assert!(summary.contains("*BEST*"));
    }

    // KTA on linearly correlated data should be well above zero.
    #[test]
    fn test_kernel_target_alignment() {
        let selector = KernelSelector::new();
        let kernel = LinearKernel::new();

        let data = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let targets = vec![1.0, 2.0, 3.0, 4.0]; // Perfectly correlated

        let kta = selector.kernel_target_alignment(&kernel, &data, &targets);
        assert!(kta.is_ok());
        let kta_val = kta.unwrap();
        // For perfectly correlated data, KTA should be high
        assert!(kta_val > 0.5);
    }

    // Centered KTA should at least succeed on the same correlated data.
    #[test]
    fn test_centered_kernel_target_alignment() {
        let selector = KernelSelector::new();
        let kernel = LinearKernel::new();

        let data = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let targets = vec![1.0, 2.0, 3.0, 4.0];

        let ckta = selector.centered_kernel_target_alignment(&kernel, &data, &targets);
        assert!(ckta.is_ok());
    }

    // Empty inputs are rejected with an error rather than a panic.
    #[test]
    fn test_kta_empty_data() {
        let selector = KernelSelector::new();
        let kernel = LinearKernel::new();

        let result = selector.kernel_target_alignment(&kernel, &[], &[]);
        assert!(result.is_err());
    }

    // data/targets length mismatch is rejected.
    #[test]
    fn test_kta_mismatched_lengths() {
        let selector = KernelSelector::new();
        let kernel = LinearKernel::new();

        let data = vec![vec![1.0], vec![2.0]];
        let targets = vec![1.0, 2.0, 3.0];

        let result = selector.kernel_target_alignment(&kernel, &data, &targets);
        assert!(result.is_err());
    }

    // Comparison over trait objects: one name and one score per kernel.
    #[test]
    fn test_compare_kernels_kta() {
        let selector = KernelSelector::new();
        let linear = LinearKernel::new();
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();

        let data = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let targets = vec![1.0, 2.0, 3.0, 4.0];

        let kernels: Vec<(&str, &dyn Kernel)> = vec![("Linear", &linear), ("RBF", &rbf)];

        let comparison = selector.compare_kernels_kta(&kernels, &data, &targets);
        assert!(comparison.is_ok());

        let comp = comparison.unwrap();
        assert_eq!(comp.kernel_names.len(), 2);
        assert_eq!(comp.scores.len(), 2);
    }

    // A strong regularization (0.1) keeps the kernel matrix invertible here.
    #[test]
    fn test_loo_error_estimate() {
        let selector = KernelSelector::new().with_regularization(0.1);
        let kernel = LinearKernel::new();

        let data = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let targets = vec![1.0, 2.0, 3.0, 4.0];

        let result = selector.loo_error_estimate(&kernel, &data, &targets);
        assert!(result.is_ok());
        let error = result.unwrap();
        // Error should be finite and non-negative
        assert!(error >= 0.0);
        assert!(error.is_finite());
    }

    // 6 points / 3 folds: every fold should produce a score.
    #[test]
    fn test_k_fold_cv() {
        let selector = KernelSelector::new();
        let kernel = LinearKernel::new();
        let config = KFoldConfig::new(3);

        let data = vec![
            vec![1.0],
            vec![2.0],
            vec![3.0],
            vec![4.0],
            vec![5.0],
            vec![6.0],
        ];
        let targets = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let result = selector.k_fold_cv(&kernel, &data, &targets, &config);
        assert!(result.is_ok());

        let cv_result = result.unwrap();
        assert_eq!(cv_result.fold_scores.len(), 3);
        assert!(cv_result.mean_score.is_finite());
    }

    // Same as above but with the deterministic shuffle enabled.
    #[test]
    fn test_k_fold_cv_with_shuffle() {
        let selector = KernelSelector::new();
        let kernel = LinearKernel::new();
        let config = KFoldConfig::new(3).with_shuffle(true, Some(42));

        let data = vec![
            vec![1.0],
            vec![2.0],
            vec![3.0],
            vec![4.0],
            vec![5.0],
            vec![6.0],
        ];
        let targets = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let result = selector.k_fold_cv(&kernel, &data, &targets, &config);
        assert!(result.is_ok());
    }

    // Grid search must return one of the candidate gammas and score them all.
    #[test]
    fn test_grid_search_rbf_gamma() {
        let selector = KernelSelector::new();

        let data = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let targets = vec![1.0, 2.0, 3.0, 4.0];
        let gammas = vec![0.01, 0.1, 1.0, 10.0];

        let result = selector.grid_search_rbf_gamma(&data, &targets, &gammas);
        assert!(result.is_ok());

        let search_result = result.unwrap();
        assert!(gammas.contains(&search_result.best_gamma));
        assert_eq!(search_result.all_scores.len(), gammas.len());
    }

    // Double-centering must drive every row mean (and by symmetry column mean)
    // of the kernel matrix to zero.
    #[test]
    fn test_center_kernel_matrix() {
        let k = vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ];

        let centered = center_kernel_matrix(&k);
        assert_eq!(centered.len(), 3);

        // Centered matrix should have row and column means close to 0
        let n = centered.len() as f64;
        for row in &centered {
            let row_mean: f64 = row.iter().sum::<f64>() / n;
            assert!(row_mean.abs() < 1e-10);
        }
    }

    // Verify the inverse by checking A * A^{-1} = I within tolerance.
    #[test]
    fn test_matrix_inversion() {
        let matrix = vec![vec![4.0, 7.0], vec![2.0, 6.0]];

        let inv = invert_matrix(&matrix).unwrap();

        // Check A * A^{-1} = I
        let n = matrix.len();
        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0;
                for k in 0..n {
                    sum += matrix[i][k] * inv[k][j];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((sum - expected).abs() < 1e-10);
            }
        }
    }

    // The LCG-based shuffle must be reproducible for a fixed seed.
    #[test]
    fn test_shuffle_deterministic() {
        let mut indices1 = vec![0, 1, 2, 3, 4];
        let mut indices2 = vec![0, 1, 2, 3, 4];

        shuffle_indices(&mut indices1, 42);
        shuffle_indices(&mut indices2, 42);

        assert_eq!(indices1, indices2); // Same seed = same shuffle
    }

    // Summary text must report the best gamma and score.
    #[test]
    fn test_gamma_search_result_summary() {
        let result = GammaSearchResult {
            best_gamma: 0.1,
            best_score: 0.9,
            all_scores: vec![(0.01, 0.5), (0.1, 0.9), (1.0, 0.7)],
        };

        let summary = result.summary();
        assert!(summary.contains("Best gamma: 0.1"));
        assert!(summary.contains("Best score: 0.9"));
    }
}