// ruvector_cnn/layers/batchnorm.rs
1//! Batch Normalization Layer
2//!
3//! Implements batch normalization with learned scale (gamma) and shift (beta) parameters:
4//! y = gamma * (x - mean) / sqrt(var + epsilon) + beta
5//!
6//! During inference, uses running mean and variance computed during training.
7
8use crate::{simd, CnnError, CnnResult, Tensor};
9
10use super::Layer;
11
/// Alias for BatchNorm (for API compatibility with PyTorch naming).
pub type BatchNorm2d = BatchNorm;

/// Batch Normalization layer
///
/// Normalizes the input across the channel dimension for NHWC tensors,
/// applying `y = gamma * (x - mean) / sqrt(var + epsilon) + beta` per channel.
#[derive(Debug, Clone)]
pub struct BatchNorm {
    /// Number of channels (features); every parameter vector below has this length.
    num_features: usize,
    /// Learned scale parameter (gamma), one value per channel.
    gamma: Vec<f32>,
    /// Learned shift parameter (beta), one value per channel.
    beta: Vec<f32>,
    /// Running mean per channel (used for inference-time normalization)
    running_mean: Vec<f32>,
    /// Running variance per channel (used for inference-time normalization)
    running_var: Vec<f32>,
    /// Small constant added to the variance for numerical stability
    epsilon: f32,
    /// Momentum for running statistics update (reserved for training mode;
    /// currently unused — `forward` only performs inference-style normalization)
    #[allow(dead_code)]
    momentum: f32,
}
36
37impl BatchNorm {
38    /// Create a new BatchNorm layer with default initialization
39    ///
40    /// - gamma initialized to 1.0
41    /// - beta initialized to 0.0
42    /// - running_mean initialized to 0.0
43    /// - running_var initialized to 1.0
44    pub fn new(num_features: usize) -> Self {
45        Self {
46            num_features,
47            gamma: vec![1.0; num_features],
48            beta: vec![0.0; num_features],
49            running_mean: vec![0.0; num_features],
50            running_var: vec![1.0; num_features],
51            epsilon: 1e-5,
52            momentum: 0.1,
53        }
54    }
55
56    /// Create BatchNorm with custom epsilon
57    pub fn with_epsilon(num_features: usize, epsilon: f32) -> Self {
58        let mut bn = Self::new(num_features);
59        bn.epsilon = epsilon;
60        bn
61    }
62
63    /// Set the learned parameters (gamma, beta)
64    pub fn set_params(&mut self, gamma: Vec<f32>, beta: Vec<f32>) -> CnnResult<()> {
65        if gamma.len() != self.num_features || beta.len() != self.num_features {
66            return Err(CnnError::invalid_shape(
67                format!("num_features={}", self.num_features),
68                format!("gamma={}, beta={}", gamma.len(), beta.len()),
69            ));
70        }
71        self.gamma = gamma;
72        self.beta = beta;
73        Ok(())
74    }
75
76    /// Set the running statistics (mean, var)
77    pub fn set_running_stats(&mut self, mean: Vec<f32>, var: Vec<f32>) -> CnnResult<()> {
78        if mean.len() != self.num_features || var.len() != self.num_features {
79            return Err(CnnError::invalid_shape(
80                format!("num_features={}", self.num_features),
81                format!("mean={}, var={}", mean.len(), var.len()),
82            ));
83        }
84        self.running_mean = mean;
85        self.running_var = var;
86        Ok(())
87    }
88
89    /// Get the number of features (channels)
90    pub fn num_features(&self) -> usize {
91        self.num_features
92    }
93
94    /// Get a reference to gamma
95    pub fn gamma(&self) -> &[f32] {
96        &self.gamma
97    }
98
99    /// Get a reference to beta
100    pub fn beta(&self) -> &[f32] {
101        &self.beta
102    }
103
104    /// Get a reference to running mean
105    pub fn running_mean(&self) -> &[f32] {
106        &self.running_mean
107    }
108
109    /// Get a reference to running variance
110    pub fn running_var(&self) -> &[f32] {
111        &self.running_var
112    }
113}
114
115impl Layer for BatchNorm {
116    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
117        // Validate input shape (must be NHWC with matching channels)
118        let shape = input.shape();
119        if shape.len() != 4 {
120            return Err(CnnError::invalid_shape(
121                "4D tensor (NHWC)",
122                format!("{}D tensor", shape.len()),
123            ));
124        }
125
126        let channels = shape[3];
127        if channels != self.num_features {
128            return Err(CnnError::invalid_shape(
129                format!("{} channels", self.num_features),
130                format!("{} channels", channels),
131            ));
132        }
133
134        let mut output = Tensor::zeros(shape);
135
136        simd::batch_norm_simd(
137            input.data(),
138            output.data_mut(),
139            &self.gamma,
140            &self.beta,
141            &self.running_mean,
142            &self.running_var,
143            self.epsilon,
144            self.num_features,
145        );
146
147        Ok(output)
148    }
149
150    fn name(&self) -> &'static str {
151        "BatchNorm"
152    }
153
154    fn num_params(&self) -> usize {
155        // gamma + beta
156        self.num_features * 2
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_norm_creation() {
        let layer = BatchNorm::new(64);
        assert_eq!(layer.num_features(), 64);
        assert_eq!(layer.gamma().len(), 64);
        assert_eq!(layer.beta().len(), 64);
        // gamma + beta => 2 * 64 learned parameters.
        assert_eq!(layer.num_params(), 128);
    }

    #[test]
    fn test_batch_norm_forward() {
        let layer = BatchNorm::new(4);
        let input = Tensor::ones(&[1, 8, 8, 4]);
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape(), input.shape());

        // Default params (gamma=1, beta=0, mean=0, var=1) leave ones nearly
        // unchanged: 1 * (1 - 0) / sqrt(1 + eps) + 0 ≈ 1.
        assert!(output.data().iter().all(|&v| (v - 1.0).abs() < 0.01));
    }

    #[test]
    fn test_batch_norm_normalization() {
        let mut layer = BatchNorm::new(2);

        // mean = [1, 2], var = [1, 4]
        layer.set_running_stats(vec![1.0, 2.0], vec![1.0, 4.0]).unwrap();

        // Two spatial positions, two channels: [[1, 2], [3, 4]].
        let input = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[1, 2, 1, 2]).unwrap();
        let output = layer.forward(&input).unwrap();

        // Channel 0: (x - 1) / sqrt(1 + eps) ≈ x - 1
        // Channel 1: (x - 2) / sqrt(4 + eps) ≈ (x - 2) / 2
        let expected = [0.0_f32, 0.0, 2.0, 1.0];
        for (got, want) in output.data().iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.01);
        }
    }

    #[test]
    fn test_batch_norm_invalid_shape() {
        let layer = BatchNorm::new(4);
        // 8 channels fed into a 4-channel layer must be rejected.
        let input = Tensor::ones(&[1, 8, 8, 8]);
        assert!(layer.forward(&input).is_err());
    }
}