tensorlogic_sklears_kernels/
composite_kernel.rs1use std::sync::Arc;
9
10use crate::error::{KernelError, Result};
11use crate::types::Kernel;
12
13pub struct WeightedSumKernel {
38 kernels: Vec<Arc<dyn Kernel>>,
40 weights: Vec<f64>,
42 normalized: bool,
44}
45
46impl WeightedSumKernel {
47 pub fn new(kernels: Vec<Box<dyn Kernel>>, weights: Vec<f64>) -> Result<Self> {
49 if kernels.is_empty() {
50 return Err(KernelError::InvalidParameter {
51 parameter: "kernels".to_string(),
52 value: "empty".to_string(),
53 reason: "at least one kernel required".to_string(),
54 });
55 }
56
57 if kernels.len() != weights.len() {
58 return Err(KernelError::DimensionMismatch {
59 expected: vec![kernels.len()],
60 got: vec![weights.len()],
61 context: "weighted sum kernel".to_string(),
62 });
63 }
64
65 if weights.iter().any(|&w| w < 0.0) {
67 return Err(KernelError::InvalidParameter {
68 parameter: "weights".to_string(),
69 value: format!("{:?}", weights),
70 reason: "all weights must be non-negative".to_string(),
71 });
72 }
73
74 let weight_sum: f64 = weights.iter().sum();
75 if weight_sum <= 0.0 {
76 return Err(KernelError::InvalidParameter {
77 parameter: "weights".to_string(),
78 value: format!("{:?}", weights),
79 reason: "weights must sum to a positive value".to_string(),
80 });
81 }
82
83 let kernels: Vec<Arc<dyn Kernel>> = kernels.into_iter().map(Arc::from).collect();
85
86 Ok(Self {
87 kernels,
88 weights,
89 normalized: false,
90 })
91 }
92
93 pub fn new_normalized(kernels: Vec<Box<dyn Kernel>>, mut weights: Vec<f64>) -> Result<Self> {
95 let weight_sum: f64 = weights.iter().sum();
96 if weight_sum <= 0.0 {
97 return Err(KernelError::InvalidParameter {
98 parameter: "weights".to_string(),
99 value: format!("{:?}", weights),
100 reason: "weights must sum to a positive value".to_string(),
101 });
102 }
103
104 for w in &mut weights {
106 *w /= weight_sum;
107 }
108
109 let mut kernel = Self::new(kernels, weights)?;
110 kernel.normalized = true;
111 Ok(kernel)
112 }
113
114 pub fn uniform(kernels: Vec<Box<dyn Kernel>>) -> Result<Self> {
116 let n = kernels.len();
117 if n == 0 {
118 return Err(KernelError::InvalidParameter {
119 parameter: "kernels".to_string(),
120 value: "empty".to_string(),
121 reason: "at least one kernel required".to_string(),
122 });
123 }
124
125 let weights = vec![1.0 / n as f64; n];
126 Self::new_normalized(kernels, weights)
127 }
128}
129
130impl Kernel for WeightedSumKernel {
131 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
132 let mut result = 0.0;
133
134 for (kernel, &weight) in self.kernels.iter().zip(self.weights.iter()) {
135 let value = kernel.compute(x, y)?;
136 result += weight * value;
137 }
138
139 Ok(result)
140 }
141
142 fn name(&self) -> &str {
143 "WeightedSum"
144 }
145
146 fn is_psd(&self) -> bool {
147 self.weights.iter().all(|&w| w >= 0.0) && self.kernels.iter().all(|k| k.is_psd())
149 }
150}
151
152pub struct ProductKernel {
176 kernels: Vec<Arc<dyn Kernel>>,
178}
179
180impl ProductKernel {
181 pub fn new(kernels: Vec<Box<dyn Kernel>>) -> Result<Self> {
183 if kernels.is_empty() {
184 return Err(KernelError::InvalidParameter {
185 parameter: "kernels".to_string(),
186 value: "empty".to_string(),
187 reason: "at least one kernel required".to_string(),
188 });
189 }
190
191 let kernels: Vec<Arc<dyn Kernel>> = kernels.into_iter().map(Arc::from).collect();
193
194 Ok(Self { kernels })
195 }
196}
197
198impl Kernel for ProductKernel {
199 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
200 let mut result = 1.0;
201
202 for kernel in &self.kernels {
203 let value = kernel.compute(x, y)?;
204 result *= value;
205 }
206
207 Ok(result)
208 }
209
210 fn name(&self) -> &str {
211 "Product"
212 }
213
214 fn is_psd(&self) -> bool {
215 self.kernels.iter().all(|k| k.is_psd())
217 }
218}
219
/// Stateless namespace for kernel-alignment computations on Gram matrices.
#[derive(Debug, Clone, Copy, Default)]
pub struct KernelAlignment;
233
234impl KernelAlignment {
235 pub fn compute_alignment(k1: &[Vec<f64>], k2: &[Vec<f64>]) -> Result<f64> {
244 if k1.is_empty() || k2.is_empty() {
245 return Err(KernelError::InvalidParameter {
246 parameter: "kernel_matrices".to_string(),
247 value: "empty".to_string(),
248 reason: "kernel matrices cannot be empty".to_string(),
249 });
250 }
251
252 let n1 = k1.len();
253 let n2 = k2.len();
254
255 if n1 != n2 {
256 return Err(KernelError::DimensionMismatch {
257 expected: vec![n1, n1],
258 got: vec![n2, n2],
259 context: "kernel alignment".to_string(),
260 });
261 }
262
263 for (i, row) in k1.iter().enumerate() {
265 if row.len() != n1 {
266 return Err(KernelError::DimensionMismatch {
267 expected: vec![n1],
268 got: vec![row.len()],
269 context: format!("k1 row {}", i),
270 });
271 }
272 }
273
274 for (i, row) in k2.iter().enumerate() {
275 if row.len() != n2 {
276 return Err(KernelError::DimensionMismatch {
277 expected: vec![n2],
278 got: vec![row.len()],
279 context: format!("k2 row {}", i),
280 });
281 }
282 }
283
284 let k1_centered = Self::center_kernel_matrix(k1);
286 let k2_centered = Self::center_kernel_matrix(k2);
287
288 let mut inner_product = 0.0;
290 for i in 0..n1 {
291 for j in 0..n1 {
292 inner_product += k1_centered[i][j] * k2_centered[i][j];
293 }
294 }
295
296 let norm1 = Self::frobenius_norm(&k1_centered);
298 let norm2 = Self::frobenius_norm(&k2_centered);
299
300 if norm1 == 0.0 || norm2 == 0.0 {
301 return Ok(0.0);
302 }
303
304 Ok(inner_product / (norm1 * norm2))
305 }
306
307 #[allow(clippy::needless_range_loop)]
309 fn center_kernel_matrix(k: &[Vec<f64>]) -> Vec<Vec<f64>> {
310 let n = k.len();
311 let mut centered = vec![vec![0.0; n]; n];
312
313 let mut row_means = vec![0.0; n];
315 let mut col_means = vec![0.0; n];
316 let mut total_mean = 0.0;
317
318 for i in 0..n {
319 for j in 0..n {
320 row_means[i] += k[i][j];
321 col_means[j] += k[i][j];
322 total_mean += k[i][j];
323 }
324 }
325
326 for mean in &mut row_means {
327 *mean /= n as f64;
328 }
329 for mean in &mut col_means {
330 *mean /= n as f64;
331 }
332 total_mean /= (n * n) as f64;
333
334 for i in 0..n {
336 for j in 0..n {
337 centered[i][j] = k[i][j] - row_means[i] - col_means[j] + total_mean;
338 }
339 }
340
341 centered
342 }
343
344 fn frobenius_norm(k: &[Vec<f64>]) -> f64 {
346 let mut sum_sq = 0.0;
347 for row in k {
348 for &val in row {
349 sum_sq += val * val;
350 }
351 }
352 sum_sq.sqrt()
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::{CosineKernel, LinearKernel, RbfKernel};
    use crate::types::RbfKernelConfig;

    /// Boxes a concrete kernel as the trait object the composite constructors accept.
    fn boxed(kernel: impl Kernel + 'static) -> Box<dyn Kernel> {
        Box::new(kernel)
    }

    /// A `LinearKernel` as a trait object.
    fn linear() -> Box<dyn Kernel> {
        boxed(LinearKernel::new())
    }

    /// A `CosineKernel` as a trait object.
    fn cosine() -> Box<dyn Kernel> {
        boxed(CosineKernel::new())
    }

    /// An `RbfKernel` built from `RbfKernelConfig::new(0.5)`, as a trait object.
    fn rbf() -> Box<dyn Kernel> {
        boxed(RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap())
    }

    /// The fixed pair of 3-D points evaluated by the single-pair tests.
    fn sample_pair() -> (Vec<f64>, Vec<f64>) {
        (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0])
    }

    #[test]
    fn test_weighted_sum_kernel() {
        let kernel = WeightedSumKernel::new(vec![linear(), rbf()], vec![0.7, 0.3]).unwrap();
        let (x, y) = sample_pair();

        let value = kernel.compute(&x, &y).unwrap();
        assert!(value > 0.0);
        assert_eq!(kernel.name(), "WeightedSum");
    }

    #[test]
    fn test_weighted_sum_normalized() {
        let kernel =
            WeightedSumKernel::new_normalized(vec![linear(), cosine()], vec![2.0, 3.0]).unwrap();
        let (x, y) = sample_pair();

        assert!(kernel.compute(&x, &y).unwrap() > 0.0);
    }

    #[test]
    fn test_weighted_sum_uniform() {
        let kernel = WeightedSumKernel::uniform(vec![linear(), cosine(), rbf()]).unwrap();
        let (x, y) = sample_pair();

        assert!(kernel.compute(&x, &y).unwrap() > 0.0);
    }

    #[test]
    fn test_weighted_sum_empty_kernels() {
        assert!(WeightedSumKernel::new(vec![], vec![]).is_err());
    }

    #[test]
    fn test_weighted_sum_dimension_mismatch() {
        assert!(WeightedSumKernel::new(vec![linear()], vec![0.5, 0.5]).is_err());
    }

    #[test]
    fn test_weighted_sum_negative_weights() {
        assert!(WeightedSumKernel::new(vec![linear()], vec![-0.5]).is_err());
    }

    #[test]
    fn test_product_kernel() {
        let kernel = ProductKernel::new(vec![linear(), cosine()]).unwrap();
        let (x, y) = sample_pair();

        let value = kernel.compute(&x, &y).unwrap();
        assert!(value > 0.0);
        assert_eq!(kernel.name(), "Product");
    }

    #[test]
    fn test_product_kernel_empty() {
        assert!(ProductKernel::new(vec![]).is_err());
    }

    #[test]
    fn test_product_psd_property() {
        let kernel = ProductKernel::new(vec![linear(), rbf()]).unwrap();
        assert!(kernel.is_psd());
    }

    #[test]
    fn test_kernel_alignment() {
        let first = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        let second = vec![
            vec![1.0, 0.75, 0.55],
            vec![0.75, 1.0, 0.65],
            vec![0.55, 0.65, 1.0],
        ];

        let alignment = KernelAlignment::compute_alignment(&first, &second).unwrap();

        // Near-identical similarity structure should align strongly, never above 1.
        assert!(alignment > 0.9);
        assert!(alignment <= 1.0);
    }

    #[test]
    fn test_kernel_alignment_identity() {
        let gram = vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ];

        let alignment = KernelAlignment::compute_alignment(&gram, &gram).unwrap();
        assert!((alignment - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_alignment_dimension_mismatch() {
        let two_by_two = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let three_by_three = vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ];

        assert!(KernelAlignment::compute_alignment(&two_by_two, &three_by_three).is_err());
    }

    #[test]
    fn test_weighted_sum_kernel_matrix() {
        let kernel = WeightedSumKernel::uniform(vec![linear(), cosine()]).unwrap();
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let matrix = kernel.compute_matrix(&inputs).unwrap();
        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Every (i, j) pair, diagonal included, must match its transpose.
        let index_pairs = (0..3usize).flat_map(|i| (0..3usize).map(move |j| (i, j)));
        for (i, j) in index_pairs {
            assert!((matrix[i][j] - matrix[j][i]).abs() < 1e-10);
        }
    }
}