scirs2_core/simd/normalization.rs

//! SIMD-accelerated normalization operations for neural networks
//!
//! This module provides optimized implementations of batch normalization
//! and layer normalization using SIMD for mean/variance computation.

use ndarray::{Array1, Array2, ArrayView1, ArrayView2};

// Import SIMD reduction functions
use super::reductions::{simd_mean_f32, simd_mean_f64, simd_variance_f32, simd_variance_f64};

/// SIMD-accelerated batch normalization for f32 arrays
///
/// Applies batch normalization: output = gamma * ((x - mean) / sqrt(var + eps)) + beta
/// Uses SIMD for computing mean and variance per feature across the batch.
///
/// # Arguments
/// * `input` - Input array of shape `[batch_size, num_features]`
/// * `gamma` - Scale parameters of shape `[num_features]`
/// * `beta` - Shift parameters of shape `[num_features]`
/// * `eps` - Small constant for numerical stability (typically 1e-5)
///
/// # Returns
/// * Tuple of (normalized_output, batch_mean, batch_var)
///
/// # Example
/// ```
/// use scirs2_core::simd::normalization::simd_batch_norm_f32;
/// use scirs2_core::ndarray::{array, Array1};
///
/// let input = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let gamma = Array1::ones(2);
/// let beta = Array1::zeros(2);
/// let (output, mean, var) = simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), 1e-5);
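/// // Per-feature means across the batch are [3.0, 4.0], matching the unit tests below.
/// assert!((mean[0] - 3.0).abs() < 1e-5);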
/// ```
#[allow(dead_code)]
pub fn simd_batch_norm_f32(
    input: &ArrayView2<f32>,
    gamma: &ArrayView1<f32>,
    beta: &ArrayView1<f32>,
    eps: f32,
) -> (Array2<f32>, Array1<f32>, Array1<f32>) {
    let (batch_size, num_features) = input.dim();

    // Use SIMD to compute mean and variance per feature across the batch
    let mut batch_mean = Array1::zeros(num_features);
    let mut batch_var = Array1::zeros(num_features);

    for j in 0..num_features {
        // Make column contiguous for SIMD processing
        let feature_col = input.column(j).to_owned();
        batch_mean[j] = simd_mean_f32(&feature_col.view());
        batch_var[j] = simd_variance_f32(&feature_col.view());
    }

    // Normalize (scalar for simplicity, but mean/var computation is SIMD-accelerated)
    let mut output = Array2::zeros((batch_size, num_features));
    for i in 0..batch_size {
        for j in 0..num_features {
            let x_norm = (input[[i, j]] - batch_mean[j]) / (batch_var[j] + eps).sqrt();
            output[[i, j]] = gamma[j] * x_norm + beta[j];
        }
    }

    (output, batch_mean, batch_var)
}

/// SIMD-accelerated batch normalization for f64 arrays
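///
/// Same semantics as [`simd_batch_norm_f32`], but for `f64` data.
///
/// # Example
/// ```
/// use scirs2_core::simd::normalization::simd_batch_norm_f64;
/// use scirs2_core::ndarray::{array, Array1};
///
/// // A minimal sketch mirroring the f32 example above.
/// let input = array![[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
/// let gamma = Array1::ones(2);
/// let beta = Array1::zeros(2);
/// let (output, mean, var) = simd_batch_norm_f64(&input.view(), &gamma.view(), &beta.view(), 1e-10);
/// assert!((mean[0] - 3.0).abs() < 1e-10);
/// ```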
#[allow(dead_code)]
pub fn simd_batch_norm_f64(
    input: &ArrayView2<f64>,
    gamma: &ArrayView1<f64>,
    beta: &ArrayView1<f64>,
    eps: f64,
) -> (Array2<f64>, Array1<f64>, Array1<f64>) {
    let (batch_size, num_features) = input.dim();

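    // Use SIMD to compute mean and variance per feature across the batch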
    let mut batch_mean = Array1::zeros(num_features);
    let mut batch_var = Array1::zeros(num_features);

    for j in 0..num_features {
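        // Make column contiguous for SIMD processing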
        let feature_col = input.column(j).to_owned();
        batch_mean[j] = simd_mean_f64(&feature_col.view());
        batch_var[j] = simd_variance_f64(&feature_col.view());
    }

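    // Normalize (scalar for simplicity, but mean/var computation is SIMD-accelerated)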
    let mut output = Array2::zeros((batch_size, num_features));
    for i in 0..batch_size {
        for j in 0..num_features {
            let x_norm = (input[[i, j]] - batch_mean[j]) / (batch_var[j] + eps).sqrt();
            output[[i, j]] = gamma[j] * x_norm + beta[j];
        }
    }

    (output, batch_mean, batch_var)
}

/// SIMD-accelerated layer normalization for f32 arrays
///
/// Applies layer normalization: output = gamma * ((x - mean) / sqrt(var + eps)) + beta
/// Unlike batch norm, layer norm normalizes across features for each sample independently.
/// Uses SIMD for computing mean and variance per sample.
///
/// # Arguments
/// * `input` - Input array of shape `[batch_size, num_features]`
/// * `gamma` - Scale parameters of shape `[num_features]`
/// * `beta` - Shift parameters of shape `[num_features]`
/// * `eps` - Small constant for numerical stability (typically 1e-5)
///
/// # Returns
/// * Tuple of (normalized_output, sample_means, sample_vars)
///
/// # Example
/// ```
/// use scirs2_core::simd::normalization::simd_layer_norm_f32;
/// use scirs2_core::ndarray::{array, Array1};
///
/// let input = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
/// let gamma = Array1::ones(3);
/// let beta = Array1::zeros(3);
/// let (output, means, vars) = simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), 1e-5);
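/// // Per-sample means: 2.0 for the first row, 5.0 for the second, matching the unit tests below.
/// assert!((means[0] - 2.0).abs() < 1e-5);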
/// ```
#[allow(dead_code)]
pub fn simd_layer_norm_f32(
    input: &ArrayView2<f32>,
    gamma: &ArrayView1<f32>,
    beta: &ArrayView1<f32>,
    eps: f32,
) -> (Array2<f32>, Array1<f32>, Array1<f32>) {
    let (batch_size, num_features) = input.dim();

    let mut sample_means = Array1::zeros(batch_size);
    let mut sample_vars = Array1::zeros(batch_size);
    let mut output = Array2::zeros((batch_size, num_features));

    // Process each sample independently using SIMD for mean/variance
    for i in 0..batch_size {
        let sample = input.row(i);
        sample_means[i] = simd_mean_f32(&sample);
        sample_vars[i] = simd_variance_f32(&sample);

        let mean = sample_means[i];
        let inv_std = 1.0 / (sample_vars[i] + eps).sqrt();

        // Normalize
        for j in 0..num_features {
            let x_norm = (sample[j] - mean) * inv_std;
            output[[i, j]] = gamma[j] * x_norm + beta[j];
        }
    }

    (output, sample_means, sample_vars)
}

/// SIMD-accelerated layer normalization for f64 arrays
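///
/// Same semantics as [`simd_layer_norm_f32`], but for `f64` data.
///
/// # Example
/// ```
/// use scirs2_core::simd::normalization::simd_layer_norm_f64;
/// use scirs2_core::ndarray::{array, Array1};
///
/// // A minimal sketch mirroring the f32 example above.
/// let input = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
/// let gamma = Array1::ones(3);
/// let beta = Array1::zeros(3);
/// let (output, means, vars) = simd_layer_norm_f64(&input.view(), &gamma.view(), &beta.view(), 1e-10);
/// assert!((means[0] - 2.0).abs() < 1e-10);
/// ```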
#[allow(dead_code)]
pub fn simd_layer_norm_f64(
    input: &ArrayView2<f64>,
    gamma: &ArrayView1<f64>,
    beta: &ArrayView1<f64>,
    eps: f64,
) -> (Array2<f64>, Array1<f64>, Array1<f64>) {
    let (batch_size, num_features) = input.dim();

    let mut sample_means = Array1::zeros(batch_size);
    let mut sample_vars = Array1::zeros(batch_size);
    let mut output = Array2::zeros((batch_size, num_features));

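    // Process each sample independently using SIMD for mean/variance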
    for i in 0..batch_size {
        let sample = input.row(i);
        sample_means[i] = simd_mean_f64(&sample);
        sample_vars[i] = simd_variance_f64(&sample);

        let mean = sample_means[i];
        let inv_std = 1.0 / (sample_vars[i] + eps).sqrt();

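        // Normalize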
        for j in 0..num_features {
            let x_norm = (sample[j] - mean) * inv_std;
            output[[i, j]] = gamma[j] * x_norm + beta[j];
        }
    }

    (output, sample_means, sample_vars)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_simd_batch_norm_f32_basic() {
        let input = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let gamma = array![1.0f32, 1.0];
        let beta = array![0.0f32, 0.0];
        let eps = 1e-5;

        let (output, mean, _var) =
            simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        // Mean should be [3, 4]
        assert!((mean[0] - 3.0).abs() < 1e-5);
        assert!((mean[1] - 4.0).abs() < 1e-5);

        // Output should keep the input shape
        assert_eq!(output.shape(), &[3, 2]);
    }

    #[test]
    fn test_simd_batch_norm_f64_basic() {
        let input = array![[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let gamma = array![1.0f64, 1.0];
        let beta = array![0.0f64, 0.0];
        let eps = 1e-10;

        let (output, mean, _var) =
            simd_batch_norm_f64(&input.view(), &gamma.view(), &beta.view(), eps);

        assert!((mean[0] - 3.0).abs() < 1e-10);
        assert!((mean[1] - 4.0).abs() < 1e-10);
        assert_eq!(output.shape(), &[3, 2]);
    }

    #[test]
    fn test_simd_layer_norm_f32_basic() {
        let input = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let gamma = array![1.0f32, 1.0, 1.0];
        let beta = array![0.0f32, 0.0, 0.0];
        let eps = 1e-5;

        let (output, means, _vars) =
            simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        // Each sample should be normalized independently
        // Sample 0 mean should be 2.0, sample 1 mean should be 5.0
        assert!((means[0] - 2.0).abs() < 1e-5);
        assert!((means[1] - 5.0).abs() < 1e-5);

        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_simd_layer_norm_f64_basic() {
        let input = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let gamma = array![1.0f64, 1.0, 1.0];
        let beta = array![0.0f64, 0.0, 0.0];
        let eps = 1e-10;

        let (output, means, _vars) =
            simd_layer_norm_f64(&input.view(), &gamma.view(), &beta.view(), eps);

        assert!((means[0] - 2.0).abs() < 1e-10);
        assert!((means[1] - 5.0).abs() < 1e-10);
        assert_eq!(output.shape(), &[2, 3]);
    }

    #[test]
    fn test_simd_batch_norm_f32_scale_shift() {
        let input = array![[0.0f32, 1.0], [2.0, 3.0]];
        let gamma = array![2.0f32, 3.0];
        let beta = array![1.0f32, -1.0];
        let eps = 1e-5;

        let (output, _mean, _var) =
            simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        // Output should be scaled by gamma and shifted by beta
        assert_eq!(output.shape(), &[2, 2]);
        // All values should be finite
        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_simd_layer_norm_f32_scale_shift() {
        let input = array![[1.0f32, 2.0, 3.0]];
        let gamma = array![2.0f32, 2.0, 2.0];
        let beta = array![1.0f32, 1.0, 1.0];
        let eps = 1e-5;

        let (output, _means, _vars) =
            simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        // Check output has been scaled and shifted
        assert_eq!(output.shape(), &[1, 3]);
        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_simd_batch_norm_f32_empty() {
        let input: Array2<f32> = Array2::zeros((0, 3));
        let gamma = array![1.0f32, 1.0, 1.0];
        let beta = array![0.0f32, 0.0, 0.0];
        let eps = 1e-5;

        let (output, _mean, _var) =
            simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert_eq!(output.shape(), &[0, 3]);
    }

    #[test]
    fn test_simd_layer_norm_f32_empty() {
        let input: Array2<f32> = Array2::zeros((0, 3));
        let gamma = array![1.0f32, 1.0, 1.0];
        let beta = array![0.0f32, 0.0, 0.0];
        let eps = 1e-5;

        let (output, _means, _vars) =
            simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert_eq!(output.shape(), &[0, 3]);
    }

    #[test]
    fn test_simd_batch_norm_f32_correctness() {
        // Test against known values
        let input = array![[0.0f32, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let gamma = array![1.0f32, 1.0];
        let beta = array![0.0f32, 0.0];
        let eps = 0.0;

        let (output, mean, var) =
            simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        // Mean should be [1.0, 1.0]
        assert!((mean[0] - 1.0).abs() < 1e-5);
        assert!((mean[1] - 1.0).abs() < 1e-5);

        // Check variance is positive and reasonable (the exact value depends on
        // whether the implementation uses population or sample variance)
        assert!(var[0] > 0.0 && var[0] < 10.0);
        assert!(var[1] > 0.0 && var[1] < 10.0);

        // Normalized values should be finite
        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_simd_layer_norm_f32_correctness() {
        // Test against known values
        let input = array![[0.0f32, 1.0, 2.0]];
        let gamma = array![1.0f32, 1.0, 1.0];
        let beta = array![0.0f32, 0.0, 0.0];
        let eps = 0.0;

        let (output, means, _vars) =
            simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        // Mean should be 1.0
        assert!((means[0] - 1.0).abs() < 1e-5);

        // After normalization with mean = 1 and population variance 2/3:
        // output[0] = (0-1)/sqrt(2/3) ≈ -1.224745
        // output[1] = (1-1)/sqrt(2/3) = 0
        // output[2] = (2-1)/sqrt(2/3) ≈ 1.224745
        // (with sample variance 1, the outer values would be ±1 instead)
        assert!(output[[0, 1]].abs() < 1e-5); // Middle value should be ~0
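        // Regardless of which variance convention the reduction uses, the
        // outer values must be antisymmetric about the middle one.
        assert!((output[[0, 0]] + output[[0, 2]]).abs() < 1e-5);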
    }
}