pub mod load;

use burn::{
    config::Config,
    module::{Module, Param},
    tensor::{backend::Backend, Tensor},
};

/// Configuration for [`GroupNorm`].
#[derive(Config)]
pub struct GroupNormConfig {
    /// Number of groups the channels are divided into.
    n_group: usize,
    /// Total number of channels; must be divisible by `n_group`.
    n_channel: usize,
    /// Small constant added to the variance for numerical stability.
    #[config(default = 1e-5)]
    eps: f64,
}

impl GroupNormConfig {
    /// Initializes a [`GroupNorm`] module with a learnable per-channel
    /// scale (`gamma`, ones) and shift (`beta`, zeros).
    pub fn init<B: Backend>(&self) -> GroupNorm<B> {
        assert!(
            self.n_channel % self.n_group == 0,
            "The number of channels {} must be divisible by the number of groups {}",
            self.n_channel,
            self.n_group
        );

        // Affine parameters start as the identity transform.
        let gamma = Tensor::ones([self.n_channel]).into();
        let beta = Tensor::zeros([self.n_channel]).into();

        GroupNorm {
            n_group: self.n_group,
            n_channel: self.n_channel,
            gamma,
            beta,
            eps: self.eps,
        }
    }
}

/// Group normalization: the channels are split into `n_group` groups, each
/// group is normalized to zero mean and unit variance, and a learnable
/// per-channel affine transform is applied afterwards.
#[derive(Module, Debug)]
pub struct GroupNorm<B: Backend> {
    n_group: usize,
    n_channel: usize,
    gamma: Param<Tensor<B, 1>>,
    beta: Param<Tensor<B, 1>>,
    eps: f64,
}

impl<B: Backend> GroupNorm<B> {
    /// Expects an input of shape `[batch, channel, ...]` with at least two
    /// dimensions; the output has the same shape as the input.
    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let shape = x.shape();
        let n_batch = shape.dims[0];
        let num_elements = shape.num_elements();

        // Broadcast shape for the per-channel affine parameters:
        // [1, n_channel, 1, ..., 1].
        let mut affine_shape = [1; D];
        affine_shape[1] = self.n_channel;

        // Collapse each group (its channels plus all remaining positions)
        // into a single axis, normalize over that axis, then restore the
        // original shape and apply the per-channel affine transform.
        layernorm(
            x.reshape([
                n_batch,
                self.n_group,
                num_elements / (n_batch * self.n_group),
            ]),
            self.eps,
        )
        .reshape(shape)
        .mul(self.gamma.val().reshape(affine_shape))
        .add(self.beta.val().reshape(affine_shape))
    }
}
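
// Worked example of the grouping in `forward` (illustrative numbers, not
// taken from the original source): an input of shape [2, 8, 16] with
// n_group = 4 is reshaped to [2, 4, 32], so each group of 8 / 4 = 2 channels
// and their 16 positions share one mean and variance; the result is reshaped
// back to [2, 8, 16] before the per-channel affine step.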

/// Normalizes `x` over its last dimension: `(x - mean) / sqrt(var + eps)`,
/// where `var` is the biased (population) variance. No affine transform is
/// applied here; `GroupNorm::forward` handles that separately.
pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tensor<B, D> {
    // `mean_dim` keeps the reduced axis (size 1), so these ops broadcast.
    let u = x.clone() - x.mean_dim(D - 1);
    // E[u^2] is the biased variance; eps is added before the square root.
    u.clone()
        .div((u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt())
}
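
// A minimal usage sketch (not part of the original file; assumes a concrete
// `Backend` implementation `B` is supplied by the caller, e.g. from
// burn-ndarray):
#[allow(dead_code)]
fn group_norm_example<B: Backend>() -> Tensor<B, 3> {
    // 4 groups over 8 channels; `eps` keeps its 1e-5 default. `new` is
    // generated by the `Config` derive from the required fields.
    let norm = GroupNormConfig::new(4, 8).init::<B>();
    // Dummy activations: batch of 2, 8 channels, 16 time steps.
    let x = Tensor::<B, 3>::zeros([2, 8, 16]);
    // The output keeps the input shape: [2, 8, 16].
    norm.forward(x)
}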