// zenu_layer/layers/batch_norm_2d.rs

1use std::collections::HashMap;
2
3use zenu_autograd::{
4    creator::{ones::ones, zeros::zeros},
5    nn::batch_norm::{batch_norm_2d, BatchNorm2dAutoGradConfig},
6    Variable,
7};
8use zenu_matrix::{device::Device, dim::DimTrait, num::Num};
9
10use crate::{Module, Parameters};
11
/// 2D batch-normalization layer.
///
/// Holds the learnable affine parameters (`scale`, `bias`) together with the
/// running statistics (`mean`, `variance`) that `batch_norm_2d` reads and
/// updates during the forward pass.
pub struct BatchNorm2d<T: Num, D: Device> {
    // Autograd-side configuration; its cached shape is re-synced in
    // `Module::call` whenever the input shape changes.
    config: BatchNorm2dAutoGradConfig<T>,
    // Moving-average factor forwarded to `batch_norm_2d`.
    // NOTE(review): whether this weights the new batch or the old running
    // value is defined by zenu_autograd — confirm there.
    momentum: f64,
    pub scale: Variable<T, D>,
    pub bias: Variable<T, D>,
    pub mean: Variable<T, D>,
    pub variance: Variable<T, D>,
}
20
impl<T: Num, D: Device> Module<T, D> for BatchNorm2d<T, D> {
    type Input = Variable<T, D>;
    type Output = Variable<T, D>;

    /// Applies 2D batch normalization to `input`.
    fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
        // Keep the autograd config's cached shape in sync with the actual
        // input shape (e.g. when the batch size changes between calls).
        // `update_shape` is invoked through `&self`, so the config
        // presumably uses interior mutability — confirm in zenu_autograd.
        if input.get_shape() != self.config.get_shape() {
            self.config.update_shape(input.get_shape().slice());
        }
        // Forward pass; `batch_norm_2d` receives clones of the parameter and
        // running-statistic handles (Variable clones share storage in
        // autograd frameworks — verify that assumption holds here).
        batch_norm_2d(
            input,
            self.scale.clone(),
            self.bias.clone(),
            self.mean.clone(),
            self.variance.clone(),
            self.momentum,
            self.config.clone(),
        )
    }
}
39
40impl<T: Num, D: Device> Parameters<T, D> for BatchNorm2d<T, D> {
41    fn weights(&self) -> HashMap<String, Variable<T, D>> {
42        let mut weights = HashMap::new();
43        weights.insert("batch_norm_2d.scale".to_string(), self.scale.clone());
44        weights
45    }
46
47    fn biases(&self) -> HashMap<String, Variable<T, D>> {
48        let mut biases = HashMap::new();
49        biases.insert("batch_norm_2d.bias".to_string(), self.bias.clone());
50        biases
51    }
52
53    fn parameters(&self) -> HashMap<String, Variable<T, D>> {
54        let mut parameters = HashMap::new();
55        for (key, value) in &self.weights() {
56            parameters.insert(key.clone(), value.clone());
57        }
58        for (key, value) in &self.biases() {
59            parameters.insert(key.clone(), value.clone());
60        }
61        parameters.insert("batch_norm_2d.mean".to_string(), self.mean.clone());
62        parameters.insert("batch_norm_2d.variance".to_string(), self.variance.clone());
63        parameters
64    }
65}
66
67impl<T: Num, D: Device> BatchNorm2d<T, D> {
68    #[must_use]
69    pub fn new(channels: usize, momentum: f64) -> Self {
70        let scale = ones([channels]);
71        let bias = zeros([channels]);
72        let mean = zeros([channels]);
73        let variance = ones([channels]);
74
75        scale.set_is_train(true);
76        bias.set_is_train(true);
77
78        scale.set_name("batch_norm_2d.scale");
79        bias.set_name("batch_norm_2d.bias");
80        mean.set_name("batch_norm_2d.mean");
81        variance.set_name("batch_norm_2d.variance");
82
83        let config = BatchNorm2dAutoGradConfig::default();
84        Self {
85            config,
86            momentum,
87            scale,
88            bias,
89            mean,
90            variance,
91        }
92    }
93
94    pub fn to<Dout: Device>(self) -> BatchNorm2d<T, Dout> {
95        BatchNorm2d {
96            config: self.config,
97            momentum: self.momentum,
98            scale: self.scale.to(),
99            bias: self.bias.to(),
100            mean: self.mean.to(),
101            variance: self.variance.to(),
102        }
103    }
104}