Skip to main content

tensorlogic_sklears_kernels/deep_kernel/
feature_extractor.rs

1//! Differentiable feature extractors for Deep Kernel Learning.
2//!
3//! [`NeuralFeatureMap`] is the trait an `F` must satisfy to be plugged
4//! into [`crate::deep_kernel::DeepKernel`]. The v0.2.0 preview ships a
5//! single implementation — [`MLPFeatureExtractor`] — a stack of
6//! [`DenseLayer`]s with ReLU / Tanh / Identity activations. Future
7//! releases may add CNN / Transformer extractors; the trait is designed
8//! to keep those additions out-of-tree until they are ready.
9//!
10//! # Naming note
11//!
12//! The trait is named `NeuralFeatureMap` (not `FeatureExtractor`) to
13//! avoid colliding with the pre-existing
14//! [`crate::feature_extraction::FeatureExtractor`] struct, which is
15//! specifically for turning [`tensorlogic_ir::TLExpr`] into numeric
16//! features. The two types coexist in the same crate and serve
17//! complementary purposes.
18
19use scirs2_core::random::{Normal, SeedableRng, StdRng};
20
21use crate::deep_kernel::layer::{Activation, DenseLayer};
22use crate::error::{KernelError, Result};
23
/// Values recorded for a single layer during a cached forward pass,
/// as the pair `(pre_activation, post_activation)`.
///
/// Consumed by the analytical backprop path in
/// [`crate::deep_kernel::gradient`].
pub type LayerCache = (Vec<f64>, Vec<f64>);

/// Bundle returned by [`MLPFeatureExtractor::forward_with_cache`] —
/// `(final_output, per_layer_caches)`.
///
/// `per_layer_caches` holds one [`LayerCache`] per layer, in the order
/// the layers are applied; `final_output` is the post-activation of
/// the last layer.
pub type ForwardCache = (Vec<f64>, Vec<LayerCache>);
31
/// A differentiable map `ℝ^{d_in} → ℝ^{d_out}` used as the feature
/// extractor inside a Deep Kernel.
///
/// Implementations must be deterministic given their parameters
/// (so `forward(x)` produces the same output for the same input at any
/// point in time), and must be `Send + Sync` so they can be shared
/// across threads like every other crate-level kernel.
pub trait NeuralFeatureMap: Send + Sync {
    /// Map an input vector to feature space.
    ///
    /// `input` is expected to hold [`Self::input_dim`] values and the
    /// returned vector holds [`Self::output_dim`] values.
    ///
    /// # Errors
    ///
    /// Implementation-defined. [`MLPFeatureExtractor`] returns
    /// [`KernelError::DimensionMismatch`] when `input.len()` differs
    /// from its input dimension, and propagates any layer error.
    fn forward(&self, input: &[f64]) -> Result<Vec<f64>>;

    /// Mutable view of the flat parameter vector (weights + biases, in
    /// layer order). Exposed as `&mut [f64]` so that optimisers can
    /// apply updates in place without owning the extractor.
    ///
    /// Writes through this slice are not guaranteed to be visible to
    /// `forward` immediately: an implementation may keep a separate
    /// structured copy of its parameters. [`MLPFeatureExtractor`], for
    /// example, requires a call to
    /// [`MLPFeatureExtractor::sync_from_flat`] after the buffer has
    /// been modified and before the next forward pass.
    fn parameters_mut(&mut self) -> &mut [f64];

    /// Immutable view of the flat parameter vector, in the same order
    /// as [`Self::parameters_mut`].
    fn parameters(&self) -> &[f64];

    /// Number of trainable scalar parameters.
    fn parameter_count(&self) -> usize;

    /// Input dimension expected by `forward`.
    fn input_dim(&self) -> usize;

