Skip to main content

tensorlogic_sklears_kernels/
kernel_utils.rs

1//! Kernel utility functions for machine learning workflows.
2//!
3//! This module provides practical utilities for kernel-based machine learning:
4//! - Kernel-target alignment for measuring kernel quality
5//! - Gram matrix operations (eigendecomposition preparation)
6//! - Distance matrix computation from kernels
7//! - Kernel matrix validation
8
9use crate::error::{KernelError, Result};
10use crate::types::Kernel;
11
12/// Compute kernel-target alignment (KTA) between a kernel matrix and target labels.
13///
14/// KTA measures how well a kernel matrix aligns with the ideal kernel matrix
15/// derived from target labels. Higher values indicate better alignment.
16///
17/// # Arguments
18/// * `kernel_matrix` - The kernel matrix K
19/// * `labels` - Binary labels (+1 or -1) for each sample
20///
21/// # Returns
22/// * Alignment score in range [-1, 1]
23///
24/// # Examples
25/// ```
26/// use tensorlogic_sklears_kernels::kernel_utils::kernel_target_alignment;
27///
28/// let K = vec![
29///     vec![1.0, 0.8, 0.2],
30///     vec![0.8, 1.0, 0.3],
31///     vec![0.2, 0.3, 1.0],
32/// ];
33/// let labels = vec![1.0, 1.0, -1.0];
34///
35/// let alignment = kernel_target_alignment(&K, &labels).unwrap();
36/// // High alignment means kernel separates classes well
37/// ```
38pub fn kernel_target_alignment(kernel_matrix: &[Vec<f64>], labels: &[f64]) -> Result<f64> {
39    let n = kernel_matrix.len();
40
41    if n == 0 {
42        return Err(KernelError::ComputationError(
43            "Kernel matrix cannot be empty".to_string(),
44        ));
45    }
46
47    if labels.len() != n {
48        return Err(KernelError::DimensionMismatch {
49            expected: vec![n],
50            got: vec![labels.len()],
51            context: "kernel-target alignment".to_string(),
52        });
53    }
54
55    // Verify square matrix
56    for row in kernel_matrix {
57        if row.len() != n {
58            return Err(KernelError::ComputationError(
59                "Kernel matrix must be square".to_string(),
60            ));
61        }
62    }
63
64    // Compute ideal kernel matrix Y = y * y^T
65    let mut ideal_kernel = vec![vec![0.0; n]; n];
66    for i in 0..n {
67        for j in 0..n {
68            ideal_kernel[i][j] = labels[i] * labels[j];
69        }
70    }
71
72    // Compute Frobenius inner product <K, Y>
73    let mut inner_product = 0.0;
74    for i in 0..n {
75        for j in 0..n {
76            inner_product += kernel_matrix[i][j] * ideal_kernel[i][j];
77        }
78    }
79
80    // Compute Frobenius norms ||K||_F and ||Y||_F
81    let k_norm = frobenius_norm(kernel_matrix);
82    let y_norm = frobenius_norm(&ideal_kernel);
83
84    if k_norm == 0.0 || y_norm == 0.0 {
85        return Ok(0.0);
86    }
87
88    // Alignment = <K, Y> / (||K||_F * ||Y||_F)
89    Ok(inner_product / (k_norm * y_norm))
90}
91
/// Compute the Frobenius norm of a matrix.
///
/// ||A||_F = sqrt(Σ_ij a_ij^2)
fn frobenius_norm(matrix: &[Vec<f64>]) -> f64 {
    // Accumulate the sum of squared entries, then take one square root.
    let mut sum_of_squares = 0.0;
    for row in matrix {
        for &entry in row {
            sum_of_squares += entry * entry;
        }
    }
    sum_of_squares.sqrt()
}
103
104/// Compute pairwise distances from a kernel matrix.
105///
106/// For a valid kernel K(x,y), the distance is:
107/// d(x,y) = sqrt(K(x,x) + K(y,y) - 2*K(x,y))
108///
109/// # Arguments
110/// * `kernel_matrix` - Symmetric kernel matrix
111///
112/// # Returns
113/// * Distance matrix
114///
115/// # Examples
116/// ```
117/// use tensorlogic_sklears_kernels::kernel_utils::distances_from_kernel;
118///
119/// let K = vec![
120///     vec![1.0, 0.8, 0.6],
121///     vec![0.8, 1.0, 0.7],
122///     vec![0.6, 0.7, 1.0],
123/// ];
124///
125/// let distances = distances_from_kernel(&K).unwrap();
126/// // distances[i][j] = distance between points i and j
127/// ```
128pub fn distances_from_kernel(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
129    let n = kernel_matrix.len();
130
131    if n == 0 {
132        return Ok(Vec::new());
133    }
134
135    // Verify square matrix
136    for row in kernel_matrix {
137        if row.len() != n {
138            return Err(KernelError::ComputationError(
139                "Kernel matrix must be square".to_string(),
140            ));
141        }
142    }
143
144    // Extract diagonal
145    let diagonal: Vec<f64> = (0..n).map(|i| kernel_matrix[i][i]).collect();
146
147    // Compute distances: d(i,j) = sqrt(K[i,i] + K[j,j] - 2*K[i,j])
148    let mut distances = vec![vec![0.0; n]; n];
149    for i in 0..n {
150        for j in 0..n {
151            let sq_dist = diagonal[i] + diagonal[j] - 2.0 * kernel_matrix[i][j];
152            // Clamp to zero for numerical stability
153            distances[i][j] = sq_dist.max(0.0).sqrt();
154        }
155    }
156
157    Ok(distances)
158}
159
160/// Check if a kernel matrix is valid (symmetric and positive semi-definite).
161///
162/// A valid kernel matrix must be:
163/// 1. Square
164/// 2. Symmetric: `K[i,j] = K[j,i]`
165/// 3. Positive semi-definite (all eigenvalues ≥ 0)
166///
167/// Note: This function only checks symmetry. Full PSD checking requires
168/// eigendecomposition which is expensive.
169///
170/// # Arguments
171/// * `kernel_matrix` - Matrix to validate
172/// * `tolerance` - Tolerance for symmetry check
173///
174/// # Returns
175/// * `true` if matrix is valid
176///
177/// # Examples
178/// ```
179/// use tensorlogic_sklears_kernels::kernel_utils::is_valid_kernel_matrix;
180///
181/// let K = vec![
182///     vec![1.0, 0.8, 0.6],
183///     vec![0.8, 1.0, 0.7],
184///     vec![0.6, 0.7, 1.0],
185/// ];
186///
187/// assert!(is_valid_kernel_matrix(&K, 1e-10).unwrap());
188/// ```
189#[allow(clippy::needless_range_loop)]
190pub fn is_valid_kernel_matrix(kernel_matrix: &[Vec<f64>], tolerance: f64) -> Result<bool> {
191    let n = kernel_matrix.len();
192
193    if n == 0 {
194        return Ok(true);
195    }
196
197    // Check square
198    for row in kernel_matrix {
199        if row.len() != n {
200            return Ok(false);
201        }
202    }
203
204    // Check symmetry
205    for i in 0..n {
206        for j in (i + 1)..n {
207            if (kernel_matrix[i][j] - kernel_matrix[j][i]).abs() > tolerance {
208                return Ok(false);
209            }
210        }
211    }
212
213    // Note: Full PSD check would require eigendecomposition
214    // For performance, we only check symmetry here
215
216    Ok(true)
217}
218
219/// Compute the effective dimensionality (rank) of a kernel matrix
220/// based on normalized eigenvalue spectrum.
221///
222/// This is useful for determining the intrinsic dimensionality of
223/// the data in kernel space.
224///
225/// # Arguments
226/// * `kernel_matrix` - Kernel matrix
227/// * `variance_threshold` - Cumulative variance threshold (e.g., 0.95 for 95%)
228///
229/// # Returns
230/// * Estimated rank (number of eigenvalues needed to reach threshold)
231///
232/// Note: This is a simplified estimate based on diagonal dominance.
233/// For accurate rank estimation, full eigendecomposition is needed.
234pub fn estimate_kernel_rank(kernel_matrix: &[Vec<f64>], variance_threshold: f64) -> Result<usize> {
235    let n = kernel_matrix.len();
236
237    if n == 0 {
238        return Ok(0);
239    }
240
241    if !(0.0..=1.0).contains(&variance_threshold) {
242        return Err(KernelError::InvalidParameter {
243            parameter: "variance_threshold".to_string(),
244            value: variance_threshold.to_string(),
245            reason: "must be in range [0, 1]".to_string(),
246        });
247    }
248
249    // Verify square matrix
250    for row in kernel_matrix {
251        if row.len() != n {
252            return Err(KernelError::ComputationError(
253                "Kernel matrix must be square".to_string(),
254            ));
255        }
256    }
257
258    // Simple estimate: use diagonal elements as proxy for eigenvalues
259    let mut diagonal: Vec<f64> = (0..n).map(|i| kernel_matrix[i][i]).collect();
260    diagonal.sort_by(|a, b| b.partial_cmp(a).unwrap()); // Sort descending
261
262    let total: f64 = diagonal.iter().sum();
263    if total == 0.0 {
264        return Ok(0);
265    }
266
267    let mut cumsum = 0.0;
268    for (rank, &val) in diagonal.iter().enumerate() {
269        cumsum += val;
270        if cumsum / total >= variance_threshold {
271            return Ok(rank + 1);
272        }
273    }
274
275    Ok(n)
276}
277
/// Compute the kernel matrix from data using a given kernel function.
///
/// This is a convenience function that wraps `Kernel::compute_matrix`;
/// it adds no behavior of its own.
///
/// # Arguments
/// * `data` - Feature vectors (one inner `Vec` per sample)
/// * `kernel` - Kernel function used for each pairwise evaluation
///
/// # Returns
/// * Kernel matrix K where `K[i][j] = kernel(data[i], data[j])`
///
/// # Errors
/// * Propagates any error produced by the underlying `Kernel::compute_matrix` call.
pub fn compute_gram_matrix(data: &[Vec<f64>], kernel: &dyn Kernel) -> Result<Vec<Vec<f64>>> {
    kernel.compute_matrix(data)
}
291
292/// Normalize each row of a data matrix (L2 normalization).
293///
294/// This is useful preprocessing for some kernel methods.
295///
296/// # Arguments
297/// * `data` - Data matrix (rows are samples)
298///
299/// # Returns
300/// * Row-normalized data matrix
301///
302/// # Examples
303/// ```
304/// use tensorlogic_sklears_kernels::kernel_utils::normalize_rows;
305///
306/// let data = vec![
307///     vec![3.0, 4.0],
308///     vec![5.0, 12.0],
309/// ];
310///
311/// let normalized = normalize_rows(&data).unwrap();
312/// // Each row now has unit norm
313/// ```
314pub fn normalize_rows(data: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
315    if data.is_empty() {
316        return Ok(Vec::new());
317    }
318
319    let mut normalized = Vec::with_capacity(data.len());
320
321    for row in data {
322        let norm: f64 = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
323
324        if norm == 0.0 {
325            // Keep zero vectors as-is
326            normalized.push(row.clone());
327        } else {
328            let normalized_row: Vec<f64> = row.iter().map(|&x| x / norm).collect();
329            normalized.push(normalized_row);
330        }
331    }
332
333    Ok(normalized)
334}
335
336/// Compute kernel bandwidth using median heuristic.
337///
338/// The median heuristic sets gamma = 1 / (2 * median(distances)^2).
339/// This is a common heuristic for RBF and Laplacian kernels.
340///
341/// # Arguments
342/// * `data` - Training data
343/// * `kernel` - Base kernel (used to compute pairwise distances)
344/// * `sample_size` - Number of pairs to sample (None = use all)
345///
346/// # Returns
347/// * Suggested gamma value
348pub fn median_heuristic_bandwidth(
349    data: &[Vec<f64>],
350    kernel: &dyn Kernel,
351    sample_size: Option<usize>,
352) -> Result<f64> {
353    let n = data.len();
354
355    if n < 2 {
356        return Err(KernelError::ComputationError(
357            "Need at least 2 samples for bandwidth estimation".to_string(),
358        ));
359    }
360
361    // Compute kernel matrix
362    let gram_matrix = kernel.compute_matrix(data)?;
363
364    // Extract diagonal
365    let diagonal: Vec<f64> = (0..n).map(|i| gram_matrix[i][i]).collect();
366
367    // Compute pairwise distances
368    let mut distances = Vec::new();
369    let sample_size = sample_size.unwrap_or(n * (n - 1) / 2);
370
371    for i in 0..n {
372        for j in (i + 1)..n {
373            let sq_dist = diagonal[i] + diagonal[j] - 2.0 * gram_matrix[i][j];
374            let dist = sq_dist.max(0.0).sqrt();
375
376            if dist > 0.0 {
377                distances.push(dist);
378            }
379
380            if distances.len() >= sample_size {
381                break;
382            }
383        }
384        if distances.len() >= sample_size {
385            break;
386        }
387    }
388
389    if distances.is_empty() {
390        return Err(KernelError::ComputationError(
391            "All pairwise distances are zero".to_string(),
392        ));
393    }
394
395    // Compute median
396    distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
397    let median = if distances.len() % 2 == 0 {
398        let mid = distances.len() / 2;
399        (distances[mid - 1] + distances[mid]) / 2.0
400    } else {
401        distances[distances.len() / 2]
402    };
403
404    // gamma = 1 / (2 * median^2)
405    let gamma = 1.0 / (2.0 * median * median);
406
407    Ok(gamma)
408}
409
// Unit tests for the kernel utilities above. Kernel matrices are hand-built
// where exact values matter, and generated via LinearKernel/RbfKernel where
// only structural properties (symmetry, sign, range) are asserted.
#[cfg(test)]
#[allow(non_snake_case, clippy::needless_range_loop)] // Allow K for kernel matrices, range loops for 2D matrix access
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};

    #[test]
    fn test_kernel_target_alignment_good() {
        // Good alignment: kernel separates classes well
        let K = vec![
            vec![1.0, 0.9, 0.1],
            vec![0.9, 1.0, 0.1],
            vec![0.1, 0.1, 1.0],
        ];
        let labels = vec![1.0, 1.0, -1.0];

        let alignment = kernel_target_alignment(&K, &labels).unwrap();

        // Alignment should be positive for well-separated classes
        // Actual computed value is around 0.59
        assert!((0.5..=1.0).contains(&alignment));
    }

    #[test]
    fn test_kernel_target_alignment_poor() {
        // Poor alignment: kernel doesn't separate classes
        // (all off-diagonal similarities are equal regardless of label)
        let K = vec![
            vec![1.0, 0.5, 0.5],
            vec![0.5, 1.0, 0.5],
            vec![0.5, 0.5, 1.0],
        ];
        let labels = vec![1.0, 1.0, -1.0];

        let alignment = kernel_target_alignment(&K, &labels).unwrap();
        assert!(alignment < 0.5); // Lower alignment
    }

    #[test]
    fn test_kernel_target_alignment_dimension_mismatch() {
        // labels has 3 entries but K is 2x2 — must be rejected
        let K = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let labels = vec![1.0, -1.0, 1.0]; // Wrong size

        let result = kernel_target_alignment(&K, &labels);
        assert!(result.is_err());
    }

    #[test]
    fn test_distances_from_kernel() {
        let K = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];

        let distances = distances_from_kernel(&K).unwrap();

        // Diagonal should be zero
        assert!(distances[0][0].abs() < 1e-10);
        assert!(distances[1][1].abs() < 1e-10);
        assert!(distances[2][2].abs() < 1e-10);

        // Distances should be symmetric
        for i in 0..3 {
            for j in 0..3 {
                assert!((distances[i][j] - distances[j][i]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_is_valid_kernel_matrix() {
        // Valid symmetric matrix
        let K = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        assert!(is_valid_kernel_matrix(&K, 1e-10).unwrap());

        // Asymmetric matrix
        let K_bad = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.7, 1.0, 0.7], // Different from K[0][1]
            vec![0.6, 0.7, 1.0],
        ];
        assert!(!is_valid_kernel_matrix(&K_bad, 1e-10).unwrap());
    }

    #[test]
    fn test_estimate_kernel_rank() {
        // Diagonal-dominant matrix; rank estimate is based on the diagonal
        let K = vec![
            vec![1.0, 0.1, 0.1],
            vec![0.1, 0.5, 0.1],
            vec![0.1, 0.1, 0.2],
        ];

        let rank = estimate_kernel_rank(&K, 0.9).unwrap();
        assert!((1..=3).contains(&rank));
    }

    #[test]
    fn test_normalize_rows() {
        let data = vec![vec![3.0, 4.0], vec![5.0, 12.0]];

        let normalized = normalize_rows(&data).unwrap();

        // Check unit norms
        for row in &normalized {
            let norm: f64 = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_normalize_rows_zero_vector() {
        let data = vec![vec![0.0, 0.0], vec![3.0, 4.0]];

        let normalized = normalize_rows(&data).unwrap();

        // Zero vector should remain zero
        assert!(normalized[0][0].abs() < 1e-10);
        assert!(normalized[0][1].abs() < 1e-10);

        // Second row should be normalized
        let norm: f64 = normalized[1].iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_median_heuristic_bandwidth() {
        // Four corners of the unit square: several distinct pairwise distances
        let data = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];

        let kernel = LinearKernel::new();
        let gamma = median_heuristic_bandwidth(&data, &kernel, None).unwrap();

        // Gamma should be positive
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_compute_gram_matrix() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let kernel = LinearKernel::new();
        let K = compute_gram_matrix(&data, &kernel).unwrap();

        // Check dimensions
        assert_eq!(K.len(), 3);
        assert_eq!(K[0].len(), 3);

        // Check symmetry
        for i in 0..3 {
            for j in 0..3 {
                assert!((K[i][j] - K[j][i]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_frobenius_norm() {
        let matrix = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        // ||A||_F = sqrt(1^2 + 2^2 + 3^2 + 4^2) = sqrt(30)
        let norm = frobenius_norm(&matrix);
        assert!((norm - 30.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_target_alignment_binary_classification() {
        // Create kernel matrix and labels for binary classification
        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();

        // Two well-separated clusters
        let data = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.2],
            vec![5.0, 5.0], // Far away
            vec![5.1, 5.1],
            vec![5.2, 5.2],
        ];

        let labels = vec![1.0, 1.0, 1.0, -1.0, -1.0, -1.0];

        let K = kernel.compute_matrix(&data).unwrap();
        let alignment = kernel_target_alignment(&K, &labels).unwrap();

        // Should have positive alignment for separated clusters
        assert!((0.0..=1.0).contains(&alignment));
    }
}
605}