tensorlogic_sklears_kernels/deep_kernel/
feature_extractor.rs1use scirs2_core::random::{Normal, SeedableRng, StdRng};
20
21use crate::deep_kernel::layer::{Activation, DenseLayer};
22use crate::error::{KernelError, Result};
23
/// Per-layer record kept by `MLPFeatureExtractor::forward_with_cache`:
/// `(pre, post)` exactly as returned by `DenseLayer::forward_with_preactivation`,
/// i.e. the layer's pre-activation vector and its post-activation output.
pub type LayerCache = (Vec<f64>, Vec<f64>);

/// Result of `MLPFeatureExtractor::forward_with_cache`: `(output, caches)`,
/// where `output` is the last layer's post-activation vector and `caches`
/// holds one [`LayerCache`] per layer, in the order the layers were applied.
pub type ForwardCache = (Vec<f64>, Vec<LayerCache>);

32pub trait NeuralFeatureMap: Send + Sync {
40 fn forward(&self, input: &[f64]) -> Result<Vec<f64>>;
42
43 fn parameters_mut(&mut self) -> &mut [f64];
47
48 fn parameters(&self) -> &[f64];
50
51 fn parameter_count(&self) -> usize;
53
54 fn input_dim(&self) -> usize;
56
57 fn output_dim(&self) -> usize;
59}
60
/// Multi-layer perceptron feature map built from a stack of [`DenseLayer`]s.
///
/// Invariants established by the constructors (`from_layers`, and
/// `xavier_init`, which delegates to it):
/// * `layers` is non-empty, and each layer's output width equals the next
///   layer's input width.
/// * `parameters` starts as a flat copy of the layers: for each layer in
///   order, its weight rows followed by its biases.
///
/// NOTE: edits made through `parameters_mut` reach `layers` (which the
/// forward passes read) only after `sync_from_flat` is called. Until then
/// the two representations can disagree.
#[derive(Clone, Debug)]
pub struct MLPFeatureExtractor {
    /// Layers applied in order by the forward passes.
    layers: Vec<DenseLayer>,
    /// Flat parameter buffer mirroring `layers`.
    parameters: Vec<f64>,
}

