zenu_layer/layers/
batch_norm_2d.rs1use std::collections::HashMap;
2
3use zenu_autograd::{
4 creator::{ones::ones, zeros::zeros},
5 nn::batch_norm::{batch_norm_2d, BatchNorm2dAutoGradConfig},
6 Variable,
7};
8use zenu_matrix::{device::Device, dim::DimTrait, num::Num};
9
10use crate::{Module, Parameters};
11
12pub struct BatchNorm2d<T: Num, D: Device> {
13 config: BatchNorm2dAutoGradConfig<T>,
14 momentum: f64,
15 pub scale: Variable<T, D>,
16 pub bias: Variable<T, D>,
17 pub mean: Variable<T, D>,
18 pub variance: Variable<T, D>,
19}
20
21impl<T: Num, D: Device> Module<T, D> for BatchNorm2d<T, D> {
22 type Input = Variable<T, D>;
23 type Output = Variable<T, D>;
24 fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
25 if input.get_shape() != self.config.get_shape() {
26 self.config.update_shape(input.get_shape().slice());
27 }
28 batch_norm_2d(
29 input,
30 self.scale.clone(),
31 self.bias.clone(),
32 self.mean.clone(),
33 self.variance.clone(),
34 self.momentum,
35 self.config.clone(),
36 )
37 }
38}
39
40impl<T: Num, D: Device> Parameters<T, D> for BatchNorm2d<T, D> {
41 fn weights(&self) -> HashMap<String, Variable<T, D>> {
42 let mut weights = HashMap::new();
43 weights.insert("batch_norm_2d.scale".to_string(), self.scale.clone());
44 weights
45 }
46
47 fn biases(&self) -> HashMap<String, Variable<T, D>> {
48 let mut biases = HashMap::new();
49 biases.insert("batch_norm_2d.bias".to_string(), self.bias.clone());
50 biases
51 }
52
53 fn parameters(&self) -> HashMap<String, Variable<T, D>> {
54 let mut parameters = HashMap::new();
55 for (key, value) in &self.weights() {
56 parameters.insert(key.clone(), value.clone());
57 }
58 for (key, value) in &self.biases() {
59 parameters.insert(key.clone(), value.clone());
60 }
61 parameters.insert("batch_norm_2d.mean".to_string(), self.mean.clone());
62 parameters.insert("batch_norm_2d.variance".to_string(), self.variance.clone());
63 parameters
64 }
65}
66
67impl<T: Num, D: Device> BatchNorm2d<T, D> {
68 #[must_use]
69 pub fn new(channels: usize, momentum: f64) -> Self {
70 let scale = ones([channels]);
71 let bias = zeros([channels]);
72 let mean = zeros([channels]);
73 let variance = ones([channels]);
74
75 scale.set_is_train(true);
76 bias.set_is_train(true);
77
78 scale.set_name("batch_norm_2d.scale");
79 bias.set_name("batch_norm_2d.bias");
80 mean.set_name("batch_norm_2d.mean");
81 variance.set_name("batch_norm_2d.variance");
82
83 let config = BatchNorm2dAutoGradConfig::default();
84 Self {
85 config,
86 momentum,
87 scale,
88 bias,
89 mean,
90 variance,
91 }
92 }
93
94 pub fn to<Dout: Device>(self) -> BatchNorm2d<T, Dout> {
95 BatchNorm2d {
96 config: self.config,
97 momentum: self.momentum,
98 scale: self.scale.to(),
99 bias: self.bias.to(),
100 mean: self.mean.to(),
101 variance: self.variance.to(),
102 }
103 }
104}