zenu_matrix/nn/batch_norm.rs
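//! 2-D batch normalization over NCHW tensors.
//!
//! The [`BatchNormalization`] trait defines the device-level forward-train,
//! backward, and forward-inference kernels. It is implemented for `Cpu` and,
//! behind the `nvidia` feature, for `Nvidia` via cuDNN. The free functions at
//! the bottom of the module add shape validation on top of the device
//! implementations.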

use crate::{
    device::{cpu::Cpu, Device, DeviceBase},
    dim::{DimDyn, DimTrait},
    matrix::{Matrix, Ref},
    num::Num,
};

#[cfg(feature = "nvidia")]
use zenu_cuda::cudnn::{
    batch_norm::{
        BatchNorm2d, BatchNorm2dBackward, BatchNorm2dBackwardBuilder, BatchNorm2dBuilder,
        BatchNorm2dInference, BatchNorm2dInferenceBuilder,
    },
    TensorFormat,
};

#[cfg(feature = "nvidia")]
use crate::device::nvidia::Nvidia;

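/// Precomputed state for the training-mode forward pass.
///
/// With the `nvidia` feature enabled this caches a cuDNN `BatchNorm2d`
/// descriptor so it is not rebuilt on every call; without it the struct
/// carries no data.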
pub struct BatchNorm2dConfig<T: Num> {
    #[cfg(feature = "nvidia")]
    pub device_batch_norm: BatchNorm2d<T>,
    _phantom: std::marker::PhantomData<T>,
}

impl<T: Num> BatchNorm2dConfig<T> {
    #[must_use]
    #[allow(unused_variables)]
    pub fn new(dim: DimDyn) -> Self {
        BatchNorm2dConfig::<T> {
            #[cfg(feature = "nvidia")]
            device_batch_norm: create_batch_norm_gpu::<T>(dim),
            _phantom: std::marker::PhantomData,
        }
    }
}

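/// Precomputed state for the backward pass; the backward-descriptor analogue
/// of `BatchNorm2dConfig`.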
pub struct BatchNorm2dBackwardConfig<T> {
    #[cfg(feature = "nvidia")]
    pub device_batch_norm_backward: BatchNorm2dBackward<T>,
    _phantom: std::marker::PhantomData<T>,
}

impl<T: Num> BatchNorm2dBackwardConfig<T> {
    #[must_use]
    #[allow(unused_variables)]
    pub fn new(dim: DimDyn) -> Self {
        BatchNorm2dBackwardConfig::<T> {
            #[cfg(feature = "nvidia")]
            device_batch_norm_backward: create_batch_norm_backward_gpu::<T>(dim),
            _phantom: std::marker::PhantomData,
        }
    }
}

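/// Precomputed state for the inference-mode forward pass.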
pub struct BatchNorm2dInferenceConfig<T> {
    #[cfg(feature = "nvidia")]
    pub device_batch_norm_inference: BatchNorm2dInference<T>,
    _phantom: std::marker::PhantomData<T>,
}

impl<T: Num> BatchNorm2dInferenceConfig<T> {
    #[must_use]
    #[allow(unused_variables)]
    pub fn new(dim: DimDyn) -> Self {
        BatchNorm2dInferenceConfig::<T> {
            #[cfg(feature = "nvidia")]
            device_batch_norm_inference: create_batch_norm_inference_gpu::<T>(dim),
            _phantom: std::marker::PhantomData,
        }
    }
}

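// The builders below describe NCHW tensors of shape (N, C, H, W). Dimensions
// are converted from `usize` to the narrower integer type cuDNN expects, so
// `try_into().unwrap()` panics if a dimension does not fit.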
#[cfg(feature = "nvidia")]
fn create_batch_norm_gpu<T: Num>(input: DimDyn) -> BatchNorm2d<T> {
    let input = (
        input[0].try_into().unwrap(),
        input[1].try_into().unwrap(),
        input[2].try_into().unwrap(),
        input[3].try_into().unwrap(),
    );
    BatchNorm2dBuilder::<T>::new()
        .input(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .output(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .scale_bias_mean_var(input.1, TensorFormat::NCHW)
        .unwrap()
        .build()
}

#[cfg(feature = "nvidia")]
fn create_batch_norm_backward_gpu<T: Num>(input: DimDyn) -> BatchNorm2dBackward<T> {
    let input = (
        input[0].try_into().unwrap(),
        input[1].try_into().unwrap(),
        input[2].try_into().unwrap(),
        input[3].try_into().unwrap(),
    );
    BatchNorm2dBackwardBuilder::<T>::new()
        .input(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .input_grad(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .output_grad(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .scale_bias_mean_var(input.1, TensorFormat::NCHW)
        .unwrap()
        .build()
}

#[cfg(feature = "nvidia")]
fn create_batch_norm_inference_gpu<T: Num>(input: DimDyn) -> BatchNorm2dInference<T> {
    let input = (
        input[0].try_into().unwrap(),
        input[1].try_into().unwrap(),
        input[2].try_into().unwrap(),
        input[3].try_into().unwrap(),
    );
    BatchNorm2dInferenceBuilder::<T>::new()
        .input(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .output(input.0, input.1, input.2, input.3, TensorFormat::NCHW)
        .unwrap()
        .scale_bias_mean_var(input.1, TensorFormat::NCHW)
        .unwrap()
        .build()
}

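/// Device-level 2-D batch-normalization kernels.
///
/// `saving_mean` and `saving_inv_variance` optionally receive the per-batch
/// mean and inverse standard deviation computed by the forward-train pass;
/// feeding them back into the backward pass avoids recomputing the batch
/// statistics from `x`.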
pub trait BatchNormalization: DeviceBase {
    #[expect(clippy::too_many_arguments)]
    fn batch_norm_2d_forward_train<T: Num>(
        momentum: f64,
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&mut T>, DimDyn, Self>,
        variance: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        device_batch_norm: &Option<BatchNorm2dConfig<T>>,
    );

    #[expect(clippy::too_many_arguments)]
    fn batch_norm_2d_backward<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y_grad: Matrix<Ref<&T>, DimDyn, Self>,
        x_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        scale_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        device_batch_norm_backward: &Option<BatchNorm2dBackwardConfig<T>>,
    );

    fn batch_norm_2d_forward_inference<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&T>, DimDyn, Self>,
        variance: Matrix<Ref<&T>, DimDyn, Self>,
        device_batch_norm_inference: &Option<BatchNorm2dInferenceConfig<T>>,
    );
}

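// cuDNN-backed implementation. When no prebuilt config is supplied, the
// descriptors are created on the fly from the shape of `x`.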
#[cfg(feature = "nvidia")]
impl BatchNormalization for Nvidia {
    fn batch_norm_2d_forward_train<T: Num>(
        momentum: f64,
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&mut T>, DimDyn, Self>,
        variance: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        device_batch_norm: &Option<BatchNorm2dConfig<T>>,
    ) {
        // cuDNN's exponential-average factor is the weight given to the newly
        // computed batch statistics, whereas `momentum` here is the weight
        // kept on the previous running statistics (see the CPU implementation
        // below), hence the conversion.
        let momentum = 1. - momentum;
        let batch_norm = match device_batch_norm {
            Some(ref batch_norm) => &batch_norm.device_batch_norm,
            None => &create_batch_norm_gpu::<T>(x.shape()),
        };
        // cuDNN accepts null pointers when the saved statistics are not wanted.
        let saving_mean = match saving_mean {
            Some(saved_mean) => saved_mean.as_mut_ptr(),
            None => std::ptr::null_mut(),
        };
        let saving_inv_variance = match saving_inv_variance {
            Some(saved_inv_variance) => saved_inv_variance.as_mut_ptr(),
            None => std::ptr::null_mut(),
        };
        batch_norm
            .forward_train(
                T::one(),
                T::zero(),
                x.as_ptr(),
                y.as_mut_ptr(),
                scale.as_ptr(),
                bias.as_ptr(),
                mean.as_mut_ptr(),
                variance.as_mut_ptr(),
                momentum,
                saving_mean,
                saving_inv_variance,
            )
            .unwrap();
    }

    fn batch_norm_2d_backward<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y_grad: Matrix<Ref<&T>, DimDyn, Self>,
        x_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        scale_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        device_batch_norm_backward: &Option<BatchNorm2dBackwardConfig<T>>,
    ) {
        let batch_norm_backward = match device_batch_norm_backward {
            Some(ref batch_norm_backward) => &batch_norm_backward.device_batch_norm_backward,
            None => &create_batch_norm_backward_gpu::<T>(x.shape()),
        };
        let saving_mean = match saving_mean {
            Some(saved_mean) => saved_mean.as_ptr(),
            None => std::ptr::null(),
        };
        let saving_inv_variance = match saving_inv_variance {
            Some(saved_inv_variance) => saved_inv_variance.as_ptr(),
            None => std::ptr::null(),
        };
        batch_norm_backward
            .backward(
                T::one(),
                T::zero(),
                T::one(),
                T::zero(),
                x.as_ptr(),
                y_grad.as_ptr(),
                x_grad.as_mut_ptr(),
                scale.as_ptr(),
                scale_grad.as_mut_ptr(),
                bias_grad.as_mut_ptr(),
                saving_mean,
                saving_inv_variance,
            )
            .unwrap();
    }

    fn batch_norm_2d_forward_inference<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&T>, DimDyn, Self>,
        variance: Matrix<Ref<&T>, DimDyn, Self>,
        device_batch_norm_inference: &Option<BatchNorm2dInferenceConfig<T>>,
    ) {
        let batch_norm_inference = match device_batch_norm_inference {
            Some(ref batch_norm_inference) => &batch_norm_inference.device_batch_norm_inference,
            None => &create_batch_norm_inference_gpu::<T>(x.shape()),
        };
        batch_norm_inference
            .forward_inference(
                T::one(),
                T::zero(),
                x.as_ptr(),
                y.as_mut_ptr(),
                scale.as_ptr(),
                bias.as_ptr(),
                mean.as_ptr(),
                variance.as_ptr(),
            )
            .unwrap();
    }
}

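// The CPU implementation transposes NCHW to NHWC and flattens to an
// (N * H * W, C) matrix so that per-channel statistics reduce over axis 0.
// Training-mode forward computes, per channel,
//     y = scale * (x - mean(x)) / sqrt(var(x) + eps) + bias,
// normalizing with the biased (population) variance while updating the
// running variance with the unbiased estimate.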
impl BatchNormalization for Cpu {
    fn batch_norm_2d_forward_train<T: Num>(
        momentum: f64,
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&mut T>, DimDyn, Self>,
        variance: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, Self>>,
        _: &Option<BatchNorm2dConfig<T>>,
    ) {
        let momentum = T::from_f64(momentum);
        let epsilon = T::from_f64(1e-10);
        let x_shape = x.shape();
        let n = x_shape[0] * x_shape[2] * x_shape[3];
        let c = x_shape[1];
        let x_transposed = x.transpose_by_index_new_matrix(&[0, 2, 3, 1]);
        let x_reshaped = x_transposed.reshape([n, c]);

        let num_elements = T::from_usize(x_reshaped.shape()[0]);

        let x_mean = x_reshaped.mean(Some(0), false);
        let x_diff = &x_reshaped - &x_mean;
        let x_variance = x_reshaped.variance(Some(0), false);
        let x_variance_unbiased = &x_variance * (num_elements / (num_elements - T::one()));

        // Blend the batch statistics into the running statistics.
        let mean_t = &x_mean * (T::one() - momentum) + &mean * momentum;
        let variance_t = &x_variance_unbiased * (T::one() - momentum) + &variance * momentum;

        // Normalization uses the biased (population) variance.
        let inv_var = Matrix::<_, DimDyn, _>::ones(variance_t.shape()) / (&x_variance + epsilon);
        let inv_std = inv_var.sqrt();

        mean.copy_from(&mean_t);
        variance.copy_from(&variance_t);

        if let Some(saving_mean_mat) = saving_mean {
            saving_mean_mat.copy_from(&x_mean);
        }
        if let Some(saving_inv_variance_mat) = saving_inv_variance {
            saving_inv_variance_mat.copy_from(&inv_std);
        }

        let x_normalized = &x_diff * &inv_std;
        let y_tmp = &x_normalized * &scale + &bias;
        let y_transposed = y_tmp.reshape([x_shape[0], x_shape[2], x_shape[3], x_shape[1]]);
        y.copy_from(&y_transposed.transpose_by_index_new_matrix(&[0, 3, 1, 2]));
    }

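    // With x_hat = (x - mean) / std, the standard batch-norm gradients are:
    //     bias_grad  = sum(y_grad)
    //     scale_grad = sum(y_grad * x_hat)
    //     x_grad     = (scale / std)
    //                  * (y_grad - mean(y_grad) - x_hat * mean(y_grad * x_hat))
    // The x_grad expression is computed below as term1 - term2 - term3.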
    fn batch_norm_2d_backward<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y_grad: Matrix<Ref<&T>, DimDyn, Self>,
        x_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        scale_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        bias_grad: Matrix<Ref<&mut T>, DimDyn, Self>,
        saving_mean: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, Self>>,
        _: &Option<BatchNorm2dBackwardConfig<T>>,
    ) {
        let epsilon = T::from_f64(1e-10);
        let n = x.shape()[0] * x.shape()[2] * x.shape()[3];
        let c = x.shape()[1];
        let x_shape = x.shape();

        // Transpose and reshape x and y_grad for easier manipulation
        let x_transposed = x.transpose_by_index_new_matrix(&[0, 2, 3, 1]);
        let x_reshaped = x_transposed.reshape([n, c]);

        let y_grad_transposed = y_grad.transpose_by_index_new_matrix(&[0, 2, 3, 1]);
        let y_grad_reshaped = y_grad_transposed.reshape([n, c]);

        // Fall back to recomputing the batch statistics when they were not
        // saved during the forward pass.
        let mean = if let Some(ref mean_mat) = saving_mean {
            mean_mat.new_matrix()
        } else {
            x_reshaped.mean(Some(0), false)
        };

        let inv_std = if let Some(ref inv_variance_mat) = saving_inv_variance {
            inv_variance_mat.new_matrix()
        } else {
            let x_variance = x_reshaped.variance(Some(0), false);
            let inv_var =
                Matrix::<_, DimDyn, _>::ones(x_variance.shape()) / (&x_variance + epsilon);
            inv_var.sqrt()
        };

        let x_centered = &x_reshaped - &mean;
        let x_hat = &x_centered * &inv_std;

        bias_grad.copy_from(&y_grad_reshaped.to_ref().sum(0, false));
        scale_grad.copy_from(&(&x_hat * &y_grad_reshaped).to_ref().sum(0, false));

        // Compute the gradients
        let term1 = &inv_std * &y_grad_reshaped * scale;
        let mut term2 = term1.to_ref().sum(0, false) / T::from_usize(n);
        term2.add_axis(0);
        let mut term3 =
            &x_centered * (&term1 * &x_centered).to_ref().sum(0, false) / T::from_usize(n);
        term3.add_axis(0);
        let term3 = term3 * &inv_std * &inv_std;

        let x_grad_reshaped = term1 - term2 - term3;

        let x_grad_transposed =
            x_grad_reshaped.reshape([x_shape[0], x_shape[2], x_shape[3], x_shape[1]]);

        x_grad.copy_from(&x_grad_transposed.transpose_by_index_new_matrix(&[0, 3, 1, 2]));
    }

    fn batch_norm_2d_forward_inference<T: Num>(
        x: Matrix<Ref<&T>, DimDyn, Self>,
        y: Matrix<Ref<&mut T>, DimDyn, Self>,
        scale: Matrix<Ref<&T>, DimDyn, Self>,
        bias: Matrix<Ref<&T>, DimDyn, Self>,
        mean: Matrix<Ref<&T>, DimDyn, Self>,
        variance: Matrix<Ref<&T>, DimDyn, Self>,
        _: &Option<BatchNorm2dInferenceConfig<T>>,
    ) {
        let epsilon = T::from_f64(1e-10);
        let n = x.shape()[0] * x.shape()[2] * x.shape()[3];
        let c = x.shape()[1];
        let x_shape = x.shape();

        // Transpose and reshape x for easier manipulation
        let x_transposed = x.transpose_by_index_new_matrix(&[0, 2, 3, 1]);
        let x_reshaped = x_transposed.reshape([n, c]);

        // Inference normalizes with the running statistics rather than the
        // batch statistics.
        let mean = mean.to_ref();
        let inv_std = Matrix::<_, DimDyn, _>::ones(variance.shape()) / (&variance + epsilon).sqrt();

        let x_centered = &x_reshaped - mean;
        let x_hat = &x_centered * &inv_std;

        let y_tmp = &x_hat * &scale + &bias;
        let y_transposed = y_tmp.reshape([x_shape[0], x_shape[2], x_shape[3], x_shape[1]]);
        y.copy_from(&y_transposed.transpose_by_index_new_matrix(&[0, 3, 1, 2]));
    }
}

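/// Validates the shapes for the forward passes: `x` and `y` must be 4-D NCHW
/// tensors of the same shape, and each per-channel vector must have length
/// `x[1]` (the channel count).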
#[expect(clippy::too_many_arguments)]
fn batch_norm_2d_shape_check(
    x: DimDyn,
    y: DimDyn,
    scale: DimDyn,
    bias: DimDyn,
    mean: DimDyn,
    variance: DimDyn,
    saving_mean: Option<DimDyn>,
    saving_inv_variance: Option<DimDyn>,
) -> Result<(), String> {
    if scale.len() != 1 {
        return Err("scale must be a vector".to_string());
    }
    if bias.len() != 1 {
        return Err("bias must be a vector".to_string());
    }
    if mean.len() != 1 {
        return Err("mean must be a vector".to_string());
    }
    if variance.len() != 1 {
        return Err("variance must be a vector".to_string());
    }
    if let Some(saving_mean) = saving_mean {
        if saving_mean.len() != 1 {
            return Err("saving_mean must be a vector".to_string());
        }
    }
    if let Some(saving_inv_variance) = saving_inv_variance {
        if saving_inv_variance.len() != 1 {
            return Err("saving_inv_variance must be a vector".to_string());
        }
    }
    if x.len() != 4 {
        return Err("x must be a 4-dimensional tensor".to_string());
    }
    if x != y {
        return Err("x and y must have the same shape".to_string());
    }
    if x[1] != scale[0] {
        return Err("x and scale must have the same number of channels".to_string());
    }
    if x[1] != bias[0] {
        return Err("x and bias must have the same number of channels".to_string());
    }
    if x[1] != mean[0] {
        return Err("x and mean must have the same number of channels".to_string());
    }
    if x[1] != variance[0] {
        return Err("x and variance must have the same number of channels".to_string());
    }
    if let Some(saving_mean) = saving_mean {
        if x[1] != saving_mean[0] {
            return Err("x and saving_mean must have the same number of channels".to_string());
        }
    }
    if let Some(saving_inv_variance) = saving_inv_variance {
        if x[1] != saving_inv_variance[0] {
            return Err(
                "x and saving_inv_variance must have the same number of channels".to_string(),
            );
        }
    }
    Ok(())
}

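/// Validates the shapes for the backward pass; the same rules as
/// `batch_norm_2d_shape_check`, applied to the gradient tensors.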
#[expect(clippy::too_many_arguments)]
fn batch_norm_2d_backward_shape_check(
    x: DimDyn,
    y_grad: DimDyn,
    x_grad: DimDyn,
    scale: DimDyn,
    scale_grad: DimDyn,
    bias_grad: DimDyn,
    saving_mean: Option<DimDyn>,
    saving_inv_variance: Option<DimDyn>,
) -> Result<(), String> {
    if scale.len() != 1 {
        return Err("scale must be a vector".to_string());
    }
    if bias_grad.len() != 1 {
        return Err("bias_grad must be a vector".to_string());
    }
    if let Some(saving_mean) = saving_mean {
        if saving_mean.len() != 1 {
            return Err("saving_mean must be a vector".to_string());
        }
    }
    if let Some(saving_inv_variance) = saving_inv_variance {
        if saving_inv_variance.len() != 1 {
            return Err("saving_inv_variance must be a vector".to_string());
        }
    }
    if x.len() != 4 {
        return Err("x must be a 4-dimensional tensor".to_string());
    }
    if x != y_grad {
        return Err("x and y_grad must have the same shape".to_string());
    }
    if x != x_grad {
        return Err("x and x_grad must have the same shape".to_string());
    }
    if x[1] != scale[0] {
        return Err("x and scale must have the same number of channels".to_string());
    }
    if x[1] != scale_grad[0] {
        return Err("x and scale_grad must have the same number of channels".to_string());
    }
    if x[1] != bias_grad[0] {
        return Err("x and bias_grad must have the same number of channels".to_string());
    }
    if let Some(saving_mean) = saving_mean {
        if x[1] != saving_mean[0] {
            return Err("x and saving_mean must have the same number of channels".to_string());
        }
    }
    if let Some(saving_inv_variance) = saving_inv_variance {
        if x[1] != saving_inv_variance[0] {
            return Err(
                "x and saving_inv_variance must have the same number of channels".to_string(),
            );
        }
    }
    Ok(())
}

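/// Training-mode 2-D batch-normalization forward pass.
///
/// Updates `mean` and `variance` in place with the new running statistics and
/// optionally stores the batch mean and inverse standard deviation in
/// `saving_mean` / `saving_inv_variance` for use by the backward pass.
///
/// # Errors
///
/// Returns an error if any input fails the shape validation described in
/// `batch_norm_2d_shape_check`.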
#[expect(clippy::too_many_arguments)]
pub fn try_batch_norm_2d_forward_train<T: Num, D: Device>(
    momentum: f64,
    x: Matrix<Ref<&T>, DimDyn, D>,
    y: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    mean: Matrix<Ref<&mut T>, DimDyn, D>,
    variance: Matrix<Ref<&mut T>, DimDyn, D>,
    saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, D>>,
    saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, D>>,
    device_batch_norm: &Option<BatchNorm2dConfig<T>>,
) -> Result<(), String> {
    let x_shape = x.shape();
    let y_shape = y.shape();
    let scale_shape = scale.shape();
    let bias_shape = bias.shape();
    let mean_shape = mean.shape();
    let variance_shape = variance.shape();
    let saving_mean_shape = saving_mean.as_ref().map(Matrix::shape);
    let saving_inv_variance_shape = saving_inv_variance.as_ref().map(Matrix::shape);

    batch_norm_2d_shape_check(
        x_shape,
        y_shape,
        scale_shape,
        bias_shape,
        mean_shape,
        variance_shape,
        saving_mean_shape,
        saving_inv_variance_shape,
    )?;

    D::batch_norm_2d_forward_train(
        momentum,
        x,
        y,
        scale,
        bias,
        mean,
        variance,
        saving_mean,
        saving_inv_variance,
        device_batch_norm,
    );

    Ok(())
}

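/// Inference-mode 2-D batch-normalization forward pass, normalizing with the
/// running `mean` and `variance` instead of batch statistics.
///
/// # Errors
///
/// Returns an error if any input fails the shape validation.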
pub fn try_batch_norm_2d_forward_inference<T: Num, D: Device>(
    x: Matrix<Ref<&T>, DimDyn, D>,
    y: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    mean: Matrix<Ref<&T>, DimDyn, D>,
    variance: Matrix<Ref<&T>, DimDyn, D>,
    device_batch_norm_inference: &Option<BatchNorm2dInferenceConfig<T>>,
) -> Result<(), String> {
    let x_shape = x.shape();
    let y_shape = y.shape();
    let scale_shape = scale.shape();
    let bias_shape = bias.shape();
    let mean_shape = mean.shape();
    let variance_shape = variance.shape();

    batch_norm_2d_shape_check(
        x_shape,
        y_shape,
        scale_shape,
        bias_shape,
        mean_shape,
        variance_shape,
        None,
        None,
    )?;

    D::batch_norm_2d_forward_inference(
        x,
        y,
        scale,
        bias,
        mean,
        variance,
        device_batch_norm_inference,
    );

    Ok(())
}

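/// 2-D batch-normalization backward pass, producing `x_grad`, `scale_grad`,
/// and `bias_grad`.
///
/// # Errors
///
/// Returns an error if any input fails the shape validation.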
#[expect(clippy::too_many_arguments)]
pub fn try_batch_norm_2d_backward<T: Num, D: Device>(
    x: Matrix<Ref<&T>, DimDyn, D>,
    y_grad: Matrix<Ref<&T>, DimDyn, D>,
    x_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    scale_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    bias_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    saving_mean: Option<Matrix<Ref<&T>, DimDyn, D>>,
    saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, D>>,
    device_batch_norm_backward: &Option<BatchNorm2dBackwardConfig<T>>,
) -> Result<(), String> {
    let x_shape = x.shape();
    let y_grad_shape = y_grad.shape();
    let x_grad_shape = x_grad.shape();
    let scale_shape = scale.shape();
    let scale_grad_shape = scale_grad.shape();
    let bias_grad_shape = bias_grad.shape();
    let saving_mean_shape = saving_mean.as_ref().map(Matrix::shape);
    let saving_inv_variance_shape = saving_inv_variance.as_ref().map(Matrix::shape);

    batch_norm_2d_backward_shape_check(
        x_shape,
        y_grad_shape,
        x_grad_shape,
        scale_shape,
        scale_grad_shape,
        bias_grad_shape,
        saving_mean_shape,
        saving_inv_variance_shape,
    )?;

    D::batch_norm_2d_backward(
        x,
        y_grad,
        x_grad,
        scale,
        scale_grad,
        bias_grad,
        saving_mean,
        saving_inv_variance,
        device_batch_norm_backward,
    );

    Ok(())
}

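/// Panicking variant of `try_batch_norm_2d_forward_train`.
///
/// A minimal sketch of a call site (not compiled here; assumes the CPU
/// backend, `f32` data, and an NCHW input of shape `[2, 3, 4, 4]`; recall
/// that `momentum` is the weight kept on the previous running statistics):
///
/// ```ignore
/// use zenu_matrix::{device::cpu::Cpu, dim::DimDyn, matrix::{Matrix, Owned}};
///
/// let x = Matrix::<Owned<f32>, DimDyn, Cpu>::zeros([2, 3, 4, 4]);
/// let mut y = Matrix::<Owned<f32>, DimDyn, Cpu>::zeros([2, 3, 4, 4]);
/// let scale = Matrix::<Owned<f32>, DimDyn, Cpu>::ones([3]);
/// let bias = Matrix::<Owned<f32>, DimDyn, Cpu>::zeros([3]);
/// let mut mean = Matrix::<Owned<f32>, DimDyn, Cpu>::zeros([3]);
/// let mut variance = Matrix::<Owned<f32>, DimDyn, Cpu>::ones([3]);
///
/// batch_norm_2d_forward_train::<f32, Cpu>(
///     0.9,
///     x.to_ref(),
///     y.to_ref_mut(),
///     scale.to_ref(),
///     bias.to_ref(),
///     mean.to_ref_mut(),
///     variance.to_ref_mut(),
///     None,
///     None,
///     &None,
/// );
/// ```
///
/// # Panics
///
/// Panics if the shape validation fails.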
#[expect(clippy::too_many_arguments)]
pub fn batch_norm_2d_forward_train<T: Num, D: Device>(
    momentum: f64,
    x: Matrix<Ref<&T>, DimDyn, D>,
    y: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    mean: Matrix<Ref<&mut T>, DimDyn, D>,
    variance: Matrix<Ref<&mut T>, DimDyn, D>,
    saving_mean: Option<Matrix<Ref<&mut T>, DimDyn, D>>,
    saving_inv_variance: Option<Matrix<Ref<&mut T>, DimDyn, D>>,
    device_batch_norm: &Option<BatchNorm2dConfig<T>>,
) {
    try_batch_norm_2d_forward_train(
        momentum,
        x,
        y,
        scale,
        bias,
        mean,
        variance,
        saving_mean,
        saving_inv_variance,
        device_batch_norm,
    )
    .unwrap();
}

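/// Panicking variant of `try_batch_norm_2d_forward_inference`.
///
/// # Panics
///
/// Panics if the shape validation fails.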
pub fn batch_norm_2d_forward_inference<T: Num, D: Device>(
    x: Matrix<Ref<&T>, DimDyn, D>,
    y: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    bias: Matrix<Ref<&T>, DimDyn, D>,
    mean: Matrix<Ref<&T>, DimDyn, D>,
    variance: Matrix<Ref<&T>, DimDyn, D>,
    device_batch_norm_inference: &Option<BatchNorm2dInferenceConfig<T>>,
) {
    try_batch_norm_2d_forward_inference(
        x,
        y,
        scale,
        bias,
        mean,
        variance,
        device_batch_norm_inference,
    )
    .unwrap();
}

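/// Panicking variant of `try_batch_norm_2d_backward`.
///
/// # Panics
///
/// Panics if the shape validation fails.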
#[expect(clippy::too_many_arguments)]
pub fn batch_norm_2d_backward<T: Num, D: Device>(
    x: Matrix<Ref<&T>, DimDyn, D>,
    y_grad: Matrix<Ref<&T>, DimDyn, D>,
    x_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    scale: Matrix<Ref<&T>, DimDyn, D>,
    scale_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    bias_grad: Matrix<Ref<&mut T>, DimDyn, D>,
    saving_mean: Option<Matrix<Ref<&T>, DimDyn, D>>,
    saving_inv_variance: Option<Matrix<Ref<&T>, DimDyn, D>>,
    device_batch_norm_backward: &Option<BatchNorm2dBackwardConfig<T>>,
) {
    try_batch_norm_2d_backward(
        x,
        y_grad,
        x_grad,
        scale,
        scale_grad,
        bias_grad,
        saving_mean,
        saving_inv_variance,
        device_batch_norm_backward,
    )
    .unwrap();
}

#[expect(clippy::unreadable_literal)]
#[cfg(test)]
mod batch_norm {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    use zenu_test::*;

    use super::*;

    #[derive(Debug)]
    struct BatchNormInputs<D: Device> {
        x: Matrix<Owned<f32>, DimDyn, D>,
        y: Matrix<Owned<f32>, DimDyn, D>,
        scale: Matrix<Owned<f32>, DimDyn, D>,
        bias: Matrix<Owned<f32>, DimDyn, D>,
        mean: Matrix<Owned<f32>, DimDyn, D>,
        variance: Matrix<Owned<f32>, DimDyn, D>,
        saved_mean: Matrix<Owned<f32>, DimDyn, D>,
        saved_variance: Matrix<Owned<f32>, DimDyn, D>,
    }

    fn small_data<D: Device>() -> BatchNormInputs<D> {
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                -1.1258398,
                -1.1523602,
                -0.25057858,
                -0.4338788,
                0.84871036,
                0.69200915,
                -0.31601277,
                -2.1152194,
                0.32227492,
                -1.2633348,
                0.3499832,
                0.30813393,
                0.11984151,
                1.2376579,
                1.1167772,
                -0.24727815,
            ],
            [2, 2, 2, 2],
        );
        let y = vec![
            -1.0970649,
            -1.1374662,
            0.23631285,
            -0.04292771,
            0.66504365,
            0.5121599,
            -0.4713051,
            -2.2266803,
            1.109001,
            -1.3065253,
            1.1512119,
            1.0874585,
            -0.04606889,
            1.0445158,
            0.92657995,
            -0.40424496,
        ];
        let running_mean = vec![-0.36513, 0.15035464];
        let running_variance = vec![0.4431935, 1.0805689];
        let saved_mean = vec![-0.40570003, 0.16706072];
        let saved_variance = vec![1.5234232, 0.97564316];
        let scale = vec![1.0, 1.0];
        let bias = vec![0.0, 0.0];
        let y = Matrix::<Owned<f32>, DimDyn, D>::from_vec(y, [2, 2, 2, 2]);
        let mean = Matrix::<Owned<f32>, DimDyn, D>::from_vec(running_mean, [2]);
        let variance = Matrix::<Owned<f32>, DimDyn, D>::from_vec(running_variance, [2]);
        let scale = Matrix::<Owned<f32>, DimDyn, D>::from_vec(scale, [2]);
        let bias = Matrix::<Owned<f32>, DimDyn, D>::from_vec(bias, [2]);
        let saved_mean = Matrix::<Owned<f32>, DimDyn, D>::from_vec(saved_mean, [2]);
        let saved_variance = Matrix::<Owned<f32>, DimDyn, D>::from_vec(saved_variance, [2]);
        BatchNormInputs {
            x,
            y,
            scale,
            bias,
            mean,
            variance,
            saved_mean,
            saved_variance,
        }
    }

    fn small_forward<D: Device>() {
        let inputs = small_data::<D>();
        let mut y_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.y.shape());
        let mut mean_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.mean.shape());
        let mut variance_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.variance.shape());
        let mut saved_mean_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.saved_mean.shape());
        let mut saved_variance_out =
            Matrix::<Owned<f32>, DimDyn, D>::alloc(inputs.saved_variance.shape());
        let batch_norm = BatchNorm2dConfig::<f32>::new(inputs.x.shape());
        D::batch_norm_2d_forward_train(
            0.1,
            inputs.x.to_ref(),
            y_out.to_ref_mut(),
            inputs.scale.to_ref(),
            inputs.bias.to_ref(),
            mean_out.to_ref_mut(),
            variance_out.to_ref_mut(),
            Some(saved_mean_out.to_ref_mut()),
            Some(saved_variance_out.to_ref_mut()),
            &Some(batch_norm),
        );

        assert_mat_eq_epsilon!(y_out.to_ref(), inputs.y.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(mean_out.to_ref(), inputs.mean.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(variance_out.to_ref(), inputs.variance.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(saved_mean_out.to_ref(), inputs.saved_mean.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(
            saved_variance_out.to_ref(),
            inputs.saved_variance.to_ref(),
            2e-4
        );
    }
    run_mat_test!(small_forward, small_forward_cpu, small_forward_gpu);

    #[derive(Debug)]
    struct BatchNormBackward<D: Device> {
        x: Matrix<Owned<f32>, DimDyn, D>,
        y_grad: Matrix<Owned<f32>, DimDyn, D>,
        scale: Matrix<Owned<f32>, DimDyn, D>,
        saved_mean: Matrix<Owned<f32>, DimDyn, D>,
        saved_variance: Matrix<Owned<f32>, DimDyn, D>,
    }

    fn small_data_backward<D: Device>() -> BatchNormBackward<D> {
        let x = vec![
            -1.1258398,
            -1.1523602,
            -0.25057858,
            -0.4338788,
            0.84871036,
            0.69200915,
            -0.31601277,
            -2.1152194,
            0.32227492,
            -1.2633348,
            0.3499832,
            0.30813393,
            0.11984151,
            1.2376579,
            1.1167772,
            -0.24727815,
        ];
        let y_grad = vec![
            -0.9246624,
            -0.42534423,
            -2.6438458,
            0.14518386,
            -0.1208664,
            -0.57972574,
            -0.622851,
            -0.3283869,
            -1.0745419,
            -0.36314395,
            -1.6710504,
            2.2655048,
            0.3116848,
            -0.1841891,
            1.2866427,
            1.1819527,
        ];
        let saved_mean = vec![-0.04057, 0.01670607];
        let saved_variance = vec![0.9492437, 1.0200632];
        let scale = vec![1.0, 1.0];
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(x, [2, 2, 2, 2]);
        let y_grad = Matrix::<Owned<f32>, DimDyn, D>::from_vec(y_grad, [2, 2, 2, 2]);
        let scale = Matrix::<Owned<f32>, DimDyn, D>::from_vec(scale, [2]);
        let saved_mean = Matrix::<Owned<f32>, DimDyn, D>::from_vec(saved_mean, [2]);
        let saved_variance = Matrix::<Owned<f32>, DimDyn, D>::from_vec(saved_variance, [2]);
        BatchNormBackward {
            x,
            y_grad,
            scale,
            saved_mean,
            saved_variance,
        }
    }

    fn small_backward<D: Device>() {
        let inputs = small_data_backward::<D>();
        let mut x_grad = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.x.shape());
        let mut scale_grad = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.scale.shape());
        let mut bias_grad = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.scale.shape());
        let batch_norm_backward = BatchNorm2dBackwardConfig::<f32>::new(inputs.x.shape());
        D::batch_norm_2d_backward(
            inputs.x.to_ref(),
            inputs.y_grad.to_ref(),
            x_grad.to_ref_mut(),
            inputs.scale.to_ref(),
            scale_grad.to_ref_mut(),
            bias_grad.to_ref_mut(),
            Some(inputs.saved_mean.to_ref()),
            Some(inputs.saved_variance.to_ref()),
            &Some(batch_norm_backward),
        );

        let x_grad_ans = vec![
            -0.06967929,
            0.41043705,
            -1.9042997,
            0.7856185,
            -0.39005604,
            -0.83055514,
            -0.69721717,
            -0.080333665,
            -0.54731166,
            0.4951802,
            -1.1199604,
            2.6264815,
            0.1793941,
            -0.52307177,
            0.99853456,
            1.131705,
        ];
        let scale_grad_ans = vec![2.0560942, 1.352522];
        let bias_grad_ans = vec![-4.6919003, 0.9442612];
        let x_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(x_grad_ans, [2, 2, 2, 2]);
        let scale_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(scale_grad_ans, [2]);
        let bias_grad_ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(bias_grad_ans, [2]);
        assert_mat_eq_epsilon!(x_grad.to_ref(), x_grad_ans.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(scale_grad.to_ref(), scale_grad_ans.to_ref(), 2e-4);
        assert_mat_eq_epsilon!(bias_grad.to_ref(), bias_grad_ans.to_ref(), 2e-4);
    }
    run_mat_test!(small_backward, small_backward_cpu, small_backward_gpu);

    fn small_forward_inference<D: Device>() {
        let inputs = small_forward_inference_data::<f32, D>();
        let mut y_out = Matrix::<Owned<f32>, DimDyn, D>::zeros(inputs.y.shape());
        let batch_norm_inference = BatchNorm2dInferenceConfig::<f32>::new(inputs.x.shape());
        D::batch_norm_2d_forward_inference(
            inputs.x.to_ref(),
            y_out.to_ref_mut(),
            inputs.scale.to_ref(),
            inputs.bias.to_ref(),
            inputs.mean.to_ref(),
            inputs.variance.to_ref(),
            &Some(batch_norm_inference),
        );

        assert_mat_eq_epsilon!(y_out.to_ref(), inputs.y.to_ref(), 3e-3);
    }
    run_mat_test!(
        small_forward_inference,
        small_forward_inference_cpu,
        small_forward_inference_gpu
    );

    #[derive(Debug)]
    struct ForwardInputs<T: Num, D: Device> {
        x: Matrix<Owned<T>, DimDyn, D>,
        y: Matrix<Owned<T>, DimDyn, D>,
        scale: Matrix<Owned<T>, DimDyn, D>,
        bias: Matrix<Owned<T>, DimDyn, D>,
        mean: Matrix<Owned<T>, DimDyn, D>,
        variance: Matrix<Owned<T>, DimDyn, D>,
    }

    fn small_forward_inference_data<T: Num, D: Device>() -> ForwardInputs<T, D> {
        let x = vec![
            -1.1258398,
            -1.1523602,
            -0.25057858,
            -0.4338788,
            0.84871036,
            0.69200915,
            -0.31601277,
            -2.1152194,
            0.32227492,
            -1.2633348,
            0.3499832,
            0.30813393,
            0.11984151,
            1.2376579,
            1.1167772,
            -0.24727815,
        ];
        let y = vec![
            -0.6203, -0.5908, -1.5910, -1.3877, 3.3524, 2.9482, 0.3480, -4.2931, -2.2263, -0.4678,
            -2.2570, -2.2106, 1.4723, 4.3557, 4.0439, 0.5253,
        ];
        let mean = vec![-0.7193, -0.4033];
        let variance = vec![0.5966, 0.1820];
        let scale = vec![-0.8567, 1.1006];
        let bias = vec![-1.0712, 0.1227];

        let x = x.into_iter().map(T::from_f64).collect();
        let y = y.into_iter().map(T::from_f64).collect();
        let mean = mean.into_iter().map(T::from_f64).collect();
        let variance = variance.into_iter().map(T::from_f64).collect();
        let scale = scale.into_iter().map(T::from_f64).collect();
        let bias = bias.into_iter().map(T::from_f64).collect();

        let x = Matrix::<Owned<T>, DimDyn, D>::from_vec(x, [2, 2, 2, 2]);
        let y = Matrix::<Owned<T>, DimDyn, D>::from_vec(y, [2, 2, 2, 2]);
        let mean = Matrix::<Owned<T>, DimDyn, D>::from_vec(mean, [2]);
        let variance = Matrix::<Owned<T>, DimDyn, D>::from_vec(variance, [2]);
        let scale = Matrix::<Owned<T>, DimDyn, D>::from_vec(scale, [2]);
        let bias = Matrix::<Owned<T>, DimDyn, D>::from_vec(bias, [2]);
        ForwardInputs {
            x,
            y,
            scale,
            bias,
            mean,
            variance,
        }
    }
}
1077}