1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7
8use super::reductions::{simd_mean_f32, simd_mean_f64, simd_variance_f32, simd_variance_f64};
10
11#[allow(dead_code)]
36pub fn simd_batch_norm_f32(
37 input: &ArrayView2<f32>,
38 gamma: &ArrayView1<f32>,
39 beta: &ArrayView1<f32>,
40 eps: f32,
41) -> (Array2<f32>, Array1<f32>, Array1<f32>) {
42 let (batch_size, num_features) = (input.shape()[0], input.shape()[1]);
43
44 let mut batch_mean = Array1::zeros(num_features);
46 let mut batch_var = Array1::zeros(num_features);
47
48 for j in 0..num_features {
49 let feature_col = input.column(j).to_owned();
51 batch_mean[j] = simd_mean_f32(&feature_col.view());
52 batch_var[j] = simd_variance_f32(&feature_col.view());
53 }
54
55 let mut output = Array2::zeros((batch_size, num_features));
57 for i in 0..batch_size {
58 for j in 0..num_features {
59 let x_norm = (input[[i, j]] - batch_mean[j]) / (batch_var[j] + eps).sqrt();
60 output[[i, j]] = gamma[j] * x_norm + beta[j];
61 }
62 }
63
64 (output, batch_mean, batch_var)
65}
66
67#[allow(dead_code)]
69pub fn simd_batch_norm_f64(
70 input: &ArrayView2<f64>,
71 gamma: &ArrayView1<f64>,
72 beta: &ArrayView1<f64>,
73 eps: f64,
74) -> (Array2<f64>, Array1<f64>, Array1<f64>) {
75 let (batch_size, num_features) = (input.shape()[0], input.shape()[1]);
76
77 let mut batch_mean = Array1::zeros(num_features);
78 let mut batch_var = Array1::zeros(num_features);
79
80 for j in 0..num_features {
81 let feature_col = input.column(j).to_owned();
82 batch_mean[j] = simd_mean_f64(&feature_col.view());
83 batch_var[j] = simd_variance_f64(&feature_col.view());
84 }
85
86 let mut output = Array2::zeros((batch_size, num_features));
87 for i in 0..batch_size {
88 for j in 0..num_features {
89 let x_norm = (input[[i, j]] - batch_mean[j]) / (batch_var[j] + eps).sqrt();
90 output[[i, j]] = gamma[j] * x_norm + beta[j];
91 }
92 }
93
94 (output, batch_mean, batch_var)
95}
96
97#[allow(dead_code)]
123pub fn simd_layer_norm_f32(
124 input: &ArrayView2<f32>,
125 gamma: &ArrayView1<f32>,
126 beta: &ArrayView1<f32>,
127 eps: f32,
128) -> (Array2<f32>, Array1<f32>, Array1<f32>) {
129 let (batch_size, num_features) = (input.shape()[0], input.shape()[1]);
130
131 let mut sample_means = Array1::zeros(batch_size);
132 let mut sample_vars = Array1::zeros(batch_size);
133 let mut output = Array2::zeros((batch_size, num_features));
134
135 for i in 0..batch_size {
137 let sample = input.row(i);
138 sample_means[i] = simd_mean_f32(&sample);
139 sample_vars[i] = simd_variance_f32(&sample);
140
141 let mean = sample_means[i];
142 let inv_std = 1.0 / (sample_vars[i] + eps).sqrt();
143
144 for j in 0..num_features {
146 let x_norm = (sample[j] - mean) * inv_std;
147 output[[i, j]] = gamma[j] * x_norm + beta[j];
148 }
149 }
150
151 (output, sample_means, sample_vars)
152}
153
154#[allow(dead_code)]
156pub fn simd_layer_norm_f64(
157 input: &ArrayView2<f64>,
158 gamma: &ArrayView1<f64>,
159 beta: &ArrayView1<f64>,
160 eps: f64,
161) -> (Array2<f64>, Array1<f64>, Array1<f64>) {
162 let (batch_size, num_features) = (input.shape()[0], input.shape()[1]);
163
164 let mut sample_means = Array1::zeros(batch_size);
165 let mut sample_vars = Array1::zeros(batch_size);
166 let mut output = Array2::zeros((batch_size, num_features));
167
168 for i in 0..batch_size {
169 let sample = input.row(i);
170 sample_means[i] = simd_mean_f64(&sample);
171 sample_vars[i] = simd_variance_f64(&sample);
172
173 let mean = sample_means[i];
174 let inv_std = 1.0 / (sample_vars[i] + eps).sqrt();
175
176 for j in 0..num_features {
177 let x_norm = (sample[j] - mean) * inv_std;
178 output[[i, j]] = gamma[j] * x_norm + beta[j];
179 }
180 }
181
182 (output, sample_means, sample_vars)
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Column means of a 3x2 batch are reported and the output keeps the input shape.
    #[test]
    fn test_simd_batch_norm_f32_basic() {
        let input = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let gamma = array![1.0f32, 1.0];
        let beta = array![0.0f32, 0.0];
        let eps = 1e-5;

        let (output, mean, _var) =
            simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert!((mean[0] - 3.0).abs() < 1e-5);
        assert!((mean[1] - 4.0).abs() < 1e-5);

        assert_eq!(output.shape(), &[3, 2]);
        // The middle sample equals the column mean, so it normalizes to ~0.
        assert!(output[[1, 0]].abs() < 1e-4);
        assert!(output[[1, 1]].abs() < 1e-4);
    }

    /// Same as the f32 basic test, at double precision.
    #[test]
    fn test_simd_batch_norm_f64_basic() {
        let input = array![[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let gamma = array![1.0f64, 1.0];
        let beta = array![0.0f64, 0.0];
        let eps = 1e-10;

        let (output, mean, _var) =
            simd_batch_norm_f64(&input.view(), &gamma.view(), &beta.view(), eps);

        assert!((mean[0] - 3.0).abs() < 1e-10);
        assert!((mean[1] - 4.0).abs() < 1e-10);
        assert_eq!(output.shape(), &[3, 2]);
        // The middle sample equals the column mean, so it normalizes to ~0.
        assert!(output[[1, 0]].abs() < 1e-9);
        assert!(output[[1, 1]].abs() < 1e-9);
    }

    /// Row means of a 2x3 batch are reported and the output keeps the input shape.
    #[test]
    fn test_simd_layer_norm_f32_basic() {
        let input = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let gamma = array![1.0f32, 1.0, 1.0];
        let beta = array![0.0f32, 0.0, 0.0];
        let eps = 1e-5;

        let (output, means, _vars) =
            simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert!((means[0] - 2.0).abs() < 1e-5);
        assert!((means[1] - 5.0).abs() < 1e-5);

        assert_eq!(output.shape(), &[2, 3]);
        // The middle feature equals the row mean, so it normalizes to ~0.
        assert!(output[[0, 1]].abs() < 1e-4);
        assert!(output[[1, 1]].abs() < 1e-4);
    }

    /// Same as the f32 basic test, at double precision.
    #[test]
    fn test_simd_layer_norm_f64_basic() {
        let input = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let gamma = array![1.0f64, 1.0, 1.0];
        let beta = array![0.0f64, 0.0, 0.0];
        let eps = 1e-10;

        let (output, means, _vars) =
            simd_layer_norm_f64(&input.view(), &gamma.view(), &beta.view(), eps);

        assert!((means[0] - 2.0).abs() < 1e-10);
        assert!((means[1] - 5.0).abs() < 1e-10);
        assert_eq!(output.shape(), &[2, 3]);
        // The middle feature equals the row mean, so it normalizes to ~0.
        assert!(output[[0, 1]].abs() < 1e-9);
        assert!(output[[1, 1]].abs() < 1e-9);
    }

    /// Non-trivial gamma/beta still yield a finite output of the right shape.
    #[test]
    fn test_simd_batch_norm_f32_scale_shift() {
        let input = array![[0.0f32, 1.0], [2.0, 3.0]];
        let gamma = array![2.0f32, 3.0];
        let beta = array![1.0f32, -1.0];
        let eps = 1e-5;

        let (output, _mean, _var) =
            simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert_eq!(output.shape(), &[2, 2]);
        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    /// Non-trivial gamma/beta still yield a finite output of the right shape.
    #[test]
    fn test_simd_layer_norm_f32_scale_shift() {
        let input = array![[1.0f32, 2.0, 3.0]];
        let gamma = array![2.0f32, 2.0, 2.0];
        let beta = array![1.0f32, 1.0, 1.0];
        let eps = 1e-5;

        let (output, _means, _vars) =
            simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert_eq!(output.shape(), &[1, 3]);
        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    /// A batch with zero rows produces an output with zero rows.
    #[test]
    fn test_simd_batch_norm_f32_empty() {
        let input: Array2<f32> = Array2::zeros((0, 3));
        let gamma = array![1.0f32, 1.0, 1.0];
        let beta = array![0.0f32, 0.0, 0.0];
        let eps = 1e-5;

        let (output, _mean, _var) =
            simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert_eq!(output.shape(), &[0, 3]);
    }

    /// A batch with zero rows produces an output with zero rows.
    #[test]
    fn test_simd_layer_norm_f32_empty() {
        let input: Array2<f32> = Array2::zeros((0, 3));
        let gamma = array![1.0f32, 1.0, 1.0];
        let beta = array![0.0f32, 0.0, 0.0];
        let eps = 1e-5;

        let (output, _means, _vars) =
            simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert_eq!(output.shape(), &[0, 3]);
    }

    /// With `eps == 0`, statistics of a simple ramp are sane and the output is finite.
    #[test]
    fn test_simd_batch_norm_f32_correctness() {
        let input = array![[0.0f32, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let gamma = array![1.0f32, 1.0];
        let beta = array![0.0f32, 0.0];
        let eps = 0.0;

        let (output, mean, var) =
            simd_batch_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert!((mean[0] - 1.0).abs() < 1e-5);
        assert!((mean[1] - 1.0).abs() < 1e-5);

        // Loose bounds: the exact value depends on the variance convention
        // used by the reduction helper.
        assert!(var[0] > 0.0 && var[0] < 10.0);
        assert!(var[1] > 0.0 && var[1] < 10.0);

        for &val in output.iter() {
            assert!(val.is_finite());
        }
    }

    /// With `eps == 0`, the element equal to the row mean normalizes to zero.
    #[test]
    fn test_simd_layer_norm_f32_correctness() {
        let input = array![[0.0f32, 1.0, 2.0]];
        let gamma = array![1.0f32, 1.0, 1.0];
        let beta = array![0.0f32, 0.0, 0.0];
        let eps = 0.0;

        let (output, means, _vars) =
            simd_layer_norm_f32(&input.view(), &gamma.view(), &beta.view(), eps);

        assert!((means[0] - 1.0).abs() < 1e-5);

        assert!(output[[0, 1]].abs() < 1e-5);
    }
}