tensorlogic_sklears_kernels/
kernel_transform.rs1use crate::error::{KernelError, Result};
11use crate::types::Kernel;
12
13pub fn normalize_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
44 let n = kernel_matrix.len();
45
46 if n == 0 {
47 return Ok(Vec::new());
48 }
49
50 for row in kernel_matrix {
52 if row.len() != n {
53 return Err(KernelError::ComputationError(
54 "Kernel matrix must be square".to_string(),
55 ));
56 }
57 }
58
59 let diagonal: Vec<f64> = (0..n).map(|i| kernel_matrix[i][i]).collect();
61
62 if diagonal.iter().any(|&d| d <= 0.0) {
64 return Err(KernelError::ComputationError(
65 "Kernel matrix has non-positive diagonal elements".to_string(),
66 ));
67 }
68
69 let sqrt_diag: Vec<f64> = diagonal.iter().map(|&d| d.sqrt()).collect();
71
72 let mut normalized = vec![vec![0.0; n]; n];
74 for i in 0..n {
75 for j in 0..n {
76 normalized[i][j] = kernel_matrix[i][j] / (sqrt_diag[i] * sqrt_diag[j]);
77 }
78 }
79
80 Ok(normalized)
81}
82
83pub fn center_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
113 let n = kernel_matrix.len();
114
115 if n == 0 {
116 return Ok(Vec::new());
117 }
118
119 for row in kernel_matrix {
121 if row.len() != n {
122 return Err(KernelError::ComputationError(
123 "Kernel matrix must be square".to_string(),
124 ));
125 }
126 }
127
128 let row_means: Vec<f64> = kernel_matrix
130 .iter()
131 .map(|row| row.iter().sum::<f64>() / n as f64)
132 .collect();
133
134 let col_means: Vec<f64> = (0..n)
136 .map(|j| kernel_matrix.iter().map(|row| row[j]).sum::<f64>() / n as f64)
137 .collect();
138
139 let grand_mean = row_means.iter().sum::<f64>() / n as f64;
141
142 let mut centered = vec![vec![0.0; n]; n];
144 #[allow(clippy::needless_range_loop)] for i in 0..n {
146 for j in 0..n {
147 centered[i][j] = kernel_matrix[i][j] - row_means[i] - col_means[j] + grand_mean;
148 }
149 }
150
151 Ok(centered)
152}
153
154pub fn standardize_kernel_matrix(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
178 let normalized = normalize_kernel_matrix(kernel_matrix)?;
179 center_kernel_matrix(&normalized)
180}
181
182pub struct NormalizedKernel {
187 base_kernel: Box<dyn Kernel>,
189 diagonal_cache: std::sync::Mutex<std::collections::HashMap<u64, f64>>,
191}
192
193impl NormalizedKernel {
194 pub fn new(base_kernel: Box<dyn Kernel>) -> Self {
212 Self {
213 base_kernel,
214 diagonal_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
215 }
216 }
217
218 fn hash_vector(x: &[f64]) -> u64 {
220 use std::collections::hash_map::DefaultHasher;
221 use std::hash::{Hash, Hasher};
222
223 let mut hasher = DefaultHasher::new();
224 for &val in x {
225 val.to_bits().hash(&mut hasher);
226 }
227 hasher.finish()
228 }
229
230 fn get_diagonal(&self, x: &[f64]) -> Result<f64> {
232 let hash = Self::hash_vector(x);
233
234 {
236 let cache = self
237 .diagonal_cache
238 .lock()
239 .expect("lock should not be poisoned");
240 if let Some(&cached) = cache.get(&hash) {
241 return Ok(cached);
242 }
243 }
244
245 let value = self.base_kernel.compute(x, x)?;
247 let mut cache = self
248 .diagonal_cache
249 .lock()
250 .expect("lock should not be poisoned");
251 cache.insert(hash, value);
252 Ok(value)
253 }
254}
255
256impl Kernel for NormalizedKernel {
257 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
258 let k_xy = self.base_kernel.compute(x, y)?;
259 let k_xx = self.get_diagonal(x)?;
260 let k_yy = self.get_diagonal(y)?;
261
262 if k_xx <= 0.0 || k_yy <= 0.0 {
263 return Err(KernelError::ComputationError(
264 "Kernel diagonal elements must be positive for normalization".to_string(),
265 ));
266 }
267
268 Ok(k_xy / (k_xx * k_yy).sqrt())
269 }
270
271 fn name(&self) -> &str {
272 "Normalized"
273 }
274}
275
#[cfg(test)]
// `K` follows the usual mathematical name for a Gram matrix. Several tests
// deliberately index `K[i][j]` to mirror the formulas they check.
#[allow(non_snake_case, clippy::needless_range_loop)]
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};

    #[test]
    fn test_normalize_kernel_matrix_basic() {
        let K = vec![
            vec![4.0, 2.0, 1.0],
            vec![2.0, 9.0, 3.0],
            vec![1.0, 3.0, 16.0],
        ];

        let K_norm =
            normalize_kernel_matrix(&K).expect("square matrix with positive diagonal");

        // Normalization puts ones on the diagonal...
        for i in 0..3 {
            assert!((K_norm[i][i] - 1.0).abs() < 1e-10);
        }

        // ...and preserves symmetry.
        for i in 0..3 {
            for j in (i + 1)..3 {
                assert!((K_norm[i][j] - K_norm[j][i]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_normalize_kernel_matrix_correctness() {
        let K = vec![vec![4.0, 2.0], vec![2.0, 9.0]];

        let K_norm =
            normalize_kernel_matrix(&K).expect("square matrix with positive diagonal");

        // 2 / (sqrt(4) * sqrt(9)) = 1/3
        assert!((K_norm[0][1] - 1.0 / 3.0).abs() < 1e-10);
        assert!((K_norm[1][0] - 1.0 / 3.0).abs() < 1e-10);
        assert!((K_norm[0][0] - 1.0).abs() < 1e-10);
        assert!((K_norm[1][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_kernel_matrix_empty() {
        let K: Vec<Vec<f64>> = Vec::new();
        let K_norm = normalize_kernel_matrix(&K).expect("empty input is valid");
        assert!(K_norm.is_empty());
    }

    #[test]
    fn test_normalize_kernel_matrix_non_square() {
        let K = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];

        let result = normalize_kernel_matrix(&K);
        assert!(result.is_err());
    }

    #[test]
    fn test_normalize_kernel_matrix_negative_diagonal() {
        let K = vec![vec![-1.0, 2.0], vec![2.0, 4.0]];

        let result = normalize_kernel_matrix(&K);
        assert!(result.is_err());
    }

    #[test]
    fn test_normalize_kernel_matrix_zero_diagonal() {
        let K = vec![vec![0.0, 0.0], vec![0.0, 4.0]];

        let result = normalize_kernel_matrix(&K);
        assert!(result.is_err());
    }

    #[test]
    fn test_center_kernel_matrix_basic() {
        let K = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];

        let K_centered = center_kernel_matrix(&K).expect("square matrix");

        // Double-centering makes every row and every column sum to zero.
        for row in &K_centered {
            let row_sum: f64 = row.iter().sum();
            assert!(row_sum.abs() < 1e-10);
        }

        for j in 0..3 {
            let col_sum: f64 = (0..3).map(|i| K_centered[i][j]).sum();
            assert!(col_sum.abs() < 1e-10);
        }

        let grand_sum: f64 = K_centered.iter().map(|row| row.iter().sum::<f64>()).sum();
        assert!(grand_sum.abs() < 1e-9);
    }

    #[test]
    fn test_center_kernel_matrix_known_values() {
        // Row/column means are [3, 5.5] and the grand mean is 4.25.
        let K = vec![vec![4.0, 2.0], vec![2.0, 9.0]];

        let K_centered = center_kernel_matrix(&K).expect("square matrix");

        let expected = [[2.25, -2.25], [-2.25, 2.25]];
        for i in 0..2 {
            for j in 0..2 {
                assert!((K_centered[i][j] - expected[i][j]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_center_kernel_matrix_empty() {
        let K: Vec<Vec<f64>> = Vec::new();
        let K_centered = center_kernel_matrix(&K).expect("empty input is valid");
        assert!(K_centered.is_empty());
    }

    #[test]
    fn test_center_kernel_matrix_non_square() {
        let K = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];

        let result = center_kernel_matrix(&K);
        assert!(result.is_err());
    }

    #[test]
    fn test_standardize_kernel_matrix() {
        let K = vec![
            vec![4.0, 2.0, 1.0],
            vec![2.0, 9.0, 3.0],
            vec![1.0, 3.0, 16.0],
        ];

        let K_std =
            standardize_kernel_matrix(&K).expect("square matrix with positive diagonal");

        for row in &K_std {
            let row_sum: f64 = row.iter().sum();
            assert!(row_sum.abs() < 1e-9);
        }
    }

    #[test]
    fn test_normalized_kernel_wrapper() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(linear);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let self_sim = normalized.compute(&x, &x).expect("self-similarity");
        assert!((self_sim - 1.0).abs() < 1e-10);

        // A normalized linear kernel is a cosine similarity, so it lies in [-1, 1].
        let sim = normalized.compute(&x, &y).expect("cross-similarity");
        assert!((-1.0..=1.0).contains(&sim));
    }

    #[test]
    fn test_normalized_kernel_rbf() {
        let rbf = Box::new(RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid RBF config"))
            as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(rbf);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![2.0, 3.0, 4.0];

        let self_sim_x = normalized.compute(&x, &x).expect("self-similarity of x");
        let self_sim_y = normalized.compute(&y, &y).expect("self-similarity of y");
        assert!((self_sim_x - 1.0).abs() < 1e-10);
        assert!((self_sim_y - 1.0).abs() < 1e-10);

        let sim = normalized.compute(&x, &y).expect("cross-similarity");
        assert!(sim > 0.0 && sim < 1.0);
    }

    #[test]
    fn test_normalized_kernel_symmetry() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(linear);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let sim_xy = normalized.compute(&x, &y).expect("k(x, y)");
        let sim_yx = normalized.compute(&y, &x).expect("k(y, x)");

        assert!((sim_xy - sim_yx).abs() < 1e-10);
    }

    #[test]
    fn test_normalized_kernel_caching() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let normalized = NormalizedKernel::new(linear);

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let sim1 = normalized.compute(&x, &y).expect("first call");
        let sim2 = normalized.compute(&x, &y).expect("second call (cached)");
        let sim3 = normalized.compute(&x, &y).expect("third call (cached)");

        assert!((sim1 - sim2).abs() < 1e-10);
        assert!((sim2 - sim3).abs() < 1e-10);

        // Exactly one diagonal entry per distinct input; repeated calls add nothing.
        let cached_entries = normalized
            .diagonal_cache
            .lock()
            .expect("cache lock")
            .len();
        assert_eq!(cached_entries, 2);
    }

    #[test]
    fn test_normalize_then_center_vs_standardize() {
        let K = vec![
            vec![4.0, 2.0, 1.0],
            vec![2.0, 9.0, 3.0],
            vec![1.0, 3.0, 16.0],
        ];

        let K_norm =
            normalize_kernel_matrix(&K).expect("square matrix with positive diagonal");
        let K_norm_cent = center_kernel_matrix(&K_norm).expect("square matrix");

        let K_std =
            standardize_kernel_matrix(&K).expect("square matrix with positive diagonal");

        for i in 0..3 {
            for j in 0..3 {
                assert!((K_norm_cent[i][j] - K_std[i][j]).abs() < 1e-10);
            }
        }
    }
}