// smelte_rs/nn/layers/layer_norm.rs

use crate::traits::{Tensor, TensorOps};
use crate::SmeltError;

/// Layer normalization: normalizes a tensor over its last dimension, then
/// scales the result elementwise with `weight` and shifts it with `bias`.
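///
/// A minimal usage sketch (fenced as `ignore` because the public paths for
/// `Tensor` and `LayerNorm` depend on how the crate re-exports these modules,
/// which is an assumption here):
///
/// ```ignore
/// use crate::cpu::f32::Tensor;
/// use crate::nn::layers::layer_norm::LayerNorm;
///
/// // Shapes follow the unit test below: a [3, 2] activation, a [3, 2] weight
/// // and a [2] bias.
/// let mut hidden = Tensor::zeros(vec![3, 2]);
/// let weight = Tensor::zeros(vec![3, 2]);
/// let bias = Tensor::zeros(vec![2]);
///
/// let layer_norm = LayerNorm::new(weight, bias, 1e-5);
/// layer_norm.forward(&mut hidden).unwrap();
/// ```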
#[derive(Clone)]
pub struct LayerNorm<T: Tensor> {
    weight: T,
    bias: T,
    epsilon: f32,
}

impl<T: Tensor + TensorOps<T>> LayerNorm<T> {
    /// Builds a [`LayerNorm`] from a `weight` tensor, a `bias` tensor and the
    /// `epsilon` added to the variance for numerical stability.
    pub fn new(weight: T, bias: T, epsilon: f32) -> Self {
        Self {
            weight,
            bias,
            epsilon,
        }
    }

    /// Runs the layer normalization in place on `tensor`: normalize over the
    /// last dimension, multiply by `weight`, then add `bias`.
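    ///
    /// Assuming `normalize` subtracts the mean and divides by the square root
    /// of the (population) variance plus `epsilon` over the last dimension, a
    /// row `[1.0, 3.0]` has mean `2.0` and variance `1.0`; with unit weight,
    /// zero bias and a negligible `epsilon` it becomes roughly `[-1.0, 1.0]`.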
    pub fn forward(&self, tensor: &mut T) -> Result<(), SmeltError> {
        // Normalize in place, using `epsilon` for numerical stability.
        T::normalize(tensor, self.epsilon)?;
        // Scale elementwise by the learned weight.
        T::mul(&self.weight, tensor)?;
        // Shift by the learned bias.
        T::add(&self.bias, tensor)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::cpu::f32::Tensor;

    #[test]
    fn test_layer_norm() {
        let mut zeros = Tensor::zeros(vec![3, 2]);
        let weights = Tensor::zeros(vec![3, 2]);
        let bias = Tensor::zeros(vec![2]);

        let layer_norm = LayerNorm::new(weights, bias, 1e-5);

        layer_norm.forward(&mut zeros).unwrap();
    }
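
    // A second smoke test, added as a sketch: `forward` mutates its argument in
    // place, so the same buffer can be normalized repeatedly. It only relies on
    // the APIs already used above (`Tensor::zeros`, `LayerNorm::new`,
    // `LayerNorm::forward`).
    #[test]
    fn test_layer_norm_repeated_forward() {
        let mut hidden = Tensor::zeros(vec![3, 2]);
        let weights = Tensor::zeros(vec![3, 2]);
        let bias = Tensor::zeros(vec![2]);

        let layer_norm = LayerNorm::new(weights, bias, 1e-5);

        // With zero weight and bias the first call leaves an all-zero tensor,
        // which is a valid input for the second call.
        layer_norm.forward(&mut hidden).unwrap();
        layer_norm.forward(&mut hidden).unwrap();
    }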
}