// scirs2_transform/kernel/kernels.rs
//! Kernel Functions Library
//!
//! Provides a comprehensive set of kernel functions for use in kernel methods
//! such as Kernel PCA, Kernel Ridge Regression, and Support Vector Machines.
//!
//! ## Available Kernels
//!
//! - **Linear**: `k(x, y) = x^T y`
//! - **Polynomial**: `k(x, y) = (gamma * x^T y + coef0)^degree`
//! - **RBF/Gaussian**: `k(x, y) = exp(-gamma * ||x - y||^2)`
//! - **Laplacian**: `k(x, y) = exp(-gamma * ||x - y||_1)`
//! - **Sigmoid/Tanh**: `k(x, y) = tanh(gamma * x^T y + coef0)`
//!
//! ## Gram Matrix and Centering
//!
//! The module also provides utilities for computing Gram matrices (kernel matrices)
//! and centering them in feature space.

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
use scirs2_core::numeric::{Float, NumCast};

use crate::error::{Result, TransformError};
/// Kernel function type
///
/// Each variant carries its own hyperparameters; evaluate a kernel between
/// two vectors with [`kernel_eval`], or over a whole dataset with
/// [`gram_matrix`] / [`cross_gram_matrix`].
#[derive(Debug, Clone, PartialEq)]
pub enum KernelType {
    /// Linear kernel: `k(x, y) = x^T y`
    Linear,
    /// Polynomial kernel: `k(x, y) = (gamma * x^T y + coef0)^degree`
    Polynomial {
        /// Scaling factor for the dot product
        gamma: f64,
        /// Independent term in the polynomial
        coef0: f64,
        /// Degree of the polynomial
        degree: u32,
    },
    /// RBF (Gaussian) kernel: `k(x, y) = exp(-gamma * ||x - y||^2)`
    RBF {
        /// Width parameter (inverse of twice the squared bandwidth)
        gamma: f64,
    },
    /// Laplacian kernel: `k(x, y) = exp(-gamma * ||x - y||_1)`
    Laplacian {
        /// Width parameter
        gamma: f64,
    },
    /// Sigmoid (tanh) kernel: `k(x, y) = tanh(gamma * x^T y + coef0)`
    /// (note: not positive semi-definite for all parameter choices)
    Sigmoid {
        /// Scaling factor
        gamma: f64,
        /// Independent term
        coef0: f64,
    },
}
56
57impl KernelType {
58    /// Create an RBF kernel with automatic gamma selection based on data
59    ///
60    /// Uses the median heuristic: gamma = 1 / (2 * median_distance^2)
61    pub fn rbf_auto<S>(x: &ArrayBase<S, Ix2>) -> Result<Self>
62    where
63        S: Data,
64        S::Elem: Float + NumCast,
65    {
66        let gamma = estimate_rbf_gamma(x)?;
67        Ok(KernelType::RBF { gamma })
68    }
69
70    /// Create a polynomial kernel with default parameters
71    pub fn polynomial_default() -> Self {
72        KernelType::Polynomial {
73            gamma: 1.0,
74            coef0: 1.0,
75            degree: 3,
76        }
77    }
78
79    /// Create an RBF kernel with the given gamma
80    pub fn rbf(gamma: f64) -> Self {
81        KernelType::RBF { gamma }
82    }
83
84    /// Create a Laplacian kernel with the given gamma
85    pub fn laplacian(gamma: f64) -> Self {
86        KernelType::Laplacian { gamma }
87    }
88
89    /// Create a sigmoid kernel with default parameters
90    pub fn sigmoid_default() -> Self {
91        KernelType::Sigmoid {
92            gamma: 1.0,
93            coef0: 0.0,
94        }
95    }
96}
97
98/// Evaluate a kernel function between two vectors
99///
100/// # Arguments
101/// * `x` - First input vector
102/// * `y` - Second input vector
103/// * `kernel` - The kernel function type
104///
105/// # Returns
106/// * `Result<f64>` - The kernel evaluation k(x, y)
107pub fn kernel_eval<S1, S2>(
108    x: &ArrayBase<S1, Ix1>,
109    y: &ArrayBase<S2, Ix1>,
110    kernel: &KernelType,
111) -> Result<f64>
112where
113    S1: Data,
114    S2: Data,
115    S1::Elem: Float + NumCast,
116    S2::Elem: Float + NumCast,
117{
118    if x.len() != y.len() {
119        return Err(TransformError::InvalidInput(format!(
120            "Vector dimensions must match: {} vs {}",
121            x.len(),
122            y.len()
123        )));
124    }
125
126    let n = x.len();
127    match kernel {
128        KernelType::Linear => {
129            let mut dot = 0.0;
130            for i in 0..n {
131                let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
132                let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
133                dot += xi * yi;
134            }
135            Ok(dot)
136        }
137        KernelType::Polynomial {
138            gamma,
139            coef0,
140            degree,
141        } => {
142            let mut dot = 0.0;
143            for i in 0..n {
144                let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
145                let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
146                dot += xi * yi;
147            }
148            Ok((gamma * dot + coef0).powi(*degree as i32))
149        }
150        KernelType::RBF { gamma } => {
151            let mut dist_sq = 0.0;
152            for i in 0..n {
153                let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
154                let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
155                let diff = xi - yi;
156                dist_sq += diff * diff;
157            }
158            Ok((-gamma * dist_sq).exp())
159        }
160        KernelType::Laplacian { gamma } => {
161            let mut l1_dist = 0.0;
162            for i in 0..n {
163                let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
164                let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
165                l1_dist += (xi - yi).abs();
166            }
167            Ok((-gamma * l1_dist).exp())
168        }
169        KernelType::Sigmoid { gamma, coef0 } => {
170            let mut dot = 0.0;
171            for i in 0..n {
172                let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
173                let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
174                dot += xi * yi;
175            }
176            Ok((gamma * dot + coef0).tanh())
177        }
178    }
179}
180
181/// Compute the Gram matrix (kernel matrix) for a dataset
182///
183/// The Gram matrix K has entries K\[i,j\] = k(x_i, x_j).
184///
185/// # Arguments
186/// * `x` - Input data matrix, shape (n_samples, n_features)
187/// * `kernel` - The kernel function type
188///
189/// # Returns
190/// * `Result<Array2<f64>>` - The Gram matrix, shape (n_samples, n_samples)
191pub fn gram_matrix<S>(x: &ArrayBase<S, Ix2>, kernel: &KernelType) -> Result<Array2<f64>>
192where
193    S: Data,
194    S::Elem: Float + NumCast,
195{
196    let n_samples = x.nrows();
197    let mut k = Array2::zeros((n_samples, n_samples));
198
199    for i in 0..n_samples {
200        for j in i..n_samples {
201            let val = kernel_eval(&x.row(i), &x.row(j), kernel)?;
202            k[[i, j]] = val;
203            k[[j, i]] = val;
204        }
205    }
206
207    Ok(k)
208}
209
210/// Compute the Gram matrix between two datasets
211///
212/// K\[i,j\] = k(x_i, y_j)
213///
214/// # Arguments
215/// * `x` - First input data, shape (n_x, n_features)
216/// * `y` - Second input data, shape (n_y, n_features)
217/// * `kernel` - The kernel function type
218///
219/// # Returns
220/// * `Result<Array2<f64>>` - The cross-kernel matrix, shape (n_x, n_y)
221pub fn cross_gram_matrix<S1, S2>(
222    x: &ArrayBase<S1, Ix2>,
223    y: &ArrayBase<S2, Ix2>,
224    kernel: &KernelType,
225) -> Result<Array2<f64>>
226where
227    S1: Data,
228    S2: Data,
229    S1::Elem: Float + NumCast,
230    S2::Elem: Float + NumCast,
231{
232    if x.ncols() != y.ncols() {
233        return Err(TransformError::InvalidInput(format!(
234            "Feature dimensions must match: {} vs {}",
235            x.ncols(),
236            y.ncols()
237        )));
238    }
239
240    let n_x = x.nrows();
241    let n_y = y.nrows();
242    let mut k = Array2::zeros((n_x, n_y));
243
244    for i in 0..n_x {
245        for j in 0..n_y {
246            k[[i, j]] = kernel_eval(&x.row(i), &y.row(j), kernel)?;
247        }
248    }
249
250    Ok(k)
251}
252
253/// Center a kernel matrix in feature space
254///
255/// The centered kernel matrix is:
256/// K_c = K - 1_n K - K 1_n + 1_n K 1_n
257///
258/// where 1_n is the n x n matrix with all entries 1/n.
259///
260/// This corresponds to centering the data in the feature space
261/// without explicitly computing the feature map.
262///
263/// # Arguments
264/// * `k` - The kernel (Gram) matrix, shape (n, n)
265///
266/// # Returns
267/// * `Result<Array2<f64>>` - The centered kernel matrix
268pub fn center_kernel_matrix(k: &Array2<f64>) -> Result<Array2<f64>> {
269    let n = k.nrows();
270    if n != k.ncols() {
271        return Err(TransformError::InvalidInput(
272            "Kernel matrix must be square".to_string(),
273        ));
274    }
275    if n == 0 {
276        return Err(TransformError::InvalidInput(
277            "Kernel matrix must be non-empty".to_string(),
278        ));
279    }
280
281    let n_f64 = n as f64;
282
283    // Compute row means and column means (they should be equal for symmetric K)
284    let row_means = k.mean_axis(Axis(0)).ok_or_else(|| {
285        TransformError::ComputationError("Failed to compute row means".to_string())
286    })?;
287    let col_means = k.mean_axis(Axis(1)).ok_or_else(|| {
288        TransformError::ComputationError("Failed to compute column means".to_string())
289    })?;
290    let grand_mean = row_means.sum() / n_f64;
291
292    let mut k_centered = Array2::zeros((n, n));
293    for i in 0..n {
294        for j in 0..n {
295            k_centered[[i, j]] = k[[i, j]] - row_means[j] - col_means[i] + grand_mean;
296        }
297    }
298
299    Ok(k_centered)
300}
301
302/// Center a test kernel matrix using the training kernel matrix statistics
303///
304/// For out-of-sample data, the centering must use the training data statistics:
305/// K_test_c = K_test - 1'_m K_train - K_test 1_n + 1'_m K_train 1_n
306///
307/// # Arguments
308/// * `k_test` - Test kernel matrix, shape (m, n) where m = test samples, n = training samples
309/// * `k_train` - Training kernel matrix, shape (n, n)
310///
311/// # Returns
312/// * `Result<Array2<f64>>` - The centered test kernel matrix, shape (m, n)
313pub fn center_kernel_matrix_test(
314    k_test: &Array2<f64>,
315    k_train: &Array2<f64>,
316) -> Result<Array2<f64>> {
317    let n_train = k_train.nrows();
318    let n_test = k_test.nrows();
319
320    if k_train.nrows() != k_train.ncols() {
321        return Err(TransformError::InvalidInput(
322            "Training kernel matrix must be square".to_string(),
323        ));
324    }
325    if k_test.ncols() != n_train {
326        return Err(TransformError::InvalidInput(format!(
327            "Test kernel matrix columns ({}) must match training samples ({})",
328            k_test.ncols(),
329            n_train
330        )));
331    }
332
333    let n_f64 = n_train as f64;
334
335    // Mean of each column of K_train
336    let train_col_means = k_train.mean_axis(Axis(0)).ok_or_else(|| {
337        TransformError::ComputationError("Failed to compute train column means".to_string())
338    })?;
339
340    // Mean of each row of K_test (mean over training samples for each test point)
341    let test_row_means = k_test.mean_axis(Axis(1)).ok_or_else(|| {
342        TransformError::ComputationError("Failed to compute test row means".to_string())
343    })?;
344
345    // Grand mean of K_train
346    let train_grand_mean = train_col_means.sum() / n_f64;
347
348    let mut k_centered = Array2::zeros((n_test, n_train));
349    for i in 0..n_test {
350        for j in 0..n_train {
351            k_centered[[i, j]] =
352                k_test[[i, j]] - test_row_means[i] - train_col_means[j] + train_grand_mean;
353        }
354    }
355
356    Ok(k_centered)
357}
358
359/// Estimate the RBF gamma parameter using the median heuristic
360///
361/// gamma = 1 / (2 * median_distance^2)
362///
363/// This is a common automatic bandwidth selection method for the RBF kernel.
364///
365/// # Arguments
366/// * `x` - Input data, shape (n_samples, n_features)
367///
368/// # Returns
369/// * `Result<f64>` - The estimated gamma parameter
370pub fn estimate_rbf_gamma<S>(x: &ArrayBase<S, Ix2>) -> Result<f64>
371where
372    S: Data,
373    S::Elem: Float + NumCast,
374{
375    let n = x.nrows();
376    if n < 2 {
377        return Err(TransformError::InvalidInput(
378            "Need at least 2 samples to estimate gamma".to_string(),
379        ));
380    }
381
382    // Compute all pairwise squared distances
383    let mut distances: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
384    for i in 0..n {
385        for j in (i + 1)..n {
386            let mut dist_sq = 0.0;
387            for k in 0..x.ncols() {
388                let xi: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0);
389                let xj: f64 = NumCast::from(x[[j, k]]).unwrap_or(0.0);
390                let diff = xi - xj;
391                dist_sq += diff * diff;
392            }
393            distances.push(dist_sq);
394        }
395    }
396
397    // Sort distances
398    distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
399
400    // Get median squared distance
401    let median_sq = if distances.len() % 2 == 0 {
402        let mid = distances.len() / 2;
403        (distances[mid - 1] + distances[mid]) / 2.0
404    } else {
405        distances[distances.len() / 2]
406    };
407
408    if median_sq < 1e-15 {
409        // Data points are very close together, use a default
410        Ok(1.0)
411    } else {
412        Ok(1.0 / (2.0 * median_sq))
413    }
414}
415
416/// Compute the diagonal of a kernel matrix (self-similarities)
417///
418/// # Arguments
419/// * `x` - Input data, shape (n_samples, n_features)
420/// * `kernel` - The kernel function type
421///
422/// # Returns
423/// * `Result<Array1<f64>>` - Diagonal entries k(x_i, x_i)
424pub fn kernel_diagonal<S>(x: &ArrayBase<S, Ix2>, kernel: &KernelType) -> Result<Array1<f64>>
425where
426    S: Data,
427    S::Elem: Float + NumCast,
428{
429    let n = x.nrows();
430    let mut diag = Array1::zeros(n);
431
432    for i in 0..n {
433        diag[i] = kernel_eval(&x.row(i), &x.row(i), kernel)?;
434    }
435
436    Ok(diag)
437}
438
439/// Check if a kernel matrix is positive semi-definite
440///
441/// A kernel matrix should be positive semi-definite (PSD). This function
442/// checks by verifying that all eigenvalues are non-negative (within tolerance).
443///
444/// # Arguments
445/// * `k` - The kernel matrix to check
446/// * `tol` - Tolerance for negative eigenvalues (default: -1e-10)
447///
448/// # Returns
449/// * `Result<bool>` - True if the matrix is PSD within the given tolerance
450pub fn is_positive_semidefinite(k: &Array2<f64>, tol: f64) -> Result<bool> {
451    if k.nrows() != k.ncols() {
452        return Err(TransformError::InvalidInput(
453            "Matrix must be square".to_string(),
454        ));
455    }
456
457    let (eigenvalues, _) =
458        scirs2_linalg::eigh(&k.view(), None).map_err(TransformError::LinalgError)?;
459
460    let min_eigenvalue = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
461
462    Ok(min_eigenvalue >= tol)
463}
464
465/// Kernel alignment between two kernel matrices
466///
467/// The alignment A(K1, K2) = <K1, K2>_F / (||K1||_F * ||K2||_F)
468/// where <.,.>_F is the Frobenius inner product.
469///
470/// This measures how similar two kernel matrices are.
471///
472/// # Arguments
473/// * `k1` - First kernel matrix
474/// * `k2` - Second kernel matrix
475///
476/// # Returns
477/// * `Result<f64>` - The alignment score in [0, 1]
478pub fn kernel_alignment(k1: &Array2<f64>, k2: &Array2<f64>) -> Result<f64> {
479    if k1.dim() != k2.dim() {
480        return Err(TransformError::InvalidInput(
481            "Kernel matrices must have the same dimensions".to_string(),
482        ));
483    }
484
485    let frobenius_inner: f64 = k1.iter().zip(k2.iter()).map(|(&a, &b)| a * b).sum();
486    let norm1: f64 = k1.iter().map(|&a| a * a).sum::<f64>().sqrt();
487    let norm2: f64 = k2.iter().map(|&a| a * a).sum::<f64>().sqrt();
488
489    let denom = norm1 * norm2;
490    if denom < 1e-15 {
491        Ok(0.0)
492    } else {
493        Ok((frobenius_inner / denom).clamp(0.0, 1.0))
494    }
495}
496
#[cfg(test)]
mod tests {
    //! Unit tests covering each kernel variant, Gram-matrix construction,
    //! feature-space centering, gamma estimation, PSD checking, and alignment.
    use super::*;
    use scirs2_core::ndarray::Array;

