use crate::{
    device::{cpu::Cpu, Device, DeviceBase},
    dim::{DimDyn, DimTrait},
    matrix::{Matrix, Ref},
    num::Num,
};

#[cfg(feature = "nvidia")]
use zenu_cuda::cudnn::{
    batch_norm::{
        BatchNorm2d, BatchNorm2dBackward, BatchNorm2dBackwardBuilder, BatchNorm2dBuilder,
        BatchNorm2dInference, BatchNorm2dInferenceBuilder,
    },
    TensorFormat,
};

#[cfg(feature = "nvidia")]
use crate::device::nvidia::Nvidia;

pub struct BatchNorm2dConfig<T: Num> {
    #[cfg(feature = "nvidia")]
    pub device_batch_norm: BatchNorm2d<T>,
    _phantom: std::marker::PhantomData<T>,
}

impl<T: Num> BatchNorm2dConfig<T> {
    #[must_use]
    #[allow(unused_variables)]
    pub fn new(dim: DimDyn) -> Self {
        BatchNorm2dConfig::<T> {
            #[cfg(feature = "nvidia")]
            device_batch_norm: create_batch_norm_gpu::<T>(dim),
            _phantom: std::marker::PhantomData,
        }
    }
}

pub struct BatchNorm2dBackwardConfig<T> {
    #[cfg(feature = "nvidia")]
    pub device_batch_norm_backward: BatchNorm2dBackward<T>,
    _phantom: std::marker::PhantomData<T>,
}

impl<T: Num> BatchNorm2dBackwardConfig<T> {
    #[must_use]
    #[allow(unused_variables)]
    pub fn new(dim: DimDyn) -> Self {
        BatchNorm2dBackwardConfig::<T> {
            #[cfg(feature = "nvidia")]
            device_batch_norm_backward: create_batch_norm_backward_gpu::<T>(dim),
            _phantom: std::marker::PhantomData,
        }
    }
}

pub struct BatchNorm2dInferenceConfig<T> {
    #[cfg(feature = "nvidia")]
    pub device_batch_norm_inference: BatchNorm2dInference<T>,
    _phantom: std::marker::PhantomData<T>,
}

impl<T: Num> BatchNorm2dInferenceConfig<T> {
    #[must_use]
    #[allow(unused_variables)]
    pub fn new(dim: DimDyn) -> Self {
        BatchNorm2dInferenceConfig::<T> {
            #[cfg(feature = "nvidia")]
            device_batch_norm_inference: create_batch_norm_inference_gpu::<T>(dim),
            _phantom: std::marker::PhantomData,
        }
    }
}

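// The helpers below configure the cuDNN-backed descriptors for NCHW tensors; the
// channel dimension (`input.1`) sizes the per-channel scale/bias/mean/variance
// descriptor.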
#[cfg(feature = "nvidia")]
fn create_batch_norm_gpu<T: Num>(input: DimDyn) -> BatchNorm2d<T> {
    let input = (
        input[0].try_into().unwrap(),
        input[1].try_into().unwrap(),
        input[2].try_into().unwrap(),
        input[3].try_into().unwrap(),
    );
    BatchNorm2dBuilder::<T>::new()
        .input(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .output(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .scale_bias_mean_var(input.1, TensorFormat::NCHW)
        .unwrap()
        .build()
}

#[cfg(feature = "nvidia")]
fn create_batch_norm_backward_gpu<T: Num>(input: DimDyn) -> BatchNorm2dBackward<T> {
    let input = (
        input[0].try_into().unwrap(),
        input[1].try_into().unwrap(),
        input[2].try_into().unwrap(),
        input[3].try_into().unwrap(),
    );
    BatchNorm2dBackwardBuilder::<T>::new()
        .input(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .input_grad(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .output_grad(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .scale_bias_mean_var(input.1, TensorFormat::NCHW)
        .unwrap()
        .build()
}

#[cfg(feature = "nvidia")]
fn create_batch_norm_inference_gpu<T: Num>(input: DimDyn) -> BatchNorm2dInference<T> {
    let input = (
        input[0].try_into().unwrap(),
        input[1].try_into().unwrap(),
        input[2].try_into().unwrap(),
        input[3].try_into().unwrap(),
    );
    BatchNorm2dInferenceBuilder::<T>::new()
        .input(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .output(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .scale_bias_mean_var(input.1, TensorFormat::NCHW)
        .unwrap()
        .build()
}

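/// Device-level 2D batch normalization over NCHW tensors.
///
/// For each channel `c`, training computes batch statistics over the
/// `N * H * W` elements of that channel and normalizes
/// `y = scale * (x - mean) / sqrt(var + eps) + bias`, while the running
/// `mean`/`variance` buffers are updated as an exponential moving average
/// controlled by `momentum`. Inference reuses the supplied running statistics
/// instead of batch statistics, and the backward pass produces gradients for
/// the input, scale, and bias.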
pub trait BatchNormalization: DeviceBase {
    #[expect(clippy::too_many_arguments)]
    fn batch_norm_2d_forward_train<T: Num>(
        momentum: f64,
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&mut T>, DimDyn, Self>,
        variance: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        device_batch_norm: &Option<BatchNorm2dConfig<T>>,
    );

    #[expect(clippy::too_many_arguments)]
    fn batch_norm_2d_backward<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y_grad: Matrix<Ref<&T>, DimDyn, Self>,
        x_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        scale_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        device_batch_norm_backward: &Option<BatchNorm2dBackwardConfig<T>>,
    );

    fn batch_norm_2d_forward_inference<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&T>, DimDyn, Self>,
        variance: Matrix<Ref<&T>, DimDyn, Self>,
        device_batch_norm_inference: &Option<BatchNorm2dInferenceConfig<T>>,
    );
}

#[cfg(feature = "nvidia")]
impl BatchNormalization for Nvidia {
    fn batch_norm_2d_forward_train<T: Num>(
        momentum: f64,
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&mut T>, DimDyn, Self>,
        variance: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        device_batch_norm: &Option<BatchNorm2dConfig<T>>,
    ) {
        // cuDNN's exponential-average factor weights the batch statistics, so convert
        // from the momentum convention used here, which weights the running statistics.
        let momentum = 1. - momentum;
        let batch_norm = match device_batch_norm {
            Some(ref batch_norm) => &batch_norm.device_batch_norm,
            None => &create_batch_norm_gpu::<T>(x.shape()),
        };
        let saving_mean = match saving_mean {
            Some(saved_mean) => saved_mean.as_mut_ptr(),
            None => std::ptr::null_mut(),
        };
        let saving_inv_variance = match saving_inv_variance {
            Some(saved_inv_variance) => saved_inv_variance.as_mut_ptr(),
            None => std::ptr::null_mut(),
        };
        batch_norm
            .forward_train(
                T::one(),
                T::zero(),
                x.as_ptr(),
                y.as_mut_ptr(),
                scale.as_ptr(),
                bias.as_ptr(),
                mean.as_mut_ptr(),
                variance.as_mut_ptr(),
                momentum,
                saving_mean,
                saving_inv_variance,
            )
            .unwrap();
    }

    fn batch_norm_2d_backward<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y_grad: Matrix<Ref<&T>, DimDyn, Self>,
        x_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        scale_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        device_batch_norm_backward: &Option<BatchNorm2dBackwardConfig<T>>,
    ) {
        let batch_norm_backward = match device_batch_norm_backward {
            Some(ref batch_norm_backward) => &batch_norm_backward.device_batch_norm_backward,
            None => &create_batch_norm_backward_gpu::<T>(x.shape()),
        };
        let saving_mean = match saving_mean {
            Some(saved_mean) => saved_mean.as_ptr(),
            None => std::ptr::null_mut(),
        };
        let saving_inv_variance = match saving_inv_variance {
            Some(saved_inv_variance) => saved_inv_variance.as_ptr(),
            None => std::ptr::null_mut(),
        };
        batch_norm_backward
            .backward(
                T::one(),
                T::zero(),
                T::one(),
                T::zero(),
                x.as_ptr(),
                y_grad.as_ptr(),
                x_grad.as_mut_ptr(),
                scale.as_ptr(),
                scale_grad.as_mut_ptr(),
                bias_grad.as_mut_ptr(),
                saving_mean,
                saving_inv_variance,
            )
            .unwrap();
    }

    fn batch_norm_2d_forward_inference<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&T>, DimDyn, Self>,
        variance: Matrix<Ref<&T>, DimDyn, Self>,
        device_batch_norm_inference: &Option<BatchNorm2dInferenceConfig<T>>,
    ) {
        let batch_norm_inference = match device_batch_norm_inference {
            Some(ref batch_norm_inference) => &batch_norm_inference.device_batch_norm_inference,
            None => &create_batch_norm_inference_gpu::<T>(x.shape()),
        };
        batch_norm_inference
            .forward_inference(
                T::one(),
                T::zero(),
                x.as_ptr(),
                y.as_mut_ptr(),
                scale.as_ptr(),
                bias.as_ptr(),
                mean.as_ptr(),
                variance.as_ptr(),
            )
            .unwrap();
    }
}

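// CPU reference implementation: tensors are transposed from NCHW to NHWC and
// flattened to `[N * H * W, C]` so that per-channel statistics reduce over axis 0.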
impl BatchNormalization for Cpu {
    fn batch_norm_2d_forward_train<T: Num>(
        momentum: f64,
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&mut T>, DimDyn, Self>,
        variance: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        _: &Option<BatchNorm2dConfig<T>>,
    ) {
        let momentum = T::from_f64(momentum);
        let epsilon = T::from_f64(1e-10);
        let x_shape = x.shape();
        let n = x_shape[0] * x_shape[2] * x_shape[3];
        let c = x_shape[1];
        let x_transposed = x.transpose_by_index_new_matrix(&[0, 2, 3, 1]);
        let x_reshaped = x_transposed.reshape([n, c]);

        let num_elements = T::from_usize(x_reshaped.shape()[0]);

        let x_mean = x_reshaped.mean(Some(0), false);
        let x_diff = &x_reshaped - &x_mean;
        let x_variance = x_reshaped.variance(Some(0), false);
        let x_variance_unbiased = &x_variance * (num_elements / (num_elements - T::one()));

        let mean_t = &x_mean * (T::one() - momentum) + &mean * momentum;
        let variance_t = &x_variance_unbiased * (T::one() - momentum) + &variance * momentum;

        let inv_var = Matrix::<_, DimDyn, _>::ones(variance_t.shape()) / (&x_variance + epsilon);
        let inv_std = inv_var.sqrt();

        mean.copy_from(&mean_t);
        variance.copy_from(&variance_t);

        if let Some(saving_mean_mat) = saving_mean {
            saving_mean_mat.copy_from(&x_mean);
        }
        if let Some(saving_inv_variance_mat) = saving_inv_variance {
            saving_inv_variance_mat.copy_from(&inv_std);
        }

        let x_normalized = &x_diff * &inv_std;
        let y_tmp = &x_normalized * &scale + &bias;
        let y_transposed = y_tmp.reshape([x_shape[0], x_shape[2], x_shape[3], x_shape[1]]);
        y.copy_from(&y_transposed.transpose_by_index_new_matrix(&[0, 3, 1, 2]));
    }

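    // Standard batch-norm input gradient, reduced per channel over the N * H * W samples:
    // dx = scale * inv_std * (dy - mean(dy) - x_hat * mean(dy * x_hat))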
    fn batch_norm_2d_backward<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y_grad: Matrix<Ref<&T>, DimDyn, Self>,
        x_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        scale_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        _: &Option<BatchNorm2dBackwardConfig<T>>,
    ) {
        let epsilon = T::from_f64(1e-10);
        let n = x.shape()[0] * x.shape()[2] * x.shape()[3];
        let c = x.shape()[1];
        let x_shape = x.shape();

        let x_transposed = x.transpose_by_index_new_matrix(&[0, 2, 3, 1]);
        let x_reshaped = x_transposed.reshape([n, c]);

        let y_grad_transposed = y_grad.transpose_by_index_new_matrix(&[0, 2, 3, 1]);
        let y_grad_reshaped = y_grad_transposed.reshape([n, c]);

        let mean = if let Some(ref mean_mat) = saving_mean {
            mean_mat.new_matrix()
        } else {
            x_reshaped.mean(Some(0), false)
        };

        let inv_std = if let Some(ref inv_variance_mat) = saving_inv_variance {
            inv_variance_mat.new_matrix()
        } else {
            let x_variance = x_reshaped.variance(Some(0), false);
            let inv_var =
                Matrix::<_, DimDyn, _>::ones(x_variance.shape()) / (&x_variance + epsilon);
            inv_var.sqrt()
        };

        let x_centered = &x_reshaped - &mean;
        let x_hat = &x_centered * &inv_std;

        bias_grad.copy_from(&y_grad_reshaped.to_ref().sum(0, false));
        scale_grad.copy_from(&(&x_hat * &y_grad_reshaped).to_ref().sum(0, false));

        let term1 = &inv_std * &y_grad_reshaped * scale;
        let mut term2 = term1.to_ref().sum(0, false) / T::from_usize(n);
        term2.add_axis(0);
        let mut term3 =
            &x_centered * (&term1 * &x_centered).to_ref().sum(0, false) / T::from_usize(n);
        term3.add_axis(0);
        let term3 = term3 * &inv_std * &inv_std;

        let x_grad_reshaped = term1 - term2 - term3;

        let x_grad_transposed =
            x_grad_reshaped.reshape([x_shape[0], x_shape[2], x_shape[3], x_shape[1]]);

        x_grad.copy_from(&x_grad_transposed.transpose_by_index_new_matrix(&[0, 3, 1, 2]));
    }

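    // Inference path: normalize with the caller-supplied running mean/variance
    // instead of batch statistics, then apply the affine scale/bias.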
    fn batch_norm_2d_forward_inference<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&T>, DimDyn, Self>,
        variance: Matrix<Ref<&T>, DimDyn, Self>,
        _: &Option<BatchNorm2dInferenceConfig<T>>,
    ) {
        let epsilon = T::from_f64(1e-10);
        let n = x.shape()[0] * x.shape()[2] * x.shape()[3];
        let c = x.shape()[1];
        let x_shape = x.shape();

        let x_transposed = x.transpose_by_index_new_matrix(&[0, 2, 3, 1]);
        let x_reshaped = x_transposed.reshape([n, c]);

        let mean = mean.to_ref();
        let inv_std = Matrix::<_, DimDyn, _>::ones(variance.shape()) / (&variance + epsilon).sqrt();

        let x_centered = &x_reshaped - mean;
        let x_hat = &x_centered * &inv_std;

        let y_tmp = &x_hat * &scale + &bias;
        let y_transposed = y_tmp.reshape([x_shape[0], x_shape[2], x_shape[3], x_shape[1]]);
        y.copy_from(&y_transposed.transpose_by_index_new_matrix(&[0, 3, 1, 2]));
    }
}

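// Validates the NCHW layout expected by the kernels above: `x` (and the tensors
// matching it) must be 4-dimensional, and every per-channel vector must have
// length `x[1]`.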
#[expect(clippy::too_many_arguments)]
fn batch_norm_2d_shape_check(
    x: DimDyn,
    y: DimDyn,
    scale: DimDyn,
    bias: DimDyn,
    mean: DimDyn,
    variance: DimDyn,
    saving_mean: Option<DimDyn>,
    saving_inv_variance: Option<DimDyn>,
) -> Result<(), String> {
    if scale.len() != 1 {
        return Err("scale must be a vector".to_string());
    }
    if bias.len() != 1 {
        return Err("bias must be a vector".to_string());
    }
    if mean.len() != 1 {
        return Err("mean must be a vector".to_string());
    }
    if variance.len() != 1 {
        return Err("variance must be a vector".to_string());
    }
    if let Some(saving_mean) = saving_mean {
        if saving_mean.len() != 1 {
            return Err("saving_mean must be a vector".to_string());
        }
    }
    if let Some(saving_inv_variance) = saving_inv_variance {
        if saving_inv_variance.len() != 1 {
            return Err("saving_inv_variance must be a vector".to_string());
        }
    }
    if x.len() != 4 {
        return Err("x must be a 4-dimensional (NCHW) tensor".to_string());
    }
    if x != y {
        return Err("x and y must have the same shape".to_string());
    }
    if x[1] != scale[0] {
        return Err("x and scale must have the same number of channels".to_string());
    }
    if x[1] != bias[0] {
        return Err("x and bias must have the same number of channels".to_string());
    }
    if x[1] != mean[0] {
        return Err("x and mean must have the same number of channels".to_string());
    }
    if x[1] != variance[0] {
        return Err("x and variance must have the same number of channels".to_string());
    }
    if let Some(saving_mean) = saving_mean {
        if x[1] != saving_mean[0] {
            return Err("x and saving_mean must have the same number of channels".to_string());
        }
    }
    if let Some(saving_inv_variance) = saving_inv_variance {
        if x[1] != saving_inv_variance[0] {
            return Err(
                "x and saving_inv_variance must have the same number of channels".to_string(),
            );
        }
    }
    Ok(())
}

#[expect(clippy::too_many_arguments)]
fn batch_norm_2d_backward_shape_check(
    x: DimDyn,
    y_grad: DimDyn,
    x_grad: DimDyn,
    scale: DimDyn,
    scale_grad: DimDyn,
    bias_grad: DimDyn,
    saving_mean: Option<DimDyn>,
    saving_inv_variance: Option<DimDyn>,
) -> Result<(), String> {
    if scale.len() != 1 {
        return Err("scale must be a vector".to_string());
    }
    if bias_grad.len() != 1 {
        return Err("bias_grad must be a vector".to_string());
    }
    if let Some(saving_mean) = saving_mean {
        if saving_mean.len() != 1 {
            return Err("saving_mean must be a vector".to_string());
        }
    }
    if let Some(saving_inv_variance) = saving_inv_variance {
        if saving_inv_variance.len() != 1 {
            return Err("saving_inv_variance must be a vector".to_string());
        }
    }
    if x.len() != 4 {
        return Err("x must be a 4-dimensional (NCHW) tensor".to_string());
    }
    if x != y_grad {
        return Err("x and y_grad must have the same shape".to_string());
    }
    if x != x_grad {
        return Err("x and x_grad must have the same shape".to_string());
    }
    if x[1] != scale[0] {
        return Err("x and scale must have the same number of channels".to_string());
    }
    if x[1] != scale_grad[0] {
        return Err("x and scale_grad must have the same number of channels".to_string());
    }
    if x[1] != bias_grad[0] {
        return Err("x and bias_grad must have the same number of channels".to_string());
    }
    if let Some(saving_mean) = saving_mean {
        if x[1] != saving_mean[0] {
            return Err("x and saving_mean must have the same number of channels".to_string());
        }
    }
    if let Some(saving_inv_variance) = saving_inv_variance {
        if x[1] != saving_inv_variance[0] {
            return Err(
                "x and saving_inv_variance must have the same number of channels".to_string(),
            );
        }
    }
    Ok(())
}

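/// Validates the shapes of all arguments, then dispatches to the device's
/// `batch_norm_2d_forward_train` implementation, returning the validation error
/// as a `String` instead of panicking.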
#[expect(clippy::too_many_arguments, clippy::missing_errors_doc)]
pub fn try_batch_norm_2d_forward_train<T: Num, D: Device>(
    momentum: f64,
    x: Matrix<Ref<&T>, DimDyn, D>,
    y: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    mean: Matrix<Ref<&mut T>, DimDyn, D>,
    variance: Matrix<Ref<&mut T>, DimDyn, D>,
    saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, D>>,
    saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, D>>,
    device_batch_norm: &Option<BatchNorm2dConfig<T>>,
) -> Result<(), String> {
    let x_shape = x.shape();
    let y_shape = y.shape();
    let scale_shape = scale.shape();
    let bias_shape = bias.shape();
    let mean_shape = mean.shape();
    let variance_shape = variance.shape();
    let saving_mean_shape = saving_mean.as_ref().map(Matrix::shape);
    let saving_inv_variance_shape = saving_inv_variance.as_ref().map(Matrix::shape);

    batch_norm_2d_shape_check(
        x_shape,
        y_shape,
        scale_shape,
        bias_shape,
        mean_shape,
        variance_shape,
        saving_mean_shape,
        saving_inv_variance_shape,
    )?;

    D::batch_norm_2d_forward_train(
        momentum,
        x,
        y,
        scale,
        bias,
        mean,
        variance,
        saving_mean,
        saving_inv_variance,
        device_batch_norm,
    );

    Ok(())
}

#[expect(clippy::missing_errors_doc)]
pub fn try_batch_norm_2d_forward_inference<T: Num, D: Device>(
    x: Matrix<Ref<&T>, DimDyn, D>,
    y: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    mean: Matrix<Ref<&T>, DimDyn, D>,
    variance: Matrix<Ref<&T>, DimDyn, D>,
    device_batch_norm_inference: &Option<BatchNorm2dInferenceConfig<T>>,
) -> Result<(), String> {
    let x_shape = x.shape();
    let y_shape = y.shape();
    let scale_shape = scale.shape();
    let bias_shape = bias.shape();
    let mean_shape = mean.shape();
    let variance_shape = variance.shape();

    batch_norm_2d_shape_check(
        x_shape,
        y_shape,
        scale_shape,
        bias_shape,
        mean_shape,
        variance_shape,
        None,
        None,
    )?;

    D::batch_norm_2d_forward_inference(
        x,
        y,
        scale,
        bias,
        mean,
        variance,
        device_batch_norm_inference,
    );

    Ok(())
}

#[expect(clippy::too_many_arguments, clippy::missing_errors_doc)]
pub fn try_batch_norm_2d_backward<T: Num, D: Device>(
    x: Matrix<Ref<&T>, DimDyn, D>,
    y_grad: Matrix<Ref<&T>, DimDyn, D>,
    x_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    scale_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    bias_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    saving_mean: Option<Matrix<Ref<&T>, DimDyn, D>>,
    saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, D>>,
    device_batch_norm_backward: &Option<BatchNorm2dBackwardConfig<T>>,
) -> Result<(), String> {
    let x_shape = x.shape();
    let y_grad_shape = y_grad.shape();
    let x_grad_shape = x_grad.shape();
    let scale_shape = scale.shape();
    let scale_grad_shape = scale_grad.shape();
    let bias_grad_shape = bias_grad.shape();
    let saving_mean_shape = saving_mean.as_ref().map(Matrix::shape);
    let saving_inv_variance_shape = saving_inv_variance.as_ref().map(Matrix::shape);

    batch_norm_2d_backward_shape_check(
        x_shape,
        y_grad_shape,
        x_grad_shape,
        scale_shape,
        scale_grad_shape,
        bias_grad_shape,
        saving_mean_shape,
        saving_inv_variance_shape,
    )?;

    D::batch_norm_2d_backward(
        x,
        y_grad,
        x_grad,
        scale,
        scale_grad,
        bias_grad,
        saving_mean,
        saving_inv_variance,
        device_batch_norm_backward,
    );

    Ok(())
}

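/// Shape-checked wrapper around [`try_batch_norm_2d_forward_train`] that panics on a
/// validation error.
///
/// Illustrative sketch of a call on the CPU backend (not a doctest: paths are
/// abbreviated and `x_data` is a placeholder `Vec<f32>` of length 2 * 3 * 4 * 4):
///
/// ```ignore
/// // x: [N, C, H, W] = [2, 3, 4, 4]; per-channel parameters have length C = 3.
/// let x = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(x_data, [2, 3, 4, 4]);
/// let mut y = Matrix::<Owned<f32>, DimDyn, Cpu>::zeros(x.shape());
/// let scale = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(vec![1.0; 3], [3]);
/// let bias = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(vec![0.0; 3], [3]);
/// let mut mean = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(vec![0.0; 3], [3]);
/// let mut variance = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(vec![1.0; 3], [3]);
/// let config = BatchNorm2dConfig::<f32>::new(x.shape());
/// batch_norm_2d_forward_train::<f32, Cpu>(
///     0.1,
///     x.to_ref(),
///     y.to_ref_mut(),
///     scale.to_ref(),
///     bias.to_ref(),
///     mean.to_ref_mut(),
///     variance.to_ref_mut(),
///     None,
///     None,
///     &Some(config),
/// );
/// ```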
#[expect(clippy::too_many_arguments, clippy::missing_panics_doc)]
pub fn batch_norm_2d_forward_train<T: Num, D: Device>(
    momentum: f64,
    x: Matrix<Ref<&T>, DimDyn, D>,
    y: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    mean: Matrix<Ref<&mut T>, DimDyn, D>,
    variance: Matrix<Ref<&mut T>, DimDyn, D>,
    saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, D>>,
    saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, D>>,
    device_batch_norm: &Option<BatchNorm2dConfig<T>>,
) {
    try_batch_norm_2d_forward_train(
        momentum,
        x,
        y,
        scale,
        bias,
        mean,
        variance,
        saving_mean,
        saving_inv_variance,
        device_batch_norm,
    )
    .unwrap();
}

#[expect(clippy::missing_panics_doc)]
pub fn batch_norm_2d_forward_inference<T: Num, D: Device>(
    x: Matrix<Ref<&T>, DimDyn, D>,
    y: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    mean: Matrix<Ref<&T>, DimDyn, D>,
    variance: Matrix<Ref<&T>, DimDyn, D>,
    device_batch_norm_inference: &Option<BatchNorm2dInferenceConfig<T>>,
) {
    try_batch_norm_2d_forward_inference(
        x,
        y,
        scale,
        bias,
        mean,
        variance,
        device_batch_norm_inference,
    )
    .unwrap();
}

#[expect(clippy::too_many_arguments, clippy::missing_panics_doc)]
pub fn batch_norm_2d_backward<T: Num, D: Device>(
    x: Matrix<Ref<&T>, DimDyn, D>,
    y_grad: Matrix<Ref<&T>, DimDyn, D>,
    x_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    scale_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    bias_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    saving_mean: Option<Matrix<Ref<&T>, DimDyn, D>>,
    saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, D>>,
    device_batch_norm_backward: &Option<BatchNorm2dBackwardConfig<T>>,
) {
    try_batch_norm_2d_backward(
        x,
        y_grad,
        x_grad,
        scale,
        scale_grad,
        bias_grad,
        saving_mean,
        saving_inv_variance,
        device_batch_norm_backward,
    )
    .unwrap();
}

#[expect(clippy::unreadable_literal)]
#[cfg(test)]
mod batch_norm {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    use zenu_test::*;

    use super::*;

    #[derive(Debug)]
    struct BatchNormInputs<D: Device> {
        x: Matrix<Owned<f32>, DimDyn, D>,
        y: Matrix<Owned<f32>, DimDyn, D>,
        scale: Matrix<Owned<f32>, DimDyn, D>,
        bias: Matrix<Owned<f32>, DimDyn, D>,
        mean: Matrix<Owned<f32>, DimDyn, D>,
        variance: Matrix<Owned<f32>, DimDyn, D>,
        saved_mean: Matrix<Owned<f32>, DimDyn, D>,
        saved_variance: Matrix<Owned<f32>, DimDyn, D>,
    }

    fn small_data<D: Device>() -> BatchNormInputs<D> {
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                -1.1258398,
                -1.1523602,
                -0.25057858,
                -0.4338788,
                0.84871036,
                0.69200915,
                -0.31601277,
                -2.1152194,
                0.32227492,
                -1.2633348,
                0.3499832,
                0.30813393,
                0.11984151,
                1.2376579,
                1.1167772,
                -0.24727815,
            ],
            [2, 2, 2, 2],
        );
        let y = vec![
            -1.0970649,
            -1.1374662,
            0.23631285,
            -0.04292771,
            0.66504365,
            0.5121599,
            -0.4713051,
            -2.2266803,
            1.109001,
            -1.3065253,
            1.1512119,
            1.0874585,
            -0.04606889,
            1.0445158,
            0.92657995,
            -0.40424496,
        ];
        let running_mean = vec![-0.36513, 0.15035464];
        let running_variance = vec![0.4431935, 1.0805689];
        let saved_mean = vec![-0.40570003, 0.16706072];
        let saved_variance = vec![1.5234232, 0.97564316];
        let scale = vec![1.0, 1.0];
        let bias = vec![0.0, 0.0];
        let y = Matrix::<Owned<f32>, DimDyn, D>::from_vec(y, [2, 2, 2, 2]);
        let mean = Matrix::<Owned<f32>, DimDyn, D>::from_vec(running_mean, [2]);
        let variance = Matrix::<Owned<f32>, DimDyn, D>::from_vec(running_variance, [2]);
        let scale = Matrix::<Owned<f32>, DimDyn, D>::from_vec(scale, [2]);
        let bias = Matrix::<Owned<f32>, DimDyn, D>::from_vec(bias, [2]);
        let saved_mean = Matrix::<Owned<f32>, DimDyn, D>::from_vec(saved_mean, [2]);
        let saved_variance = Matrix::<Owned<f32>, DimDyn, D>::from_vec(saved_variance, [2]);
        BatchNormInputs {
            x,
            y,
            scale,
            bias,
            mean,
            variance,
            saved_mean,
            saved_variance,
        }
    }

    fn small_forward<D: Device>() {
        let inputs = small_data::<D>();
        let mut y_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.y.shape());
        let mut mean_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.mean.shape());
        let mut variance_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.variance.shape());
        let mut saved_mean_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.saved_mean.shape());
        let mut saved_variance_out =
            Matrix::<Owned<f32>, DimDyn, D>::alloc(inputs.saved_variance.shape());
        let batch_norm = BatchNorm2dConfig::<f32>::new(inputs.x.shape());
        D::batch_norm_2d_forward_train(
            0.1,
            inputs.x.to_ref(),
            y_out.to_ref_mut(),
            inputs.scale.to_ref(),
            inputs.bias.to_ref(),
            mean_out.to_ref_mut(),
            variance_out.to_ref_mut(),
            Some(saved_mean_out.to_ref_mut()),
            Some(saved_variance_out.to_ref_mut()),
            &Some(batch_norm),
        );

        assert_mat_eq_epsilon!(y_out.to_ref(), inputs.y.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(mean_out.to_ref(), inputs.mean.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(variance_out.to_ref(), inputs.variance.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(saved_mean_out.to_ref(), inputs.saved_mean.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(
            saved_variance_out.to_ref(),
            inputs.saved_variance.to_ref(),
            2e-4
        );
    }
    run_mat_test!(small_forward, small_forward_cpu, small_forward_gpu);

    #[derive(Debug)]
    struct BatchNormBackward<D: Device> {
        x: Matrix<Owned<f32>, DimDyn, D>,
        y_grad: Matrix<Owned<f32>, DimDyn, D>,
        scale: Matrix<Owned<f32>, DimDyn, D>,
        saved_mean: Matrix<Owned<f32>, DimDyn, D>,
        saved_variance: Matrix<Owned<f32>, DimDyn, D>,
    }

    fn small_data_backward<D: Device>() -> BatchNormBackward<D> {
        let x = vec![
            -1.1258398,
            -1.1523602,
            -0.25057858,
            -0.4338788,
            0.84871036,
            0.69200915,
            -0.31601277,
            -2.1152194,
            0.32227492,
            -1.2633348,
            0.3499832,
            0.30813393,
            0.11984151,
            1.2376579,
            1.1167772,
            -0.24727815,
        ];
        let y_grad = vec![
            -0.9246624,
            -0.42534423,
            -2.6438458,
            0.14518386,
            -0.1208664,
            -0.57972574,
            -0.622851,
            -0.3283869,
            -1.0745419,
            -0.36314395,
            -1.6710504,
            2.2655048,
            0.3116848,
            -0.1841891,
            1.2866427,
            1.1819527,
        ];
        let saved_mean = vec![-0.04057, 0.01670607];
        let saved_variance = vec![0.9492437, 1.0200632];
        let scale = vec![1.0, 1.0];
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(x, [2, 2, 2, 2]);
        let y_grad = Matrix::<Owned<f32>, DimDyn, D>::from_vec(y_grad, [2, 2, 2, 2]);
        let scale = Matrix::<Owned<f32>, DimDyn, D>::from_vec(scale, [2]);
        let saved_mean = Matrix::<Owned<f32>, DimDyn, D>::from_vec(saved_mean, [2]);
        let saved_variance = Matrix::<Owned<f32>, DimDyn, D>::from_vec(saved_variance, [2]);
        BatchNormBackward {
            x,
            y_grad,
            scale,
            saved_mean,
            saved_variance,
        }
    }

    fn small_backward<D: Device>() {
        let inputs = small_data_backward::<D>();
        let mut x_grad = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.x.shape());
        let mut scale_grad = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.scale.shape());
        let mut bias_grad = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.scale.shape());
        let batch_norm_backward = BatchNorm2dBackwardConfig::<f32>::new(inputs.x.shape());
        D::batch_norm_2d_backward(
            inputs.x.to_ref(),
            inputs.y_grad.to_ref(),
            x_grad.to_ref_mut(),
            inputs.scale.to_ref(),
            scale_grad.to_ref_mut(),
            bias_grad.to_ref_mut(),
            Some(inputs.saved_mean.to_ref()),
            Some(inputs.saved_variance.to_ref()),
            &Some(batch_norm_backward),
        );

        let x_grad_ans = vec![
            -0.06967929,
            0.41043705,
            -1.9042997,
            0.7856185,
            -0.39005604,
            -0.83055514,
            -0.69721717,
            -0.080333665,
            -0.54731166,
            0.4951802,
            -1.1199604,
            2.6264815,
            0.1793941,
            -0.52307177,
            0.99853456,
            1.131705,
        ];
        let scale_grad_ans = vec![2.0560942, 1.352522];
        let bias_grad_ans = vec![-4.6919003, 0.9442612];
        let x_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(x_grad_ans, [2, 2, 2, 2]);
        let scale_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(scale_grad_ans, [2]);
        let bias_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(bias_grad_ans, [2]);
        assert_mat_eq_epsilon!(x_grad.to_ref(), x_grad_ans.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(scale_grad.to_ref(), scale_grad_ans.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(bias_grad.to_ref(), bias_grad_ans.to_ref(), 2e-4);
    }
    run_mat_test!(small_backward, small_backward_cpu, small_backward_gpu);

    fn small_forward_inference<D: Device>() {
        let inputs = small_forward_inference_data::<f32, D>();
        let mut y_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.y.shape());
        let batch_norm_inference = BatchNorm2dInferenceConfig::<f32>::new(inputs.x.shape());
        D::batch_norm_2d_forward_inference(
            inputs.x.to_ref(),
            y_out.to_ref_mut(),
            inputs.scale.to_ref(),
            inputs.bias.to_ref(),
            inputs.mean.to_ref(),
            inputs.variance.to_ref(),
            &Some(batch_norm_inference),
        );

        assert_mat_eq_epsilon!(y_out.to_ref(), inputs.y.to_ref(), 3e-3);
    }
    run_mat_test!(
        small_forward_inference,
        small_forward_inference_cpu,
        small_forward_inference_gpu
    );

    #[derive(Debug)]
    struct ForwardInputs<T: Num, D: Device> {
        x: Matrix<Owned<T>, DimDyn, D>,
        y: Matrix<Owned<T>, DimDyn, D>,
        scale: Matrix<Owned<T>, DimDyn, D>,
        bias: Matrix<Owned<T>, DimDyn, D>,
        mean: Matrix<Owned<T>, DimDyn, D>,
        variance: Matrix<Owned<T>, DimDyn, D>,
    }

    fn small_forward_inference_data<T: Num, D: Device>() -> ForwardInputs<T, D> {
        let x = vec![
            -1.1258398,
            -1.1523602,
            -0.25057858,
            -0.4338788,
            0.84871036,
            0.69200915,
            -0.31601277,
            -2.1152194,
            0.32227492,
            -1.2633348,
            0.3499832,
            0.30813393,
            0.11984151,
            1.2376579,
            1.1167772,
            -0.24727815,
        ];
        let y = vec![
            -0.6203, -0.5908, -1.5910, -1.3877, 3.3524, 2.9482, 0.3480, -4.2931, -2.2263, -0.4678,
            -2.2570, -2.2106, 1.4723, 4.3557, 4.0439, 0.5253,
        ];
        let mean = vec![-0.7193, -0.4033];
        let variance = vec![0.5966, 0.1820];
        let scale = vec![-0.8567, 1.1006];
        let bias = vec![-1.0712, 0.1227];

        let x = x.into_iter().map(T::from_f64).collect();
        let y = y.into_iter().map(T::from_f64).collect();
        let mean = mean.into_iter().map(T::from_f64).collect();
        let variance = variance.into_iter().map(T::from_f64).collect();
        let scale = scale.into_iter().map(T::from_f64).collect();
        let bias = bias.into_iter().map(T::from_f64).collect();

        let x = Matrix::<Owned<T>, DimDyn, D>::from_vec(x, [2, 2, 2, 2]);
        let y = Matrix::<Owned<T>, DimDyn, D>::from_vec(y, [2, 2, 2, 2]);
        let mean = Matrix::<Owned<T>, DimDyn, D>::from_vec(mean, [2]);
        let variance = Matrix::<Owned<T>, DimDyn, D>::from_vec(variance, [2]);
        let scale = Matrix::<Owned<T>, DimDyn, D>::from_vec(scale, [2]);
        let bias = Matrix::<Owned<T>, DimDyn, D>::from_vec(bias, [2]);
        ForwardInputs {
            x,
            y,
            scale,
            bias,
            mean,
            variance,
        }
    }
}
1077}