1use std::fmt;
40
/// Errors reported by the kernel-alignment routines in this file.
///
/// `Eq` is derived alongside `PartialEq` because every payload (`usize`,
/// `String`) has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AlignmentError {
    /// The supplied data does not describe a square `n × n` matrix.
    NonSquareMatrix,
    /// Two matrices that must have the same size do not.
    DimensionMismatch {
        /// Size `n` of the first (reference) matrix.
        expected: usize,
        /// Size `n` of the matrix that did not match.
        got: usize,
    },
    /// A computation could not produce a meaningful value; the payload
    /// describes the cause.
    NumericalError(String),
    /// The matrix is singular or numerically close to singular.
    SingularMatrix,
}
62
63impl fmt::Display for AlignmentError {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 match self {
66 Self::NonSquareMatrix => write!(f, "Matrix is not square"),
67 Self::DimensionMismatch { expected, got } => write!(
68 f,
69 "Dimension mismatch: expected {}×{}, got {}×{}",
70 expected, expected, got, got
71 ),
72 Self::NumericalError(msg) => write!(f, "Numerical error: {}", msg),
73 Self::SingularMatrix => write!(f, "Matrix is singular or near-singular"),
74 }
75 }
76}
77
78impl std::error::Error for AlignmentError {}
79
80#[derive(Debug, Clone)]
88pub struct KernelMatrix {
89 data: Vec<Vec<f64>>,
90 n: usize,
91}
92
93impl KernelMatrix {
94 pub fn new(data: Vec<Vec<f64>>) -> Result<KernelMatrix, AlignmentError> {
99 let n = data.len();
100 for row in &data {
101 if row.len() != n {
102 return Err(AlignmentError::NonSquareMatrix);
103 }
104 }
105 Ok(KernelMatrix { data, n })
106 }
107
108 pub fn from_flat(flat: &[f64], n: usize) -> Result<KernelMatrix, AlignmentError> {
110 if flat.len() != n * n {
111 return Err(AlignmentError::NonSquareMatrix);
112 }
113 let data = (0..n).map(|i| flat[i * n..(i + 1) * n].to_vec()).collect();
114 Ok(KernelMatrix { data, n })
115 }
116
117 pub fn identity(n: usize) -> KernelMatrix {
119 let mut data = vec![vec![0.0_f64; n]; n];
120 #[allow(clippy::needless_range_loop)]
121 for i in 0..n {
122 data[i][i] = 1.0;
123 }
124 KernelMatrix { data, n }
125 }
126
127 pub fn from_labels(labels: &[f64]) -> KernelMatrix {
133 let n = labels.len();
134 let mut data = vec![vec![0.0_f64; n]; n];
135 for i in 0..n {
136 for j in 0..n {
137 data[i][j] = if (labels[i] - labels[j]).abs() < 1e-10 {
139 1.0
140 } else {
141 -1.0
142 };
143 }
144 }
145 KernelMatrix { data, n }
146 }
147
148 #[inline]
150 pub fn get(&self, i: usize, j: usize) -> f64 {
151 self.data[i][j]
152 }
153
154 #[inline]
156 pub fn n(&self) -> usize {
157 self.n
158 }
159
160 pub fn trace(&self) -> f64 {
162 (0..self.n).map(|i| self.data[i][i]).sum()
163 }
164
165 pub fn frobenius_norm_sq(&self) -> f64 {
167 self.data
168 .iter()
169 .flat_map(|row| row.iter())
170 .map(|&v| v * v)
171 .sum()
172 }
173
174 pub fn frobenius_inner(&self, other: &KernelMatrix) -> Result<f64, AlignmentError> {
179 if self.n != other.n {
180 return Err(AlignmentError::DimensionMismatch {
181 expected: self.n,
182 got: other.n,
183 });
184 }
185 let mut sum = 0.0_f64;
186 for i in 0..self.n {
187 for j in 0..self.n {
188 sum += self.data[i][j] * other.data[i][j];
189 }
190 }
191 Ok(sum)
192 }
193
194 pub fn center(&self) -> KernelMatrix {
202 let n = self.n;
203 let n_f = n as f64;
204
205 let row_means: Vec<f64> = self
207 .data
208 .iter()
209 .map(|row| row.iter().sum::<f64>() / n_f)
210 .collect();
211
212 let col_means: Vec<f64> = (0..n)
214 .map(|j| (0..n).map(|i| self.data[i][j]).sum::<f64>() / n_f)
215 .collect();
216
217 let grand_mean: f64 = row_means.iter().sum::<f64>() / n_f;
219
220 let mut data = vec![vec![0.0_f64; n]; n];
221 for i in 0..n {
222 for j in 0..n {
223 data[i][j] = self.data[i][j] - row_means[i] - col_means[j] + grand_mean;
224 }
225 }
226 KernelMatrix { data, n }
227 }
228
229 #[allow(dead_code)]
233 fn matmul(&self, other: &KernelMatrix) -> Result<KernelMatrix, AlignmentError> {
234 if self.n != other.n {
235 return Err(AlignmentError::DimensionMismatch {
236 expected: self.n,
237 got: other.n,
238 });
239 }
240 let n = self.n;
241 let mut data = vec![vec![0.0_f64; n]; n];
242 #[allow(clippy::needless_range_loop)]
243 for i in 0..n {
244 for k in 0..n {
245 let aik = self.data[i][k];
246 if aik == 0.0 {
247 continue;
248 }
249 for j in 0..n {
250 data[i][j] += aik * other.data[k][j];
251 }
252 }
253 }
254 Ok(KernelMatrix { data, n })
255 }
256
257 #[allow(dead_code)]
261 fn trace_product(&self, other: &KernelMatrix) -> Result<f64, AlignmentError> {
262 if self.n != other.n {
263 return Err(AlignmentError::DimensionMismatch {
264 expected: self.n,
265 got: other.n,
266 });
267 }
268 let n = self.n;
269 let mut tr = 0.0_f64;
270 for i in 0..n {
271 for j in 0..n {
272 tr += self.data[i][j] * other.data[j][i];
273 }
274 }
275 Ok(tr)
276 }
277}
278
/// Outcome of a single alignment computation.
///
/// `PartialEq` is derived so results can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentResult {
    /// Alignment score, computed as `numerator / denominator`.
    pub score: f64,
    /// Un-normalised similarity term (the inner product of the two matrices).
    pub numerator: f64,
    /// Normalising term the numerator is divided by.
    pub denominator: f64,
    /// Size `n` of the matrices that were compared.
    pub n_samples: usize,
}
295
/// Summary of several alignment metrics for one pair of matrices.
///
/// `PartialEq` is derived so summaries can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentStats {
    /// Kernel-target alignment score.
    pub kta: f64,
    /// Centered kernel alignment score.
    pub cka: f64,
    /// HSIC value for the pair.
    pub hsic: f64,
    /// Size `n` of the matrices that were compared.
    pub n_samples: usize,
}
308
309pub fn kernel_target_alignment(
323 k: &KernelMatrix,
324 target: &KernelMatrix,
325) -> Result<AlignmentResult, AlignmentError> {
326 if k.n() != target.n() {
327 return Err(AlignmentError::DimensionMismatch {
328 expected: k.n(),
329 got: target.n(),
330 });
331 }
332
333 let numerator = k.frobenius_inner(target)?;
334 let norm_k_sq = k.frobenius_norm_sq();
335 let norm_t_sq = target.frobenius_norm_sq();
336 let denominator = (norm_k_sq * norm_t_sq).sqrt();
337
338 if denominator < f64::EPSILON {
339 return Err(AlignmentError::NumericalError(
340 "One or both kernel matrices have zero Frobenius norm".to_string(),
341 ));
342 }
343
344 Ok(AlignmentResult {
345 score: numerator / denominator,
346 numerator,
347 denominator,
348 n_samples: k.n(),
349 })
350}
351
352pub fn centered_kernel_alignment(
362 k1: &KernelMatrix,
363 k2: &KernelMatrix,
364) -> Result<AlignmentResult, AlignmentError> {
365 if k1.n() != k2.n() {
366 return Err(AlignmentError::DimensionMismatch {
367 expected: k1.n(),
368 got: k2.n(),
369 });
370 }
371
372 let k1_c = k1.center();
373 let k2_c = k2.center();
374
375 let n_sq = (k1.n() * k1.n()) as f64;
376
377 let hsic_12 = k1_c.frobenius_inner(&k2_c)? / n_sq;
378 let hsic_11 = k1_c.frobenius_norm_sq() / n_sq;
379 let hsic_22 = k2_c.frobenius_norm_sq() / n_sq;
380
381 let denominator_sq = hsic_11 * hsic_22;
382 if denominator_sq < f64::EPSILON * f64::EPSILON {
383 return Err(AlignmentError::NumericalError(
384 "HSIC self-alignment is zero; cannot normalise CKA".to_string(),
385 ));
386 }
387
388 let denominator = denominator_sq.sqrt();
389 let score = hsic_12 / denominator;
390
391 Ok(AlignmentResult {
392 score,
393 numerator: hsic_12,
394 denominator,
395 n_samples: k1.n(),
396 })
397}
398
399pub fn hsic(k: &KernelMatrix, l: &KernelMatrix) -> Result<f64, AlignmentError> {
408 if k.n() != l.n() {
409 return Err(AlignmentError::DimensionMismatch {
410 expected: k.n(),
411 got: l.n(),
412 });
413 }
414 let n_sq = (k.n() * k.n()) as f64;
415 let k_c = k.center();
416 let l_c = l.center();
417 let inner = k_c.frobenius_inner(&l_c)?;
418 Ok(inner / n_sq)
419}
420
421pub fn alignment_stats(
425 k: &KernelMatrix,
426 target: &KernelMatrix,
427) -> Result<AlignmentStats, AlignmentError> {
428 if k.n() != target.n() {
429 return Err(AlignmentError::DimensionMismatch {
430 expected: k.n(),
431 got: target.n(),
432 });
433 }
434
435 let kta_result = kernel_target_alignment(k, target)?;
436 let cka_result = centered_kernel_alignment(k, target)?;
437 let hsic_val = hsic(k, target)?;
438
439 Ok(AlignmentStats {
440 kta: kta_result.score,
441 cka: cka_result.score,
442 hsic: hsic_val,
443 n_samples: k.n(),
444 })
445}
446
/// Settings for the alignment optimisers in this file.
///
/// `PartialEq` is derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct AlignmentOptConfig {
    /// Upper bound on the number of gradient-ascent update steps.
    pub max_iterations: usize,
    /// Factor the gradient is multiplied by in each update step.
    pub learning_rate: f64,
    /// Gradient ascent stops once the score changes by less than this amount
    /// between consecutive steps.
    pub tolerance: f64,
    /// `true` to optimise centered kernel alignment, `false` to optimise
    /// kernel-target alignment.
    pub use_cka: bool,
    /// Perturbation size for the central finite-difference gradient.
    pub fd_step: f64,
}
465
466impl Default for AlignmentOptConfig {
467 fn default() -> Self {
468 AlignmentOptConfig {
469 max_iterations: 50,
470 learning_rate: 0.01,
471 tolerance: 1e-6,
472 use_cka: true,
473 fd_step: 1e-5,
474 }
475 }
476}
477
/// Outcome of an alignment optimisation run.
///
/// `PartialEq` is derived so results can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct OptimizationResult {
    /// Alignment score associated with `best_params`.
    pub best_score: f64,
    /// Parameter vector reported by the optimiser.
    pub best_params: Vec<f64>,
    /// Every score recorded during the run, in evaluation order.
    pub score_history: Vec<f64>,
    /// Whether the optimiser reported convergence.
    pub converged: bool,
    /// Number of recorded evaluations; the optimisers in this file set it to
    /// the length of `score_history`.
    pub iterations: usize,
}
492
493fn evaluate_alignment(
495 kernel_fn: &dyn Fn(&[f64]) -> KernelMatrix,
496 target: &KernelMatrix,
497 params: &[f64],
498 use_cka: bool,
499) -> Result<f64, AlignmentError> {
500 let k = kernel_fn(params);
501 if use_cka {
502 centered_kernel_alignment(&k, target).map(|r| r.score)
503 } else {
504 kernel_target_alignment(&k, target).map(|r| r.score)
505 }
506}
507
508pub fn grid_search_alignment(
524 kernel_fn: &dyn Fn(&[f64]) -> KernelMatrix,
525 target: &KernelMatrix,
526 params_grid: &[Vec<f64>],
527 config: &AlignmentOptConfig,
528) -> Result<OptimizationResult, AlignmentError> {
529 if params_grid.is_empty() {
530 return Err(AlignmentError::NumericalError(
531 "params_grid must not be empty".to_string(),
532 ));
533 }
534
535 let mut best_score = f64::NEG_INFINITY;
536 let mut best_params = params_grid[0].clone();
537 let mut score_history = Vec::with_capacity(params_grid.len());
538
539 for params in params_grid {
540 let score = evaluate_alignment(kernel_fn, target, params, config.use_cka)?;
541 score_history.push(score);
542 if score > best_score {
543 best_score = score;
544 best_params = params.clone();
545 }
546 }
547
548 Ok(OptimizationResult {
549 best_score,
550 best_params,
551 score_history,
552 converged: true,
553 iterations: params_grid.len(),
554 })
555}
556
557pub fn gradient_ascent_alignment(
575 kernel_fn: &dyn Fn(&[f64]) -> KernelMatrix,
576 target: &KernelMatrix,
577 initial_params: &[f64],
578 config: &AlignmentOptConfig,
579) -> Result<OptimizationResult, AlignmentError> {
580 if initial_params.is_empty() {
581 return Err(AlignmentError::NumericalError(
582 "initial_params must not be empty".to_string(),
583 ));
584 }
585
586 let d = initial_params.len();
587 let mut params = initial_params.to_vec();
588 let mut score_history = Vec::with_capacity(config.max_iterations);
589 let mut converged = false;
590
591 let mut current_score = evaluate_alignment(kernel_fn, target, ¶ms, config.use_cka)?;
592 score_history.push(current_score);
593
594 for _iter in 0..config.max_iterations {
595 let mut grad = vec![0.0_f64; d];
597 for k in 0..d {
598 let mut params_fwd = params.clone();
599 let mut params_bwd = params.clone();
600 params_fwd[k] += config.fd_step;
601 params_bwd[k] -= config.fd_step;
602
603 let score_fwd = evaluate_alignment(kernel_fn, target, ¶ms_fwd, config.use_cka)?;
604 let score_bwd = evaluate_alignment(kernel_fn, target, ¶ms_bwd, config.use_cka)?;
605 grad[k] = (score_fwd - score_bwd) / (2.0 * config.fd_step);
606 }
607
608 for k in 0..d {
610 params[k] += config.learning_rate * grad[k];
611 }
612
613 let new_score = evaluate_alignment(kernel_fn, target, ¶ms, config.use_cka)?;
614 score_history.push(new_score);
615
616 if (new_score - current_score).abs() < config.tolerance {
617 converged = true;
618 current_score = new_score;
619 break;
620 }
621 current_score = new_score;
622 }
623
624 let iterations = score_history.len();
625 Ok(OptimizationResult {
626 best_score: current_score,
627 best_params: params,
628 score_history,
629 converged,
630 iterations,
631 })
632}
633
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the RBF kernel matrix `exp(-gamma * (x_i - x_j)^2)` over
    /// one-dimensional samples.
    fn rbf_kernel_matrix(data: &[f64], gamma: f64) -> KernelMatrix {
        let n = data.len();
        let mut mat = vec![vec![0.0_f64; n]; n];
        for i in 0..n {
            for j in 0..n {
                let diff = data[i] - data[j];
                mat[i][j] = (-gamma * diff * diff).exp();
            }
        }
        KernelMatrix::new(mat).expect("valid kernel matrix")
    }

    // ---- KernelMatrix basics ----

    #[test]
    fn test_identity_trace_equals_n() {
        for n in [1_usize, 3, 5, 10] {
            let id = KernelMatrix::identity(n);
            let tr = id.trace();
            assert!(
                (tr - n as f64).abs() < 1e-12,
                "identity trace should be {n}, got {tr}"
            );
        }
    }

    #[test]
    fn test_from_labels_correct_values() {
        let labels = vec![0.0, 0.0, 1.0, 1.0];
        let k = KernelMatrix::from_labels(&labels);
        assert_eq!(k.n(), 4);
        // Same label -> +1.
        assert!((k.get(0, 1) - 1.0).abs() < 1e-12);
        assert!((k.get(2, 3) - 1.0).abs() < 1e-12);
        assert!((k.get(0, 0) - 1.0).abs() < 1e-12);
        // Different label -> -1.
        assert!((k.get(0, 2) + 1.0).abs() < 1e-12);
        assert!((k.get(1, 3) + 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_center_zero_row_column_sums() {
        let data = vec![
            vec![4.0, 2.0, 1.0],
            vec![2.0, 3.0, 0.5],
            vec![1.0, 0.5, 2.0],
        ];
        let k = KernelMatrix::new(data).expect("valid");
        let k_c = k.center();
        let n = k_c.n();

        for i in 0..n {
            let row_sum: f64 = (0..n).map(|j| k_c.get(i, j)).sum();
            assert!(row_sum.abs() < 1e-10, "centered row {i} sum = {row_sum}");
            let col_sum: f64 = (0..n).map(|j| k_c.get(j, i)).sum();
            assert!(col_sum.abs() < 1e-10, "centered col {i} sum = {col_sum}");
        }
    }

    #[test]
    fn test_frobenius_inner_symmetric() {
        let data1 = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let data2 = vec![vec![1.0, 0.5], vec![0.5, 2.0]];
        let k1 = KernelMatrix::new(data1).expect("valid");
        let k2 = KernelMatrix::new(data2).expect("valid");

        let inner_12 = k1.frobenius_inner(&k2).expect("ok");
        let inner_21 = k2.frobenius_inner(&k1).expect("ok");
        assert!(
            (inner_12 - inner_21).abs() < 1e-12,
            "<K1,K2> = {inner_12}, <K2,K1> = {inner_21}"
        );
    }

    #[test]
    fn test_frobenius_norm_identity() {
        for n in [1_usize, 4, 9] {
            let id = KernelMatrix::identity(n);
            let norm_sq = id.frobenius_norm_sq();
            let norm = norm_sq.sqrt();
            let expected = (n as f64).sqrt();
            assert!(
                (norm - expected).abs() < 1e-12,
                "||I_n||_F should be sqrt({n}) = {expected}, got {norm}"
            );
        }
    }

    #[test]
    fn test_from_flat_validates_square() {
        // 4 values fill a 2x2 matrix.
        let flat = vec![1.0, 0.0, 0.0, 1.0];
        assert!(KernelMatrix::from_flat(&flat, 2).is_ok());

        // 5 values cannot fill a 2x2 matrix.
        let bad = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert!(matches!(
            KernelMatrix::from_flat(&bad, 2),
            Err(AlignmentError::NonSquareMatrix)
        ));
    }

    // ---- Kernel-target alignment ----

    #[test]
    fn test_kta_identical_kernels_is_one() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let k = rbf_kernel_matrix(&data, 0.5);
        let result = kernel_target_alignment(&k, &k).expect("ok");
        assert!(
            (result.score - 1.0).abs() < 1e-10,
            "KTA of K with itself should be 1.0, got {}",
            result.score
        );
    }

    #[test]
    fn test_kta_with_label_target_positive() {
        // Two well-separated clusters whose labels match the clustering.
        let data = vec![0.0, 0.1, 0.2, 10.0, 10.1, 10.2];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let k = rbf_kernel_matrix(&data, 1.0);
        let target = KernelMatrix::from_labels(&labels);
        let result = kernel_target_alignment(&k, &target).expect("ok");
        assert!(
            result.score > 0.0,
            "KTA should be positive for clustered data, got {}",
            result.score
        );
    }

    #[test]
    fn test_kta_range_is_minus_one_to_one() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let labels = vec![0.0, 1.0, 0.0, 1.0];
        let k = rbf_kernel_matrix(&data, 1.0);
        let target = KernelMatrix::from_labels(&labels);
        let result = kernel_target_alignment(&k, &target).expect("ok");
        assert!(
            result.score >= -1.0 - 1e-9 && result.score <= 1.0 + 1e-9,
            "KTA score out of range: {}",
            result.score
        );
    }

    // ---- Centered kernel alignment ----

    #[test]
    fn test_cka_identical_kernels_is_one() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let k = rbf_kernel_matrix(&data, 0.5);
        let result = centered_kernel_alignment(&k, &k).expect("ok");
        assert!(
            (result.score - 1.0).abs() < 1e-10,
            "CKA of K with itself should be 1.0, got {}",
            result.score
        );
    }

    #[test]
    fn test_cka_invariant_to_scaling() {
        let data = vec![0.5, 1.0, 2.0, 3.0, 4.0];
        let k = rbf_kernel_matrix(&data, 0.3);
        let labels = vec![0.0, 0.0, 1.0, 1.0, 1.0];
        let target = KernelMatrix::from_labels(&labels);

        let n = k.n();
        // Multiply every entry by 2.
        let scaled_data: Vec<Vec<f64>> = (0..n)
            .map(|i| (0..n).map(|j| 2.0 * k.get(i, j)).collect())
            .collect();
        let k_scaled = KernelMatrix::new(scaled_data).expect("valid");

        let cka_original = centered_kernel_alignment(&k, &target).expect("ok").score;
        let cka_scaled = centered_kernel_alignment(&k_scaled, &target)
            .expect("ok")
            .score;

        assert!(
            (cka_original - cka_scaled).abs() < 1e-10,
            "CKA should be invariant to scaling: {cka_original} vs {cka_scaled}"
        );
    }

    #[test]
    fn test_cka_invariant_to_mean_shift() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let k = rbf_kernel_matrix(&data, 0.2);
        let labels = vec![0.0, 0.0, 1.0, 1.0, 1.0];
        let target = KernelMatrix::from_labels(&labels);

        let n = k.n();
        // Add the same constant to every entry.
        let c = 3.0_f64;
        let shifted_data: Vec<Vec<f64>> = (0..n)
            .map(|i| (0..n).map(|j| k.get(i, j) + c).collect())
            .collect();
        let k_shifted = KernelMatrix::new(shifted_data).expect("valid");

        let cka_original = centered_kernel_alignment(&k, &target).expect("ok").score;
        let cka_shifted = centered_kernel_alignment(&k_shifted, &target)
            .expect("ok")
            .score;

        assert!(
            (cka_original - cka_shifted).abs() < 1e-9,
            "CKA should be invariant to constant mean shift: {cka_original} vs {cka_shifted}"
        );
    }

    // ---- HSIC ----

    #[test]
    fn test_hsic_identical_kernel_positive() {
        let data = vec![1.0, 3.0, 5.0, 7.0];
        let k = rbf_kernel_matrix(&data, 1.0);
        let val = hsic(&k, &k).expect("ok");
        assert!(val > 0.0, "HSIC(K,K) should be positive, got {val}");
    }

    #[test]
    fn test_hsic_near_independent_kernels() {
        let n = 8;
        let identity = KernelMatrix::identity(n);

        // A constant matrix centres to all zeros.
        let data = vec![vec![1.0_f64; n]; n];
        let constant_k = KernelMatrix::new(data).expect("valid");

        let val = hsic(&identity, &constant_k).expect("ok");
        assert!(
            val.abs() < 1e-12,
            "HSIC(I, 1*1^T) after centering should be ~0, got {val}"
        );
    }

    // ---- alignment_stats ----

    #[test]
    fn test_alignment_stats_reports_all_metrics() {
        let data = vec![0.0, 0.5, 1.0, 5.0, 5.5, 6.0];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let k = rbf_kernel_matrix(&data, 2.0);
        let target = KernelMatrix::from_labels(&labels);

        let stats = alignment_stats(&k, &target).expect("ok");
        assert_eq!(stats.n_samples, 6);
        assert!(stats.kta >= -1.0 - 1e-9 && stats.kta <= 1.0 + 1e-9);
        assert!(stats.cka >= -1.0 - 1e-9 && stats.cka <= 1.0 + 1e-9);
    }

    #[test]
    fn test_alignment_stats_perfect_alignment_near_one() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let k = rbf_kernel_matrix(&data, 0.5);
        let stats = alignment_stats(&k, &k).expect("ok");
        assert!(
            (stats.kta - 1.0).abs() < 1e-10,
            "KTA should be 1.0, got {}",
            stats.kta
        );
        assert!(
            (stats.cka - 1.0).abs() < 1e-10,
            "CKA should be 1.0, got {}",
            stats.cka
        );
    }

    // ---- Optimisers ----

    #[test]
    fn test_grid_search_finds_best_params() {
        let data = vec![0.0, 0.2, 0.4, 5.0, 5.2, 5.4];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let target = KernelMatrix::from_labels(&labels);

        let params_grid: Vec<Vec<f64>> =
            vec![vec![0.01], vec![0.1], vec![1.0], vec![5.0], vec![10.0]];

        let config = AlignmentOptConfig {
            use_cka: true,
            ..Default::default()
        };

        let kernel_fn = |params: &[f64]| rbf_kernel_matrix(&data, params[0]);

        let result = grid_search_alignment(&kernel_fn, &target, &params_grid, &config).expect("ok");

        assert_eq!(result.iterations, 5);
        assert_eq!(result.score_history.len(), 5);
        assert!(result.converged);

        let max_in_history = result
            .score_history
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        assert!(
            (result.best_score - max_in_history).abs() < 1e-12,
            "best_score {} should equal max in history {}",
            result.best_score,
            max_in_history
        );
    }

    #[test]
    fn test_gradient_ascent_converges_toward_higher_alignment() {
        let data = vec![0.0, 0.3, 0.6, 4.0, 4.3, 4.6];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let target = KernelMatrix::from_labels(&labels);

        let kernel_fn = |params: &[f64]| rbf_kernel_matrix(&data, params[0].abs());

        let initial_params = vec![0.01_f64];
        let config = AlignmentOptConfig {
            max_iterations: 30,
            learning_rate: 0.05,
            tolerance: 1e-8,
            use_cka: true,
            fd_step: 1e-4,
        };

        let result =
            gradient_ascent_alignment(&kernel_fn, &target, &initial_params, &config).expect("ok");

        assert!(
            !result.score_history.is_empty(),
            "score_history must be non-empty"
        );
        let first_score = result.score_history[0];
        assert!(
            result.best_score >= first_score - 1e-6,
            "gradient ascent should not decrease alignment: final {} < initial {}",
            result.best_score,
            first_score
        );
    }

    #[test]
    fn test_score_history_non_decreasing_approximately() {
        let data = vec![0.0, 0.5, 1.0, 6.0, 6.5, 7.0];
        let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let target = KernelMatrix::from_labels(&labels);

        let kernel_fn = |params: &[f64]| rbf_kernel_matrix(&data, params[0].abs() + 1e-3);

        let config = AlignmentOptConfig {
            max_iterations: 20,
            learning_rate: 0.02,
            tolerance: 1e-9,
            use_cka: true,
            fd_step: 1e-4,
        };

        let result = gradient_ascent_alignment(&kernel_fn, &target, &[0.01], &config).expect("ok");

        let n = result.score_history.len();
        if n >= 2 {
            let final_score = result.score_history[n - 1];
            let initial_score = result.score_history[0];
            // Allow a 5% relative drop (with a small absolute floor).
            assert!(
                final_score >= initial_score - 0.05 * initial_score.abs().max(1e-3),
                "score history should trend upward: initial={initial_score}, final={final_score}"
            );
        }
    }

    // ---- Dimension-mismatch errors ----

    #[test]
    fn test_kta_dimension_mismatch_error() {
        let k1 = KernelMatrix::identity(3);
        let k2 = KernelMatrix::identity(4);
        let result = kernel_target_alignment(&k1, &k2);
        assert!(matches!(
            result,
            Err(AlignmentError::DimensionMismatch {
                expected: 3,
                got: 4
            })
        ));
    }

    #[test]
    fn test_cka_dimension_mismatch_error() {
        let k1 = KernelMatrix::identity(2);
        let k2 = KernelMatrix::identity(5);
        let result = centered_kernel_alignment(&k1, &k2);
        assert!(matches!(
            result,
            Err(AlignmentError::DimensionMismatch {
                expected: 2,
                got: 5
            })
        ));
    }

    #[test]
    fn test_hsic_dimension_mismatch_error() {
        let k1 = KernelMatrix::identity(3);
        let k2 = KernelMatrix::identity(6);
        let result = hsic(&k1, &k2);
        assert!(matches!(
            result,
            Err(AlignmentError::DimensionMismatch {
                expected: 3,
                got: 6
            })
        ));
    }

    #[test]
    fn test_alignment_stats_dimension_mismatch() {
        let k = KernelMatrix::identity(3);
        let target = KernelMatrix::identity(4);
        let result = alignment_stats(&k, &target);
        assert!(matches!(
            result,
            Err(AlignmentError::DimensionMismatch { .. })
        ));
    }

    // ---- Private matrix helpers ----

    #[test]
    fn test_matmul_identity_neutral() {
        let n = 4;
        let id = KernelMatrix::identity(n);
        let k = rbf_kernel_matrix(&[1.0, 2.0, 3.0, 4.0], 0.5);
        let product = k.matmul(&id).expect("ok");
        for i in 0..n {
            for j in 0..n {
                let diff = (product.get(i, j) - k.get(i, j)).abs();
                assert!(diff < 1e-12, "K*I should equal K at ({i},{j}): diff={diff}");
            }
        }
    }

    #[test]
    fn test_trace_product_vs_matmul_trace() {
        let k = rbf_kernel_matrix(&[0.0, 1.0, 2.0, 3.0], 0.4);
        let l = rbf_kernel_matrix(&[0.0, 1.0, 2.0, 3.0], 0.8);
        let via_trace_product = k.trace_product(&l).expect("ok");
        let via_matmul = k.matmul(&l).expect("ok").trace();
        assert!(
            (via_trace_product - via_matmul).abs() < 1e-10,
            "trace(K*L) via trace_product ({via_trace_product}) vs matmul ({via_matmul})"
        );
    }
}