// batchnorm_example/batchnorm_example.rs
use ndarray::{Array, Array4};
use rand::rngs::SmallRng;
use rand::Rng;
use rand::SeedableRng;
use scirs2_neural::layers::{BatchNorm, Conv2D, Dense, Layer, PaddingMode};

7fn main() -> Result<(), Box<dyn std::error::Error>> {
8 println!("Batch Normalization Example");
9
10 let mut rng = SmallRng::seed_from_u64(42);
12
13 let batch_size = 2;
17 let in_channels = 3; let height = 32;
19 let width = 32;
20
21 let conv = Conv2D::new(in_channels, 16, (3, 3), (1, 1), PaddingMode::Same, &mut rng)?;
23
24 let batch_norm = BatchNorm::new(16, 0.9, 1e-5, &mut rng)?;
26
27 let input = Array4::<f32>::from_elem((batch_size, in_channels, height, width), 0.0);
29
30 let mut input_mut = input.clone();
32 for n in 0..batch_size {
33 for c in 0..in_channels {
34 for h in 0..height {
35 for w in 0..width {
36 input_mut[[n, c, h, w]] = rng.random_range(-1.0..1.0);
37 }
38 }
39 }
40 }
41
42 println!("Input shape: {:?}", input_mut.shape());
44 let conv_output = conv.forward(&input_mut.into_dyn())?;
45 println!("Conv output shape: {:?}", conv_output.shape());
46
47 let bn_output = batch_norm.forward(&conv_output)?;
49 println!("BatchNorm output shape: {:?}", bn_output.shape());
50
51 let conv_mean = compute_mean(&conv_output);
53 let conv_std = compute_std(&conv_output, conv_mean);
54
55 let bn_mean = compute_mean(&bn_output);
56 let bn_std = compute_std(&bn_output, bn_mean);
57
58 println!("\nStatistics before BatchNorm:");
59 println!(" Mean: {:.6}", conv_mean);
60 println!(" Std: {:.6}", conv_std);
61
62 println!("\nStatistics after BatchNorm:");
63 println!(" Mean: {:.6}", bn_mean);
64 println!(" Std: {:.6}", bn_std);
65
66 let mut bn_inference = BatchNorm::new(16, 0.9, 1e-5, &mut rng)?;
68
69 bn_inference.forward(&conv_output)?;
71
72 bn_inference.set_training(false);
74 let bn_inference_output = bn_inference.forward(&conv_output)?;
75
76 let bn_inference_mean = compute_mean(&bn_inference_output);
77 let bn_inference_std = compute_std(&bn_inference_output, bn_inference_mean);
78
79 println!("\nStatistics in inference mode:");
80 println!(" Mean: {:.6}", bn_inference_mean);
81 println!(" Std: {:.6}", bn_inference_std);
82
83 println!("\nExample: BatchNorm in a simple neural network");
85
86 let batch_size = 16;
88 let in_features = 10;
89 let mut input_2d = Array::from_elem((batch_size, in_features), 0.0);
90
91 for n in 0..batch_size {
93 for f in 0..in_features {
94 input_2d[[n, f]] = rng.random_range(-1.0..1.0);
95 }
96 }
97
98 let dense1 = Dense::new(in_features, 32, None, &mut rng)?;
100
101 let bn1 = BatchNorm::new(32, 0.9, 1e-5, &mut rng)?;
103
104 let dense1_output = dense1.forward(&input_2d.into_dyn())?;
106 let bn1_output = bn1.forward(&dense1_output)?;
107
108 println!(
109 "Dense output stats - Mean: {:.6}, Std: {:.6}",
110 compute_mean(&dense1_output),
111 compute_std(&dense1_output, compute_mean(&dense1_output))
112 );
113
114 println!(
115 "After BatchNorm - Mean: {:.6}, Std: {:.6}",
116 compute_mean(&bn1_output),
117 compute_std(&bn1_output, compute_mean(&bn1_output))
118 );
119
120 Ok(())
121}
122
123fn compute_mean<F: num_traits::Float>(arr: &Array<F, ndarray::IxDyn>) -> F {
125 let n = arr.len();
126 let mut sum = F::zero();
127
128 for &val in arr.iter() {
129 sum = sum + val;
130 }
131
132 sum / F::from(n).unwrap()
133}
134
135fn compute_std<F: num_traits::Float>(arr: &Array<F, ndarray::IxDyn>, mean: F) -> F {
137 let n = arr.len();
138 let mut sum_sq = F::zero();
139
140 for &val in arr.iter() {
141 let diff = val - mean;
142 sum_sq = sum_sq + diff * diff;
143 }
144
145 (sum_sq / F::from(n).unwrap()).sqrt()
146}