1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
20use scirs2_core::numeric::{Float, NumCast};
21
22use crate::error::{Result, TransformError};
23
/// Kernel function used when computing pairwise similarities between samples.
///
/// Each variant carries the hyperparameters of the corresponding kernel;
/// evaluation itself is performed by [`kernel_eval`].
#[derive(Debug, Clone, PartialEq)]
pub enum KernelType {
    /// Linear kernel: k(x, y) = <x, y>.
    Linear,
    /// Polynomial kernel: k(x, y) = (gamma * <x, y> + coef0)^degree.
    Polynomial {
        /// Scale applied to the dot product.
        gamma: f64,
        /// Independent (bias) term added before exponentiation.
        coef0: f64,
        /// Polynomial degree (applied via integer power).
        degree: u32,
    },
    /// Gaussian RBF kernel: k(x, y) = exp(-gamma * ||x - y||^2).
    RBF {
        /// Bandwidth parameter; larger values localize the kernel.
        gamma: f64,
    },
    /// Laplacian kernel: k(x, y) = exp(-gamma * ||x - y||_1).
    Laplacian {
        /// Bandwidth parameter applied to the L1 distance.
        gamma: f64,
    },
    /// Sigmoid (hyperbolic tangent) kernel: k(x, y) = tanh(gamma * <x, y> + coef0).
    Sigmoid {
        /// Scale applied to the dot product.
        gamma: f64,
        /// Independent (bias) term added before tanh.
        coef0: f64,
    },
}
56
57impl KernelType {
58 pub fn rbf_auto<S>(x: &ArrayBase<S, Ix2>) -> Result<Self>
62 where
63 S: Data,
64 S::Elem: Float + NumCast,
65 {
66 let gamma = estimate_rbf_gamma(x)?;
67 Ok(KernelType::RBF { gamma })
68 }
69
70 pub fn polynomial_default() -> Self {
72 KernelType::Polynomial {
73 gamma: 1.0,
74 coef0: 1.0,
75 degree: 3,
76 }
77 }
78
79 pub fn rbf(gamma: f64) -> Self {
81 KernelType::RBF { gamma }
82 }
83
84 pub fn laplacian(gamma: f64) -> Self {
86 KernelType::Laplacian { gamma }
87 }
88
89 pub fn sigmoid_default() -> Self {
91 KernelType::Sigmoid {
92 gamma: 1.0,
93 coef0: 0.0,
94 }
95 }
96}
97
98pub fn kernel_eval<S1, S2>(
108 x: &ArrayBase<S1, Ix1>,
109 y: &ArrayBase<S2, Ix1>,
110 kernel: &KernelType,
111) -> Result<f64>
112where
113 S1: Data,
114 S2: Data,
115 S1::Elem: Float + NumCast,
116 S2::Elem: Float + NumCast,
117{
118 if x.len() != y.len() {
119 return Err(TransformError::InvalidInput(format!(
120 "Vector dimensions must match: {} vs {}",
121 x.len(),
122 y.len()
123 )));
124 }
125
126 let n = x.len();
127 match kernel {
128 KernelType::Linear => {
129 let mut dot = 0.0;
130 for i in 0..n {
131 let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
132 let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
133 dot += xi * yi;
134 }
135 Ok(dot)
136 }
137 KernelType::Polynomial {
138 gamma,
139 coef0,
140 degree,
141 } => {
142 let mut dot = 0.0;
143 for i in 0..n {
144 let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
145 let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
146 dot += xi * yi;
147 }
148 Ok((gamma * dot + coef0).powi(*degree as i32))
149 }
150 KernelType::RBF { gamma } => {
151 let mut dist_sq = 0.0;
152 for i in 0..n {
153 let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
154 let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
155 let diff = xi - yi;
156 dist_sq += diff * diff;
157 }
158 Ok((-gamma * dist_sq).exp())
159 }
160 KernelType::Laplacian { gamma } => {
161 let mut l1_dist = 0.0;
162 for i in 0..n {
163 let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
164 let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
165 l1_dist += (xi - yi).abs();
166 }
167 Ok((-gamma * l1_dist).exp())
168 }
169 KernelType::Sigmoid { gamma, coef0 } => {
170 let mut dot = 0.0;
171 for i in 0..n {
172 let xi: f64 = NumCast::from(x[i]).unwrap_or(0.0);
173 let yi: f64 = NumCast::from(y[i]).unwrap_or(0.0);
174 dot += xi * yi;
175 }
176 Ok((gamma * dot + coef0).tanh())
177 }
178 }
179}
180
181pub fn gram_matrix<S>(x: &ArrayBase<S, Ix2>, kernel: &KernelType) -> Result<Array2<f64>>
192where
193 S: Data,
194 S::Elem: Float + NumCast,
195{
196 let n_samples = x.nrows();
197 let mut k = Array2::zeros((n_samples, n_samples));
198
199 for i in 0..n_samples {
200 for j in i..n_samples {
201 let val = kernel_eval(&x.row(i), &x.row(j), kernel)?;
202 k[[i, j]] = val;
203 k[[j, i]] = val;
204 }
205 }
206
207 Ok(k)
208}
209
210pub fn cross_gram_matrix<S1, S2>(
222 x: &ArrayBase<S1, Ix2>,
223 y: &ArrayBase<S2, Ix2>,
224 kernel: &KernelType,
225) -> Result<Array2<f64>>
226where
227 S1: Data,
228 S2: Data,
229 S1::Elem: Float + NumCast,
230 S2::Elem: Float + NumCast,
231{
232 if x.ncols() != y.ncols() {
233 return Err(TransformError::InvalidInput(format!(
234 "Feature dimensions must match: {} vs {}",
235 x.ncols(),
236 y.ncols()
237 )));
238 }
239
240 let n_x = x.nrows();
241 let n_y = y.nrows();
242 let mut k = Array2::zeros((n_x, n_y));
243
244 for i in 0..n_x {
245 for j in 0..n_y {
246 k[[i, j]] = kernel_eval(&x.row(i), &y.row(j), kernel)?;
247 }
248 }
249
250 Ok(k)
251}
252
253pub fn center_kernel_matrix(k: &Array2<f64>) -> Result<Array2<f64>> {
269 let n = k.nrows();
270 if n != k.ncols() {
271 return Err(TransformError::InvalidInput(
272 "Kernel matrix must be square".to_string(),
273 ));
274 }
275 if n == 0 {
276 return Err(TransformError::InvalidInput(
277 "Kernel matrix must be non-empty".to_string(),
278 ));
279 }
280
281 let n_f64 = n as f64;
282
283 let row_means = k.mean_axis(Axis(0)).ok_or_else(|| {
285 TransformError::ComputationError("Failed to compute row means".to_string())
286 })?;
287 let col_means = k.mean_axis(Axis(1)).ok_or_else(|| {
288 TransformError::ComputationError("Failed to compute column means".to_string())
289 })?;
290 let grand_mean = row_means.sum() / n_f64;
291
292 let mut k_centered = Array2::zeros((n, n));
293 for i in 0..n {
294 for j in 0..n {
295 k_centered[[i, j]] = k[[i, j]] - row_means[j] - col_means[i] + grand_mean;
296 }
297 }
298
299 Ok(k_centered)
300}
301
302pub fn center_kernel_matrix_test(
314 k_test: &Array2<f64>,
315 k_train: &Array2<f64>,
316) -> Result<Array2<f64>> {
317 let n_train = k_train.nrows();
318 let n_test = k_test.nrows();
319
320 if k_train.nrows() != k_train.ncols() {
321 return Err(TransformError::InvalidInput(
322 "Training kernel matrix must be square".to_string(),
323 ));
324 }
325 if k_test.ncols() != n_train {
326 return Err(TransformError::InvalidInput(format!(
327 "Test kernel matrix columns ({}) must match training samples ({})",
328 k_test.ncols(),
329 n_train
330 )));
331 }
332
333 let n_f64 = n_train as f64;
334
335 let train_col_means = k_train.mean_axis(Axis(0)).ok_or_else(|| {
337 TransformError::ComputationError("Failed to compute train column means".to_string())
338 })?;
339
340 let test_row_means = k_test.mean_axis(Axis(1)).ok_or_else(|| {
342 TransformError::ComputationError("Failed to compute test row means".to_string())
343 })?;
344
345 let train_grand_mean = train_col_means.sum() / n_f64;
347
348 let mut k_centered = Array2::zeros((n_test, n_train));
349 for i in 0..n_test {
350 for j in 0..n_train {
351 k_centered[[i, j]] =
352 k_test[[i, j]] - test_row_means[i] - train_col_means[j] + train_grand_mean;
353 }
354 }
355
356 Ok(k_centered)
357}
358
359pub fn estimate_rbf_gamma<S>(x: &ArrayBase<S, Ix2>) -> Result<f64>
371where
372 S: Data,
373 S::Elem: Float + NumCast,
374{
375 let n = x.nrows();
376 if n < 2 {
377 return Err(TransformError::InvalidInput(
378 "Need at least 2 samples to estimate gamma".to_string(),
379 ));
380 }
381
382 let mut distances: Vec<f64> = Vec::with_capacity(n * (n - 1) / 2);
384 for i in 0..n {
385 for j in (i + 1)..n {
386 let mut dist_sq = 0.0;
387 for k in 0..x.ncols() {
388 let xi: f64 = NumCast::from(x[[i, k]]).unwrap_or(0.0);
389 let xj: f64 = NumCast::from(x[[j, k]]).unwrap_or(0.0);
390 let diff = xi - xj;
391 dist_sq += diff * diff;
392 }
393 distances.push(dist_sq);
394 }
395 }
396
397 distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
399
400 let median_sq = if distances.len() % 2 == 0 {
402 let mid = distances.len() / 2;
403 (distances[mid - 1] + distances[mid]) / 2.0
404 } else {
405 distances[distances.len() / 2]
406 };
407
408 if median_sq < 1e-15 {
409 Ok(1.0)
411 } else {
412 Ok(1.0 / (2.0 * median_sq))
413 }
414}
415
416pub fn kernel_diagonal<S>(x: &ArrayBase<S, Ix2>, kernel: &KernelType) -> Result<Array1<f64>>
425where
426 S: Data,
427 S::Elem: Float + NumCast,
428{
429 let n = x.nrows();
430 let mut diag = Array1::zeros(n);
431
432 for i in 0..n {
433 diag[i] = kernel_eval(&x.row(i), &x.row(i), kernel)?;
434 }
435
436 Ok(diag)
437}
438
439pub fn is_positive_semidefinite(k: &Array2<f64>, tol: f64) -> Result<bool> {
451 if k.nrows() != k.ncols() {
452 return Err(TransformError::InvalidInput(
453 "Matrix must be square".to_string(),
454 ));
455 }
456
457 let (eigenvalues, _) =
458 scirs2_linalg::eigh(&k.view(), None).map_err(TransformError::LinalgError)?;
459
460 let min_eigenvalue = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
461
462 Ok(min_eigenvalue >= tol)
463}
464
465pub fn kernel_alignment(k1: &Array2<f64>, k2: &Array2<f64>) -> Result<f64> {
479 if k1.dim() != k2.dim() {
480 return Err(TransformError::InvalidInput(
481 "Kernel matrices must have the same dimensions".to_string(),
482 ));
483 }
484
485 let frobenius_inner: f64 = k1.iter().zip(k2.iter()).map(|(&a, &b)| a * b).sum();
486 let norm1: f64 = k1.iter().map(|&a| a * a).sum::<f64>().sqrt();
487 let norm2: f64 = k2.iter().map(|&a| a * a).sum::<f64>().sqrt();
488
489 let denom = norm1 * norm2;
490 if denom < 1e-15 {
491 Ok(0.0)
492 } else {
493 Ok((frobenius_inner / denom).clamp(0.0, 1.0))
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Asserts two floats agree to within 1e-10.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// 5x3 fixture shared across tests.
    fn sample_data() -> Array2<f64> {
        Array::from_shape_vec(
            (5, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
            ],
        )
        .expect("Failed to create sample data")
    }

    #[test]
    fn test_linear_kernel() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array::from_vec(vec![4.0, 5.0, 6.0]);
        let value =
            kernel_eval(&a.view(), &b.view(), &KernelType::Linear).expect("kernel eval failed");
        // 1*4 + 2*5 + 3*6 = 32
        assert_close(value, 32.0);
    }

    #[test]
    fn test_polynomial_kernel() {
        let a = Array::from_vec(vec![1.0, 2.0]);
        let b = Array::from_vec(vec![3.0, 4.0]);
        let kernel = KernelType::Polynomial {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
        };
        let value = kernel_eval(&a.view(), &b.view(), &kernel).expect("kernel eval failed");
        // (1*3 + 2*4 + 1)^2 = 12^2 = 144
        assert_close(value, 144.0);
    }

    #[test]
    fn test_rbf_kernel() {
        let a = Array::from_vec(vec![1.0, 0.0]);
        let b = Array::from_vec(vec![0.0, 1.0]);
        let kernel = KernelType::RBF { gamma: 0.5 };
        let value = kernel_eval(&a.view(), &b.view(), &kernel).expect("kernel eval failed");
        // ||a - b||^2 = 2, so exp(-0.5 * 2) = e^-1
        assert_close(value, (-1.0_f64).exp());
    }

    #[test]
    fn test_rbf_kernel_self() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let kernel = KernelType::RBF { gamma: 1.0 };
        let value = kernel_eval(&a.view(), &a.view(), &kernel).expect("kernel eval failed");
        // Zero distance to itself → kernel value exactly 1.
        assert_close(value, 1.0);
    }

    #[test]
    fn test_laplacian_kernel() {
        let a = Array::from_vec(vec![1.0, 2.0]);
        let b = Array::from_vec(vec![3.0, 4.0]);
        let kernel = KernelType::Laplacian { gamma: 0.5 };
        let value = kernel_eval(&a.view(), &b.view(), &kernel).expect("kernel eval failed");
        // ||a - b||_1 = 4, so exp(-0.5 * 4) = e^-2
        assert_close(value, (-2.0_f64).exp());
    }

    #[test]
    fn test_sigmoid_kernel() {
        let a = Array::from_vec(vec![1.0, 0.0]);
        let b = Array::from_vec(vec![0.0, 1.0]);
        let kernel = KernelType::Sigmoid {
            gamma: 1.0,
            coef0: 0.0,
        };
        let value = kernel_eval(&a.view(), &b.view(), &kernel).expect("kernel eval failed");
        // Orthogonal vectors → tanh(0) = 0.
        assert_close(value, 0.0);
    }

    #[test]
    fn test_gram_matrix_symmetry() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let gram = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");

        assert_eq!(gram.shape(), &[5, 5]);
        for row in 0..5 {
            for col in 0..5 {
                assert!(
                    (gram[[row, col]] - gram[[col, row]]).abs() < 1e-10,
                    "Gram matrix not symmetric at ({}, {})",
                    row,
                    col
                );
            }
        }
    }

    #[test]
    fn test_gram_matrix_diagonal() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let gram = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");

        // An RBF kernel of any point with itself is exactly 1.
        for idx in 0..5 {
            assert!(
                (gram[[idx, idx]] - 1.0).abs() < 1e-10,
                "RBF diagonal should be 1.0"
            );
        }
    }

    #[test]
    fn test_cross_gram_matrix() {
        let a = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("Failed");
        let b = Array::from_shape_vec((2, 2), vec![1.5, 2.5, 3.5, 4.5]).expect("Failed");
        let kernel = KernelType::Linear;
        let gram = cross_gram_matrix(&a.view(), &b.view(), &kernel).expect("cross gram matrix failed");

        assert_eq!(gram.shape(), &[3, 2]);
        // 1*1.5 + 2*2.5 = 6.5
        assert_close(gram[[0, 0]], 6.5);
    }

    #[test]
    fn test_center_kernel_matrix() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.01 };
        let gram = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");
        let centered = center_kernel_matrix(&gram).expect("centering failed");

        // After centering, every row mean and column mean must vanish.
        for axis in [Axis(0), Axis(1)] {
            let means = centered.mean_axis(axis).expect("Failed to compute means");
            for m in means.iter() {
                assert!(
                    m.abs() < 1e-10,
                    "Centered kernel mean should be ~0, got {}",
                    m
                );
            }
        }
    }

    #[test]
    fn test_center_kernel_matrix_test() {
        let x_train = sample_data();
        let x_test =
            Array::from_shape_vec((2, 3), vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5]).expect("Failed");
        let kernel = KernelType::RBF { gamma: 0.01 };

        let k_train = gram_matrix(&x_train.view(), &kernel).expect("gram failed");
        let k_test =
            cross_gram_matrix(&x_test.view(), &x_train.view(), &kernel).expect("cross gram failed");

        let centered =
            center_kernel_matrix_test(&k_test, &k_train).expect("test centering failed");
        assert_eq!(centered.shape(), &[2, 5]);

        // Sanity check: centering must never introduce NaN/inf.
        for value in centered.iter() {
            assert!(value.is_finite());
        }
    }

    #[test]
    fn test_estimate_rbf_gamma() {
        let gamma = estimate_rbf_gamma(&sample_data().view()).expect("gamma estimation failed");
        assert!(gamma > 0.0);
        assert!(gamma.is_finite());
    }

    #[test]
    fn test_kernel_diagonal() {
        let data = sample_data();
        let diag = kernel_diagonal(&data.view(), &KernelType::Linear).expect("diagonal failed");

        assert_eq!(diag.len(), 5);
        // First row is (1, 2, 3): 1 + 4 + 9 = 14
        assert_close(diag[0], 14.0);
    }

    #[test]
    fn test_rbf_gram_psd() {
        let data = sample_data();
        let kernel = KernelType::RBF { gamma: 0.1 };
        let gram = gram_matrix(&data.view(), &kernel).expect("gram matrix failed");
        let psd = is_positive_semidefinite(&gram, -1e-10).expect("PSD check failed");
        assert!(psd, "RBF Gram matrix should be PSD");
    }

    #[test]
    fn test_kernel_alignment() {
        let data = sample_data();
        let k1 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.1 }).expect("gram failed");
        let k2 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.1 }).expect("gram failed");

        let alignment = kernel_alignment(&k1, &k2).expect("alignment failed");
        assert!(
            (alignment - 1.0).abs() < 1e-10,
            "Self-alignment should be 1.0, got {}",
            alignment
        );
    }

    #[test]
    fn test_kernel_alignment_different() {
        let data = sample_data();
        let k1 = gram_matrix(&data.view(), &KernelType::RBF { gamma: 0.01 }).expect("gram failed");
        let k2 = gram_matrix(&data.view(), &KernelType::Linear).expect("gram failed");

        let alignment = kernel_alignment(&k1, &k2).expect("alignment failed");
        assert!(alignment >= 0.0 && alignment <= 1.0);
    }

    #[test]
    fn test_rbf_auto() {
        let kernel = KernelType::rbf_auto(&sample_data().view()).expect("auto rbf failed");
        match kernel {
            KernelType::RBF { gamma } => {
                assert!(gamma > 0.0);
                assert!(gamma.is_finite());
            }
            _ => panic!("Expected RBF kernel type"),
        }
    }

    #[test]
    fn test_dimension_mismatch() {
        let a = Array::from_vec(vec![1.0, 2.0]);
        let b = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let outcome = kernel_eval(&a.view(), &b.view(), &KernelType::Linear);
        assert!(outcome.is_err());
    }
}