tensorlogic_sklears_kernels/
kernel_transform.rs1use crate::error::{KernelError, Result};
11use crate::types::Kernel;
12
13pub fn normalize_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
44 let n = kernel_matrix.len();
45
46 if n == 0 {
47 return Ok(Vec::new());
48 }
49
50 for row in kernel_matrix {
52 if row.len() != n {
53 return Err(KernelError::ComputationError(
54 "Kernel matrix must be square".to_string(),
55 ));
56 }
57 }
58
59 let diagonal: Vec<f64> = (0..n).map(|i| kernel_matrix[i][i]).collect();
61
62 if diagonal.iter().any(|&d| d <= 0.0) {
64 return Err(KernelError::ComputationError(
65 "Kernel matrix has non-positive diagonal elements".to_string(),
66 ));
67 }
68
69 let sqrt_diag: Vec<f64> = diagonal.iter().map(|&d| d.sqrt()).collect();
71
72 let mut normalized = vec![vec![0.0; n]; n];
74 for i in 0..n {
75 for j in 0..n {
76 normalized[i][j] = kernel_matrix[i][j] / (sqrt_diag[i] * sqrt_diag[j]);
77 }
78 }
79
80 Ok(normalized)
81}
82
83pub fn center_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
113 let n = kernel_matrix.len();
114
115 if n == 0 {
116 return Ok(Vec::new());
117 }
118
119 for row in kernel_matrix {
121 if row.len() != n {
122 return Err(KernelError::ComputationError(
123 "Kernel matrix must be square".to_string(),
124 ));
125 }
126 }
127
128 let row_means: Vec<f64> = kernel_matrix
130 .iter()
131 .map(|row| row.iter().sum::<f64>() / n as f64)
132 .collect();
133
134 let col_means: Vec<f64> = (0..n)
136 .map(|j| kernel_matrix.iter().map(|row| row[j]).sum::<f64>() / n as f64)
137 .collect();
138
139 let grand_mean = row_means.iter().sum::<f64>() / n as f64;
141
142 let mut centered = vec![vec![0.0; n]; n];
144 #[allow(clippy::needless_range_loop)] for i in 0..n {
146 for j in 0..n {
147 centered[i][j] = kernel_matrix[i][j] - row_means[i] - col_means[j] + grand_mean;
148 }
149 }
150
151 Ok(centered)
152}
153
154pub fn standardize_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
178 let normalized = normalize_kernel_matrix(kernel_matrix)?;
179 center_kernel_matrix(&normalized)
180}
181
182pub struct NormalizedKernel {
187 base_kernel: Box<dyn Kernel>,
189 diagonal_cache: std::sync::Mutex<std::collections::HashMap<u64, f64>>,
191}
192
193impl NormalizedKernel {
194 pub fn new(base_kernel: Box<dyn Kernel>) -> Self {
212 Self {
213 base_kernel,
214 diagonal_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
215 }
216 }
217
218 fn hash_vector(x: &[f64]) -> u64 {
220 use std::collections::hash_map::DefaultHasher;
221 use std::hash::{Hash, Hasher};
222
223 let mut hasher = DefaultHasher::new();
224 for &val in x {
225 val.to_bits().hash(&mut hasher);
226 }
227 hasher.finish()
228 }
229
230 fn get_diagonal(&self, x: &[f64]) -> Result<f64> {
232 let hash = Self::hash_vector(x);
233
234 {
236 let cache = self.diagonal_cache.lock().unwrap();
237 if let Some(&cached) = cache.get(&hash) {
238 return Ok(cached);
239 }
240 }
241
242 let value = self.base_kernel.compute(x, x)?;
244 let mut cache = self.diagonal_cache.lock().unwrap();
245 cache.insert(hash, value);
246 Ok(value)
247 }
248}
249
250impl Kernel for NormalizedKernel {
251 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
252 let k_xy = self.base_kernel.compute(x, y)?;
253 let k_xx = self.get_diagonal(x)?;
254 let k_yy = self.get_diagonal(y)?;
255
256 if k_xx <= 0.0 || k_yy <= 0.0 {
257 return Err(KernelError::ComputationError(
258 "Kernel diagonal elements must be positive for normalization".to_string(),
259 ));
260 }
261
262 Ok(k_xy / (k_xx * k_yy).sqrt())
263 }
264
265 fn name(&self) -> &str {
266 "Normalized"
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};

    /// Asserts `|actual - expected| < tol`, reporting the caller's location on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    /// Symmetric 3x3 matrix with distinct positive diagonal entries, shared by several tests.
    fn sample_matrix() -> Vec<Vec<f64>> {
        vec![
            vec![4.0, 2.0, 1.0],
            vec![2.0, 9.0, 3.0],
            vec![1.0, 3.0, 16.0],
        ]
    }

    #[test]
    fn test_normalize_kernel_matrix_basic() {
        let k_norm = normalize_kernel_matrix(&sample_matrix()).unwrap();

        // The diagonal becomes all ones.
        for (i, row) in k_norm.iter().enumerate() {
            assert_close(row[i], 1.0, 1e-10);
        }

        // Symmetry of the input is preserved.
        for (i, j) in [(0, 1), (0, 2), (1, 2)] {
            assert_close(k_norm[i][j], k_norm[j][i], 1e-10);
        }
    }

    #[test]
    fn test_normalize_kernel_matrix_correctness() {
        let k = vec![vec![4.0, 2.0], vec![2.0, 9.0]];
        let k_norm = normalize_kernel_matrix(&k).unwrap();

        // 2 / (sqrt(4) * sqrt(9)) = 1/3
        assert_close(k_norm[0][1], 1.0 / 3.0, 1e-10);
    }

    #[test]
    fn test_normalize_kernel_matrix_empty() {
        let empty: Vec<Vec<f64>> = Vec::new();
        assert!(normalize_kernel_matrix(&empty).unwrap().is_empty());
    }

    #[test]
    fn test_normalize_kernel_matrix_non_square() {
        let ragged = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];
        assert!(normalize_kernel_matrix(&ragged).is_err());
    }

    #[test]
    fn test_normalize_kernel_matrix_negative_diagonal() {
        let k = vec![vec![-1.0, 2.0], vec![2.0, 4.0]];
        assert!(normalize_kernel_matrix(&k).is_err());
    }

    #[test]
    fn test_center_kernel_matrix_basic() {
        let k = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        let k_centered = center_kernel_matrix(&k).unwrap();

        // Every row sums to zero.
        for row in &k_centered {
            assert_close(row.iter().sum::<f64>(), 0.0, 1e-10);
        }

        // Every column sums to zero.
        for j in 0..3 {
            let col_sum: f64 = k_centered.iter().map(|row| row[j]).sum();
            assert_close(col_sum, 0.0, 1e-10);
        }

        // Hence the total sums to zero as well.
        let grand_sum: f64 = k_centered.iter().map(|row| row.iter().sum::<f64>()).sum();
        assert_close(grand_sum, 0.0, 1e-9);
    }

    #[test]
    fn test_center_kernel_matrix_empty() {
        let empty: Vec<Vec<f64>> = Vec::new();
        assert!(center_kernel_matrix(&empty).unwrap().is_empty());
    }

    #[test]
    fn test_center_kernel_matrix_non_square() {
        let ragged = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];
        assert!(center_kernel_matrix(&ragged).is_err());
    }

    #[test]
    fn test_standardize_kernel_matrix() {
        let k_std = standardize_kernel_matrix(&sample_matrix()).unwrap();

        for row in &k_std {
            assert_close(row.iter().sum::<f64>(), 0.0, 1e-9);
        }
    }

    #[test]
    fn test_normalized_kernel_wrapper() {
        let normalized = NormalizedKernel::new(Box::new(LinearKernel::new()));
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        // Self-similarity is exactly one after normalisation.
        assert_close(normalized.compute(&x, &x).unwrap(), 1.0, 1e-10);

        // Cross-similarity is a cosine, so it lies in [-1, 1].
        let sim = normalized.compute(&x, &y).unwrap();
        assert!((-1.0..=1.0).contains(&sim));
    }

    #[test]
    fn test_normalized_kernel_rbf() {
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let normalized = NormalizedKernel::new(Box::new(rbf));
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![2.0, 3.0, 4.0];

        for v in [&x, &y] {
            assert_close(normalized.compute(v, v).unwrap(), 1.0, 1e-10);
        }

        let sim = normalized.compute(&x, &y).unwrap();
        assert!(sim > 0.0 && sim < 1.0);
    }

    #[test]
    fn test_normalized_kernel_symmetry() {
        let normalized = NormalizedKernel::new(Box::new(LinearKernel::new()));
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let forward = normalized.compute(&x, &y).unwrap();
        let backward = normalized.compute(&y, &x).unwrap();
        assert_close(forward, backward, 1e-10);
    }

    #[test]
    fn test_normalized_kernel_caching() {
        let normalized = NormalizedKernel::new(Box::new(LinearKernel::new()));
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        // Repeated calls (served from the diagonal cache after the first) agree.
        let sims: Vec<f64> = (0..3)
            .map(|_| normalized.compute(&x, &y).unwrap())
            .collect();
        for pair in sims.windows(2) {
            assert_close(pair[0], pair[1], 1e-10);
        }
    }

    #[test]
    fn test_normalize_then_center_vs_standardize() {
        let k = sample_matrix();

        let k_norm_cent = center_kernel_matrix(&normalize_kernel_matrix(&k).unwrap()).unwrap();
        let k_std = standardize_kernel_matrix(&k).unwrap();

        for (lhs_row, rhs_row) in k_norm_cent.iter().zip(&k_std) {
            for (&lhs, &rhs) in lhs_row.iter().zip(rhs_row) {
                assert_close(lhs, rhs, 1e-10);
            }
        }
    }
}