// ruvector_cnn/layers/batchnorm.rs

use crate::{simd, CnnError, CnnResult, Tensor};

use super::Layer;

/// PyTorch-style alias for the 2D batch-normalization layer.
pub type BatchNorm2d = BatchNorm;
/// Inference-time batch normalization over the channel axis of an NHWC tensor.
///
/// Each channel is normalized with the stored running statistics and then
/// scaled/shifted by the learned affine parameters. The per-element math
/// (presumably `y = gamma * (x - mean) / sqrt(var + epsilon) + beta`) is
/// implemented in `simd::batch_norm_simd` — confirm the exact formula there.
#[derive(Debug, Clone)]
pub struct BatchNorm {
    // Number of channels (the C in NHWC) this layer expects.
    num_features: usize,
    // Per-channel scale, initialized to 1.0 (identity).
    gamma: Vec<f32>,
    // Per-channel shift, initialized to 0.0 (identity).
    beta: Vec<f32>,
    // Per-channel running mean used at inference, initialized to 0.0.
    running_mean: Vec<f32>,
    // Per-channel running variance used at inference, initialized to 1.0.
    running_var: Vec<f32>,
    // Small constant added to the variance for numerical stability.
    epsilon: f32,
    // EMA factor for updating running stats during training; currently unused
    // because only the inference path (`forward`) is implemented.
    #[allow(dead_code)]
    momentum: f32,
}
36
37impl BatchNorm {
38 pub fn new(num_features: usize) -> Self {
45 Self {
46 num_features,
47 gamma: vec![1.0; num_features],
48 beta: vec![0.0; num_features],
49 running_mean: vec![0.0; num_features],
50 running_var: vec![1.0; num_features],
51 epsilon: 1e-5,
52 momentum: 0.1,
53 }
54 }
55
56 pub fn with_epsilon(num_features: usize, epsilon: f32) -> Self {
58 let mut bn = Self::new(num_features);
59 bn.epsilon = epsilon;
60 bn
61 }
62
63 pub fn set_params(&mut self, gamma: Vec<f32>, beta: Vec<f32>) -> CnnResult<()> {
65 if gamma.len() != self.num_features || beta.len() != self.num_features {
66 return Err(CnnError::invalid_shape(
67 format!("num_features={}", self.num_features),
68 format!("gamma={}, beta={}", gamma.len(), beta.len()),
69 ));
70 }
71 self.gamma = gamma;
72 self.beta = beta;
73 Ok(())
74 }
75
76 pub fn set_running_stats(&mut self, mean: Vec<f32>, var: Vec<f32>) -> CnnResult<()> {
78 if mean.len() != self.num_features || var.len() != self.num_features {
79 return Err(CnnError::invalid_shape(
80 format!("num_features={}", self.num_features),
81 format!("mean={}, var={}", mean.len(), var.len()),
82 ));
83 }
84 self.running_mean = mean;
85 self.running_var = var;
86 Ok(())
87 }
88
89 pub fn num_features(&self) -> usize {
91 self.num_features
92 }
93
94 pub fn gamma(&self) -> &[f32] {
96 &self.gamma
97 }
98
99 pub fn beta(&self) -> &[f32] {
101 &self.beta
102 }
103
104 pub fn running_mean(&self) -> &[f32] {
106 &self.running_mean
107 }
108
109 pub fn running_var(&self) -> &[f32] {
111 &self.running_var
112 }
113}
114
impl Layer for BatchNorm {
    /// Applies inference-mode batch normalization to a 4D NHWC tensor,
    /// delegating the per-element arithmetic to `simd::batch_norm_simd`.
    ///
    /// # Errors
    /// Returns `CnnError::invalid_shape` when `input` is not 4-dimensional or
    /// when its channel axis (last dimension) differs from `num_features`.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let shape = input.shape();
        // Reject anything that is not a rank-4 (NHWC) tensor up front.
        if shape.len() != 4 {
            return Err(CnnError::invalid_shape(
                "4D tensor (NHWC)",
                format!("{}D tensor", shape.len()),
            ));
        }

        // The channel axis must carry exactly the configured feature count,
        // since gamma/beta/mean/var are all length `num_features`.
        let channels = shape[3];
        if channels != self.num_features {
            return Err(CnnError::invalid_shape(
                format!("{} channels", self.num_features),
                format!("{} channels", channels),
            ));
        }

        let mut output = Tensor::zeros(shape);

        // Normalization kernel writes the result into `output` in place.
        simd::batch_norm_simd(
            input.data(),
            output.data_mut(),
            &self.gamma,
            &self.beta,
            &self.running_mean,
            &self.running_var,
            self.epsilon,
            self.num_features,
        );

        Ok(output)
    }

    fn name(&self) -> &'static str {
        "BatchNorm"
    }

    // Two learnable values (gamma, beta) per channel; running stats are
    // buffers, not parameters, and are intentionally excluded.
    fn num_params(&self) -> usize {
        self.num_features * 2
    }
}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_norm_creation() {
        let bn = BatchNorm::new(64);
        assert_eq!(bn.num_features(), 64);
        assert_eq!(bn.gamma().len(), 64);
        assert_eq!(bn.beta().len(), 64);
        // Two learnable vectors (gamma, beta) of length 64.
        assert_eq!(bn.num_params(), 128);
    }

    #[test]
    fn test_batch_norm_forward() {
        // With default stats (mean = 0, var = 1) and identity affine params,
        // the layer is a near-no-op, up to the epsilon term in the divisor.
        let bn = BatchNorm::new(4);
        let input = Tensor::ones(&[1, 8, 8, 4]);
        let output = bn.forward(&input).unwrap();

        assert_eq!(output.shape(), input.shape());

        for &val in output.data() {
            assert!((val - 1.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_batch_norm_normalization() {
        let mut bn = BatchNorm::new(2);

        bn.set_running_stats(vec![1.0, 2.0], vec![1.0, 4.0]).unwrap();

        // NHWC layout: consecutive elements alternate channels 0 and 1.
        // ch0: (1-1)/sqrt(1) = 0, (3-1)/sqrt(1) = 2
        // ch1: (2-2)/sqrt(4) = 0, (4-2)/sqrt(4) = 1
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 1, 2]).unwrap();
        let output = bn.forward(&input).unwrap();

        assert!(output.data()[0].abs() < 0.01);
        assert!(output.data()[1].abs() < 0.01);
        assert!((output.data()[2] - 2.0).abs() < 0.01);
        assert!((output.data()[3] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_batch_norm_channel_mismatch() {
        // Valid 4D tensor whose channel axis (8) disagrees with num_features (4).
        // (Previously misnamed `test_batch_norm_invalid_shape` — the shape
        // itself was valid; only the channel count was wrong.)
        let bn = BatchNorm::new(4);
        let input = Tensor::ones(&[1, 8, 8, 8]);
        assert!(bn.forward(&input).is_err());
    }

    #[test]
    fn test_batch_norm_wrong_rank() {
        // A non-4D tensor must be rejected before reaching the SIMD kernel;
        // this branch of `forward` was previously untested.
        let bn = BatchNorm::new(4);
        let input = Tensor::ones(&[8, 8, 4]);
        assert!(bn.forward(&input).is_err());
    }
}