Skip to main content

tensorlogic_sklears_kernels/deep_kernel/
gradient.rs

1//! Gradient helpers for [`DeepKernel`]s.
2//!
3//! Two paths are provided:
4//!
//! * [`finite_difference_gradient`] — central differences over the flat
//!   parameter buffer of an [`MLPFeatureExtractor`]. Works with any base
//!   kernel, costs `2P` kernel evaluations for `P` parameters, and serves
//!   as the reference in the crate's own correctness tests.
10//! * [`rbf_dkl_gradient`] — the analytical gradient for the
11//!   RBF-base / MLP-extractor special case. Closed form:
12//!
13//!   `∂K_DKL / ∂θ = K_DKL · (-2γ) · Σ_k (g(x) - g(y))_k · ∂g_k(x)/∂θ`
14//!   `             + K_DKL · ( 2γ) · Σ_k (g(x) - g(y))_k · ∂g_k(y)/∂θ`
15//!
16//!   (the two sums come from `∂/∂θ || g(x) - g(y) ||²`). The Jacobians
17//!   `∂g_k(·)/∂θ` are obtained by standard MLP backprop, reusing the
18//!   per-layer pre/post-activation cache produced by
19//!   [`MLPFeatureExtractor::forward_with_cache`].
20//!
21//! # Scope (v0.2.0 preview)
22//!
//! * The analytical chain rule is implemented only for the
//!   [`MLPFeatureExtractor`] + [`RbfKernel`] pair — the paradigmatic DKL
//!   configuration. Other base kernels (still paired with an
//!   `MLPFeatureExtractor`) must be gradient-checked with
//!   [`finite_difference_gradient`]; autodiff integration is out of scope
//!   for this release.
28//! * Gradients w.r.t. base-kernel hyperparameters (e.g. the RBF `γ`)
29//!   are **not** implemented here; the mixture side of the workspace
30//!   (`learned_composition`) handles that use case.
31
32use crate::deep_kernel::feature_extractor::{LayerCache, MLPFeatureExtractor, NeuralFeatureMap};
33use crate::deep_kernel::kernel::DeepKernel;
34use crate::deep_kernel::layer::Activation;
35use crate::error::{KernelError, Result};
36use crate::tensor_kernels::RbfKernel;
37use crate::types::Kernel;
38
39/// Numerical gradient `∂K_DKL/∂θ` via central finite differences on the
40/// flat parameter buffer. Returns a vector of length
41/// `kernel.feature_extractor().parameter_count()`.
42///
43/// The caller must pass an `MLPFeatureExtractor` (or any extractor that
44/// shares the `parameters` / `sync_from_flat` contract) — the helper
45/// needs to perturb the flat buffer and then push the update back into
46/// the layer weights before the next forward pass.
47pub fn finite_difference_gradient<K: Kernel>(
48    kernel: &mut DeepKernel<MLPFeatureExtractor, K>,
49    x: &[f64],
50    y: &[f64],
51    h: f64,
52) -> Result<Vec<f64>> {
53    if !(h.is_finite() && h > 0.0) {
54        return Err(KernelError::InvalidParameter {
55            parameter: "h".to_string(),
56            value: h.to_string(),
57            reason: "finite-difference step must be a positive finite number".to_string(),
58        });
59    }
60    let p = kernel.feature_extractor().parameter_count();
61    let mut grad = Vec::with_capacity(p);
62    let baseline = kernel.feature_extractor().parameters().to_vec();
63    for i in 0..p {
64        let mut plus = baseline.clone();
65        plus[i] += h;
66        kernel
67            .feature_extractor_mut()
68            .parameters_mut()
69            .copy_from_slice(&plus);
70        kernel.feature_extractor_mut().sync_from_flat()?;
71        let f_plus = kernel.compute(x, y)?;
72
73        let mut minus = baseline.clone();
74        minus[i] -= h;
75        kernel
76            .feature_extractor_mut()
77            .parameters_mut()
78            .copy_from_slice(&minus);
79        kernel.feature_extractor_mut().sync_from_flat()?;
80        let f_minus = kernel.compute(x, y)?;
81
82        grad.push((f_plus - f_minus) / (2.0 * h));
83    }
84    // Restore the original parameter buffer so the kernel is unchanged
85    // from the caller's point of view.
86    kernel
87        .feature_extractor_mut()
88        .parameters_mut()
89        .copy_from_slice(&baseline);
90    kernel.feature_extractor_mut().sync_from_flat()?;
91    Ok(grad)
92}
93
94/// Analytical gradient of `K_DKL(x, y)` w.r.t. the MLP parameters for
95/// the RBF-base case. Returns a vector of length
96/// `kernel.feature_extractor().parameter_count()` whose entries mirror
97/// the flat parameter layout
98/// `layer0.weights(row-major) ++ layer0.biases ++ layer1.weights ++ ...`.
99///
100/// The closed form is derived in the module doc-comment. The implementation
101/// performs one forward pass with per-layer caches on each of `x` and
102/// `y`, computes the output-space difference vector `Δ = g(x) - g(y)`,
103/// and back-propagates it through both networks to accumulate the
104/// per-parameter gradient.
105pub fn rbf_dkl_gradient(
106    kernel: &DeepKernel<MLPFeatureExtractor, RbfKernel>,
107    x: &[f64],
108    y: &[f64],
109) -> Result<Vec<f64>> {
110    let mlp = kernel.feature_extractor();
111    let (fx, cache_x) = mlp.forward_with_cache(x)?;
112    let (fy, cache_y) = mlp.forward_with_cache(y)?;
113    if fx.len() != fy.len() {
114        return Err(KernelError::DimensionMismatch {
115            expected: vec![fx.len()],
116            got: vec![fy.len()],
117            context: "rbf_dkl_gradient: feature dims".to_string(),
118        });
119    }
120    let diff: Vec<f64> = fx.iter().zip(fy.iter()).map(|(a, b)| a - b).collect();
121    let sq_dist: f64 = diff.iter().map(|d| d * d).sum();
122    let gamma = kernel.base_kernel().gamma();
123    let k_val = (-gamma * sq_dist).exp();
124    // Seed vectors for backprop: ∂K/∂g(x)_k = -2γ·Δ_k·K, ∂K/∂g(y)_k = +2γ·Δ_k·K.
125    let seed_x: Vec<f64> = diff.iter().map(|d| -2.0 * gamma * d * k_val).collect();
126    let seed_y: Vec<f64> = diff.iter().map(|d| 2.0 * gamma * d * k_val).collect();
127
128    let mut grad = vec![0.0; mlp.parameter_count()];
129    accumulate_backward(mlp, &cache_x, x, &seed_x, &mut grad)?;
130    accumulate_backward(mlp, &cache_y, y, &seed_y, &mut grad)?;
131    Ok(grad)
132}
133
134/// Backpropagate an output-space gradient through an MLP and accumulate
135/// the per-parameter gradient into `out_grad`.
136///
137/// * `caches` holds `(pre_activation, post_activation)` for each layer
138///   as produced by `MLPFeatureExtractor::forward_with_cache`.
139/// * `input` is the original input to layer 0 (so we can form the
140///   Jacobian of layer 0's weights w.r.t. the input).
141/// * `seed` is `∂K/∂(output of MLP)`, length = output dimension.
142///
143/// Mutates `out_grad` in place. The layout matches
144/// `flatten_layers` in `feature_extractor.rs`:
145/// `(layer0.weights row-major, layer0.biases, layer1.weights, ...)`.
146fn accumulate_backward(
147    mlp: &MLPFeatureExtractor,
148    caches: &[LayerCache],
149    input: &[f64],
150    seed: &[f64],
151    out_grad: &mut [f64],
152) -> Result<()> {
153    let layers = mlp.layers();
154    if caches.len() != layers.len() {
155        return Err(KernelError::DimensionMismatch {
156            expected: vec![layers.len()],
157            got: vec![caches.len()],
158            context: "accumulate_backward: cache length".to_string(),
159        });
160    }
161    // Build a reverse offset table — offsets[i] points at the first
162    // parameter slot for layer `i` inside the flat buffer.
163    let mut offsets = Vec::with_capacity(layers.len());
164    let mut running = 0usize;
165    for layer in layers {
166        offsets.push(running);
167        running += layer.parameter_count();
168    }
169
170    // `delta` holds ∂K / ∂(post-activation of the current layer).
171    let mut delta = seed.to_vec();
172    if delta.len() != layers[layers.len() - 1].output_dim() {
173        return Err(KernelError::DimensionMismatch {
174            expected: vec![layers[layers.len() - 1].output_dim()],
175            got: vec![delta.len()],
176            context: "accumulate_backward: seed length".to_string(),
177        });
178    }
179
180    for layer_idx in (0..layers.len()).rev() {
181        let layer = &layers[layer_idx];
182        let (pre, _post) = &caches[layer_idx];
183        // ∂K / ∂(pre-activation of this layer) = delta * f'(pre).
184        let activation = layer.activation();
185        let mut delta_pre = Vec::with_capacity(pre.len());
186        for (d, &p) in delta.iter().zip(pre.iter()) {
187            delta_pre.push(d * derivative(activation, p));
188        }
189        // Previous-layer activations (input to this layer).
190        let prev_activation: &[f64] = if layer_idx == 0 {
191            input
192        } else {
193            &caches[layer_idx - 1].1
194        };
195        // Gradients on this layer's parameters.
196        // weight grad: (∂K/∂pre[i]) * prev_activation[j] for the
197        // entry weights[i][j].
198        let w_base = offsets[layer_idx];
199        let in_dim = layer.input_dim();
200        let out_dim = layer.output_dim();
201        for (i, &dpre_i) in delta_pre.iter().enumerate() {
202            let row_offset = w_base + i * in_dim;
203            for (j, &prev_j) in prev_activation.iter().enumerate() {
204                out_grad[row_offset + j] += dpre_i * prev_j;
205            }
206        }
207        let b_base = w_base + out_dim * in_dim;
208        for (i, &dpre_i) in delta_pre.iter().enumerate() {
209            out_grad[b_base + i] += dpre_i;
210        }
211        // Propagate delta backward through the affine layer.
212        if layer_idx > 0 {
213            let mut new_delta = vec![0.0; in_dim];
214            for (i, &dpre_i) in delta_pre.iter().enumerate() {
215                let row = &layer.weights[i];
216                for (j, &w_ij) in row.iter().enumerate() {
217                    new_delta[j] += dpre_i * w_ij;
218                }
219            }
220            delta = new_delta;
221        }
222    }
223    Ok(())
224}
225
226/// Local wrapper around [`Activation::derivative`] so `accumulate_backward`
227/// does not need to import the enum directly (keeps imports tight).
228fn derivative(activation: Activation, z: f64) -> f64 {
229    activation.derivative(z)
230}
231
#[cfg(test)]
mod tests {
    // `super::*` already brings in `MLPFeatureExtractor`, `DeepKernel`,
    // `Activation`, `RbfKernel`, and `KernelError` from the parent's imports.
    use super::*;
    use crate::types::RbfKernelConfig;