    /// Output dimension produced by `forward`.
    fn output_dim(&self) -> usize;
}
60
/// Multi-layer perceptron feature extractor.
///
/// The network is a sequence of [`DenseLayer`]s applied in order; the
/// output of layer `i` is the input of layer `i + 1`. The layer stack
/// must be non-empty and layer shapes must match transitively; both
/// conditions are checked by [`MLPFeatureExtractor::from_layers`].
///
/// Parameters are stored twice on purpose:
///
/// * as structured `layers: Vec<DenseLayer>` — used by `forward`.
/// * as flat `parameters: Vec<f64>` — exposed to optimisers.
///
/// The flat buffer is built from the layers at construction time.
/// Afterwards the two views are only kept in sync explicitly:
/// [`MLPFeatureExtractor::parameters_mut`] returns a borrow into the
/// flat buffer and [`MLPFeatureExtractor::sync_from_flat`] pushes the
/// flat buffer back into the layer weights. Mutating the flat buffer
/// directly requires a subsequent `sync_from_flat` call before the
/// next `forward`; until then `forward` keeps using the old weights.
#[derive(Clone, Debug)]
pub struct MLPFeatureExtractor {
    /// Structured view: the layers, in application order.
    layers: Vec<DenseLayer>,
    /// Flat view: `layer0.weights ++ layer0.biases ++ layer1.weights
    /// ++ ...`, as produced by `flatten_layers`.
    parameters: Vec<f64>,
}
82
83impl MLPFeatureExtractor {
84    /// Wrap an existing `Vec<DenseLayer>` as an MLP feature extractor.
85    ///
86    /// Fails when the layer stack is empty or when consecutive layer
87    /// shapes do not match.
88    pub fn from_layers(layers: Vec<DenseLayer>) -> Result<Self> {
89        if layers.is_empty() {
90            return Err(KernelError::InvalidParameter {
91                parameter: "layers".to_string(),
92                value: "[]".to_string(),
93                reason: "MLPFeatureExtractor requires at least one layer".to_string(),
94            });
95        }
96        for pair in layers.windows(2) {
97            let (a, b) = (&pair[0], &pair[1]);
98            if a.output_dim() != b.input_dim() {
99                return Err(KernelError::DimensionMismatch {
100                    expected: vec![a.output_dim()],
101                    got: vec![b.input_dim()],
102                    context: "MLPFeatureExtractor: layer shape chain".to_string(),
103                });
104            }
105        }
106        let parameters = flatten_layers(&layers);
107        Ok(Self { layers, parameters })
108    }
109
110    /// Build an MLP from a list of layer widths and a parallel list of
111    /// activations (one per weight matrix — i.e. `widths.len() - 1`
112    /// entries). Weights are Xavier/Glorot-normal initialised via
113    /// SciRS2-Core's seeded RNG; biases are zero.
114    pub fn xavier_init(widths: &[usize], activations: &[Activation], seed: u64) -> Result<Self> {
115        if widths.len() < 2 {
116            return Err(KernelError::InvalidParameter {
117                parameter: "widths".to_string(),
118                value: format!("{:?}", widths),
119                reason: "xavier_init requires at least input and output widths".to_string(),
120            });
121        }
122        if widths.contains(&0) {
123            return Err(KernelError::InvalidParameter {
124                parameter: "widths".to_string(),
125                value: format!("{:?}", widths),
126                reason: "widths must be strictly positive".to_string(),
127            });
128        }
129        if activations.len() != widths.len() - 1 {
130            return Err(KernelError::DimensionMismatch {
131                expected: vec![widths.len() - 1],
132                got: vec![activations.len()],
133                context: "xavier_init: activations length".to_string(),
134            });
135        }
136        let mut rng = StdRng::seed_from_u64(seed);
137        let mut layers = Vec::with_capacity(widths.len() - 1);
138        for (pair, &activation) in widths.windows(2).zip(activations.iter()) {
139            let fan_in = pair[0];
140            let fan_out = pair[1];
141            let std = (2.0 / (fan_in + fan_out) as f64).sqrt();
142            let dist = Normal::new(0.0, std).map_err(|e| KernelError::InvalidParameter {
143                parameter: "xavier stddev".to_string(),
144                value: std.to_string(),
145                reason: format!("Normal::new failed: {}", e),
146            })?;
147            let mut weights = Vec::with_capacity(fan_out);
148            for _ in 0..fan_out {
149                let mut row = Vec::with_capacity(fan_in);
150                for _ in 0..fan_in {
151                    row.push(rng.sample(dist));
152                }
153                weights.push(row);
154            }
155            let biases = vec![0.0; fan_out];
156            layers.push(DenseLayer::new(weights, biases, activation)?);
157        }
158        Self::from_layers(layers)
159    }
160
161    /// Immutable view of the layer stack.
162    pub fn layers(&self) -> &[DenseLayer] {
163        &self.layers
164    }
165
166    /// Number of layers.
167    pub fn num_layers(&self) -> usize {
168        self.layers.len()
169    }
170
171    /// Forward pass with per-layer caches of `(pre_activation,
172    /// post_activation)` tensors. Used by the analytical gradient path
173    /// in [`crate::deep_kernel::gradient`].
174    pub fn forward_with_cache(&self, input: &[f64]) -> Result<ForwardCache> {
175        if input.len() != self.input_dim() {
176            return Err(KernelError::DimensionMismatch {
177                expected: vec![self.input_dim()],
178                got: vec![input.len()],
179                context: "MLPFeatureExtractor::forward_with_cache input".to_string(),
180            });
181        }
182        let mut cache = Vec::with_capacity(self.layers.len());
183        let mut current = input.to_vec();
184        for layer in &self.layers {
185            let (pre, post) = layer.forward_with_preactivation(&current)?;
186            cache.push((pre, post.clone()));
187            current = post;
188        }
189        Ok((current, cache))
190    }
191
192    /// Push the flat parameter buffer back into the per-layer
193    /// `weights` / `biases`. Call this after mutating the flat buffer
194    /// returned from [`Self::parameters_mut`] but before the next
195    /// forward pass.
196    pub fn sync_from_flat(&mut self) -> Result<()> {
197        let mut idx = 0;
198        for layer in self.layers.iter_mut() {
199            for row in layer.weights.iter_mut() {
200                for w in row.iter_mut() {
201                    let v = *self.parameters.get(idx).ok_or_else(|| {
202                        KernelError::ComputationError(
203                            "parameter buffer too short during sync_from_flat".to_string(),
204                        )
205                    })?;
206                    if !v.is_finite() {
207                        return Err(KernelError::InvalidParameter {
208                            parameter: format!("parameters[{}]", idx),
209                            value: v.to_string(),
210                            reason: "parameters must remain finite".to_string(),
211                        });
212                    }
213                    *w = v;
214                    idx += 1;
215                }
216            }
217            for b in layer.biases.iter_mut() {
218                let v = *self.parameters.get(idx).ok_or_else(|| {
219                    KernelError::ComputationError(
220                        "parameter buffer too short during sync_from_flat".to_string(),
221                    )
222                })?;
223                if !v.is_finite() {
224                    return Err(KernelError::InvalidParameter {
225                        parameter: format!("parameters[{}]", idx),
226                        value: v.to_string(),
227                        reason: "parameters must remain finite".to_string(),
228                    });
229                }
230                *b = v;
231                idx += 1;
232            }
233        }
234        Ok(())
235    }
236}
237
238impl NeuralFeatureMap for MLPFeatureExtractor {
239    fn forward(&self, input: &[f64]) -> Result<Vec<f64>> {
240        if input.len() != self.input_dim() {
241            return Err(KernelError::DimensionMismatch {
242                expected: vec![self.input_dim()],
243                got: vec![input.len()],
244                context: "MLPFeatureExtractor::forward input".to_string(),
245            });
246        }
247        let mut current = input.to_vec();
248        for layer in &self.layers {
249            current = layer.forward(&current)?;
250        }
251        Ok(current)
252    }
253
254    fn parameters_mut(&mut self) -> &mut [f64] {
255        &mut self.parameters
256    }
257
258    fn parameters(&self) -> &[f64] {
259        &self.parameters
260    }
261
262    fn parameter_count(&self) -> usize {
263        self.parameters.len()
264    }
265
266    fn input_dim(&self) -> usize {
267        self.layers
268            .first()
269            .map(|l| l.input_dim())
270            .unwrap_or_default()
271    }
272
273    fn output_dim(&self) -> usize {
274        self.layers
275            .last()
276            .map(|l| l.output_dim())
277            .unwrap_or_default()
278    }
279}
280
281/// Serialise the weights / biases of a layer stack into a flat vector,
282/// in the canonical `layer0.weights ++ layer0.biases ++ layer1.weights
283/// ++ ...` order.
284fn flatten_layers(layers: &[DenseLayer]) -> Vec<f64> {
285    let mut out = Vec::with_capacity(layers.iter().map(DenseLayer::parameter_count).sum());
286    for layer in layers {
287        for row in &layer.weights {
288            out.extend_from_slice(row);
289        }
290        out.extend_from_slice(&layer.biases);
291    }
292    out
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an identity-activated layer from raw weights and biases.
    fn identity_layer(weights: Vec<Vec<f64>>, biases: Vec<f64>) -> DenseLayer {
        DenseLayer::new(weights, biases, Activation::Identity).expect("valid")
    }

    /// A single 1→1 layer with unit weight, zero bias and identity
    /// activation must hand its input back unchanged.
    #[test]
    fn mlp_forward_identity_of_1x1() {
        let unit = identity_layer(vec![vec![1.0]], vec![0.0]);
        let network = MLPFeatureExtractor::from_layers(vec![unit]).expect("valid mlp");
        let features = network.forward(&[2.5]).expect("forward");
        assert_eq!(features, vec![2.5]);
    }

    /// Stacking a layer that emits 1 value in front of a layer that
    /// consumes 3 must be rejected at construction time.
    #[test]
    fn mlp_rejects_shape_chain_mismatch() {
        let narrow = identity_layer(vec![vec![1.0, 0.0]], vec![0.0]);
        let wide = identity_layer(vec![vec![1.0, 1.0, 1.0]], vec![0.0]);
        let outcome = MLPFeatureExtractor::from_layers(vec![narrow, wide]);
        let err = outcome.expect_err("must fail");
        assert!(matches!(err, KernelError::DimensionMismatch { .. }));
    }

    /// The flat buffer lists the weight rows first, then the biases:
    /// 2x2 weights + 2 biases = 6 parameters.
    #[test]
    fn mlp_parameter_roundtrip() {
        let weights = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let biases = vec![0.5, -0.5];
        let layer = DenseLayer::new(weights, biases, Activation::ReLU).expect("valid");
        let network = MLPFeatureExtractor::from_layers(vec![layer]).expect("valid");
        assert_eq!(network.parameter_count(), 6);
        assert_eq!(network.parameters(), &[1.0, 2.0, 3.0, 4.0, 0.5, -0.5]);
    }
}