83impl MLPFeatureExtractor {
84 pub fn from_layers(layers: Vec<DenseLayer>) -> Result<Self> {
89 if layers.is_empty() {
90 return Err(KernelError::InvalidParameter {
91 parameter: "layers".to_string(),
92 value: "[]".to_string(),
93 reason: "MLPFeatureExtractor requires at least one layer".to_string(),
94 });
95 }
96 for pair in layers.windows(2) {
97 let (a, b) = (&pair[0], &pair[1]);
98 if a.output_dim() != b.input_dim() {
99 return Err(KernelError::DimensionMismatch {
100 expected: vec![a.output_dim()],
101 got: vec![b.input_dim()],
102 context: "MLPFeatureExtractor: layer shape chain".to_string(),
103 });
104 }
105 }
106 let parameters = flatten_layers(&layers);
107 Ok(Self { layers, parameters })
108 }
109
110 pub fn xavier_init(widths: &[usize], activations: &[Activation], seed: u64) -> Result<Self> {
115 if widths.len() < 2 {
116 return Err(KernelError::InvalidParameter {
117 parameter: "widths".to_string(),
118 value: format!("{:?}", widths),
119 reason: "xavier_init requires at least input and output widths".to_string(),
120 });
121 }
122 if widths.contains(&0) {
123 return Err(KernelError::InvalidParameter {
124 parameter: "widths".to_string(),
125 value: format!("{:?}", widths),
126 reason: "widths must be strictly positive".to_string(),
127 });
128 }
129 if activations.len() != widths.len() - 1 {
130 return Err(KernelError::DimensionMismatch {
131 expected: vec![widths.len() - 1],
132 got: vec![activations.len()],
133 context: "xavier_init: activations length".to_string(),
134 });
135 }
136 let mut rng = StdRng::seed_from_u64(seed);
137 let mut layers = Vec::with_capacity(widths.len() - 1);
138 for (pair, &activation) in widths.windows(2).zip(activations.iter()) {
139 let fan_in = pair[0];
140 let fan_out = pair[1];
141 let std = (2.0 / (fan_in + fan_out) as f64).sqrt();
142 let dist = Normal::new(0.0, std).map_err(|e| KernelError::InvalidParameter {
143 parameter: "xavier stddev".to_string(),
144 value: std.to_string(),
145 reason: format!("Normal::new failed: {}", e),
146 })?;
147 let mut weights = Vec::with_capacity(fan_out);
148 for _ in 0..fan_out {
149 let mut row = Vec::with_capacity(fan_in);
150 for _ in 0..fan_in {
151 row.push(rng.sample(dist));
152 }
153 weights.push(row);
154 }
155 let biases = vec![0.0; fan_out];
156 layers.push(DenseLayer::new(weights, biases, activation)?);
157 }
158 Self::from_layers(layers)
159 }
160
161 pub fn layers(&self) -> &[DenseLayer] {
163 &self.layers
164 }
165
166 pub fn num_layers(&self) -> usize {
168 self.layers.len()
169 }
170
171 pub fn forward_with_cache(&self, input: &[f64]) -> Result<ForwardCache> {
175 if input.len() != self.input_dim() {
176 return Err(KernelError::DimensionMismatch {
177 expected: vec![self.input_dim()],
178 got: vec![input.len()],
179 context: "MLPFeatureExtractor::forward_with_cache input".to_string(),
180 });
181 }
182 let mut cache = Vec::with_capacity(self.layers.len());
183 let mut current = input.to_vec();
184 for layer in &self.layers {
185 let (pre, post) = layer.forward_with_preactivation(¤t)?;
186 cache.push((pre, post.clone()));
187 current = post;
188 }
189 Ok((current, cache))
190 }
191
192 pub fn sync_from_flat(&mut self) -> Result<()> {
197 let mut idx = 0;
198 for layer in self.layers.iter_mut() {
199 for row in layer.weights.iter_mut() {
200 for w in row.iter_mut() {
201 let v = *self.parameters.get(idx).ok_or_else(|| {
202 KernelError::ComputationError(
203 "parameter buffer too short during sync_from_flat".to_string(),
204 )
205 })?;
206 if !v.is_finite() {
207 return Err(KernelError::InvalidParameter {
208 parameter: format!("parameters[{}]", idx),
209 value: v.to_string(),
210 reason: "parameters must remain finite".to_string(),
211 });
212 }
213 *w = v;
214 idx += 1;
215 }
216 }
217 for b in layer.biases.iter_mut() {
218 let v = *self.parameters.get(idx).ok_or_else(|| {
219 KernelError::ComputationError(
220 "parameter buffer too short during sync_from_flat".to_string(),
221 )
222 })?;
223 if !v.is_finite() {
224 return Err(KernelError::InvalidParameter {
225 parameter: format!("parameters[{}]", idx),
226 value: v.to_string(),
227 reason: "parameters must remain finite".to_string(),
228 });
229 }
230 *b = v;
231 idx += 1;
232 }
233 }
234 Ok(())
235 }
236}
237
238impl NeuralFeatureMap for MLPFeatureExtractor {
239 fn forward(&self, input: &[f64]) -> Result<Vec<f64>> {
240 if input.len() != self.input_dim() {
241 return Err(KernelError::DimensionMismatch {
242 expected: vec![self.input_dim()],
243 got: vec![input.len()],
244 context: "MLPFeatureExtractor::forward input".to_string(),
245 });
246 }
247 let mut current = input.to_vec();
248 for layer in &self.layers {
249 current = layer.forward(¤t)?;
250 }
251 Ok(current)
252 }
253
254 fn parameters_mut(&mut self) -> &mut [f64] {
255 &mut self.parameters
256 }
257
258 fn parameters(&self) -> &[f64] {
259 &self.parameters
260 }
261
262 fn parameter_count(&self) -> usize {
263 self.parameters.len()
264 }
265
266 fn input_dim(&self) -> usize {
267 self.layers
268 .first()
269 .map(|l| l.input_dim())
270 .unwrap_or_default()
271 }
272
273 fn output_dim(&self) -> usize {
274 self.layers
275 .last()
276 .map(|l| l.output_dim())
277 .unwrap_or_default()
278 }
279}
280
281fn flatten_layers(layers: &[DenseLayer]) -> Vec<f64> {
285 let mut out = Vec::with_capacity(layers.iter().map(DenseLayer::parameter_count).sum());
286 for layer in layers {
287 for row in &layer.weights {
288 out.extend_from_slice(row);
289 }
290 out.extend_from_slice(&layer.biases);
291 }
292 out
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// A 1x1 identity layer with weight `w` and zero bias.
    fn scale_layer(w: f64) -> DenseLayer {
        DenseLayer::new(vec![vec![w]], vec![0.0], Activation::Identity).expect("valid layer")
    }

    #[test]
    fn mlp_forward_identity_of_1x1() {
        let layer =
            DenseLayer::new(vec![vec![1.0]], vec![0.0], Activation::Identity).expect("valid");
        let mlp = MLPFeatureExtractor::from_layers(vec![layer]).expect("valid mlp");
        let out = mlp.forward(&[2.5]).expect("forward");
        assert_eq!(out, vec![2.5]);
    }

    #[test]
    fn mlp_rejects_shape_chain_mismatch() {
        let a =
            DenseLayer::new(vec![vec![1.0, 0.0]], vec![0.0], Activation::Identity).expect("valid");
        let b = DenseLayer::new(vec![vec![1.0, 1.0, 1.0]], vec![0.0], Activation::Identity)
            .expect("valid");
        let err = MLPFeatureExtractor::from_layers(vec![a, b]).expect_err("must fail");
        assert!(matches!(err, KernelError::DimensionMismatch { .. }));
    }

    #[test]
    fn mlp_rejects_empty_layer_stack() {
        let err = MLPFeatureExtractor::from_layers(Vec::new()).expect_err("must fail");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
    }

    #[test]
    fn mlp_rejects_wrong_input_length() {
        let mlp = MLPFeatureExtractor::from_layers(vec![scale_layer(1.0)]).expect("valid mlp");
        let err = mlp.forward(&[1.0, 2.0]).expect_err("must fail");
        assert!(matches!(err, KernelError::DimensionMismatch { .. }));
        let err = mlp.forward_with_cache(&[]).expect_err("must fail");
        assert!(matches!(err, KernelError::DimensionMismatch { .. }));
    }

    #[test]
    fn mlp_parameter_roundtrip() {
        let layer = DenseLayer::new(
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![0.5, -0.5],
            Activation::ReLU,
        )
        .expect("valid");
        let mlp = MLPFeatureExtractor::from_layers(vec![layer]).expect("valid");
        assert_eq!(mlp.parameter_count(), 6);
        assert_eq!(mlp.parameters(), &[1.0, 2.0, 3.0, 4.0, 0.5, -0.5]);
    }

    #[test]
    fn forward_with_cache_agrees_with_forward() {
        let mlp = MLPFeatureExtractor::from_layers(vec![scale_layer(2.0), scale_layer(-1.5)])
            .expect("valid mlp");
        let (out, cache) = mlp.forward_with_cache(&[4.0]).expect("forward_with_cache");
        assert_eq!(out, mlp.forward(&[4.0]).expect("forward"));
        assert_eq!(cache.len(), mlp.num_layers());
        assert_eq!(cache.last().map(|(_, post)| post.clone()), Some(out));
    }

    #[test]
    fn sync_from_flat_pushes_edits_into_layers() {
        let mut mlp = MLPFeatureExtractor::from_layers(vec![scale_layer(1.0)]).expect("valid");
        mlp.parameters_mut()[0] = 3.0;
        mlp.sync_from_flat().expect("finite parameters must sync");
        assert_eq!(mlp.layers()[0].weights[0][0], 3.0);
        assert_eq!(mlp.forward(&[2.0]).expect("forward"), vec![6.0]);
    }

    #[test]
    fn sync_from_flat_rejects_non_finite() {
        let mut mlp = MLPFeatureExtractor::from_layers(vec![scale_layer(1.0)]).expect("valid");
        // Index 1 is the single bias (weights come first in the flat layout).
        mlp.parameters_mut()[1] = f64::NAN;
        let err = mlp.sync_from_flat().expect_err("NaN must be rejected");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
    }

    #[test]
    fn xavier_init_is_seed_deterministic_with_expected_shape() {
        let acts = [Activation::ReLU, Activation::Identity];
        let a = MLPFeatureExtractor::xavier_init(&[3, 4, 2], &acts, 7).expect("valid");
        let b = MLPFeatureExtractor::xavier_init(&[3, 4, 2], &acts, 7).expect("valid");
        assert_eq!(a.num_layers(), 2);
        assert_eq!(a.input_dim(), 3);
        assert_eq!(a.output_dim(), 2);
        // (3 * 4 weights + 4 biases) + (4 * 2 weights + 2 biases)
        assert_eq!(a.parameter_count(), 26);
        assert_eq!(a.parameters(), b.parameters());
        assert!(a
            .layers()
            .iter()
            .all(|layer| layer.biases.iter().all(|&x| x == 0.0)));
    }

    #[test]
    fn xavier_init_validates_arguments() {
        let err = MLPFeatureExtractor::xavier_init(&[3], &[], 0).expect_err("too few widths");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
        let err = MLPFeatureExtractor::xavier_init(&[3, 0], &[Activation::Identity], 0)
            .expect_err("zero width");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
        let err = MLPFeatureExtractor::xavier_init(&[3, 2], &[], 0)
            .expect_err("missing activation");
        assert!(matches!(err, KernelError::DimensionMismatch { .. }));
    }
}