Skip to main content

tensorlogic_sklears_kernels/
kernel_transform.rs

1//! Kernel transformation utilities for preprocessing and normalization.
2//!
3//! This module provides utilities for transforming kernel matrices, including:
4//! - Kernel normalization (normalize to unit diagonal)
5//! - Kernel centering (for kernel PCA)
6//! - Kernel standardization
7//!
8//! These transformations are essential for many kernel-based algorithms.
9
10use crate::error::{KernelError, Result};
11use crate::types::Kernel;
12
13/// Normalize a kernel matrix to have unit diagonal entries.
14///
15/// Normalized kernel: K_norm(x,y) = K(x,y) / sqrt(K(x,x) * K(y,y))
16///
17/// This ensures all diagonal entries equal 1.0, which is useful for
18/// algorithms that assume normalized kernels.
19///
20/// # Arguments
21/// * `kernel_matrix` - Input kernel matrix (must be square)
22///
23/// # Returns
24/// * Normalized kernel matrix
25///
26/// # Examples
27/// ```
28/// use tensorlogic_sklears_kernels::kernel_transform::normalize_kernel_matrix;
29///
30/// let K = vec![
31///     vec![4.0, 2.0, 1.0],
32///     vec![2.0, 9.0, 3.0],
33///     vec![1.0, 3.0, 16.0],
34/// ];
35///
36/// let K_norm = normalize_kernel_matrix(&K).expect("unwrap");
37///
38/// // All diagonal entries should be 1.0
39/// assert!((K_norm[0][0] - 1.0).abs() < 1e-10);
40/// assert!((K_norm[1][1] - 1.0).abs() < 1e-10);
41/// assert!((K_norm[2][2] - 1.0).abs() < 1e-10);
42/// ```
43pub fn normalize_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
44    let n = kernel_matrix.len();
45
46    if n == 0 {
47        return Ok(Vec::new());
48    }
49
50    // Verify square matrix
51    for row in kernel_matrix {
52        if row.len() != n {
53            return Err(KernelError::ComputationError(
54                "Kernel matrix must be square".to_string(),
55            ));
56        }
57    }
58
59    // Extract diagonal elements
60    let diagonal: Vec<f64> = (0..n).map(|i| kernel_matrix[i][i]).collect();
61
62    // Check for non-positive diagonal elements
63    if diagonal.iter().any(|&d| d <= 0.0) {
64        return Err(KernelError::ComputationError(
65            "Kernel matrix has non-positive diagonal elements".to_string(),
66        ));
67    }
68
69    // Compute normalization factors
70    let sqrt_diag: Vec<f64> = diagonal.iter().map(|&d| d.sqrt()).collect();
71
72    // Normalize: K_norm[i,j] = K[i,j] / (sqrt(K[i,i]) * sqrt(K[j,j]))
73    let mut normalized = vec![vec![0.0; n]; n];
74    for i in 0..n {
75        for j in 0..n {
76            normalized[i][j] = kernel_matrix[i][j] / (sqrt_diag[i] * sqrt_diag[j]);
77        }
78    }
79
80    Ok(normalized)
81}
82
83/// Center a kernel matrix by removing the mean in feature space.
84///
85/// Centered kernel: K_c = (I - 1/n * 11^T) K (I - 1/n * 11^T)
86///
87/// This transformation is required for kernel PCA to ensure the
88/// data is centered in feature space.
89///
90/// # Arguments
91/// * `kernel_matrix` - Input kernel matrix (must be square)
92///
93/// # Returns
94/// * Centered kernel matrix
95///
96/// # Examples
97/// ```
98/// use tensorlogic_sklears_kernels::kernel_transform::center_kernel_matrix;
99///
100/// let K = vec![
101///     vec![1.0, 0.8, 0.6],
102///     vec![0.8, 1.0, 0.7],
103///     vec![0.6, 0.7, 1.0],
104/// ];
105///
106/// let K_centered = center_kernel_matrix(&K).expect("unwrap");
107///
108/// // Row and column means should be approximately zero
109/// let row_mean: f64 = K_centered[0].iter().sum::<f64>() / 3.0;
110/// assert!(row_mean.abs() < 1e-10);
111/// ```
112pub fn center_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
113    let n = kernel_matrix.len();
114
115    if n == 0 {
116        return Ok(Vec::new());
117    }
118
119    // Verify square matrix
120    for row in kernel_matrix {
121        if row.len() != n {
122            return Err(KernelError::ComputationError(
123                "Kernel matrix must be square".to_string(),
124            ));
125        }
126    }
127
128    // Compute row means
129    let row_means: Vec<f64> = kernel_matrix
130        .iter()
131        .map(|row| row.iter().sum::<f64>() / n as f64)
132        .collect();
133
134    // Compute column means (for symmetric matrices, same as row means, but compute anyway)
135    let col_means: Vec<f64> = (0..n)
136        .map(|j| kernel_matrix.iter().map(|row| row[j]).sum::<f64>() / n as f64)
137        .collect();
138
139    // Compute grand mean
140    let grand_mean = row_means.iter().sum::<f64>() / n as f64;
141
142    // Center: K_c[i,j] = K[i,j] - row_mean[i] - col_mean[j] + grand_mean
143    let mut centered = vec![vec![0.0; n]; n];
144    #[allow(clippy::needless_range_loop)] // Nested loops needed for matrix indexing
145    for i in 0..n {
146        for j in 0..n {
147            centered[i][j] = kernel_matrix[i][j] - row_means[i] - col_means[j] + grand_mean;
148        }
149    }
150
151    Ok(centered)
152}
153
154/// Standardize a kernel matrix (normalize then center).
155///
156/// This combines normalization and centering in one operation,
157/// which is useful for many kernel-based algorithms.
158///
159/// # Arguments
160/// * `kernel_matrix` - Input kernel matrix (must be square)
161///
162/// # Returns
163/// * Standardized kernel matrix
164///
165/// # Examples
166/// ```
167/// use tensorlogic_sklears_kernels::kernel_transform::standardize_kernel_matrix;
168///
169/// let K = vec![
170///     vec![4.0, 2.0, 1.0],
171///     vec![2.0, 9.0, 3.0],
172///     vec![1.0, 3.0, 16.0],
173/// ];
174///
175/// let K_std = standardize_kernel_matrix(&K).expect("unwrap");
176/// ```
177pub fn standardize_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
178    let normalized = normalize_kernel_matrix(kernel_matrix)?;
179    center_kernel_matrix(&normalized)
180}
181
182/// Wrapper that creates a normalized version of any kernel.
183///
184/// The normalized kernel computes K_norm(x,y) = K(x,y) / sqrt(K(x,x) * K(y,y))
185/// ensuring that K_norm(x,x) = 1.0 for all x.
186pub struct NormalizedKernel {
187    /// Base kernel
188    base_kernel: Box<dyn Kernel>,
189    /// Cache for diagonal values K(x,x) (thread-safe)
190    diagonal_cache: std::sync::Mutex<std::collections::HashMap<u64, f64>>,
191}
192
193impl NormalizedKernel {
194    /// Create a new normalized kernel wrapper
195    ///
196    /// # Examples
197    /// ```
198    /// use tensorlogic_sklears_kernels::{LinearKernel, NormalizedKernel, Kernel};
199    ///
200    /// let linear = Box::new(LinearKernel::new());
201    /// let normalized = NormalizedKernel::new(linear);
202    ///
203    /// let x = vec![1.0, 2.0, 3.0];
204    /// let y = vec![4.0, 5.0, 6.0];
205    /// let sim = normalized.compute(&x, &y).expect("unwrap");
206    ///
207    /// // Self-similarity should be 1.0
208    /// let self_sim = normalized.compute(&x, &x).expect("unwrap");
209    /// assert!((self_sim - 1.0).abs() < 1e-10);
210    /// ```
211    pub fn new(base_kernel: Box<dyn Kernel>) -> Self {
212        Self {
213            base_kernel,
214            diagonal_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
215        }
216    }
217
218    /// Hash a vector for caching (simple hash for demonstration)
219    fn hash_vector(x: &[f64]) -> u64 {
220        use std::collections::hash_map::DefaultHasher;
221        use std::hash::{Hash, Hasher};
222
223        let mut hasher = DefaultHasher::new();
224        for &val in x {
225            val.to_bits().hash(&mut hasher);
226        }
227        hasher.finish()
228    }
229
230    /// Get diagonal value K(x,x) with caching
231    fn get_diagonal(&self, x: &[f64]) -> Result<f64> {
232        let hash = Self::hash_vector(x);
233
234        // Check cache
235        {
236            let cache = self
237                .diagonal_cache
238                .lock()
239                .expect("lock should not be poisoned");
240            if let Some(&cached) = cache.get(&hash) {
241                return Ok(cached);
242            }
243        }
244
245        // Compute and cache
246        let value = self.base_kernel.compute(x, x)?;
247        let mut cache = self
248            .diagonal_cache
249            .lock()
250            .expect("lock should not be poisoned");
251        cache.insert(hash, value);
252        Ok(value)
253    }
254}
255
256impl Kernel for NormalizedKernel {
257    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
258        let k_xy = self.base_kernel.compute(x, y)?;
259        let k_xx = self.get_diagonal(x)?;
260        let k_yy = self.get_diagonal(y)?;
261
262        if k_xx <= 0.0 || k_yy <= 0.0 {
263            return Err(KernelError::ComputationError(
264                "Kernel diagonal elements must be positive for normalization".to_string(),
265            ));
266        }
267
268        Ok(k_xy / (k_xx * k_yy).sqrt())
269    }
270
271    fn name(&self) -> &str {
272        "Normalized"
273    }
274}
275
#[cfg(test)]
#[allow(non_snake_case, clippy::needless_range_loop)] // Allow K for kernel matrices, range loops for 2D matrix access
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Dot-product kernel that counts its evaluations, so tests can observe how
    /// often `NormalizedKernel` actually calls its base kernel.
    struct CountingLinearKernel {
        calls: Arc<AtomicUsize>,
    }

    impl Kernel for CountingLinearKernel {
        fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(x.iter().zip(y).map(|(a, b)| a * b).sum())
        }

        fn name(&self) -> &str {
            "CountingLinear"
        }
    }

    #[test]
    fn test_normalize_kernel_matrix_basic() {
        let K = vec![
            vec![4.0, 2.0, 1.0],
            vec![2.0, 9.0, 3.0],
            vec![1.0, 3.0, 16.0],
        ];

        let K_norm = normalize_kernel_matrix(&K).expect("unwrap");

        // Check diagonal is all 1.0
        assert!((K_norm[0][0] - 1.0).abs() < 1e-10);
        assert!((K_norm[1][1] - 1.0).abs() < 1e-10);
        assert!((K_norm[2][2] - 1.0).abs() < 1e-10);

        // Check symmetry preserved
        assert!((K_norm[0][1] - K_norm[1][0]).abs() < 1e-10);
        assert!((K_norm[0][2] - K_norm[2][0]).abs() < 1e-10);
        assert!((K_norm[1][2] - K_norm[2][1]).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_kernel_matrix_correctness() {
        let K = vec![vec![4.0, 2.0], vec![2.0, 9.0]];

        let K_norm = normalize_kernel_matrix(&K).expect("unwrap");

        // K_norm[0][1] = K[0][1] / sqrt(K[0][0] * K[1][1])
        //              = 2.0 / sqrt(4.0 * 9.0)
        //              = 2.0 / 6.0 = 1/3
        assert!((K_norm[0][1] - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_kernel_matrix_empty() {
        let K: Vec<Vec<f64>> = Vec::new();
        let K_norm = normalize_kernel_matrix(&K).expect("unwrap");
        assert!(K_norm.is_empty());
    }

    #[test]
    fn test_normalize_kernel_matrix_non_square() {
        let K = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];

        let result = normalize_kernel_matrix(&K);
        assert!(result.is_err());
    }

    #[test]
    fn test_normalize_kernel_matrix_negative_diagonal() {
        let K = vec![vec![-1.0, 2.0], vec![2.0, 4.0]];

        let result = normalize_kernel_matrix(&K);
        assert!(result.is_err());
    }

    #[test]
    fn test_normalize_kernel_matrix_zero_diagonal() {
        // A zero diagonal entry would otherwise cause a division by zero.
        let K = vec![vec![0.0, 0.0], vec![0.0, 4.0]];

        let result = normalize_kernel_matrix(&K);
        assert!(result.is_err());
    }

    #[test]
    fn test_center_kernel_matrix_basic() {
        let K = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];

        let K_centered = center_kernel_matrix(&K).expect("unwrap");

        // Check row sums are approximately zero
        for row in &K_centered {
            let row_sum: f64 = row.iter().sum();
            assert!(row_sum.abs() < 1e-10);
        }

        // Check column sums are approximately zero
        for j in 0..3 {
            let col_sum: f64 = (0..3).map(|i| K_centered[i][j]).sum();
            assert!(col_sum.abs() < 1e-10);
        }

        // Check grand sum is approximately zero
        let grand_sum: f64 = K_centered.iter().map(|row| row.iter().sum::<f64>()).sum();
        assert!(grand_sum.abs() < 1e-9);
    }

    #[test]
    fn test_center_kernel_matrix_values() {
        // Asymmetric input so that row means ([0.5, 1.0]) and column means
        // ([1.5, 0.0]) differ; grand mean is 0.75.
        let K = vec![vec![1.0, 0.0], vec![2.0, 0.0]];

        let K_centered = center_kernel_matrix(&K).expect("unwrap");

        let expected = [[-0.25, 0.25], [0.25, -0.25]];
        for i in 0..2 {
            for j in 0..2 {
                assert!((K_centered[i][j] - expected[i][j]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_center_kernel_matrix_empty() {
        let K: Vec<Vec<f64>> = Vec::new();
        let K_centered = center_kernel_matrix(&K).expect("unwrap");
        assert!(K_centered.is_empty());
    }

    #[test]
    fn test_center_kernel_matrix_non_square() {
        let K = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];

        let result = center_kernel_matrix(&K);
        assert!(result.is_err());
    }

    #[test]
    fn test_standardize_kernel_matrix() {
        let K = vec![
            vec![4.0, 2.0, 1.0],
            vec![2.0, 9.0, 3.0],
            vec![1.0, 3.0, 16.0],
        ];

        let K_std = standardize_kernel_matrix(&K).expect("unwrap");

        // After standardization, row/column sums should be close to zero
        for row in &K_std {
            let row_sum: f64 = row.iter().sum();
            assert!(row_sum.abs() < 1e-9);
        }
        for j in 0..3 {
            let col_sum: f64 = (0..3).map(|i| K_std[i][j]).sum();
            assert!(col_sum.abs() < 1e-9);
        }
    }

    #[test]
    fn test_normalized_kernel_wrapper() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(linear);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        // Self-similarity should be 1.0
        let self_sim = normalized.compute(&x, &x).expect("unwrap");
        assert!((self_sim - 1.0).abs() < 1e-10);

        // Compute normalized similarity
        let sim = normalized.compute(&x, &y).expect("unwrap");
        assert!((-1.0..=1.0).contains(&sim));
    }

    #[test]
    fn test_normalized_kernel_rbf() {
        let rbf =
            Box::new(RbfKernel::new(RbfKernelConfig::new(0.5)).expect("unwrap")) as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(rbf);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![2.0, 3.0, 4.0];

        // Self-similarity should be 1.0
        let self_sim_x = normalized.compute(&x, &x).expect("unwrap");
        let self_sim_y = normalized.compute(&y, &y).expect("unwrap");
        assert!((self_sim_x - 1.0).abs() < 1e-10);
        assert!((self_sim_y - 1.0).abs() < 1e-10);

        // Cross-similarity should be in (0, 1)
        let sim = normalized.compute(&x, &y).expect("unwrap");
        assert!(sim > 0.0 && sim < 1.0);
    }

    #[test]
    fn test_normalized_kernel_symmetry() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(linear);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let sim_xy = normalized.compute(&x, &y).expect("unwrap");
        let sim_yx = normalized.compute(&y, &x).expect("unwrap");

        assert!((sim_xy - sim_yx).abs() < 1e-10);
    }

    #[test]
    fn test_normalized_kernel_caching() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counting = Box::new(CountingLinearKernel {
            calls: Arc::clone(&calls),
        }) as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(counting);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        // First call evaluates K(x,y), K(x,x) and K(y,y).
        let sim1 = normalized.compute(&x, &y).expect("unwrap");
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        // Later calls only evaluate the cross term; both diagonals come from the cache.
        let sim2 = normalized.compute(&x, &y).expect("unwrap");
        let sim3 = normalized.compute(&y, &x).expect("unwrap");
        assert_eq!(calls.load(Ordering::SeqCst), 5);

        // <x,y> = 32, |x|^2 = 14, |y|^2 = 77
        let expected = 32.0 / (14.0f64 * 77.0).sqrt();
        assert!((sim1 - expected).abs() < 1e-10);
        assert!((sim1 - sim2).abs() < 1e-10);
        assert!((sim2 - sim3).abs() < 1e-10);
    }

    #[test]
    fn test_normalized_kernel_zero_vector() {
        // K(0,0) = 0 for a dot-product kernel, so normalization is undefined.
        let counting = Box::new(CountingLinearKernel {
            calls: Arc::new(AtomicUsize::new(0)),
        }) as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(counting);

        let zero = vec![0.0, 0.0, 0.0];
        let x = vec![1.0, 2.0, 3.0];
        assert!(normalized.compute(&zero, &x).is_err());
        assert!(normalized.compute(&x, &zero).is_err());
    }

    #[test]
    fn test_normalize_then_center_vs_standardize() {
        let K = vec![
            vec![4.0, 2.0, 1.0],
            vec![2.0, 9.0, 3.0],
            vec![1.0, 3.0, 16.0],
        ];

        // Method 1: Normalize then center
        let K_norm = normalize_kernel_matrix(&K).expect("unwrap");
        let K_norm_cent = center_kernel_matrix(&K_norm).expect("unwrap");

        // Method 2: Use standardize
        let K_std = standardize_kernel_matrix(&K).expect("unwrap");

        // Should be identical
        for i in 0..3 {
            for j in 0..3 {
                assert!((K_norm_cent[i][j] - K_std[i][j]).abs() < 1e-10);
            }
        }
    }
}