    /// `2 → 3 → 2` MLP (tanh hidden layer, identity output) with a fixed seed.
    fn mini_mlp(seed: u64) -> MLPFeatureExtractor {
        MLPFeatureExtractor::xavier_init(
            &[2, 3, 2],
            &[Activation::Tanh, Activation::Identity],
            seed,
        )
        .expect("xavier init")
    }

    /// `2 → 4 → 3 → 2` MLP, so the backward pass has to propagate through
    /// more than one hidden layer.
    fn deep_mlp(seed: u64) -> MLPFeatureExtractor {
        MLPFeatureExtractor::xavier_init(
            &[2, 4, 3, 2],
            &[Activation::Tanh, Activation::Tanh, Activation::Identity],
            seed,
        )
        .expect("xavier init")
    }

    /// Checks that the analytical and central-difference gradients have the
    /// expected length and agree entrywise within `1e-3`.
    fn assert_gradients_agree(
        mut dkl: DeepKernel<MLPFeatureExtractor, RbfKernel>,
        x: &[f64],
        y: &[f64],
    ) {
        let analytical = rbf_dkl_gradient(&dkl, x, y).expect("analytical");
        let numerical = finite_difference_gradient(&mut dkl, x, y, 1e-5).expect("finite diff");
        assert_eq!(analytical.len(), numerical.len());
        assert_eq!(analytical.len(), dkl.feature_extractor().parameter_count());
        for (i, (a, n)) in analytical.iter().zip(numerical.iter()).enumerate() {
            assert!(
                (a - n).abs() < 1e-3,
                "param {} mismatch: analytical={}, numerical={}",
                i,
                a,
                n
            );
        }
    }