    // Shared 5x3 fixture used by the matrix-level tests.
    fn sample_data() -> Array2<f64> {
        Array::from_shape_vec(
            (5, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
            ],
        )
        .expect("Failed to create sample data")
    }

    #[test]
    fn test_linear_kernel() {
        let x = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array::from_vec(vec![4.0, 5.0, 6.0]);
        let result =
            kernel_eval(&x.view(), &y.view(), &KernelType::Linear).expect("kernel eval failed");
        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
        assert!((result - 32.0).abs() < 1e-10);
    }

    #[test]
    fn test_polynomial_kernel() {
        let x = Array::from_vec(vec![1.0, 2.0]);
        let y = Array::from_vec(vec![3.0, 4.0]);
        let kernel = KernelType::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let result = kernel_eval(&x.view(), &y.view(), &kernel).expect("kernel eval failed");
        // (1*3 + 2*4 + 1)^2 = (3 + 8 + 1)^2 = 12^2 = 144
        assert!((result - 144.0).abs() < 1e-10);
    }

    #[test]
    fn test_rbf_kernel() {
        let x = Array::from_vec(vec![1.0, 0.0]);
        let y = Array::from_vec(vec![0.0, 1.0]);
        let kernel = KernelType::RBF { gamma: 0.5 };
        let result = kernel_eval(&x.view(), &y.view(), &kernel).expect("kernel eval failed");
        // exp(-0.5 * (1 + 1)) = exp(-1) ~ 0.3679
        assert!((result - (-1.0_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_rbf_kernel_self() {
        let x = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let kernel = KernelType::RBF { gamma: 1.0 };
        let result = kernel_eval(&x.view(), &x.view(), &kernel).expect("kernel eval failed");
        // k(x, x) = exp(0) = 1
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_laplacian_kernel() {
        let x = Array::from_vec(vec![1.0, 2.0]);
        let y = Array::from_vec(vec![3.0, 4.0]);
        let kernel = KernelType::Laplacian { gamma: 0.5 };
        let result = kernel_eval(&x.view(), &y.view(), &kernel).expect("kernel eval failed");
        // exp(-0.5 * (|1-3| + |2-4|)) = exp(-0.5 * 4) = exp(-2)
        assert!((result - (-2.0_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_sigmoid_kernel() {
        let x = Array::from_vec(vec![1.0, 0.0]);
        let y = Array::from_vec(vec![0.0, 1.0]);
        let kernel = KernelType::Sigmoid {
            gamma: 1.0,
            coef0: 0.0,
        };
        let result = kernel_eval(&x.view(), &y.view(), &kernel).expect("kernel eval failed");
        // tanh(1 * 0 + 0) = tanh(0) = 0
        assert!((result - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_gram_matrix_symmetry() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let k = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");

        assert_eq!(k.shape(), &[5, 5]);
        // K[i,j] must equal K[j,i] since the kernel is symmetric.
        for i in 0..5 {
            for j in 0..5 {
                assert!(
                    (k[[i, j]] - k[[j, i]]).abs() < 1e-10,
                    "Gram matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_gram_matrix_diagonal() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let k = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");

        // RBF diagonal should be 1.0 (distance of a point to itself is 0).
        for i in 0..5 {
            assert!(
                (k[[i, i]] - 1.0).abs() < 1e-10,
                "RBF diagonal should be 1.0"
            );
        }
    }

    #[test]
    fn test_cross_gram_matrix() {
        let x = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("Failed");
        let y = Array::from_shape_vec((2, 2), vec![1.5, 2.5, 3.5, 4.5]).expect("Failed");
        let kernel = KernelType::Linear;
        let k = cross_gram_matrix(&x.view(), &y.view(), &kernel).expect("cross gram matrix failed");

        assert_eq!(k.shape(), &[3, 2]);
        // k[0,0] = 1*1.5 + 2*2.5 = 1.5 + 5.0 = 6.5
        assert!((k[[0, 0]] - 6.5).abs() < 1e-10);
    }

    #[test]
    fn test_center_kernel_matrix() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.01 };
        let k = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");
        let k_centered = center_kernel_matrix(&k).expect("centering failed");

        // Centered kernel matrix should have zero column means
        let col_means = k_centered
            .mean_axis(Axis(0))
            .expect("Failed to compute means");
        for i in 0..col_means.len() {
            assert!(
                col_means[i].abs() < 1e-10,
                "Centered kernel column mean should be ~0, got {}",
                col_means[i]
            );
        }

        // And zero row means
        let row_means = k_centered
            .mean_axis(Axis(1))
            .expect("Failed to compute means");
        for i in 0..row_means.len() {
            assert!(
                row_means[i].abs() < 1e-10,
                "Centered kernel row mean should be ~0, got {}",
                row_means[i]
            );
        }
    }

    #[test]
    fn test_center_kernel_matrix_test() {
        let x_train = sample_data();
        let x_test =
            Array::from_shape_vec((2, 3), vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5]).expect("Failed");
        let kernel = KernelType::RBF { gamma: 0.01 };

        let k_train = gram_matrix(&x_train.view(), &kernel).expect("gram failed");
        let k_test =
            cross_gram_matrix(&x_test.view(), &x_train.view(), &kernel).expect("cross gram failed");

        let k_test_centered =
            center_kernel_matrix_test(&k_test, &k_train).expect("test centering failed");
        assert_eq!(k_test_centered.shape(), &[2, 5]);

        // Values should be finite
        for val in k_test_centered.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_estimate_rbf_gamma() {
        let data = sample_data();
        let gamma = estimate_rbf_gamma(&data.view()).expect("gamma estimation failed");
        // Median heuristic must yield a positive, finite bandwidth.
        assert!(gamma > 0.0);
        assert!(gamma.is_finite());
    }

    #[test]
    fn test_kernel_diagonal() {
        let data = sample_data();
        let kernel = KernelType::Linear;
        let diag = kernel_diagonal(&data.view(), &kernel).expect("diagonal failed");

        assert_eq!(diag.len(), 5);
        // Linear kernel diagonal: k(x, x) = x^T x
        // First row: 1^2 + 2^2 + 3^2 = 14
        assert!((diag[0] - 14.0).abs() < 1e-10);
    }

    #[test]
    fn test_rbf_gram_psd() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let k = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");
        let psd = is_positive_semidefinite(&k, -1e-10).expect("PSD check failed");
        assert!(psd, "RBF Gram matrix should be PSD");
    }

    #[test]
    fn test_kernel_alignment() {
        let data = sample_data();
        let k1 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.1 }).expect("gram failed");
        let k2 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.1 }).expect("gram failed");

        let alignment = kernel_alignment(&k1, &k2).expect("alignment failed");
        // Same kernel should have alignment 1.0
        assert!(
            (alignment - 1.0).abs() < 1e-10,
            "Self-alignment should be 1.0, got {}",
            alignment
        );
    }

    #[test]
    fn test_kernel_alignment_different() {
        let data = sample_data();
        let k1 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.01 }).expect("gram failed");
        let k2 = gram_matrix(&data.view(), &KernelType::Linear).expect("gram failed");

        // Alignment between differing kernels must still land in [0, 1].
        let alignment = kernel_alignment(&k1, &k2).expect("alignment failed");
        assert!(alignment >= 0.0 && alignment <= 1.0);
    }

    #[test]
    fn test_rbf_auto() {
        let data = sample_data();
        let kernel = KernelType::rbf_auto(&data.view()).expect("auto rbf failed");
        // rbf_auto must produce an RBF variant with a usable gamma.
        match kernel {
            KernelType::RBF { gamma } => {
                assert!(gamma > 0.0);
                assert!(gamma.is_finite());
            }
            _ => panic!("Expected RBF kernel type"),
        }
    }

    #[test]
    fn test_dimension_mismatch() {
        let x = Array::from_vec(vec![1.0, 2.0]);
        let y = Array::from_vec(vec![1.0, 2.0, 3.0]);
        // Mismatched lengths must be rejected, not silently truncated.
        let result = kernel_eval(&x.view(), &y.view(), &KernelType::Linear);
        assert!(result.is_err());
    }
}
751}