tensorlogic_sklears_kernels/deep_kernel/
layer.rs1use crate::error::{KernelError, Result};
12
/// Element-wise non-linearity applied by a [`DenseLayer`] after its affine
/// transform.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Activation {
    /// `f(x) = x`; values pass through unchanged.
    Identity,
    /// `f(x) = x` for `x >= 0`, `0` for `x < 0`. NaN propagates.
    ReLU,
    /// `f(x) = tanh(x)`.
    Tanh,
}

impl Activation {
    /// Applies the activation to every element of `values`, in place.
    ///
    /// Each element is mapped through [`Activation::apply_scalar`], so the
    /// slice and scalar entry points always agree (including on NaN inputs).
    pub fn apply_inplace(&self, values: &mut [f64]) {
        // Identity is a no-op; skip the pass over the slice entirely.
        if *self == Self::Identity {
            return;
        }
        for v in values.iter_mut() {
            *v = self.apply_scalar(*v);
        }
    }

    /// Applies the activation to a single value.
    pub fn apply_scalar(&self, v: f64) -> f64 {
        match self {
            Self::Identity => v,
            // A comparison instead of `f64::max`: `max` returns the non-NaN
            // operand, which would silently turn a NaN input into 0.0, whereas
            // Identity and Tanh propagate NaN.
            Self::ReLU => {
                if v < 0.0 {
                    0.0
                } else {
                    v
                }
            }
            Self::Tanh => v.tanh(),
        }
    }

    /// Derivative of the activation, evaluated at the pre-activation value.
    ///
    /// For ReLU the derivative is taken as `0.0` for every input that is not
    /// strictly positive (the kink at `0.0` and NaN included). For Tanh it is
    /// `1 - tanh(x)^2`.
    pub fn derivative(&self, pre_activation: f64) -> f64 {
        match self {
            Self::Identity => 1.0,
            Self::ReLU => {
                if pre_activation > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            Self::Tanh => {
                let t = pre_activation.tanh();
                1.0 - t * t
            }
        }
    }

    /// Human-readable name of the activation (matches the variant name).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Identity => "Identity",
            Self::ReLU => "ReLU",
            Self::Tanh => "Tanh",
        }
    }
}

