tensorlogic_sklears_kernels/deep_kernel/
kernel.rs1use std::fmt;
21
22use crate::deep_kernel::feature_extractor::NeuralFeatureMap;
23use crate::error::Result;
24use crate::types::Kernel;
25
26#[derive(Clone, Debug)]
33pub struct DeepKernel<F: NeuralFeatureMap, K: Kernel> {
34 extractor: F,
35 base: K,
36}
37
38impl<F: NeuralFeatureMap, K: Kernel> DeepKernel<F, K> {
39 pub fn new(extractor: F, base: K) -> Self {
41 Self { extractor, base }
42 }
43
44 pub fn feature_extractor(&self) -> &F {
46 &self.extractor
47 }
48
49 pub fn feature_extractor_mut(&mut self) -> &mut F {
53 &mut self.extractor
54 }
55
56 pub fn base_kernel(&self) -> &K {
58 &self.base
59 }
60
61 pub fn features(&self, x: &[f64]) -> Result<Vec<f64>> {
63 self.extractor.forward(x)
64 }
65
66 pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
68 let fx = self.extractor.forward(x)?;
69 let fy = self.extractor.forward(y)?;
70 self.base.compute(&fx, &fy)
71 }
72
73 pub fn compute_gram(&self, xs: &[&[f64]], ys: &[&[f64]]) -> Result<Vec<Vec<f64>>> {
79 let fx: Vec<Vec<f64>> = xs
80 .iter()
81 .map(|x| self.extractor.forward(x))
82 .collect::<Result<Vec<_>>>()?;
83 let fy: Vec<Vec<f64>> = ys
84 .iter()
85 .map(|y| self.extractor.forward(y))
86 .collect::<Result<Vec<_>>>()?;
87 let mut matrix = vec![vec![0.0; fy.len()]; fx.len()];
88 for i in 0..fx.len() {
89 for j in 0..fy.len() {
90 matrix[i][j] = self.base.compute(&fx[i], &fy[j])?;
91 }
92 }
93 Ok(matrix)
94 }
95
96 pub fn compute_symmetric_gram(&self, xs: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
101 let fx: Vec<Vec<f64>> = xs
102 .iter()
103 .map(|x| self.extractor.forward(x))
104 .collect::<Result<Vec<_>>>()?;
105 let n = fx.len();
106 let mut matrix = vec![vec![0.0; n]; n];
107 for i in 0..n {
108 for j in i..n {
109 let v = self.base.compute(&fx[i], &fx[j])?;
110 matrix[i][j] = v;
111 matrix[j][i] = v;
112 }
113 }
114 Ok(matrix)
115 }
116}
117
118impl<F: NeuralFeatureMap, K: Kernel> Kernel for DeepKernel<F, K> {
119 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
120 self.evaluate(x, y)
121 }
122
123 fn name(&self) -> &str {
124 "DeepKernel"
125 }
126
127 fn is_psd(&self) -> bool {
128 self.base.is_psd()
132 }
133}
134
135pub trait FeatureMapShape {
140 fn feature_dim(&self) -> usize;
143}
144
145impl<M: NeuralFeatureMap> FeatureMapShape for M {
146 fn feature_dim(&self) -> usize {
147 self.output_dim()
148 }
149}
150
151pub struct DeepKernelSummary<'a, F, K>
153where
154 F: NeuralFeatureMap,
155 K: Kernel,
156{
157 pub kernel: &'a DeepKernel<F, K>,
158}
159
160impl<F, K> fmt::Display for DeepKernelSummary<'_, F, K>
161where
162 F: NeuralFeatureMap,
163 K: Kernel,
164{
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 write!(
167 f,
168 "DeepKernel(in={}, features={}, base={})",
169 self.kernel.extractor.input_dim(),
170 self.kernel.extractor.output_dim(),
171 self.kernel.base.name()
172 )
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deep_kernel::feature_extractor::MLPFeatureExtractor;
    use crate::deep_kernel::layer::{Activation, DenseLayer};
    use crate::types::RbfKernelConfig;
    use crate::{LinearKernel, RbfKernel};

    /// One dense layer with weight 1, bias 0 and identity activation.
    /// It is intended to act as the identity map on 1-D inputs.
    fn identity_mlp_1x1() -> MLPFeatureExtractor {
        let layer =
            DenseLayer::new(vec![vec![1.0]], vec![0.0], Activation::Identity).expect("valid");
        MLPFeatureExtractor::from_layers(vec![layer]).expect("valid")
    }

    #[test]
    fn deep_kernel_with_identity_equals_base() {
        let linear = LinearKernel::new();
        let dkl = DeepKernel::new(identity_mlp_1x1(), linear);
        let expected = LinearKernel::new().compute(&[3.0], &[4.0]).expect("linear");
        let got = dkl.compute(&[3.0], &[4.0]).expect("deep");
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn deep_kernel_propagates_psd_from_base() {
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let dkl = DeepKernel::new(identity_mlp_1x1(), rbf);
        assert!(dkl.is_psd());
    }

    #[test]
    fn symmetric_gram_matches_cross_gram_and_pointwise_evaluation() {
        let dkl = DeepKernel::new(identity_mlp_1x1(), LinearKernel::new());
        let points = vec![vec![1.0], vec![-2.0], vec![0.5]];

        let sym = dkl.compute_symmetric_gram(&points).expect("symmetric gram");
        let refs: Vec<&[f64]> = points.iter().map(Vec::as_slice).collect();
        let full = dkl.compute_gram(&refs, &refs).expect("cross gram");

        assert_eq!(sym.len(), points.len());
        assert_eq!(full.len(), points.len());
        for (i, row) in sym.iter().enumerate() {
            assert_eq!(row.len(), points.len());
            for (j, &v) in row.iter().enumerate() {
                let direct = dkl.evaluate(&points[i], &points[j]).expect("evaluate");
                assert!((v - direct).abs() < 1e-12);
                assert!((v - full[i][j]).abs() < 1e-12);
                assert!((v - sym[j][i]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn gram_builders_handle_empty_inputs() {
        let dkl = DeepKernel::new(identity_mlp_1x1(), LinearKernel::new());

        assert!(dkl.compute_symmetric_gram(&[]).expect("empty").is_empty());

        // One row, zero columns: the row exists but is empty.
        let one: [&[f64]; 1] = [&[2.0]];
        let gram = dkl.compute_gram(&one, &[]).expect("empty columns");
        assert_eq!(gram, vec![Vec::<f64>::new()]);
    }
}