Skip to main content

tensorlogic_sklears_kernels/deep_kernel/
kernel.rs

1//! The [`DeepKernel`] type — a Deep Kernel Learning wrapper that
2//! composes a base kernel with a neural feature extractor.
3//!
4//! Given a base kernel `K_base` and a feature map `g_θ`, the Deep
5//! Kernel is
6//!
7//! ```text
8//! K_DKL(x, y) = K_base(g_θ(x), g_θ(y)).
9//! ```
10//!
11//! This generic wrapper implements the crate-level [`Kernel`] trait so a
12//! `DeepKernel` can slot into any downstream machinery that consumes
13//! `dyn Kernel` (SVM adapters, Gram-matrix utilities, kernel-alignment
14//! search, etc.).
15//!
16//! The base kernel and feature extractor are both owned by the
17//! wrapper. Cloning clones both; mutating parameters requires holding a
18//! `&mut DeepKernel` and going through [`DeepKernel::feature_extractor_mut`].
19
20use std::fmt;
21
22use crate::deep_kernel::feature_extractor::NeuralFeatureMap;
23use crate::error::Result;
24use crate::types::Kernel;
25
26/// Composition of a neural feature extractor with a classical kernel.
27///
28/// `F` — the neural feature extractor (e.g.
29/// [`crate::deep_kernel::MLPFeatureExtractor`]).
30///
31/// `K` — the base kernel (e.g. [`crate::RbfKernel`]).
32#[derive(Clone, Debug)]
33pub struct DeepKernel<F: NeuralFeatureMap, K: Kernel> {
34    extractor: F,
35    base: K,
36}
37
38impl<F: NeuralFeatureMap, K: Kernel> DeepKernel<F, K> {
39    /// Compose a feature extractor with a base kernel.
40    pub fn new(extractor: F, base: K) -> Self {
41        Self { extractor, base }
42    }
43
44    /// Immutable view of the feature extractor.
45    pub fn feature_extractor(&self) -> &F {
46        &self.extractor
47    }
48
49    /// Mutable view of the feature extractor — needed by optimisers
50    /// that write into `parameters_mut()` and then call `sync_from_flat`
51    /// on a concrete MLP.
52    pub fn feature_extractor_mut(&mut self) -> &mut F {
53        &mut self.extractor
54    }
55
56    /// Immutable view of the base kernel.
57    pub fn base_kernel(&self) -> &K {
58        &self.base
59    }
60
61    /// Apply the feature map to a single input.
62    pub fn features(&self, x: &[f64]) -> Result<Vec<f64>> {
63        self.extractor.forward(x)
64    }
65
66    /// Evaluate the composed kernel on a single input pair.
67    pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
68        let fx = self.extractor.forward(x)?;
69        let fy = self.extractor.forward(y)?;
70        self.base.compute(&fx, &fy)
71    }
72
73    /// Compute a Gram matrix `G[i,j] = K_DKL(xs[i], ys[j])`.
74    ///
75    /// Feature maps are cached — each `xs[i]` is passed through the
76    /// extractor at most once, same for `ys[j]`. For square `xs == ys`
77    /// callers should prefer [`Self::compute_symmetric_gram`].
78    pub fn compute_gram(&self, xs: &[&[f64]], ys: &[&[f64]]) -> Result<Vec<Vec<f64>>> {
79        let fx: Vec<Vec<f64>> = xs
80            .iter()
81            .map(|x| self.extractor.forward(x))
82            .collect::<Result<Vec<_>>>()?;
83        let fy: Vec<Vec<f64>> = ys
84            .iter()
85            .map(|y| self.extractor.forward(y))
86            .collect::<Result<Vec<_>>>()?;
87        let mut matrix = vec![vec![0.0; fy.len()]; fx.len()];
88        for i in 0..fx.len() {
89            for j in 0..fy.len() {
90                matrix[i][j] = self.base.compute(&fx[i], &fy[j])?;
91            }
92        }
93        Ok(matrix)
94    }
95
96    /// Symmetric Gram matrix for a single input set.
97    ///
98    /// Takes advantage of `K(a, b) == K(b, a)` to halve the base-kernel
99    /// evaluations; still calls the feature extractor `n` times.
100    pub fn compute_symmetric_gram(&self, xs: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
101        let fx: Vec<Vec<f64>> = xs
102            .iter()
103            .map(|x| self.extractor.forward(x))
104            .collect::<Result<Vec<_>>>()?;
105        let n = fx.len();
106        let mut matrix = vec![vec![0.0; n]; n];
107        for i in 0..n {
108            for j in i..n {
109                let v = self.base.compute(&fx[i], &fx[j])?;
110                matrix[i][j] = v;
111                matrix[j][i] = v;
112            }
113        }
114        Ok(matrix)
115    }
116}
117
118impl<F: NeuralFeatureMap, K: Kernel> Kernel for DeepKernel<F, K> {
119    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
120        self.evaluate(x, y)
121    }
122
123    fn name(&self) -> &str {
124        "DeepKernel"
125    }
126
127    fn is_psd(&self) -> bool {
128        // K_DKL(x, y) = K_base(g(x), g(y)) is PSD iff K_base is PSD, by
129        // the classical "kernels are closed under composition with
130        // arbitrary maps" result.
131        self.base.is_psd()
132    }
133}
134
135/// Helper trait for kinds of feature extractor whose output dimension
136/// matches the base kernel's expected input dimension. Implemented
137/// automatically for every `NeuralFeatureMap`; exists purely as a
138/// documentation anchor.
139pub trait FeatureMapShape {
140    /// Output dimension of the feature map — i.e. the dimension the
141    /// base kernel will see.
142    fn feature_dim(&self) -> usize;
143}
144
145impl<M: NeuralFeatureMap> FeatureMapShape for M {
146    fn feature_dim(&self) -> usize {
147        self.output_dim()
148    }
149}
150
151/// Debug helper — prints extractor shape and base kernel name.
152pub struct DeepKernelSummary<'a, F, K>
153where
154    F: NeuralFeatureMap,
155    K: Kernel,
156{
157    pub kernel: &'a DeepKernel<F, K>,
158}
159
160impl<F, K> fmt::Display for DeepKernelSummary<'_, F, K>
161where
162    F: NeuralFeatureMap,
163    K: Kernel,
164{
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        write!(
167            f,
168            "DeepKernel(in={}, features={}, base={})",
169            self.kernel.extractor.input_dim(),
170            self.kernel.extractor.output_dim(),
171            self.kernel.base.name()
172        )
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deep_kernel::feature_extractor::MLPFeatureExtractor;
    use crate::deep_kernel::layer::{Activation, DenseLayer};
    use crate::types::RbfKernelConfig;
    use crate::{LinearKernel, RbfKernel};

    /// One-input, one-output MLP made of a single dense layer with weight
    /// `1.0`, bias `0.0` and the identity activation.
    fn identity_mlp_1x1() -> MLPFeatureExtractor {
        let weights = vec![vec![1.0]];
        let bias = vec![0.0];
        let dense = DenseLayer::new(weights, bias, Activation::Identity).expect("valid");
        MLPFeatureExtractor::from_layers(vec![dense]).expect("valid")
    }

    /// With the identity extractor the deep kernel must reproduce the bare
    /// base kernel on the same inputs.
    #[test]
    fn deep_kernel_with_identity_equals_base() {
        let lhs = [3.0];
        let rhs = [4.0];
        let deep = DeepKernel::new(identity_mlp_1x1(), LinearKernel::new());
        let reference = LinearKernel::new().compute(&lhs, &rhs).expect("linear");
        let actual = deep.compute(&lhs, &rhs).expect("deep");
        assert!((actual - reference).abs() < 1e-12);
    }

    /// Wrapping an RBF base kernel must still report PSD.
    #[test]
    fn deep_kernel_propagates_psd_from_base() {
        let config = RbfKernelConfig::new(0.5);
        let base = RbfKernel::new(config).expect("valid");
        let deep = DeepKernel::new(identity_mlp_1x1(), base);
        assert!(deep.is_psd());
    }
}
205}