Skip to main content

tensorlogic_sklears_kernels/
kernel_alignment.rs

//! # Kernel Alignment
//!
//! Implements Kernel Target Alignment (KTA), Centered Kernel Alignment (CKA),
//! HSIC (Hilbert-Schmidt Independence Criterion), and optimization routines for
//! selecting kernel hyperparameters based on alignment with a target kernel.
//!
//! ## Overview
//!
//! Kernel alignment metrics quantify how well a kernel matrix captures the
//! structure of the learning problem. Given labels, one constructs an "ideal"
//! target kernel where `T[i,j] = +1` if `labels[i] == labels[j]` and `-1`
//! otherwise. A high alignment between `K` and `T` indicates that the kernel
//! assigns high similarity to same-class pairs and low similarity to
//! different-class pairs.
//!
//! ### Kernel Target Alignment (KTA)
//!
//! ```text
//! KTA(K, T) = <K, T>_F / (||K||_F * ||T||_F)
//! ```
//!
//! ### Centered Kernel Alignment (CKA)
//!
//! CKA applies double-centering before computing alignment, making it invariant
//! to isotropic scaling and constant shifts:
//!
//! ```text
//! CKA(K1, K2) = HSIC(K1, K2) / sqrt(HSIC(K1,K1) * HSIC(K2,K2))
//! ```
//!
//! where `HSIC(K, L) = (1/n^2) * <H*K*H, H*L*H>_F` is the biased HSIC estimate and
//! `H = I - (1/n) * 1*1^T` is the centering matrix.
//!
//! ## References
//!
//! - Cortes, C., Mohri, M., & Rostamizadeh, A. (2012). Algorithms for learning
//!   kernels based on centered alignment. JMLR.
//! - Kornblith, S., et al. (2019). Similarity of neural network representations
//!   revisited. ICML.

39use std::fmt;
40
41// ---------------------------------------------------------------------------
42// Error type
43// ---------------------------------------------------------------------------
44
/// Failure modes of the kernel-alignment routines in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum AlignmentError {
    /// The input cannot be interpreted as a square (`n×n`) matrix.
    NonSquareMatrix,
    /// Two matrices that must share a dimension do not.
    DimensionMismatch {
        /// Dimension of the first (reference) matrix.
        expected: usize,
        /// Dimension of the second matrix.
        got: usize,
    },
    /// A numerical problem prevented a meaningful result (for example a
    /// zero-norm matrix); the payload describes what went wrong.
    NumericalError(String),
    /// The matrix is singular or close to singular.
    SingularMatrix,
}

impl fmt::Display for AlignmentError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AlignmentError::NonSquareMatrix => f.write_str("Matrix is not square"),
            AlignmentError::DimensionMismatch {
                expected: want,
                got: have,
            } => {
                // All matrices here are square, so each size is rendered as `d×d`.
                write!(f, "Dimension mismatch: expected {want}×{want}, got {have}×{have}")
            }
            AlignmentError::NumericalError(detail) => write!(f, "Numerical error: {detail}"),
            AlignmentError::SingularMatrix => f.write_str("Matrix is singular or near-singular"),
        }
    }
}

