use rand::Rng;
use rand::distributions::{Distribution, Uniform};

use crate::Activation;
use crate::{Error, Result};

/// Weight initialization scheme used when constructing a [`Layer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Init {
    /// All weights start at zero.
    Zeros,
    /// Xavier/Glorot uniform: samples from `U(-limit, limit)` with
    /// `limit = sqrt(6 / (fan_in + fan_out))`.
    Xavier,
    /// He/Kaiming uniform: samples from `U(-limit, limit)` with
    /// `limit = sqrt(6 / fan_in)`.
    He,
}

/// A fully connected layer computing `out = activation(W * x + b)`.
///
/// Weights are stored row-major with one row of `in_dim` values per output
/// neuron, so `weights[o * in_dim + i]` connects input `i` to output `o`.
#[derive(Debug, Clone)]
pub struct Layer {
    in_dim: usize,
    out_dim: usize,
    activation: Activation,
    weights: Vec<f32>,
    biases: Vec<f32>,
}

impl Layer {
    #[inline]
    pub fn activation(&self) -> Activation {
        self.activation
    }
}

impl Layer {
    /// Rebuilds a layer from raw parts (e.g. after deserialization),
    /// validating dimensions, shapes, and that every value is finite.
    #[cfg(feature = "serde")]
    pub(crate) fn from_parts(
        in_dim: usize,
        out_dim: usize,
        activation: Activation,
        weights: Vec<f32>,
        biases: Vec<f32>,
    ) -> Result<Self> {
        if in_dim == 0 || out_dim == 0 {
            return Err(Error::InvalidData(format!(
                "layer dims must be > 0, got in_dim={in_dim} out_dim={out_dim}"
            )));
        }

        activation
            .validate()
            .map_err(|e| Error::InvalidData(format!("invalid activation: {e}")))?;

        let expected_w = in_dim
            .checked_mul(out_dim)
            .ok_or_else(|| Error::InvalidData("layer weight shape overflow".to_owned()))?;
        if weights.len() != expected_w {
            return Err(Error::InvalidData(format!(
                "weights length {} does not match out_dim * in_dim ({} * {})",
                weights.len(),
                out_dim,
                in_dim
            )));
        }
        if biases.len() != out_dim {
            return Err(Error::InvalidData(format!(
                "biases length {} does not match out_dim {}",
                biases.len(),
                out_dim
            )));
        }

        if weights.iter().any(|v| !v.is_finite()) {
            return Err(Error::InvalidData(
                "weights must contain only finite values".to_owned(),
            ));
        }
        if biases.iter().any(|v| !v.is_finite()) {
            return Err(Error::InvalidData(
                "biases must contain only finite values".to_owned(),
            ));
        }

        Ok(Self {
            in_dim,
            out_dim,
            activation,
            weights,
            biases,
        })
    }

    /// Creates a new layer with the given activation. Weights are drawn from
    /// `rng` according to `init`; biases start at zero.
    pub fn new_with_rng<R: Rng + ?Sized>(
        in_dim: usize,
        out_dim: usize,
        init: Init,
        activation: Activation,
        rng: &mut R,
    ) -> Result<Self> {
        if in_dim == 0 || out_dim == 0 {
            return Err(Error::InvalidConfig("layer dims must be > 0".to_owned()));
        }

        activation.validate()?;

        let mut weights = vec![0.0; in_dim * out_dim];
        match init {
            Init::Zeros => {}
            Init::Xavier => {
                let fan_in = in_dim as f32;
                let fan_out = out_dim as f32;
                let limit = (6.0 / (fan_in + fan_out)).sqrt();
                let dist = Uniform::new(-limit, limit);
                for w in &mut weights {
                    *w = dist.sample(rng);
                }
            }
            Init::He => {
                let fan_in = in_dim as f32;
                let limit = (6.0 / fan_in).sqrt();
                let dist = Uniform::new(-limit, limit);
                for w in &mut weights {
                    *w = dist.sample(rng);
                }
            }
        }

        let biases = vec![0.0; out_dim];

        Ok(Self {
            in_dim,
            out_dim,
            activation,
            weights,
            biases,
        })
    }

    #[inline]
    pub fn in_dim(&self) -> usize {
        self.in_dim
    }

    #[inline]
    pub fn out_dim(&self) -> usize {
        self.out_dim
    }

    #[inline]
    pub(crate) fn weights(&self) -> &[f32] {
        &self.weights
    }

    #[inline]
    pub(crate) fn biases(&self) -> &[f32] {
        &self.biases
    }

    #[inline]
    #[cfg(test)]
    pub(crate) fn weights_mut(&mut self) -> &mut [f32] {
        &mut self.weights
    }

    #[inline]
    #[cfg(test)]
    pub(crate) fn biases_mut(&mut self) -> &mut [f32] {
        &mut self.biases
    }

    /// Computes the forward pass for a single example.
    ///
    /// For each output `o`, this evaluates
    /// `outputs[o] = activation(biases[o] + sum_i weights[o * in_dim + i] * inputs[i])`.
    ///
    /// # Panics
    ///
    /// Panics if `inputs.len() != in_dim` or `outputs.len() != out_dim`.
    #[inline]
    pub fn forward(&self, inputs: &[f32], outputs: &mut [f32]) {
        assert_eq!(
            inputs.len(),
            self.in_dim,
            "inputs len {} does not match layer in_dim {}",
            inputs.len(),
            self.in_dim
        );
        assert_eq!(
            outputs.len(),
            self.out_dim,
            "outputs len {} does not match layer out_dim {}",
            outputs.len(),
            self.out_dim
        );

        let activation = self.activation;

        for (o, out) in outputs.iter_mut().enumerate() {
            let mut sum = self.biases[o];
            let row = o * self.in_dim;
            for (i, &x) in inputs.iter().enumerate() {
                sum = self.weights[row + i].mul_add(x, sum);
            }
            *out = activation.forward(sum);
        }
    }

    /// Computes gradients for a single example, overwriting `d_inputs`,
    /// `d_weights`, and `d_biases`.
    ///
    /// `outputs` must be the values produced by [`Layer::forward`] for the same
    /// `inputs`; the activation gradient is recovered from them via
    /// `grad_from_output`. `d_outputs` is the gradient of the loss with respect
    /// to this layer's outputs.
    ///
    /// # Panics
    ///
    /// Panics if any slice length does not match the layer's dimensions.
    #[inline]
    pub fn backward(
        &self,
        inputs: &[f32],
        outputs: &[f32],
        d_outputs: &[f32],
        d_inputs: &mut [f32],
        d_weights: &mut [f32],
        d_biases: &mut [f32],
    ) {
        assert_eq!(
            inputs.len(),
            self.in_dim,
            "inputs len {} does not match layer in_dim {}",
            inputs.len(),
            self.in_dim
        );
        assert_eq!(
            outputs.len(),
            self.out_dim,
            "outputs len {} does not match layer out_dim {}",
            outputs.len(),
            self.out_dim
        );
        assert_eq!(
            d_outputs.len(),
            self.out_dim,
            "d_outputs len {} does not match layer out_dim {}",
            d_outputs.len(),
            self.out_dim
        );
        assert_eq!(
            d_inputs.len(),
            self.in_dim,
            "d_inputs len {} does not match layer in_dim {}",
            d_inputs.len(),
            self.in_dim
        );
        assert_eq!(
            d_weights.len(),
            self.weights.len(),
            "d_weights len {} does not match weights len {}",
            d_weights.len(),
            self.weights.len()
        );
        assert_eq!(
            d_biases.len(),
            self.out_dim,
            "d_biases len {} does not match layer out_dim {}",
            d_biases.len(),
            self.out_dim
        );

        d_inputs.fill(0.0);

        let activation = self.activation;

        for o in 0..self.out_dim {
            // d_z is the gradient at the pre-activation sum for output o.
            let d_z = d_outputs[o] * activation.grad_from_output(outputs[o]);
            d_biases[o] = d_z;

            let row = o * self.in_dim;
            for i in 0..self.in_dim {
                let w = self.weights[row + i];
                d_weights[row + i] = d_z * inputs[i];
                d_inputs[i] = w.mul_add(d_z, d_inputs[i]);
            }
        }
    }

    /// Like [`Layer::backward`], but accumulates into `d_weights` and
    /// `d_biases` (useful for summing gradients over a mini-batch). `d_inputs`
    /// is still overwritten, since it is per-example.
    ///
    /// # Panics
    ///
    /// Panics if any slice length does not match the layer's dimensions.
    #[inline]
    pub fn backward_accumulate(
        &self,
        inputs: &[f32],
        outputs: &[f32],
        d_outputs: &[f32],
        d_inputs: &mut [f32],
        d_weights: &mut [f32],
        d_biases: &mut [f32],
    ) {
        assert_eq!(
            inputs.len(),
            self.in_dim,
            "inputs len {} does not match layer in_dim {}",
            inputs.len(),
            self.in_dim
        );
        assert_eq!(
            outputs.len(),
            self.out_dim,
            "outputs len {} does not match layer out_dim {}",
            outputs.len(),
            self.out_dim
        );
        assert_eq!(
            d_outputs.len(),
            self.out_dim,
            "d_outputs len {} does not match layer out_dim {}",
            d_outputs.len(),
            self.out_dim
        );
        assert_eq!(
            d_inputs.len(),
            self.in_dim,
            "d_inputs len {} does not match layer in_dim {}",
            d_inputs.len(),
            self.in_dim
        );
        assert_eq!(
            d_weights.len(),
            self.weights.len(),
            "d_weights len {} does not match weights len {}",
            d_weights.len(),
            self.weights.len()
        );
        assert_eq!(
            d_biases.len(),
            self.out_dim,
            "d_biases len {} does not match layer out_dim {}",
            d_biases.len(),
            self.out_dim
        );

        d_inputs.fill(0.0);

        let activation = self.activation;

        for o in 0..self.out_dim {
            let d_z = d_outputs[o] * activation.grad_from_output(outputs[o]);
            d_biases[o] += d_z;

            let row = o * self.in_dim;
            for i in 0..self.in_dim {
                let w = self.weights[row + i];
                d_weights[row + i] += d_z * inputs[i];
                d_inputs[i] = w.mul_add(d_z, d_inputs[i]);
            }
        }
    }

    /// Applies one plain SGD update: `w -= lr * dw` and `b -= lr * db`.
    ///
    /// # Panics
    ///
    /// Panics if gradient lengths do not match the parameter lengths.
    #[inline]
    pub fn sgd_step(&mut self, d_weights: &[f32], d_biases: &[f32], lr: f32) {
        assert_eq!(
            d_weights.len(),
            self.weights.len(),
            "d_weights len {} does not match weights len {}",
            d_weights.len(),
            self.weights.len()
        );
        assert_eq!(
            d_biases.len(),
            self.biases.len(),
            "d_biases len {} does not match biases len {}",
            d_biases.len(),
            self.biases.len()
        );

        for (w, &dw) in self.weights.iter_mut().zip(d_weights) {
            *w -= lr * dw;
        }
        for (b, &db) in self.biases.iter_mut().zip(d_biases) {
            *b -= lr * db;
        }
    }

    /// Shrinks every weight by `lr * weight_decay * w` (L2-style weight decay
    /// applied separately from the gradient step). Biases are not decayed.
    pub(crate) fn apply_weight_decay(&mut self, lr: f32, weight_decay: f32) {
        assert!(
            lr.is_finite() && lr > 0.0,
            "learning rate must be finite and > 0"
        );
        assert!(
            weight_decay.is_finite() && weight_decay >= 0.0,
            "weight_decay must be finite and >= 0"
        );

        if weight_decay == 0.0 {
            return;
        }

        let scale = lr * weight_decay;
        for w in &mut self.weights {
            *w -= scale * *w;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    fn loss_for_layer(layer: &Layer, input: &[f32], target: &[f32], out: &mut [f32]) -> f32 {
        layer.forward(input, out);
        crate::loss::mse(out, target)
    }

    fn assert_close(analytic: f32, numeric: f32, abs_tol: f32, rel_tol: f32) {
        let diff = (analytic - numeric).abs();
        let scale = analytic.abs().max(numeric.abs()).max(1.0);
        assert!(
            diff <= abs_tol || diff / scale <= rel_tol,
            "analytic={analytic} numeric={numeric} diff={diff}"
        );
    }

    #[test]
    fn seeded_init_is_deterministic() {
        let mut rng_a = StdRng::seed_from_u64(123);
        let mut rng_b = StdRng::seed_from_u64(123);
        let a = Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng_a).unwrap();
        let b = Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng_b).unwrap();
        assert_eq!(a.weights, b.weights);
        assert_eq!(a.biases, b.biases);
    }
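
    // A minimal sanity check of `Init::Zeros`: with zero weights and zero
    // biases the pre-activation sum is 0.0 for every output, so a Tanh layer
    // should produce exactly 0.0 everywhere (this assumes `Activation::Tanh`
    // maps a pre-activation of 0.0 to exactly 0.0).
    #[test]
    fn zeros_init_forward_is_zero_for_tanh() {
        let mut rng = StdRng::seed_from_u64(7);
        let layer = Layer::new_with_rng(4, 3, Init::Zeros, Activation::Tanh, &mut rng).unwrap();
        let input = vec![0.5_f32, -1.0, 2.0, 0.25];
        let mut out = vec![1.0_f32; 3];
        layer.forward(&input, &mut out);
        assert!(out.iter().all(|&v| v == 0.0));
    }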

    #[test]
    fn backward_matches_numeric_gradients() {
        let in_dim = 3;
        let out_dim = 2;
        let mut rng = StdRng::seed_from_u64(0);
        let mut layer =
            Layer::new_with_rng(in_dim, out_dim, Init::Xavier, Activation::Tanh, &mut rng)
                .unwrap();

        let mut input = vec![0.3_f32, -0.7_f32, 0.1_f32];
        let target = vec![0.2_f32, -0.1_f32];

        let mut outputs = vec![0.0_f32; out_dim];
        layer.forward(&input, &mut outputs);

        let mut d_outputs = vec![0.0_f32; out_dim];
        let _loss = crate::loss::mse_backward(&outputs, &target, &mut d_outputs);

        let mut d_inputs = vec![0.0_f32; in_dim];
        let mut d_weights = vec![0.0_f32; in_dim * out_dim];
        let mut d_biases = vec![0.0_f32; out_dim];

        layer.backward(
            &input,
            &outputs,
            &d_outputs,
            &mut d_inputs,
            &mut d_weights,
            &mut d_biases,
        );

        let eps = 1e-3_f32;
        let abs_tol = 1e-3_f32;
        let rel_tol = 1e-2_f32;

        let mut out_tmp = vec![0.0_f32; out_dim];

        // Check d_weights against central differences of the loss.
        for (p, &analytic) in d_weights.iter().enumerate() {
            let orig = layer.weights[p];

            layer.weights[p] = orig + eps;
            let loss_plus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            layer.weights[p] = orig - eps;
            let loss_minus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            layer.weights[p] = orig;

            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            assert_close(analytic, numeric, abs_tol, rel_tol);
        }

        // Check d_biases.
        for (p, &analytic) in d_biases.iter().enumerate() {
            let orig = layer.biases[p];

            layer.biases[p] = orig + eps;
            let loss_plus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            layer.biases[p] = orig - eps;
            let loss_minus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            layer.biases[p] = orig;

            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            assert_close(analytic, numeric, abs_tol, rel_tol);
        }

        // Check d_inputs.
        for i in 0..input.len() {
            let orig = input[i];

            input[i] = orig + eps;
            let loss_plus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            input[i] = orig - eps;
            let loss_minus = loss_for_layer(&layer, &input, &target, &mut out_tmp);

            input[i] = orig;

            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            let analytic = d_inputs[i];
            assert_close(analytic, numeric, abs_tol, rel_tol);
        }
    }
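
    // A small sketch of the update rule in `sgd_step`: every weight should
    // move by exactly `-lr * dw` and every bias by `-lr * db`.
    #[test]
    fn sgd_step_applies_scaled_gradients() {
        let mut rng = StdRng::seed_from_u64(1);
        let mut layer =
            Layer::new_with_rng(2, 2, Init::Xavier, Activation::Tanh, &mut rng).unwrap();
        let before_w = layer.weights.clone();
        let before_b = layer.biases.clone();

        let d_weights = vec![0.5_f32; 4];
        let d_biases = vec![-1.0_f32; 2];
        let lr = 0.1_f32;
        layer.sgd_step(&d_weights, &d_biases, lr);

        for (w, orig) in layer.weights.iter().zip(&before_w) {
            assert_close(*w, orig - lr * 0.5, 1e-6, 1e-6);
        }
        for (b, orig) in layer.biases.iter().zip(&before_b) {
            assert_close(*b, orig + lr * 1.0, 1e-6, 1e-6);
        }
    }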

    #[test]
    #[should_panic]
    fn forward_panics_on_input_shape_mismatch() {
        let mut rng = StdRng::seed_from_u64(0);
        let layer = Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng).unwrap();
        let input = vec![0.0_f32; 2];
        let mut out = vec![0.0_f32; 2];
        layer.forward(&input, &mut out);
    }

    #[test]
    #[should_panic]
    fn forward_panics_on_output_shape_mismatch() {
        let mut rng = StdRng::seed_from_u64(0);
        let layer = Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng).unwrap();
        let input = vec![0.0_f32; 3];
        let mut out = vec![0.0_f32; 1];
        layer.forward(&input, &mut out);
    }
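
    // A minimal check of `apply_weight_decay`: each weight shrinks by a factor
    // of `(1 - lr * weight_decay)` while biases are left untouched.
    #[test]
    fn weight_decay_scales_weights_only() {
        let mut rng = StdRng::seed_from_u64(2);
        let mut layer =
            Layer::new_with_rng(3, 2, Init::Xavier, Activation::Tanh, &mut rng).unwrap();
        let before_w = layer.weights.clone();
        let before_b = layer.biases.clone();

        let lr = 0.1_f32;
        let weight_decay = 0.5_f32;
        layer.apply_weight_decay(lr, weight_decay);

        for (w, orig) in layer.weights.iter().zip(&before_w) {
            assert_close(*w, orig * (1.0 - lr * weight_decay), 1e-6, 1e-6);
        }
        assert_eq!(layer.biases, before_b);
    }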
}