use crate::traits::{Tensor, TensorOps};
use crate::SmeltError;

4#[derive(Clone)]
6pub struct LayerNorm<T: Tensor> {
7 weight: T,
8 bias: T,
9 epsilon: f32,
10}
11
12impl<T: Tensor + TensorOps<T>> LayerNorm<T> {
13 pub fn new(weight: T, bias: T, epsilon: f32) -> Self {
15 Self {
16 weight,
17 bias,
18 epsilon,
19 }
20 }
21
22 pub fn forward(&self, tensor: &mut T) -> Result<(), SmeltError> {
24 T::normalize(tensor, self.epsilon)?;
25 T::mul(&self.weight, tensor)?;
26 T::add(&self.bias, tensor)
27 }
28}
29
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cpu::f32::Tensor;

    /// Smoke test: a layer with all-zero parameters must run `forward` on a
    /// `[3, 2]` zero input without returning an error.
    #[test]
    fn test_layer_norm() {
        // NOTE(review): `weight` is `[3, 2]` while `bias` is `[2]`; a per-feature
        // weight would normally match `bias` — confirm which shape is intended.
        let weight = Tensor::zeros(vec![3, 2]);
        let bias = Tensor::zeros(vec![2]);
        let layer_norm = LayerNorm::new(weight, bias, 1e-5);

        let mut input = Tensor::zeros(vec![3, 2]);
        layer_norm.forward(&mut input).unwrap();
    }
}