tensorlogic_sklears_kernels/
kernel_utils.rs1use crate::error::{KernelError, Result};
10use crate::types::Kernel;
11
12pub fn kernel_target_alignment(kernel_matrix: &[Vec<f64>], labels: &[f64]) -> Result<f64> {
39 let n = kernel_matrix.len();
40
41 if n == 0 {
42 return Err(KernelError::ComputationError(
43 "Kernel matrix cannot be empty".to_string(),
44 ));
45 }
46
47 if labels.len() != n {
48 return Err(KernelError::DimensionMismatch {
49 expected: vec![n],
50 got: vec![labels.len()],
51 context: "kernel-target alignment".to_string(),
52 });
53 }
54
55 for row in kernel_matrix {
57 if row.len() != n {
58 return Err(KernelError::ComputationError(
59 "Kernel matrix must be square".to_string(),
60 ));
61 }
62 }
63
64 let mut ideal_kernel = vec![vec![0.0; n]; n];
66 for i in 0..n {
67 for j in 0..n {
68 ideal_kernel[i][j] = labels[i] * labels[j];
69 }
70 }
71
72 let mut inner_product = 0.0;
74 for i in 0..n {
75 for j in 0..n {
76 inner_product += kernel_matrix[i][j] * ideal_kernel[i][j];
77 }
78 }
79
80 let k_norm = frobenius_norm(kernel_matrix);
82 let y_norm = frobenius_norm(&ideal_kernel);
83
84 if k_norm == 0.0 || y_norm == 0.0 {
85 return Ok(0.0);
86 }
87
88 Ok(inner_product / (k_norm * y_norm))
90}
91
/// Frobenius norm of a row-major matrix: sqrt of the sum of squared entries.
fn frobenius_norm(matrix: &[Vec<f64>]) -> f64 {
    let mut sum_of_squares = 0.0;
    for row in matrix {
        for &value in row {
            sum_of_squares += value * value;
        }
    }
    sum_of_squares.sqrt()
}
103
104pub fn distances_from_kernel(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
129 let n = kernel_matrix.len();
130
131 if n == 0 {
132 return Ok(Vec::new());
133 }
134
135 for row in kernel_matrix {
137 if row.len() != n {
138 return Err(KernelError::ComputationError(
139 "Kernel matrix must be square".to_string(),
140 ));
141 }
142 }
143
144 let diagonal: Vec<f64> = (0..n).map(|i| kernel_matrix[i][i]).collect();
146
147 let mut distances = vec![vec![0.0; n]; n];
149 for i in 0..n {
150 for j in 0..n {
151 let sq_dist = diagonal[i] + diagonal[j] - 2.0 * kernel_matrix[i][j];
152 distances[i][j] = sq_dist.max(0.0).sqrt();
154 }
155 }
156
157 Ok(distances)
158}
159
160#[allow(clippy::needless_range_loop)]
190pub fn is_valid_kernel_matrix(kernel_matrix: &[Vec<f64>], tolerance: f64) -> Result<bool> {
191 let n = kernel_matrix.len();
192
193 if n == 0 {
194 return Ok(true);
195 }
196
197 for row in kernel_matrix {
199 if row.len() != n {
200 return Ok(false);
201 }
202 }
203
204 for i in 0..n {
206 for j in (i + 1)..n {
207 if (kernel_matrix[i][j] - kernel_matrix[j][i]).abs() > tolerance {
208 return Ok(false);
209 }
210 }
211 }
212
213 Ok(true)
217}
218
219pub fn estimate_kernel_rank(kernel_matrix: &[Vec<f64>], variance_threshold: f64) -> Result<usize> {
235 let n = kernel_matrix.len();
236
237 if n == 0 {
238 return Ok(0);
239 }
240
241 if !(0.0..=1.0).contains(&variance_threshold) {
242 return Err(KernelError::InvalidParameter {
243 parameter: "variance_threshold".to_string(),
244 value: variance_threshold.to_string(),
245 reason: "must be in range [0, 1]".to_string(),
246 });
247 }
248
249 for row in kernel_matrix {
251 if row.len() != n {
252 return Err(KernelError::ComputationError(
253 "Kernel matrix must be square".to_string(),
254 ));
255 }
256 }
257
258 let mut diagonal: Vec<f64> = (0..n).map(|i| kernel_matrix[i][i]).collect();
260 diagonal.sort_by(|a, b| b.partial_cmp(a).unwrap()); let total: f64 = diagonal.iter().sum();
263 if total == 0.0 {
264 return Ok(0);
265 }
266
267 let mut cumsum = 0.0;
268 for (rank, &val) in diagonal.iter().enumerate() {
269 cumsum += val;
270 if cumsum / total >= variance_threshold {
271 return Ok(rank + 1);
272 }
273 }
274
275 Ok(n)
276}
277
/// Computes the Gram (kernel) matrix of `data` with the supplied kernel.
///
/// Thin convenience wrapper around the kernel's `compute_matrix`; any error
/// from the kernel implementation is propagated unchanged.
pub fn compute_gram_matrix(data: &[Vec<f64>], kernel: &dyn Kernel) -> Result<Vec<Vec<f64>>> {
    kernel.compute_matrix(data)
}
291
292pub fn normalize_rows(data: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
315 if data.is_empty() {
316 return Ok(Vec::new());
317 }
318
319 let mut normalized = Vec::with_capacity(data.len());
320
321 for row in data {
322 let norm: f64 = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
323
324 if norm == 0.0 {
325 normalized.push(row.clone());
327 } else {
328 let normalized_row: Vec<f64> = row.iter().map(|&x| x / norm).collect();
329 normalized.push(normalized_row);
330 }
331 }
332
333 Ok(normalized)
334}
335
336pub fn median_heuristic_bandwidth(
349 data: &[Vec<f64>],
350 kernel: &dyn Kernel,
351 sample_size: Option<usize>,
352) -> Result<f64> {
353 let n = data.len();
354
355 if n < 2 {
356 return Err(KernelError::ComputationError(
357 "Need at least 2 samples for bandwidth estimation".to_string(),
358 ));
359 }
360
361 let gram_matrix = kernel.compute_matrix(data)?;
363
364 let diagonal: Vec<f64> = (0..n).map(|i| gram_matrix[i][i]).collect();
366
367 let mut distances = Vec::new();
369 let sample_size = sample_size.unwrap_or(n * (n - 1) / 2);
370
371 for i in 0..n {
372 for j in (i + 1)..n {
373 let sq_dist = diagonal[i] + diagonal[j] - 2.0 * gram_matrix[i][j];
374 let dist = sq_dist.max(0.0).sqrt();
375
376 if dist > 0.0 {
377 distances.push(dist);
378 }
379
380 if distances.len() >= sample_size {
381 break;
382 }
383 }
384 if distances.len() >= sample_size {
385 break;
386 }
387 }
388
389 if distances.is_empty() {
390 return Err(KernelError::ComputationError(
391 "All pairwise distances are zero".to_string(),
392 ));
393 }
394
395 distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
397 let median = if distances.len() % 2 == 0 {
398 let mid = distances.len() / 2;
399 (distances[mid - 1] + distances[mid]) / 2.0
400 } else {
401 distances[distances.len() / 2]
402 };
403
404 let gamma = 1.0 / (2.0 * median * median);
406
407 Ok(gamma)
408}
409
#[cfg(test)]
#[allow(non_snake_case, clippy::needless_range_loop)]
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};

    #[test]
    fn test_kernel_target_alignment_good() {
        // Block-structured kernel that agrees with the {+1, +1, -1} split.
        let K = vec![
            vec![1.0, 0.9, 0.1],
            vec![0.9, 1.0, 0.1],
            vec![0.1, 0.1, 1.0],
        ];
        let labels = vec![1.0, 1.0, -1.0];

        let alignment = kernel_target_alignment(&K, &labels).unwrap();

        // A well-aligned kernel should land in the upper half of [0, 1].
        assert!((0.5..=1.0).contains(&alignment));
    }

    #[test]
    fn test_kernel_target_alignment_poor() {
        // Uniform off-diagonal similarity carries no label information.
        let K = vec![
            vec![1.0, 0.5, 0.5],
            vec![0.5, 1.0, 0.5],
            vec![0.5, 0.5, 1.0],
        ];
        let labels = vec![1.0, 1.0, -1.0];

        let alignment = kernel_target_alignment(&K, &labels).unwrap();
        assert!(alignment < 0.5);
    }

    #[test]
    fn test_kernel_target_alignment_dimension_mismatch() {
        // A 2x2 kernel with 3 labels must be rejected.
        let K = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let labels = vec![1.0, -1.0, 1.0];
        let result = kernel_target_alignment(&K, &labels);
        assert!(result.is_err());
    }

    #[test]
    fn test_distances_from_kernel() {
        let K = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];

        let distances = distances_from_kernel(&K).unwrap();

        // Self-distances must vanish.
        assert!(distances[0][0].abs() < 1e-10);
        assert!(distances[1][1].abs() < 1e-10);
        assert!(distances[2][2].abs() < 1e-10);

        // The distance matrix must be symmetric.
        for i in 0..3 {
            for j in 0..3 {
                assert!((distances[i][j] - distances[j][i]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_is_valid_kernel_matrix() {
        // Symmetric matrix passes.
        let K = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        assert!(is_valid_kernel_matrix(&K, 1e-10).unwrap());

        // Asymmetric entry ([0][1] = 0.8 vs [1][0] = 0.7) fails.
        let K_bad = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.7, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        assert!(!is_valid_kernel_matrix(&K_bad, 1e-10).unwrap());
    }

    #[test]
    fn test_estimate_kernel_rank() {
        let K = vec![
            vec![1.0, 0.1, 0.1],
            vec![0.1, 0.5, 0.1],
            vec![0.1, 0.1, 0.2],
        ];

        // The estimate is a heuristic; only require it to be a plausible rank.
        let rank = estimate_kernel_rank(&K, 0.9).unwrap();
        assert!((1..=3).contains(&rank));
    }

    #[test]
    fn test_normalize_rows() {
        let data = vec![vec![3.0, 4.0], vec![5.0, 12.0]];

        let normalized = normalize_rows(&data).unwrap();

        // Every non-zero row is scaled to unit Euclidean norm.
        for row in &normalized {
            let norm: f64 = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_normalize_rows_zero_vector() {
        let data = vec![vec![0.0, 0.0], vec![3.0, 4.0]];

        let normalized = normalize_rows(&data).unwrap();

        // The zero row passes through unchanged...
        assert!(normalized[0][0].abs() < 1e-10);
        assert!(normalized[0][1].abs() < 1e-10);

        // ...while the non-zero row is normalized.
        let norm: f64 = normalized[1].iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_median_heuristic_bandwidth() {
        // Corners of the unit square: distances are a mix of 1 and sqrt(2).
        let data = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let kernel = LinearKernel::new();
        let gamma = median_heuristic_bandwidth(&data, &kernel, None).unwrap();

        assert!(gamma > 0.0);
    }

    #[test]
    fn test_compute_gram_matrix() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let kernel = LinearKernel::new();
        let K = compute_gram_matrix(&data, &kernel).unwrap();

        assert_eq!(K.len(), 3);
        assert_eq!(K[0].len(), 3);

        // A Gram matrix is symmetric by construction.
        for i in 0..3 {
            for j in 0..3 {
                assert!((K[i][j] - K[j][i]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_frobenius_norm() {
        // sqrt(1 + 4 + 9 + 16) = sqrt(30)
        let matrix = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let norm = frobenius_norm(&matrix);
        assert!((norm - 30.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_target_alignment_binary_classification() {
        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();

        // Two well-separated clusters of three points each.
        let data = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.2],
            vec![5.0, 5.0],
            vec![5.1, 5.1],
            vec![5.2, 5.2],
        ];

        let labels = vec![1.0, 1.0, 1.0, -1.0, -1.0, -1.0];

        let K = kernel.compute_matrix(&data).unwrap();
        let alignment = kernel_target_alignment(&K, &labels).unwrap();

        assert!((0.0..=1.0).contains(&alignment));
    }
}