impl std::error::Error for AlignmentError {}
79
80// ---------------------------------------------------------------------------
81// KernelMatrix
82// ---------------------------------------------------------------------------
83
84/// A square kernel matrix (n×n, symmetric positive semi-definite).
85///
86/// All matrix operations are implemented from scratch over `Vec<Vec<f64>>`.
87#[derive(Debug, Clone)]
88pub struct KernelMatrix {
89    data: Vec<Vec<f64>>,
90    n: usize,
91}
92
93impl KernelMatrix {
94    /// Construct from a row-major `Vec<Vec<f64>>`.
95    ///
96    /// Returns `Err(AlignmentError::NonSquareMatrix)` if any row has a length
97    /// different from the number of rows.
98    pub fn new(data: Vec<Vec<f64>>) -> Result<KernelMatrix, AlignmentError> {
99        let n = data.len();
100        for row in &data {
101            if row.len() != n {
102                return Err(AlignmentError::NonSquareMatrix);
103            }
104        }
105        Ok(KernelMatrix { data, n })
106    }
107
108    /// Construct from a flat slice of length `n*n` in row-major order.
109    pub fn from_flat(flat: &[f64], n: usize) -> Result<KernelMatrix, AlignmentError> {
110        if flat.len() != n * n {
111            return Err(AlignmentError::NonSquareMatrix);
112        }
113        let data = (0..n).map(|i| flat[i * n..(i + 1) * n].to_vec()).collect();
114        Ok(KernelMatrix { data, n })
115    }
116
117    /// Construct the `n×n` identity kernel matrix.
118    pub fn identity(n: usize) -> KernelMatrix {
119        let mut data = vec![vec![0.0_f64; n]; n];
120        #[allow(clippy::needless_range_loop)]
121        for i in 0..n {
122            data[i][i] = 1.0;
123        }
124        KernelMatrix { data, n }
125    }
126
127    /// Construct the "ideal" label kernel:
128    /// `K[i,j] = 1.0` if `labels[i] == labels[j]`, else `-1.0`.
129    ///
130    /// This is equivalent to the outer product of the label sign vector and is
131    /// the target used in KTA for binary or multi-class classification.
132    pub fn from_labels(labels: &[f64]) -> KernelMatrix {
133        let n = labels.len();
134        let mut data = vec![vec![0.0_f64; n]; n];
135        for i in 0..n {
136            for j in 0..n {
137                // Use approximate equality to handle floating-point labels
138                data[i][j] = if (labels[i] - labels[j]).abs() < 1e-10 {
139                    1.0
140                } else {
141                    -1.0
142                };
143            }
144        }
145        KernelMatrix { data, n }
146    }
147
148    /// Return element `(i, j)`.
149    #[inline]
150    pub fn get(&self, i: usize, j: usize) -> f64 {
151        self.data[i][j]
152    }
153
154    /// Return the dimension `n` (number of rows/columns).
155    #[inline]
156    pub fn n(&self) -> usize {
157        self.n
158    }
159
160    /// Compute the trace: `sum_i K[i,i]`.
161    pub fn trace(&self) -> f64 {
162        (0..self.n).map(|i| self.data[i][i]).sum()
163    }
164
165    /// Compute `||K||_F^2 = sum_{i,j} K[i,j]^2`.
166    pub fn frobenius_norm_sq(&self) -> f64 {
167        self.data
168            .iter()
169            .flat_map(|row| row.iter())
170            .map(|&v| v * v)
171            .sum()
172    }
173
174    /// Compute `<K1, K2>_F = sum_{i,j} K1[i,j] * K2[i,j]`.
175    ///
176    /// Returns `Err(AlignmentError::DimensionMismatch)` if the matrices have
177    /// different sizes.
178    pub fn frobenius_inner(&self, other: &KernelMatrix) -> Result<f64, AlignmentError> {
179        if self.n != other.n {
180            return Err(AlignmentError::DimensionMismatch {
181                expected: self.n,
182                got: other.n,
183            });
184        }
185        let mut sum = 0.0_f64;
186        for i in 0..self.n {
187            for j in 0..self.n {
188                sum += self.data[i][j] * other.data[i][j];
189            }
190        }
191        Ok(sum)
192    }
193
194    /// Double-center the kernel matrix: `K_c = H * K * H`
195    /// where `H = I - (1/n) * 1*1^T` is the centering matrix.
196    ///
197    /// Equivalent to:
198    /// ```text
199    /// K_c[i,j] = K[i,j] - row_mean[i] - col_mean[j] + grand_mean
200    /// ```
201    pub fn center(&self) -> KernelMatrix {
202        let n = self.n;
203        let n_f = n as f64;
204
205        // Row means
206        let row_means: Vec<f64> = self
207            .data
208            .iter()
209            .map(|row| row.iter().sum::<f64>() / n_f)
210            .collect();
211
212        // Column means
213        let col_means: Vec<f64> = (0..n)
214            .map(|j| (0..n).map(|i| self.data[i][j]).sum::<f64>() / n_f)
215            .collect();
216
217        // Grand mean
218        let grand_mean: f64 = row_means.iter().sum::<f64>() / n_f;
219
220        let mut data = vec![vec![0.0_f64; n]; n];
221        for i in 0..n {
222            for j in 0..n {
223                data[i][j] = self.data[i][j] - row_means[i] - col_means[j] + grand_mean;
224            }
225        }
226        KernelMatrix { data, n }
227    }
228
229    /// Matrix multiply: `(self * other)[i,j] = sum_k self[i,k] * other[k,j]`.
230    ///
231    /// Used internally; not exposed as a primary public API.
232    #[allow(dead_code)]
233    fn matmul(&self, other: &KernelMatrix) -> Result<KernelMatrix, AlignmentError> {
234        if self.n != other.n {
235            return Err(AlignmentError::DimensionMismatch {
236                expected: self.n,
237                got: other.n,
238            });
239        }
240        let n = self.n;
241        let mut data = vec![vec![0.0_f64; n]; n];
242        #[allow(clippy::needless_range_loop)]
243        for i in 0..n {
244            for k in 0..n {
245                let aik = self.data[i][k];
246                if aik == 0.0 {
247                    continue;
248                }
249                for j in 0..n {
250                    data[i][j] += aik * other.data[k][j];
251                }
252            }
253        }
254        Ok(KernelMatrix { data, n })
255    }
256
257    /// Compute `trace(self * other)` efficiently in O(n^2) without full matmul.
258    ///
259    /// `trace(A * B) = sum_{i,j} A[i,j] * B[j,i]`
260    #[allow(dead_code)]
261    fn trace_product(&self, other: &KernelMatrix) -> Result<f64, AlignmentError> {
262        if self.n != other.n {
263            return Err(AlignmentError::DimensionMismatch {
264                expected: self.n,
265                got: other.n,
266            });
267        }
268        let n = self.n;
269        let mut tr = 0.0_f64;
270        for i in 0..n {
271            for j in 0..n {
272                tr += self.data[i][j] * other.data[j][i];
273            }
274        }
275        Ok(tr)
276    }
277}
278
279// ---------------------------------------------------------------------------
280// Result types
281// ---------------------------------------------------------------------------
282
/// Result of a pairwise kernel alignment computation.
///
/// `score == numerator / denominator`. The numerator and denominator depend on
/// which function produced the result:
///
/// * [`kernel_target_alignment`]: `numerator = <K1, K2>_F` and
///   `denominator = ||K1||_F * ||K2||_F`.
/// * [`centered_kernel_alignment`]: `numerator = HSIC(K1, K2) = <K1_c, K2_c>_F / n^2`
///   and `denominator = sqrt(HSIC(K1, K1) * HSIC(K2, K2))`, where `K_c` is the
///   double-centred matrix.
#[derive(Debug, Clone)]
pub struct AlignmentResult {
    /// The alignment score `numerator / denominator`. By Cauchy–Schwarz it lies
    /// in `[-1, 1]`, up to rounding.
    pub score: f64,
    /// Numerator of the score (see the type-level docs for its definition).
    pub numerator: f64,
    /// Denominator of the score (see the type-level docs for its definition).
    pub denominator: f64,
    /// Number of samples `n`.
    pub n_samples: usize,
}
295
/// Every alignment metric between a kernel `K` and a target kernel `T`, gathered
/// in one value.
#[derive(Debug, Clone)]
pub struct AlignmentStats {
    /// Uncentered Kernel Target Alignment: `<K, T>_F / (||K||_F * ||T||_F)`.
    pub kta: f64,
    /// Centered Kernel Alignment, i.e. the alignment of the double-centred matrices.
    pub cka: f64,
    /// Biased HSIC estimate: `(1/n^2) * <K_c, T_c>_F`.
    pub hsic: f64,
    /// Sample count `n` (both matrices are `n×n`).
    pub n_samples: usize,
}
308
309// ---------------------------------------------------------------------------
310// Core alignment functions
311// ---------------------------------------------------------------------------
312
313/// Compute the **Kernel Target Alignment (KTA)** between kernel `k` and target
314/// kernel `target`.
315///
316/// ```text
317/// KTA(K, T) = <K, T>_F / (||K||_F * ||T||_F)
318/// ```
319///
320/// The score lies in `[-1, 1]`; values near `1` indicate the kernel faithfully
321/// encodes the label structure.
322pub fn kernel_target_alignment(
323    k: &KernelMatrix,
324    target: &KernelMatrix,
325) -> Result<AlignmentResult, AlignmentError> {
326    if k.n() != target.n() {
327        return Err(AlignmentError::DimensionMismatch {
328            expected: k.n(),
329            got: target.n(),
330        });
331    }
332
333    let numerator = k.frobenius_inner(target)?;
334    let norm_k_sq = k.frobenius_norm_sq();
335    let norm_t_sq = target.frobenius_norm_sq();
336    let denominator = (norm_k_sq * norm_t_sq).sqrt();
337
338    if denominator < f64::EPSILON {
339        return Err(AlignmentError::NumericalError(
340            "One or both kernel matrices have zero Frobenius norm".to_string(),
341        ));
342    }
343
344    Ok(AlignmentResult {
345        score: numerator / denominator,
346        numerator,
347        denominator,
348        n_samples: k.n(),
349    })
350}
351
352/// Compute the **Centered Kernel Alignment (CKA)** between kernel matrices
353/// `k1` and `k2`.
354///
355/// CKA applies double-centering (via the centering matrix `H`) before alignment,
356/// making it invariant to isotropic scaling and mean shifts:
357///
358/// ```text
359/// CKA(K1, K2) = HSIC(K1, K2) / sqrt(HSIC(K1, K1) * HSIC(K2, K2))
360/// ```
361pub fn centered_kernel_alignment(
362    k1: &KernelMatrix,
363    k2: &KernelMatrix,
364) -> Result<AlignmentResult, AlignmentError> {
365    if k1.n() != k2.n() {
366        return Err(AlignmentError::DimensionMismatch {
367            expected: k1.n(),
368            got: k2.n(),
369        });
370    }
371
372    let k1_c = k1.center();
373    let k2_c = k2.center();
374
375    let n_sq = (k1.n() * k1.n()) as f64;
376
377    let hsic_12 = k1_c.frobenius_inner(&k2_c)? / n_sq;
378    let hsic_11 = k1_c.frobenius_norm_sq() / n_sq;
379    let hsic_22 = k2_c.frobenius_norm_sq() / n_sq;
380
381    let denominator_sq = hsic_11 * hsic_22;
382    if denominator_sq < f64::EPSILON * f64::EPSILON {
383        return Err(AlignmentError::NumericalError(
384            "HSIC self-alignment is zero; cannot normalise CKA".to_string(),
385        ));
386    }
387
388    let denominator = denominator_sq.sqrt();
389    let score = hsic_12 / denominator;
390
391    Ok(AlignmentResult {
392        score,
393        numerator: hsic_12,
394        denominator,
395        n_samples: k1.n(),
396    })
397}
398
399/// Compute the **biased HSIC** (Hilbert-Schmidt Independence Criterion) estimate:
400///
401/// ```text
402/// HSIC(K, L) = (1/n^2) * trace(K * H * L * H)
403///            = (1/n^2) * <K_c, L_c>_F
404/// ```
405///
406/// where `K_c = H*K*H` and `L_c = H*L*H` are the doubly-centred versions.
407pub fn hsic(k: &KernelMatrix, l: &KernelMatrix) -> Result<f64, AlignmentError> {
408    if k.n() != l.n() {
409        return Err(AlignmentError::DimensionMismatch {
410            expected: k.n(),
411            got: l.n(),
412        });
413    }
414    let n_sq = (k.n() * k.n()) as f64;
415    let k_c = k.center();
416    let l_c = l.center();
417    let inner = k_c.frobenius_inner(&l_c)?;
418    Ok(inner / n_sq)
419}
420
421/// Compute **all alignment metrics** in a single pass.
422///
423/// Returns [`AlignmentStats`] containing KTA, CKA, and HSIC values.
424pub fn alignment_stats(
425    k: &KernelMatrix,
426    target: &KernelMatrix,
427) -> Result<AlignmentStats, AlignmentError> {
428    if k.n() != target.n() {
429        return Err(AlignmentError::DimensionMismatch {
430            expected: k.n(),
431            got: target.n(),
432        });
433    }
434
435    let kta_result = kernel_target_alignment(k, target)?;
436    let cka_result = centered_kernel_alignment(k, target)?;
437    let hsic_val = hsic(k, target)?;
438
439    Ok(AlignmentStats {
440        kta: kta_result.score,
441        cka: cka_result.score,
442        hsic: hsic_val,
443        n_samples: k.n(),
444    })
445}
446
447// ---------------------------------------------------------------------------
448// Optimization types and routines
449// ---------------------------------------------------------------------------
450
/// Settings for alignment-driven kernel hyperparameter search.
///
/// The [`Default`] value of each field is noted in its documentation.
#[derive(Debug, Clone)]
pub struct AlignmentOptConfig {
    /// Upper bound on the number of gradient-ascent steps. Default: `50`.
    pub max_iterations: usize,
    /// Gradient-ascent step size `η`. Default: `0.01`.
    pub learning_rate: f64,
    /// Stop once `|Δscore|` between consecutive steps drops below this value.
    /// Default: `1e-6`.
    pub tolerance: f64,
    /// Score candidates with CKA when `true`, with KTA when `false`. Default: `true`.
    pub use_cka: bool,
    /// Step `ε` of the central finite-difference gradient. Default: `1e-5`.
    pub fd_step: f64,
}