93#[derive(Clone, Debug)]
100pub struct DenseLayer {
101 pub weights: Vec<Vec<f64>>,
103 pub biases: Vec<f64>,
105 pub activation: Activation,
107}
108
109impl DenseLayer {
110 pub fn new(weights: Vec<Vec<f64>>, biases: Vec<f64>, activation: Activation) -> Result<Self> {
116 if weights.is_empty() {
117 return Err(KernelError::InvalidParameter {
118 parameter: "weights".to_string(),
119 value: "[]".to_string(),
120 reason: "dense layer must have at least one output".to_string(),
121 });
122 }
123 let input_dim = weights[0].len();
124 if input_dim == 0 {
125 return Err(KernelError::InvalidParameter {
126 parameter: "weights[0]".to_string(),
127 value: "[]".to_string(),
128 reason: "dense layer must have at least one input".to_string(),
129 });
130 }
131 for (i, row) in weights.iter().enumerate() {
132 if row.len() != input_dim {
133 return Err(KernelError::DimensionMismatch {
134 expected: vec![input_dim],
135 got: vec![row.len()],
136 context: format!("DenseLayer::new weights[{}]", i),
137 });
138 }
139 for (j, &w) in row.iter().enumerate() {
140 if !w.is_finite() {
141 return Err(KernelError::InvalidParameter {
142 parameter: format!("weights[{}][{}]", i, j),
143 value: w.to_string(),
144 reason: "weights must be finite".to_string(),
145 });
146 }
147 }
148 }
149 if biases.len() != weights.len() {
150 return Err(KernelError::DimensionMismatch {
151 expected: vec![weights.len()],
152 got: vec![biases.len()],
153 context: "DenseLayer::new biases length".to_string(),
154 });
155 }
156 for (i, &b) in biases.iter().enumerate() {
157 if !b.is_finite() {
158 return Err(KernelError::InvalidParameter {
159 parameter: format!("biases[{}]", i),
160 value: b.to_string(),
161 reason: "biases must be finite".to_string(),
162 });
163 }
164 }
165 Ok(Self {
166 weights,
167 biases,
168 activation,
169 })
170 }
171
172 pub fn input_dim(&self) -> usize {
174 self.weights[0].len()
175 }
176
177 pub fn output_dim(&self) -> usize {
179 self.weights.len()
180 }
181
182 pub fn activation(&self) -> Activation {
184 self.activation
185 }
186
187 pub fn forward(&self, input: &[f64]) -> Result<Vec<f64>> {
189 let (_, post) = self.forward_with_preactivation(input)?;
190 Ok(post)
191 }
192
193 pub fn forward_with_preactivation(&self, input: &[f64]) -> Result<(Vec<f64>, Vec<f64>)> {
198 if input.len() != self.input_dim() {
199 return Err(KernelError::DimensionMismatch {
200 expected: vec![self.input_dim()],
201 got: vec![input.len()],
202 context: "DenseLayer::forward input length".to_string(),
203 });
204 }
205 let mut pre = Vec::with_capacity(self.output_dim());
206 for (row, &bias) in self.weights.iter().zip(self.biases.iter()) {
207 let mut acc = bias;
208 for (w, x) in row.iter().zip(input.iter()) {
209 acc += w * x;
210 }
211 pre.push(acc);
212 }
213 let mut post = pre.clone();
214 self.activation.apply_inplace(&mut post);
215 Ok((pre, post))
216 }
217
218 pub fn parameter_count(&self) -> usize {
220 self.output_dim() * self.input_dim() + self.output_dim()
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn activation_relu_clamps_negative_to_zero() {
        let mut v = vec![-2.0, -0.5, 0.0, 0.5, 3.5];
        Activation::ReLU.apply_inplace(&mut v);
        assert_eq!(v, vec![0.0, 0.0, 0.0, 0.5, 3.5]);
    }

    #[test]
    fn activation_tanh_matches_std() {
        let v = Activation::Tanh.apply_scalar(0.7);
        assert!((v - 0.7_f64.tanh()).abs() < 1e-12);
    }

    #[test]
    fn activation_derivative_identity_is_one() {
        assert_eq!(Activation::Identity.derivative(-5.0), 1.0);
        assert_eq!(Activation::Identity.derivative(7.0), 1.0);
    }

    #[test]
    fn activation_derivative_relu_and_tanh() {
        assert_eq!(Activation::ReLU.derivative(-1.0), 0.0);
        assert_eq!(Activation::ReLU.derivative(0.0), 0.0);
        assert_eq!(Activation::ReLU.derivative(2.0), 1.0);
        assert!((Activation::Tanh.derivative(0.0) - 1.0).abs() < 1e-12);
        let t = 0.3_f64.tanh();
        assert!((Activation::Tanh.derivative(0.3) - (1.0 - t * t)).abs() < 1e-12);
    }

    #[test]
    fn dense_layer_forward_identity() {
        let layer = DenseLayer::new(
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            vec![0.0, 0.0],
            Activation::Identity,
        )
        .expect("valid layer");
        let out = layer.forward(&[3.0, 4.0]).expect("forward");
        assert_eq!(out, vec![3.0, 4.0]);
    }

    #[test]
    fn dense_layer_forward_with_preactivation_relu() {
        let layer = DenseLayer::new(
            vec![vec![1.0, -1.0], vec![-1.0, 1.0]],
            vec![0.5, 0.5],
            Activation::ReLU,
        )
        .expect("valid layer");
        let (pre, post) = layer
            .forward_with_preactivation(&[2.0, 1.0])
            .expect("forward");
        assert_eq!(pre, vec![1.5, -0.5]);
        assert_eq!(post, vec![1.5, 0.0]);
    }

    #[test]
    fn dense_layer_rejects_dim_mismatch_input() {
        let layer =
            DenseLayer::new(vec![vec![1.0, 2.0]], vec![0.5], Activation::Identity).expect("valid");
        let err = layer
            .forward(&[1.0, 2.0, 3.0])
            .expect_err("must fail on 3-dim input");
        assert!(matches!(err, KernelError::DimensionMismatch { .. }));
    }

    #[test]
    fn dense_layer_rejects_empty_weights() {
        let err = DenseLayer::new(vec![], vec![], Activation::Identity).expect_err("must fail");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
    }

    #[test]
    fn dense_layer_rejects_empty_first_row() {
        let err =
            DenseLayer::new(vec![vec![]], vec![0.0], Activation::Identity).expect_err("must fail");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
    }

    #[test]
    fn dense_layer_rejects_jagged_weights() {
        let err = DenseLayer::new(
            vec![vec![1.0, 2.0], vec![3.0]],
            vec![0.0, 0.0],
            Activation::Identity,
        )
        .expect_err("must fail");
        assert!(matches!(err, KernelError::DimensionMismatch { .. }));
    }

    #[test]
    fn dense_layer_rejects_non_finite_weight() {
        let err = DenseLayer::new(vec![vec![1.0, f64::NAN]], vec![0.0], Activation::Identity)
            .expect_err("must fail");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
    }

    #[test]
    fn dense_layer_rejects_non_finite_bias() {
        let err = DenseLayer::new(vec![vec![1.0, 2.0]], vec![f64::INFINITY], Activation::Tanh)
            .expect_err("must fail");
        assert!(matches!(err, KernelError::InvalidParameter { .. }));
    }

    #[test]
    fn dense_layer_rejects_bias_length_mismatch() {
        let err = DenseLayer::new(
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![0.0],
            Activation::Identity,
        )
        .expect_err("must fail");
        assert!(matches!(err, KernelError::DimensionMismatch { .. }));
    }

    #[test]
    fn dense_layer_parameter_count() {
        let layer = DenseLayer::new(
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]],
            vec![0.1, 0.2],
            Activation::ReLU,
        )
        .expect("valid");
        assert_eq!(layer.parameter_count(), 8);
    }
}