    #[test]
    fn analytical_matches_finite_difference_for_rbf_mlp() {
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.8)).expect("valid");
        let dkl = DeepKernel::new(mini_mlp(17), rbf);
        assert_gradients_agree(dkl, &[0.3, -0.5], &[-0.2, 0.4]);
    }

    #[test]
    fn analytical_matches_finite_difference_for_deeper_mlp() {
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let dkl = DeepKernel::new(deep_mlp(23), rbf);
        assert_gradients_agree(dkl, &[0.7, 0.1], &[-0.4, -0.6]);
    }

    #[test]
    fn finite_difference_restores_parameters() {
        let mlp = mini_mlp(11);
        let before = mlp.parameters().to_vec();
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let mut dkl = DeepKernel::new(mlp, rbf);
        let _ = finite_difference_gradient(&mut dkl, &[0.2, 0.1], &[-0.1, 0.3], 1e-5)
            .expect("finite diff");
        let after = dkl.feature_extractor().parameters().to_vec();
        // Without the length check, `zip` would let a truncated buffer pass.
        assert_eq!(before.len(), after.len());
        for (a, b) in before.iter().zip(after.iter()) {
            assert!((a - b).abs() < 1e-12);
        }
    }

    #[test]
    fn finite_difference_rejects_invalid_steps() {
        for step in [0.0, -1e-5, f64::NAN, f64::INFINITY] {
            let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
            let mut dkl = DeepKernel::new(mini_mlp(0), rbf);
            let err = finite_difference_gradient(&mut dkl, &[0.0, 0.0], &[0.0, 0.0], step)
                .expect_err("non-positive or non-finite step must fail");
            assert!(
                matches!(err, KernelError::InvalidParameter { .. }),
                "step {} produced the wrong error variant",
                step
            );
        }
    }
}