tensorlogic_sklears_kernels/deep_kernel/
gradient.rs1use crate::deep_kernel::feature_extractor::{LayerCache, MLPFeatureExtractor, NeuralFeatureMap};
33use crate::deep_kernel::kernel::DeepKernel;
34use crate::deep_kernel::layer::Activation;
35use crate::error::{KernelError, Result};
36use crate::tensor_kernels::RbfKernel;
37use crate::types::Kernel;
38
39pub fn finite_difference_gradient<K: Kernel>(
48 kernel: &mut DeepKernel<MLPFeatureExtractor, K>,
49 x: &[f64],
50 y: &[f64],
51 h: f64,
52) -> Result<Vec<f64>> {
53 if !(h.is_finite() && h > 0.0) {
54 return Err(KernelError::InvalidParameter {
55 parameter: "h".to_string(),
56 value: h.to_string(),
57 reason: "finite-difference step must be a positive finite number".to_string(),
58 });
59 }
60 let p = kernel.feature_extractor().parameter_count();
61 let mut grad = Vec::with_capacity(p);
62 let baseline = kernel.feature_extractor().parameters().to_vec();
63 for i in 0..p {
64 let mut plus = baseline.clone();
65 plus[i] += h;
66 kernel
67 .feature_extractor_mut()
68 .parameters_mut()
69 .copy_from_slice(&plus);
70 kernel.feature_extractor_mut().sync_from_flat()?;
71 let f_plus = kernel.compute(x, y)?;
72
73 let mut minus = baseline.clone();
74 minus[i] -= h;
75 kernel
76 .feature_extractor_mut()
77 .parameters_mut()
78 .copy_from_slice(&minus);
79 kernel.feature_extractor_mut().sync_from_flat()?;
80 let f_minus = kernel.compute(x, y)?;
81
82 grad.push((f_plus - f_minus) / (2.0 * h));
83 }
84 kernel
87 .feature_extractor_mut()
88 .parameters_mut()
89 .copy_from_slice(&baseline);
90 kernel.feature_extractor_mut().sync_from_flat()?;
91 Ok(grad)
92}
93
94pub fn rbf_dkl_gradient(
106 kernel: &DeepKernel<MLPFeatureExtractor, RbfKernel>,
107 x: &[f64],
108 y: &[f64],
109) -> Result<Vec<f64>> {
110 let mlp = kernel.feature_extractor();
111 let (fx, cache_x) = mlp.forward_with_cache(x)?;
112 let (fy, cache_y) = mlp.forward_with_cache(y)?;
113 if fx.len() != fy.len() {
114 return Err(KernelError::DimensionMismatch {
115 expected: vec![fx.len()],
116 got: vec![fy.len()],
117 context: "rbf_dkl_gradient: feature dims".to_string(),
118 });
119 }
120 let diff: Vec<f64> = fx.iter().zip(fy.iter()).map(|(a, b)| a - b).collect();
121 let sq_dist: f64 = diff.iter().map(|d| d * d).sum();
122 let gamma = kernel.base_kernel().gamma();
123 let k_val = (-gamma * sq_dist).exp();
124 let seed_x: Vec<f64> = diff.iter().map(|d| -2.0 * gamma * d * k_val).collect();
126 let seed_y: Vec<f64> = diff.iter().map(|d| 2.0 * gamma * d * k_val).collect();
127
128 let mut grad = vec![0.0; mlp.parameter_count()];
129 accumulate_backward(mlp, &cache_x, x, &seed_x, &mut grad)?;
130 accumulate_backward(mlp, &cache_y, y, &seed_y, &mut grad)?;
131 Ok(grad)
132}
133
134fn accumulate_backward(
147 mlp: &MLPFeatureExtractor,
148 caches: &[LayerCache],
149 input: &[f64],
150 seed: &[f64],
151 out_grad: &mut [f64],
152) -> Result<()> {
153 let layers = mlp.layers();
154 if caches.len() != layers.len() {
155 return Err(KernelError::DimensionMismatch {
156 expected: vec![layers.len()],
157 got: vec![caches.len()],
158 context: "accumulate_backward: cache length".to_string(),
159 });
160 }
161 let mut offsets = Vec::with_capacity(layers.len());
164 let mut running = 0usize;
165 for layer in layers {
166 offsets.push(running);
167 running += layer.parameter_count();
168 }
169
170 let mut delta = seed.to_vec();
172 if delta.len() != layers[layers.len() - 1].output_dim() {
173 return Err(KernelError::DimensionMismatch {
174 expected: vec![layers[layers.len() - 1].output_dim()],
175 got: vec![delta.len()],
176 context: "accumulate_backward: seed length".to_string(),
177 });
178 }
179
180 for layer_idx in (0..layers.len()).rev() {
181 let layer = &layers[layer_idx];
182 let (pre, _post) = &caches[layer_idx];
183 let activation = layer.activation();
185 let mut delta_pre = Vec::with_capacity(pre.len());
186 for (d, &p) in delta.iter().zip(pre.iter()) {
187 delta_pre.push(d * derivative(activation, p));
188 }
189 let prev_activation: &[f64] = if layer_idx == 0 {
191 input
192 } else {
193 &caches[layer_idx - 1].1
194 };
195 let w_base = offsets[layer_idx];
199 let in_dim = layer.input_dim();
200 let out_dim = layer.output_dim();
201 for (i, &dpre_i) in delta_pre.iter().enumerate() {
202 let row_offset = w_base + i * in_dim;
203 for (j, &prev_j) in prev_activation.iter().enumerate() {
204 out_grad[row_offset + j] += dpre_i * prev_j;
205 }
206 }
207 let b_base = w_base + out_dim * in_dim;
208 for (i, &dpre_i) in delta_pre.iter().enumerate() {
209 out_grad[b_base + i] += dpre_i;
210 }
211 if layer_idx > 0 {
213 let mut new_delta = vec![0.0; in_dim];
214 for (i, &dpre_i) in delta_pre.iter().enumerate() {
215 let row = &layer.weights[i];
216 for (j, &w_ij) in row.iter().enumerate() {
217 new_delta[j] += dpre_i * w_ij;
218 }
219 }
220 delta = new_delta;
221 }
222 }
223 Ok(())
224}
225
226fn derivative(activation: Activation, z: f64) -> f64 {
229 activation.derivative(z)
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deep_kernel::feature_extractor::MLPFeatureExtractor;
    use crate::deep_kernel::kernel::DeepKernel;
    use crate::deep_kernel::layer::Activation;
    use crate::types::RbfKernelConfig;

    /// Small 2 → 3 → 2 MLP (tanh hidden layer, identity output). Initialised
    /// with `xavier_init` using the given `seed`.
    fn mini_mlp(seed: u64) -> MLPFeatureExtractor {
        MLPFeatureExtractor::xavier_init(
            &[2, 3, 2],
            &[Activation::Tanh, Activation::Identity],
            seed,
        )
        .expect("xavier init")
    }

    #[test]
    fn analytical_matches_finite_difference_for_rbf_mlp() {
        let mlp = mini_mlp(17);
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.8)).expect("valid");
        let mut dkl = DeepKernel::new(mlp, rbf);

        let x = vec![0.3, -0.5];
        let y = vec![-0.2, 0.4];
        let analytical = rbf_dkl_gradient(&dkl, &x, &y).expect("analytical");
        let numerical = finite_difference_gradient(&mut dkl, &x, &y, 1e-5).expect("finite diff");
        assert_eq!(analytical.len(), numerical.len());
        for (i, (a, n)) in analytical.iter().zip(numerical.iter()).enumerate() {
            assert!(
                (a - n).abs() < 1e-3,
                "param {} mismatch: analytical={}, numerical={}",
                i,
                a,
                n
            );
        }
    }

    #[test]
    fn analytical_gradient_vanishes_for_identical_inputs() {
        // With x == y, the feature difference is exactly zero. Both
        // output-side seeds are therefore zero, and so is every parameter
        // gradient.
        let mlp = mini_mlp(5);
        let expected_len = mlp.parameter_count();
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.8)).expect("valid");
        let dkl = DeepKernel::new(mlp, rbf);
        let x = vec![0.25, -0.75];
        let grad = rbf_dkl_gradient(&dkl, &x, &x).expect("analytical");
        assert_eq!(grad.len(), expected_len);
        assert!(
            grad.iter().all(|g| *g == 0.0),
            "expected all-zero gradient, got {:?}",
            grad
        );
    }

    #[test]
    fn finite_difference_restores_parameters() {
        let mlp = mini_mlp(11);
        let before = mlp.parameters().to_vec();
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let mut dkl = DeepKernel::new(mlp, rbf);
        let _ = finite_difference_gradient(&mut dkl, &[0.2, 0.1], &[-0.1, 0.3], 1e-5)
            .expect("finite diff");
        let after = dkl.feature_extractor().parameters().to_vec();
        // Restoration copies the snapshot back, so it must be exact. Comparing
        // whole vectors also catches any change in length.
        assert_eq!(before, after);
    }

    #[test]
    fn finite_difference_rejects_zero_step() {
        let mlp = mini_mlp(0);
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let mut dkl = DeepKernel::new(mlp, rbf);
        let err = finite_difference_gradient(&mut dkl, &[0.0, 0.0], &[0.0, 0.0], 0.0)
            .expect_err("zero step must fail");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
    }

    #[test]
    fn finite_difference_rejects_negative_and_non_finite_steps() {
        let mlp = mini_mlp(3);
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let mut dkl = DeepKernel::new(mlp, rbf);
        for &h in &[-1e-5, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let err = finite_difference_gradient(&mut dkl, &[0.0, 0.0], &[0.0, 0.0], h)
                .expect_err("invalid step must fail");
            assert!(
                matches!(err, KernelError::InvalidParameter { .. }),
                "unexpected error kind for h = {}",
                h
            );
        }
    }
}