impl Default for AlignmentOptConfig {
    fn default() -> Self {
        Self {
            use_cka: true,
            max_iterations: 50,
            tolerance: 1e-6,
            learning_rate: 0.01,
            fd_step: 1e-5,
        }
    }
}
477
/// Outcome of a kernel alignment optimisation run.
///
/// Produced by [`grid_search_alignment`] and [`gradient_ascent_alignment`]. Some
/// fields mean slightly different things for the two (see each field).
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Best alignment score found.
    pub best_score: f64,
    /// Kernel hyperparameters that achieved `best_score`.
    pub best_params: Vec<f64>,
    /// Alignment scores in evaluation order.
    ///
    /// * Grid search: entry `i` is the score of grid point `i`.
    /// * Gradient ascent: the first entry is the score of the initial parameters,
    ///   and each later entry is the score after one ascent step.
    pub score_history: Vec<f64>,
    /// Whether the run stopped by meeting its convergence criterion.
    ///
    /// * Gradient ascent: set when the score change fell below the tolerance
    ///   within `max_iterations` steps.
    /// * Grid search: always `true`.
    pub converged: bool,
    /// Number of scores recorded, equal to `score_history.len()`.
    ///
    /// * Grid search: the number of grid points.
    /// * Gradient ascent: the number of ascent steps plus one for the initial
    ///   evaluation.
    pub iterations: usize,
}
492
493/// Evaluate the alignment score for a given parameter vector.
494fn evaluate_alignment(
495    kernel_fn: &dyn Fn(&[f64]) -> KernelMatrix,
496    target: &KernelMatrix,
497    params: &[f64],
498    use_cka: bool,
499) -> Result<f64, AlignmentError> {
500    let k = kernel_fn(params);
501    if use_cka {
502        centered_kernel_alignment(&k, target).map(|r| r.score)
503    } else {
504        kernel_target_alignment(&k, target).map(|r| r.score)
505    }
506}
507
508/// **Grid search** over a discrete set of kernel parameter vectors, returning
509/// the one that maximises alignment with `target`.
510///
511/// # Arguments
512///
513/// * `kernel_fn` - A closure that maps a parameter vector to a [`KernelMatrix`].
514/// * `target`    - The target kernel (e.g. built from labels via
515///   [`KernelMatrix::from_labels`]).
516/// * `params_grid` - The set of parameter vectors to evaluate.
517/// * `config`    - Search configuration (determines CKA vs KTA).
518///
519/// # Returns
520///
521/// An [`OptimizationResult`] with `best_params` set to the grid vector
522/// achieving the highest alignment.
523pub fn grid_search_alignment(
524    kernel_fn: &dyn Fn(&[f64]) -> KernelMatrix,
525    target: &KernelMatrix,
526    params_grid: &[Vec<f64>],
527    config: &AlignmentOptConfig,
528) -> Result<OptimizationResult, AlignmentError> {
529    if params_grid.is_empty() {
530        return Err(AlignmentError::NumericalError(
531            "params_grid must not be empty".to_string(),
532        ));
533    }
534
535    let mut best_score = f64::NEG_INFINITY;
536    let mut best_params = params_grid[0].clone();
537    let mut score_history = Vec::with_capacity(params_grid.len());
538
539    for params in params_grid {
540        let score = evaluate_alignment(kernel_fn, target, params, config.use_cka)?;
541        score_history.push(score);
542        if score > best_score {
543            best_score = score;
544            best_params = params.clone();
545        }
546    }
547
548    Ok(OptimizationResult {
549        best_score,
550        best_params,
551        score_history,
552        converged: true,
553        iterations: params_grid.len(),
554    })
555}
556
557/// **Gradient ascent** on the alignment score via finite differences.
558///
559/// For each parameter `θ_k`, the partial derivative is approximated as:
560///
561/// ```text
562/// ∂A/∂θ_k ≈ (A(θ + ε*e_k) - A(θ - ε*e_k)) / (2ε)
563/// ```
564///
565/// Parameters are updated as `θ ← θ + η * ∇A(θ)` until convergence or
566/// `max_iterations` is reached.
567///
568/// # Arguments
569///
570/// * `kernel_fn`      - A closure mapping parameters to a [`KernelMatrix`].
571/// * `target`         - Target kernel.
572/// * `initial_params` - Starting parameter vector.
573/// * `config`         - Optimisation configuration.
574pub fn gradient_ascent_alignment(
575    kernel_fn: &dyn Fn(&[f64]) -> KernelMatrix,
576    target: &KernelMatrix,
577    initial_params: &[f64],
578    config: &AlignmentOptConfig,
579) -> Result<OptimizationResult, AlignmentError> {
580    if initial_params.is_empty() {
581        return Err(AlignmentError::NumericalError(
582            "initial_params must not be empty".to_string(),
583        ));
584    }
585
586    let d = initial_params.len();
587    let mut params = initial_params.to_vec();
588    let mut score_history = Vec::with_capacity(config.max_iterations);
589    let mut converged = false;
590
591    let mut current_score = evaluate_alignment(kernel_fn, target, &params, config.use_cka)?;
592    score_history.push(current_score);
593
594    for _iter in 0..config.max_iterations {
595        // Compute finite-difference gradient
596        let mut grad = vec![0.0_f64; d];
597        for k in 0..d {
598            let mut params_fwd = params.clone();
599            let mut params_bwd = params.clone();
600            params_fwd[k] += config.fd_step;
601            params_bwd[k] -= config.fd_step;
602
603            let score_fwd = evaluate_alignment(kernel_fn, target, &params_fwd, config.use_cka)?;
604            let score_bwd = evaluate_alignment(kernel_fn, target, &params_bwd, config.use_cka)?;
605            grad[k] = (score_fwd - score_bwd) / (2.0 * config.fd_step);
606        }
607
608        // Gradient ascent step
609        for k in 0..d {
610            params[k] += config.learning_rate * grad[k];
611        }
612
613        let new_score = evaluate_alignment(kernel_fn, target, &params, config.use_cka)?;
614        score_history.push(new_score);
615
616        if (new_score - current_score).abs() < config.tolerance {
617            converged = true;
618            current_score = new_score;
619            break;
620        }
621        current_score = new_score;
622    }
623
624    let iterations = score_history.len();
625    Ok(OptimizationResult {
626        best_score: current_score,
627        best_params: params,
628        score_history,
629        converged,
630        iterations,
631    })
632}
633
634// ---------------------------------------------------------------------------
635// Tests
636// ---------------------------------------------------------------------------
637
638#[cfg(test)]
639mod tests {
640    use super::*;
641
642    // Helper: build a small RBF kernel matrix for a 1-D dataset
643    fn rbf_kernel_matrix(data: &[f64], gamma: f64) -> KernelMatrix {
644        let n = data.len();
645        let mut mat = vec![vec![0.0_f64; n]; n];
646        for i in 0..n {
647            for j in 0..n {
648                let diff = data[i] - data[j];
649                mat[i][j] = (-gamma * diff * diff).exp();
650            }
651        }
652        KernelMatrix::new(mat).expect("valid kernel matrix")
653    }
654
655    // ---------------------------------------------------------------------------
656    // KernelMatrix structural tests
657    // ---------------------------------------------------------------------------
658
659    #[test]
660    fn test_identity_trace_equals_n() {
661        for n in [1_usize, 3, 5, 10] {
662            let id = KernelMatrix::identity(n);
663            let tr = id.trace();
664            assert!(
665                (tr - n as f64).abs() < 1e-12,
666                "identity trace should be {n}, got {tr}"
667            );
668        }
669    }
670
671    #[test]
672    fn test_from_labels_correct_values() {
673        let labels = vec![0.0, 0.0, 1.0, 1.0];
674        let k = KernelMatrix::from_labels(&labels);
675        assert_eq!(k.n(), 4);
676        // Same-class pairs
677        assert!((k.get(0, 1) - 1.0).abs() < 1e-12);
678        assert!((k.get(2, 3) - 1.0).abs() < 1e-12);
679        // Diagonal
680        assert!((k.get(0, 0) - 1.0).abs() < 1e-12);
681        // Cross-class pairs
682        assert!((k.get(0, 2) + 1.0).abs() < 1e-12);
683        assert!((k.get(1, 3) + 1.0).abs() < 1e-12);
684    }
685
686    #[test]
687    fn test_center_zero_row_column_sums() {
688        // Use a non-trivial positive semidefinite matrix
689        let data = vec![
690            vec![4.0, 2.0, 1.0],
691            vec![2.0, 3.0, 0.5],
692            vec![1.0, 0.5, 2.0],
693        ];
694        let k = KernelMatrix::new(data).expect("valid");
695        let k_c = k.center();
696        let n = k_c.n();
697
698        for i in 0..n {
699            let row_sum: f64 = (0..n).map(|j| k_c.get(i, j)).sum();
700            assert!(row_sum.abs() < 1e-10, "centered row {i} sum = {row_sum}");
701            let col_sum: f64 = (0..n).map(|j| k_c.get(j, i)).sum();
702            assert!(col_sum.abs() < 1e-10, "centered col {i} sum = {col_sum}");
703        }
704    }
705
706    #[test]
707    fn test_frobenius_inner_symmetric() {
708        let data1 = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
709        let data2 = vec![vec![1.0, 0.5], vec![0.5, 2.0]];
710        let k1 = KernelMatrix::new(data1).expect("valid");
711        let k2 = KernelMatrix::new(data2).expect("valid");
712
713        let inner_12 = k1.frobenius_inner(&k2).expect("ok");
714        let inner_21 = k2.frobenius_inner(&k1).expect("ok");
715        assert!(
716            (inner_12 - inner_21).abs() < 1e-12,
717            "<K1,K2> = {inner_12}, <K2,K1> = {inner_21}"
718        );
719    }
720
721    #[test]
722    fn test_frobenius_norm_identity() {
723        for n in [1_usize, 4, 9] {
724            let id = KernelMatrix::identity(n);
725            let norm_sq = id.frobenius_norm_sq();
726            let norm = norm_sq.sqrt();
727            let expected = (n as f64).sqrt();
728            assert!(
729                (norm - expected).abs() < 1e-12,
730                "||I_n||_F should be sqrt({n}) = {expected}, got {norm}"
731            );
732        }
733    }
734
735    #[test]
736    fn test_from_flat_validates_square() {
737        // 2×2 from flat works
738        let flat = vec![1.0, 0.0, 0.0, 1.0];
739        assert!(KernelMatrix::from_flat(&flat, 2).is_ok());
740
741        // 5 elements cannot form a square matrix
742        let bad = vec![1.0, 2.0, 3.0, 4.0, 5.0];
743        assert!(matches!(
744            KernelMatrix::from_flat(&bad, 2),
745            Err(AlignmentError::NonSquareMatrix)
746        ));
747    }
748
749    // ---------------------------------------------------------------------------
750    // KTA tests
751    // ---------------------------------------------------------------------------
752
753    #[test]
754    fn test_kta_identical_kernels_is_one() {
755        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
756        let k = rbf_kernel_matrix(&data, 0.5);
757        let result = kernel_target_alignment(&k, &k).expect("ok");
758        assert!(
759            (result.score - 1.0).abs() < 1e-10,
760            "KTA of K with itself should be 1.0, got {}",
761            result.score
762        );
763    }
764
765    #[test]
766    fn test_kta_with_label_target_positive() {
767        // Two well-separated clusters → RBF with large gamma → high KTA
768        let data = vec![0.0, 0.1, 0.2, 10.0, 10.1, 10.2];
769        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
770        let k = rbf_kernel_matrix(&data, 1.0);
771        let target = KernelMatrix::from_labels(&labels);
772        let result = kernel_target_alignment(&k, &target).expect("ok");
773        assert!(
774            result.score > 0.0,
775            "KTA should be positive for clustered data, got {}",
776            result.score
777        );
778    }
779
780    #[test]
781    fn test_kta_range_is_minus_one_to_one() {
782        let data = vec![1.0, 2.0, 3.0, 4.0];
783        let labels = vec![0.0, 1.0, 0.0, 1.0];
784        let k = rbf_kernel_matrix(&data, 1.0);
785        let target = KernelMatrix::from_labels(&labels);
786        let result = kernel_target_alignment(&k, &target).expect("ok");
787        assert!(
788            result.score >= -1.0 - 1e-9 && result.score <= 1.0 + 1e-9,
789            "KTA score out of range: {}",
790            result.score
791        );
792    }
793
794    // ---------------------------------------------------------------------------
795    // CKA tests
796    // ---------------------------------------------------------------------------
797
798    #[test]
799    fn test_cka_identical_kernels_is_one() {
800        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
801        let k = rbf_kernel_matrix(&data, 0.5);
802        let result = centered_kernel_alignment(&k, &k).expect("ok");
803        assert!(
804            (result.score - 1.0).abs() < 1e-10,
805            "CKA of K with itself should be 1.0, got {}",
806            result.score
807        );
808    }
809
810    #[test]
811    fn test_cka_invariant_to_scaling() {
812        let data = vec![0.5, 1.0, 2.0, 3.0, 4.0];
813        let k = rbf_kernel_matrix(&data, 0.3);
814        let labels = vec![0.0, 0.0, 1.0, 1.0, 1.0];
815        let target = KernelMatrix::from_labels(&labels);
816
817        // Build 2*K
818        let n = k.n();
819        let scaled_data: Vec<Vec<f64>> = (0..n)
820            .map(|i| (0..n).map(|j| 2.0 * k.get(i, j)).collect())
821            .collect();
822        let k_scaled = KernelMatrix::new(scaled_data).expect("valid");
823
824        let cka_original = centered_kernel_alignment(&k, &target).expect("ok").score;
825        let cka_scaled = centered_kernel_alignment(&k_scaled, &target)
826            .expect("ok")
827            .score;
828
829        assert!(
830            (cka_original - cka_scaled).abs() < 1e-10,
831            "CKA should be invariant to scaling: {cka_original} vs {cka_scaled}"
832        );
833    }
834
835    #[test]
836    fn test_cka_invariant_to_mean_shift() {
837        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
838        let k = rbf_kernel_matrix(&data, 0.2);
839        let labels = vec![0.0, 0.0, 1.0, 1.0, 1.0];
840        let target = KernelMatrix::from_labels(&labels);
841
842        // Shift K by a constant c → K' = K + c*11^T
843        let n = k.n();
844        let c = 3.0_f64;
845        let shifted_data: Vec<Vec<f64>> = (0..n)
846            .map(|i| (0..n).map(|j| k.get(i, j) + c).collect())
847            .collect();
848        let k_shifted = KernelMatrix::new(shifted_data).expect("valid");
849
850        let cka_original = centered_kernel_alignment(&k, &target).expect("ok").score;
851        let cka_shifted = centered_kernel_alignment(&k_shifted, &target)
852            .expect("ok")
853            .score;
854
855        assert!(
856            (cka_original - cka_shifted).abs() < 1e-9,
857            "CKA should be invariant to constant mean shift: {cka_original} vs {cka_shifted}"
858        );
859    }
860
861    // ---------------------------------------------------------------------------
862    // HSIC tests
863    // ---------------------------------------------------------------------------
864
865    #[test]
866    fn test_hsic_identical_kernel_positive() {
867        let data = vec![1.0, 3.0, 5.0, 7.0];
868        let k = rbf_kernel_matrix(&data, 1.0);
869        let val = hsic(&k, &k).expect("ok");
870        assert!(val > 0.0, "HSIC(K,K) should be positive, got {val}");
871    }
872
873    #[test]
874    fn test_hsic_near_independent_kernels() {
875        // An identity kernel encodes no inter-sample similarity; pairing with
876        // a label kernel that has off-diagonal structure yields a small HSIC.
877        let n = 8;
878        let identity = KernelMatrix::identity(n);
879
880        // Constant kernel (all ones) is trivially uninformative after centering
881        let data = vec![vec![1.0_f64; n]; n];
882        let constant_k = KernelMatrix::new(data).expect("valid");
883
884        let val = hsic(&identity, &constant_k).expect("ok");
885        // After centering a constant matrix becomes all zeros → HSIC = 0
886        assert!(
887            val.abs() < 1e-12,
888            "HSIC(I, 1*1^T) after centering should be ~0, got {val}"
889        );
890    }
891
892    // ---------------------------------------------------------------------------
893    // AlignmentStats test
894    // ---------------------------------------------------------------------------
895
896    #[test]
897    fn test_alignment_stats_reports_all_metrics() {
898        let data = vec![0.0, 0.5, 1.0, 5.0, 5.5, 6.0];
899        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
900        let k = rbf_kernel_matrix(&data, 2.0);
901        let target = KernelMatrix::from_labels(&labels);
902
903        let stats = alignment_stats(&k, &target).expect("ok");
904        assert_eq!(stats.n_samples, 6);
905        // KTA, CKA should be in [-1,1]
906        assert!(stats.kta >= -1.0 - 1e-9 && stats.kta <= 1.0 + 1e-9);
907        assert!(stats.cka >= -1.0 - 1e-9 && stats.cka <= 1.0 + 1e-9);
908    }
909
910    #[test]
911    fn test_alignment_stats_perfect_alignment_near_one() {
912        // Identical kernels should give KTA = CKA = 1.0
913        let data = vec![1.0, 2.0, 3.0, 4.0];
914        let k = rbf_kernel_matrix(&data, 0.5);
915        let stats = alignment_stats(&k, &k).expect("ok");
916        assert!(
917            (stats.kta - 1.0).abs() < 1e-10,
918            "KTA should be 1.0, got {}",
919            stats.kta
920        );
921        assert!(
922            (stats.cka - 1.0).abs() < 1e-10,
923            "CKA should be 1.0, got {}",
924            stats.cka
925        );
926    }
927
928    // ---------------------------------------------------------------------------
929    // Optimisation tests
930    // ---------------------------------------------------------------------------
931
932    #[test]
933    fn test_grid_search_finds_best_params() {
934        let data = vec![0.0, 0.2, 0.4, 5.0, 5.2, 5.4];
935        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
936        let target = KernelMatrix::from_labels(&labels);
937
938        // Grid of gamma values: larger gamma → tighter clusters → higher alignment
939        let params_grid: Vec<Vec<f64>> =
940            vec![vec![0.01], vec![0.1], vec![1.0], vec![5.0], vec![10.0]];
941
942        let config = AlignmentOptConfig {
943            use_cka: true,
944            ..Default::default()
945        };
946
947        let kernel_fn = |params: &[f64]| rbf_kernel_matrix(&data, params[0]);
948
949        let result = grid_search_alignment(&kernel_fn, &target, &params_grid, &config).expect("ok");
950
951        assert_eq!(result.iterations, 5);
952        assert_eq!(result.score_history.len(), 5);
953        assert!(result.converged);
954
955        // Verify best_score is actually the maximum in history
956        let max_in_history = result
957            .score_history
958            .iter()
959            .cloned()
960            .fold(f64::NEG_INFINITY, f64::max);
961        assert!(
962            (result.best_score - max_in_history).abs() < 1e-12,
963            "best_score {} should equal max in history {}",
964            result.best_score,
965            max_in_history
966        );
967    }
968
969    #[test]
970    fn test_gradient_ascent_converges_toward_higher_alignment() {
971        let data = vec![0.0, 0.3, 0.6, 4.0, 4.3, 4.6];
972        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
973        let target = KernelMatrix::from_labels(&labels);
974
975        let kernel_fn = |params: &[f64]| rbf_kernel_matrix(&data, params[0].abs());
976
977        let initial_params = vec![0.01_f64];
978        let config = AlignmentOptConfig {
979            max_iterations: 30,
980            learning_rate: 0.05,
981            tolerance: 1e-8,
982            use_cka: true,
983            fd_step: 1e-4,
984        };
985
986        let result =
987            gradient_ascent_alignment(&kernel_fn, &target, &initial_params, &config).expect("ok");
988
989        assert!(
990            !result.score_history.is_empty(),
991            "score_history must be non-empty"
992        );
993        // The final score should be >= the initial score
994        let first_score = result.score_history[0];
995        assert!(
996            result.best_score >= first_score - 1e-6,
997            "gradient ascent should not decrease alignment: final {} < initial {}",
998            result.best_score,
999            first_score
1000        );
1001    }
1002
1003    #[test]
1004    fn test_score_history_non_decreasing_approximately() {
1005        // We run gradient ascent on a simple 1-parameter RBF and check that the
1006        // alignment trend is upward (allowing small oscillations due to FD noise).
1007        let data = vec![0.0, 0.5, 1.0, 6.0, 6.5, 7.0];
1008        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
1009        let target = KernelMatrix::from_labels(&labels);
1010
1011        let kernel_fn = |params: &[f64]| rbf_kernel_matrix(&data, params[0].abs() + 1e-3);
1012
1013        let config = AlignmentOptConfig {
1014            max_iterations: 20,
1015            learning_rate: 0.02,
1016            tolerance: 1e-9,
1017            use_cka: true,
1018            fd_step: 1e-4,
1019        };
1020
1021        let result = gradient_ascent_alignment(&kernel_fn, &target, &[0.01], &config).expect("ok");
1022
1023        // The final best score should not be catastrophically worse than the midpoint
1024        let n = result.score_history.len();
1025        if n >= 2 {
1026            let final_score = result.score_history[n - 1];
1027            let initial_score = result.score_history[0];
1028            // Allow a 5% relative tolerance (gradient ascent may oscillate slightly)
1029            assert!(
1030                final_score >= initial_score - 0.05 * initial_score.abs().max(1e-3),
1031                "score history should trend upward: initial={initial_score}, final={final_score}"
1032            );
1033        }
1034    }
1035
1036    // ---------------------------------------------------------------------------
1037    // Error handling tests
1038    // ---------------------------------------------------------------------------
1039
1040    #[test]
1041    fn test_kta_dimension_mismatch_error() {
1042        let k1 = KernelMatrix::identity(3);
1043        let k2 = KernelMatrix::identity(4);
1044        let result = kernel_target_alignment(&k1, &k2);
1045        assert!(matches!(
1046            result,
1047            Err(AlignmentError::DimensionMismatch {
1048                expected: 3,
1049                got: 4
1050            })
1051        ));
1052    }
1053
1054    #[test]
1055    fn test_cka_dimension_mismatch_error() {
1056        let k1 = KernelMatrix::identity(2);
1057        let k2 = KernelMatrix::identity(5);
1058        let result = centered_kernel_alignment(&k1, &k2);
1059        assert!(matches!(
1060            result,
1061            Err(AlignmentError::DimensionMismatch {
1062                expected: 2,
1063                got: 5
1064            })
1065        ));
1066    }
1067
1068    #[test]
1069    fn test_hsic_dimension_mismatch_error() {
1070        let k1 = KernelMatrix::identity(3);
1071        let k2 = KernelMatrix::identity(6);
1072        let result = hsic(&k1, &k2);
1073        assert!(matches!(
1074            result,
1075            Err(AlignmentError::DimensionMismatch {
1076                expected: 3,
1077                got: 6
1078            })
1079        ));
1080    }
1081
1082    #[test]
1083    fn test_alignment_stats_dimension_mismatch() {
1084        let k = KernelMatrix::identity(3);
1085        let target = KernelMatrix::identity(4);
1086        let result = alignment_stats(&k, &target);
1087        assert!(matches!(
1088            result,
1089            Err(AlignmentError::DimensionMismatch { .. })
1090        ));
1091    }
1092
1093    // ---------------------------------------------------------------------------
1094    // Matrix operation correctness
1095    // ---------------------------------------------------------------------------
1096
1097    #[test]
1098    fn test_matmul_identity_neutral() {
1099        let n = 4;
1100        let id = KernelMatrix::identity(n);
1101        let k = rbf_kernel_matrix(&[1.0, 2.0, 3.0, 4.0], 0.5);
1102        let product = k.matmul(&id).expect("ok");
1103        for i in 0..n {
1104            for j in 0..n {
1105                let diff = (product.get(i, j) - k.get(i, j)).abs();
1106                assert!(diff < 1e-12, "K*I should equal K at ({i},{j}): diff={diff}");
1107            }
1108        }
1109    }
1110
1111    #[test]
1112    fn test_trace_product_vs_matmul_trace() {
1113        let k = rbf_kernel_matrix(&[0.0, 1.0, 2.0, 3.0], 0.4);
1114        let l = rbf_kernel_matrix(&[0.0, 1.0, 2.0, 3.0], 0.8);
1115        let via_trace_product = k.trace_product(&l).expect("ok");
1116        let via_matmul = k.matmul(&l).expect("ok").trace();
1117        assert!(
1118            (via_trace_product - via_matmul).abs() < 1e-10,
1119            "trace(K*L) via trace_product ({via_trace_product}) vs matmul ({via_matmul})"
1120        );
1121    }
1122}