tensorlogic_sklears_kernels/composite_kernel.rs

//! Composite kernels for combining multiple kernel functions.
//!
//! This module provides ways to combine existing kernels through:
//! - Weighted sum (convex combinations)
//! - Product (multiplicative combinations)
//! - Kernel alignment (meta-learning)

use std::sync::Arc;

use crate::error::{KernelError, Result};
use crate::types::Kernel;

/// Weighted sum of multiple kernels: K(x,y) = Σ_i w_i * K_i(x,y)
///
/// Combines multiple kernels through a non-negative weighted sum.
/// Weights need not sum to 1.0; use [`WeightedSumKernel::new_normalized`]
/// to obtain a convex combination.
///
/// # Example
///
/// ```rust
/// use tensorlogic_sklears_kernels::{
///     LinearKernel, RbfKernel, RbfKernelConfig,
///     WeightedSumKernel, Kernel
/// };
///
/// let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
/// let rbf = Box::new(RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap()) as Box<dyn Kernel>;
///
/// let weights = vec![0.7, 0.3];
/// let composite = WeightedSumKernel::new(vec![linear, rbf], weights).unwrap();
///
/// let x = vec![1.0, 2.0, 3.0];
/// let y = vec![4.0, 5.0, 6.0];
/// let sim = composite.compute(&x, &y).unwrap();
/// // sim = 0.7 * linear(x,y) + 0.3 * rbf(x,y)
/// ```
pub struct WeightedSumKernel {
    /// Component kernels
    kernels: Vec<Arc<dyn Kernel>>,
    /// Weights for each kernel
    weights: Vec<f64>,
    /// Whether weights are normalized
    normalized: bool,
}

impl WeightedSumKernel {
    /// Create a new weighted sum kernel
    pub fn new(kernels: Vec<Box<dyn Kernel>>, weights: Vec<f64>) -> Result<Self> {
        if kernels.is_empty() {
            return Err(KernelError::InvalidParameter {
                parameter: "kernels".to_string(),
                value: "empty".to_string(),
                reason: "at least one kernel required".to_string(),
            });
        }

        if kernels.len() != weights.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![kernels.len()],
                got: vec![weights.len()],
                context: "weighted sum kernel".to_string(),
            });
        }

        // Check weights are non-negative
        if weights.iter().any(|&w| w < 0.0) {
            return Err(KernelError::InvalidParameter {
                parameter: "weights".to_string(),
                value: format!("{:?}", weights),
                reason: "all weights must be non-negative".to_string(),
            });
        }

        let weight_sum: f64 = weights.iter().sum();
        if weight_sum <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "weights".to_string(),
                value: format!("{:?}", weights),
                reason: "weights must sum to a positive value".to_string(),
            });
        }

        // Convert Box to Arc for shared ownership
        let kernels: Vec<Arc<dyn Kernel>> = kernels.into_iter().map(Arc::from).collect();

        Ok(Self {
            kernels,
            weights,
            normalized: false,
        })
    }

    /// Create with normalized weights (sum to 1.0)
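    ///
    /// # Example
    ///
    /// A minimal sketch; the weights `[2.0, 3.0]` below are rescaled to
    /// `[0.4, 0.6]` before use.
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::{LinearKernel, CosineKernel, WeightedSumKernel, Kernel};
    ///
    /// let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
    /// let cosine = Box::new(CosineKernel::new()) as Box<dyn Kernel>;
    ///
    /// let kernel = WeightedSumKernel::new_normalized(vec![linear, cosine], vec![2.0, 3.0]).unwrap();
    /// let sim = kernel.compute(&[1.0, 2.0], &[3.0, 4.0]).unwrap();
    /// assert!(sim.is_finite());
    /// ```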
    pub fn new_normalized(kernels: Vec<Box<dyn Kernel>>, mut weights: Vec<f64>) -> Result<Self> {
        let weight_sum: f64 = weights.iter().sum();
        if weight_sum <= 0.0 {
            return Err(KernelError::InvalidParameter {
                parameter: "weights".to_string(),
                value: format!("{:?}", weights),
                reason: "weights must sum to a positive value".to_string(),
            });
        }

        // Normalize weights
        for w in &mut weights {
            *w /= weight_sum;
        }

        let mut kernel = Self::new(kernels, weights)?;
        kernel.normalized = true;
        Ok(kernel)
    }

    /// Create with uniform weights
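    ///
    /// # Example
    ///
    /// A minimal sketch: with two kernels, each receives weight 1/2.
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::{LinearKernel, CosineKernel, WeightedSumKernel, Kernel};
    ///
    /// let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
    /// let cosine = Box::new(CosineKernel::new()) as Box<dyn Kernel>;
    ///
    /// let kernel = WeightedSumKernel::uniform(vec![linear, cosine]).unwrap();
    /// let sim = kernel.compute(&[1.0, 2.0], &[3.0, 4.0]).unwrap();
    /// assert!(sim.is_finite());
    /// ```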
    pub fn uniform(kernels: Vec<Box<dyn Kernel>>) -> Result<Self> {
        let n = kernels.len();
        if n == 0 {
            return Err(KernelError::InvalidParameter {
                parameter: "kernels".to_string(),
                value: "empty".to_string(),
                reason: "at least one kernel required".to_string(),
            });
        }

        let weights = vec![1.0 / n as f64; n];
        Self::new_normalized(kernels, weights)
    }
}

impl Kernel for WeightedSumKernel {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        let mut result = 0.0;

        for (kernel, &weight) in self.kernels.iter().zip(self.weights.iter()) {
            let value = kernel.compute(x, y)?;
            result += weight * value;
        }

        Ok(result)
    }

    fn name(&self) -> &str {
        "WeightedSum"
    }

    fn is_psd(&self) -> bool {
        // Weighted sum of PSD kernels is PSD if weights are non-negative
        self.weights.iter().all(|&w| w >= 0.0) && self.kernels.iter().all(|k| k.is_psd())
    }
}

/// Product of multiple kernels: K(x,y) = Π_i K_i(x,y)
///
/// Combines kernels through multiplication.
/// The resulting kernel corresponds to the tensor product of the component
/// feature spaces.
///
/// # Example
///
/// ```rust
/// use tensorlogic_sklears_kernels::{
///     LinearKernel, CosineKernel,
///     ProductKernel, Kernel
/// };
///
/// let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
/// let cosine = Box::new(CosineKernel::new()) as Box<dyn Kernel>;
///
/// let product = ProductKernel::new(vec![linear, cosine]).unwrap();
///
/// let x = vec![1.0, 2.0, 3.0];
/// let y = vec![4.0, 5.0, 6.0];
/// let sim = product.compute(&x, &y).unwrap();
/// // sim = linear(x,y) * cosine(x,y)
/// ```
pub struct ProductKernel {
    /// Component kernels
    kernels: Vec<Arc<dyn Kernel>>,
}

impl ProductKernel {
    /// Create a new product kernel
    pub fn new(kernels: Vec<Box<dyn Kernel>>) -> Result<Self> {
        if kernels.is_empty() {
            return Err(KernelError::InvalidParameter {
                parameter: "kernels".to_string(),
                value: "empty".to_string(),
                reason: "at least one kernel required".to_string(),
            });
        }

        // Convert Box to Arc for shared ownership
        let kernels: Vec<Arc<dyn Kernel>> = kernels.into_iter().map(Arc::from).collect();

        Ok(Self { kernels })
    }
}

impl Kernel for ProductKernel {
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        let mut result = 1.0;

        for kernel in &self.kernels {
            let value = kernel.compute(x, y)?;
            result *= value;
        }

        Ok(result)
    }

    fn name(&self) -> &str {
        "Product"
    }

    fn is_psd(&self) -> bool {
        // Product of PSD kernels is PSD (Schur product theorem)
        self.kernels.iter().all(|k| k.is_psd())
    }
}

/// Kernel alignment computation for measuring similarity between kernels.
///
/// Kernel alignment measures how well two kernels agree on a dataset.
/// It is useful for kernel selection and meta-learning.
///
/// # Formula
///
/// ```text
/// A(K1, K2) = <K1, K2>_F / (||K1||_F * ||K2||_F)
/// ```
///
/// Where `<·,·>_F` is the Frobenius inner product and `||·||_F` is the
/// Frobenius norm. [`KernelAlignment::compute_alignment`] applies this
/// formula to *centered* kernel matrices (centered kernel alignment).
pub struct KernelAlignment;

impl KernelAlignment {
    /// Compute centered kernel alignment between two kernel matrices.
    ///
    /// # Arguments
    /// * `k1` - First kernel matrix (square, n × n)
    /// * `k2` - Second kernel matrix (square, n × n)
    ///
    /// # Returns
    /// Alignment score in the range [-1, 1], where 1 means perfect alignment
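    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `KernelAlignment` is re-exported at the
    /// crate root like the kernel types above; the matrix values are
    /// illustrative.
    ///
    /// ```rust
    /// use tensorlogic_sklears_kernels::KernelAlignment;
    ///
    /// let k1 = vec![
    ///     vec![1.0, 0.8, 0.6],
    ///     vec![0.8, 1.0, 0.7],
    ///     vec![0.6, 0.7, 1.0],
    /// ];
    /// let k2 = vec![
    ///     vec![1.0, 0.75, 0.55],
    ///     vec![0.75, 1.0, 0.65],
    ///     vec![0.55, 0.65, 1.0],
    /// ];
    ///
    /// let alignment = KernelAlignment::compute_alignment(&k1, &k2).unwrap();
    /// assert!(alignment > 0.9 && alignment <= 1.0);
    /// ```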
    pub fn compute_alignment(k1: &[Vec<f64>], k2: &[Vec<f64>]) -> Result<f64> {
        if k1.is_empty() || k2.is_empty() {
            return Err(KernelError::InvalidParameter {
                parameter: "kernel_matrices".to_string(),
                value: "empty".to_string(),
                reason: "kernel matrices cannot be empty".to_string(),
            });
        }

        let n1 = k1.len();
        let n2 = k2.len();

        if n1 != n2 {
            return Err(KernelError::DimensionMismatch {
                expected: vec![n1, n1],
                got: vec![n2, n2],
                context: "kernel alignment".to_string(),
            });
        }

        // Check square matrices
        for (i, row) in k1.iter().enumerate() {
            if row.len() != n1 {
                return Err(KernelError::DimensionMismatch {
                    expected: vec![n1],
                    got: vec![row.len()],
                    context: format!("k1 row {}", i),
                });
            }
        }

        for (i, row) in k2.iter().enumerate() {
            if row.len() != n2 {
                return Err(KernelError::DimensionMismatch {
                    expected: vec![n2],
                    got: vec![row.len()],
                    context: format!("k2 row {}", i),
                });
            }
        }

        // Center the kernel matrices
        let k1_centered = Self::center_kernel_matrix(k1);
        let k2_centered = Self::center_kernel_matrix(k2);

        // Compute Frobenius inner product
        let mut inner_product = 0.0;
        for i in 0..n1 {
            for j in 0..n1 {
                inner_product += k1_centered[i][j] * k2_centered[i][j];
            }
        }

        // Compute Frobenius norms
        let norm1 = Self::frobenius_norm(&k1_centered);
        let norm2 = Self::frobenius_norm(&k2_centered);

        if norm1 == 0.0 || norm2 == 0.0 {
            return Ok(0.0);
        }

        Ok(inner_product / (norm1 * norm2))
    }

    /// Center a kernel matrix via double centering: K_c = H·K·H with
    /// H = I - (1/n)·11ᵀ, implemented by subtracting the row and column
    /// means and adding back the grand mean.
    #[allow(clippy::needless_range_loop)]
    fn center_kernel_matrix(k: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = k.len();
        let mut centered = vec![vec![0.0; n]; n];

        // Compute row and column means
        let mut row_means = vec![0.0; n];
        let mut col_means = vec![0.0; n];
        let mut total_mean = 0.0;

        for i in 0..n {
            for j in 0..n {
                row_means[i] += k[i][j];
                col_means[j] += k[i][j];
                total_mean += k[i][j];
            }
        }

        for mean in &mut row_means {
            *mean /= n as f64;
        }
        for mean in &mut col_means {
            *mean /= n as f64;
        }
        total_mean /= (n * n) as f64;

        // Center the matrix
        for i in 0..n {
            for j in 0..n {
                centered[i][j] = k[i][j] - row_means[i] - col_means[j] + total_mean;
            }
        }

        centered
    }

    /// Compute the Frobenius norm of a matrix
    fn frobenius_norm(k: &[Vec<f64>]) -> f64 {
        let mut sum_sq = 0.0;
        for row in k {
            for &val in row {
                sum_sq += val * val;
            }
        }
        sum_sq.sqrt()
    }
}

#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use super::*;
    use crate::tensor_kernels::{CosineKernel, LinearKernel, RbfKernel};
    use crate::types::RbfKernelConfig;

    #[test]
    fn test_weighted_sum_kernel() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let rbf = Box::new(RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap()) as Box<dyn Kernel>;

        let weights = vec![0.7, 0.3];
        let kernel = WeightedSumKernel::new(vec![linear, rbf], weights).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let result = kernel.compute(&x, &y).unwrap();
        assert!(result > 0.0);
        assert_eq!(kernel.name(), "WeightedSum");
    }

    #[test]
    fn test_weighted_sum_normalized() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let cosine = Box::new(CosineKernel::new()) as Box<dyn Kernel>;

        let weights = vec![2.0, 3.0]; // Will be normalized to [0.4, 0.6]
        let kernel = WeightedSumKernel::new_normalized(vec![linear, cosine], weights).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let result = kernel.compute(&x, &y).unwrap();
        assert!(result > 0.0);
    }
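
    // A small consistency sketch (values are illustrative): the normalized
    // weighted sum with weights [2.0, 3.0] should equal the manual
    // combination 0.4 * linear(x, y) + 0.6 * cosine(x, y).
    #[test]
    fn test_weighted_sum_matches_manual_combination() {
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let expected = 0.4 * LinearKernel::new().compute(&x, &y).unwrap()
            + 0.6 * CosineKernel::new().compute(&x, &y).unwrap();

        let kernel = WeightedSumKernel::new_normalized(
            vec![
                Box::new(LinearKernel::new()) as Box<dyn Kernel>,
                Box::new(CosineKernel::new()) as Box<dyn Kernel>,
            ],
            vec![2.0, 3.0],
        )
        .unwrap();

        assert!((kernel.compute(&x, &y).unwrap() - expected).abs() < 1e-10);
    }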

    #[test]
    fn test_weighted_sum_uniform() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let cosine = Box::new(CosineKernel::new()) as Box<dyn Kernel>;
        let rbf = Box::new(RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap()) as Box<dyn Kernel>;

        let kernel = WeightedSumKernel::uniform(vec![linear, cosine, rbf]).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let result = kernel.compute(&x, &y).unwrap();
        assert!(result > 0.0);
    }

    #[test]
    fn test_weighted_sum_empty_kernels() {
        let result = WeightedSumKernel::new(vec![], vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_sum_dimension_mismatch() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let result = WeightedSumKernel::new(vec![linear], vec![0.5, 0.5]);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_sum_negative_weights() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let result = WeightedSumKernel::new(vec![linear], vec![-0.5]);
        assert!(result.is_err());
    }
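
    // A sketch mirroring `test_product_psd_property` below, assuming the
    // component kernels report PSD: a non-negative weighted sum of PSD
    // kernels should itself report PSD.
    #[test]
    fn test_weighted_sum_psd_property() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let rbf = Box::new(RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap()) as Box<dyn Kernel>;

        let kernel = WeightedSumKernel::new(vec![linear, rbf], vec![0.7, 0.3]).unwrap();
        assert!(kernel.is_psd());
    }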

    #[test]
    fn test_product_kernel() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let cosine = Box::new(CosineKernel::new()) as Box<dyn Kernel>;

        let kernel = ProductKernel::new(vec![linear, cosine]).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let result = kernel.compute(&x, &y).unwrap();
        assert!(result > 0.0);
        assert_eq!(kernel.name(), "Product");
    }

    #[test]
    fn test_product_kernel_empty() {
        let result = ProductKernel::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_product_psd_property() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let rbf = Box::new(RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap()) as Box<dyn Kernel>;

        let kernel = ProductKernel::new(vec![linear, rbf]).unwrap();
        assert!(kernel.is_psd());
    }

    #[test]
    fn test_kernel_alignment() {
        // Create two similar kernel matrices
        let k1 = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];

        let k2 = vec![
            vec![1.0, 0.75, 0.55],
            vec![0.75, 1.0, 0.65],
            vec![0.55, 0.65, 1.0],
        ];

        let alignment = KernelAlignment::compute_alignment(&k1, &k2).unwrap();

        // Similar matrices should have high alignment
        assert!(alignment > 0.9);
        assert!(alignment <= 1.0);
    }

    #[test]
    fn test_kernel_alignment_identity() {
        let k = vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ];

        let alignment = KernelAlignment::compute_alignment(&k, &k).unwrap();

        // A kernel should have perfect alignment with itself
        assert!((alignment - 1.0).abs() < 1e-10);
    }
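
    // The Frobenius inner product is symmetric, so alignment should not
    // depend on argument order (matrix values are illustrative).
    #[test]
    fn test_kernel_alignment_symmetry() {
        let k1 = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        let k2 = vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ];

        let a12 = KernelAlignment::compute_alignment(&k1, &k2).unwrap();
        let a21 = KernelAlignment::compute_alignment(&k2, &k1).unwrap();
        assert!((a12 - a21).abs() < 1e-10);
    }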

    #[test]
    fn test_kernel_alignment_dimension_mismatch() {
        let k1 = vec![vec![1.0, 0.5], vec![0.5, 1.0]];

        let k2 = vec![
            vec![1.0, 0.5, 0.3],
            vec![0.5, 1.0, 0.4],
            vec![0.3, 0.4, 1.0],
        ];

        let result = KernelAlignment::compute_alignment(&k1, &k2);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_sum_kernel_matrix() {
        let linear = Box::new(LinearKernel::new()) as Box<dyn Kernel>;
        let cosine = Box::new(CosineKernel::new()) as Box<dyn Kernel>;

        let kernel = WeightedSumKernel::uniform(vec![linear, cosine]).unwrap();

        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let matrix = kernel.compute_matrix(&inputs).unwrap();
        assert_eq!(matrix.len(), 3);
        assert_eq!(matrix[0].len(), 3);

        // Check symmetry
        for i in 0..3 {
            for j in 0..3 {
                assert!((matrix[i][j] - matrix[j][i]).abs() < 1e-10);
            }
        }
    }
}