batchnorm_example/batchnorm_example.rs
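//! Demonstrates batch normalization in scirs2-neural: it is applied to 4D
//! convolutional activations and to 2D dense activations, and the output
//! statistics are compared before/after normalization and between training
//! and inference modes.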

use ndarray::{Array, Array4};
use rand::rngs::SmallRng;
use rand::Rng;
use rand::SeedableRng;
use scirs2_neural::layers::{BatchNorm, Conv2D, Dense, Layer, PaddingMode};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Batch Normalization Example");

    // Initialize random number generator with a fixed seed for reproducibility
    let mut rng = SmallRng::seed_from_u64(42);

    // Create a sample CNN architecture with batch normalization

    // 1. Define input dimensions
    let batch_size = 2;
    let in_channels = 3; // RGB input
    let height = 32;
    let width = 32;

    // 2. Create convolutional layer
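    // (maps the 3 RGB channels to 16 output channels with a 3x3 kernel and stride 1;
    // "same" padding preserves the 32x32 spatial size)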
    let conv = Conv2D::new(in_channels, 16, (3, 3), (1, 1), PaddingMode::Same, &mut rng)?;

    // 3. Create batch normalization layer for the conv output
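    // The first argument is the number of features to normalize (the 16 conv
    // channels); 0.9 and 1e-5 are, by the usual convention, the running-statistics
    // momentum and the epsilon added to the variance for numerical stability.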
    let batch_norm = BatchNorm::new(16, 0.9, 1e-5, &mut rng)?;

    // 4. Create random input data in NCHW layout (batch, channels, height, width)
    let mut input = Array4::<f32>::from_elem((batch_size, in_channels, height, width), 0.0);

    // Randomly fill input with values between -1 and 1
    for n in 0..batch_size {
        for c in 0..in_channels {
            for h in 0..height {
                for w in 0..width {
                    input[[n, c, h, w]] = rng.random_range(-1.0..1.0);
                }
            }
        }
    }

    // 5. Forward pass through conv layer
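    // `into_dyn` converts the fixed-dimension Array4 into the dynamically
    // dimensioned array that the layers' `forward` methods accept here.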
    println!("Input shape: {:?}", input.shape());
    let conv_output = conv.forward(&input.into_dyn())?;
    println!("Conv output shape: {:?}", conv_output.shape());

    // 6. Forward pass through batch normalization
    let bn_output = batch_norm.forward(&conv_output)?;
    println!("BatchNorm output shape: {:?}", bn_output.shape());
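    // In training mode the layer normalizes each channel with the statistics of the
    // current batch; assuming the usual scale = 1, shift = 0 initialization, the
    // output should have mean close to 0 and standard deviation close to 1.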

    // Print statistics of the conv output and batch norm output
    let conv_mean = compute_mean(&conv_output);
    let conv_std = compute_std(&conv_output, conv_mean);

    let bn_mean = compute_mean(&bn_output);
    let bn_std = compute_std(&bn_output, bn_mean);

    println!("\nStatistics before BatchNorm:");
    println!("  Mean: {:.6}", conv_mean);
    println!("  Std:  {:.6}", conv_std);

    println!("\nStatistics after BatchNorm:");
    println!("  Mean: {:.6}", bn_mean);
    println!("  Std:  {:.6}", bn_std);

    // Demonstrate inference mode with a fresh BatchNorm layer
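    // In inference mode BatchNorm normalizes with the running mean and variance
    // accumulated during training rather than the current batch statistics, so a
    // fresh layer is first run once in training mode to populate those statistics.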
    let mut bn_inference = BatchNorm::new(16, 0.9, 1e-5, &mut rng)?;

    // First do a forward pass in training mode to accumulate statistics
    bn_inference.forward(&conv_output)?;

    // Now switch to inference mode
    bn_inference.set_training(false);
    let bn_inference_output = bn_inference.forward(&conv_output)?;

    let bn_inference_mean = compute_mean(&bn_inference_output);
    let bn_inference_std = compute_std(&bn_inference_output, bn_inference_mean);

    println!("\nStatistics in inference mode:");
    println!("  Mean: {:.6}", bn_inference_mean);
    println!("  Std:  {:.6}", bn_inference_std);

    // Example of using BatchNorm in a simple neural network
    println!("\nExample: BatchNorm in a simple neural network");

    // Create a random 2D input (batch_size, features)
    let batch_size = 16;
    let in_features = 10;
    let mut input_2d = Array::from_elem((batch_size, in_features), 0.0);

    // Randomly fill input with values between -1 and 1
    for n in 0..batch_size {
        for f in 0..in_features {
            input_2d[[n, f]] = rng.random_range(-1.0..1.0);
        }
    }

    // Create dense layer
    let dense1 = Dense::new(in_features, 32, None, &mut rng)?;

    // Create batch norm for dense output
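    // For 2D (batch, features) inputs, each of the 32 dense output features is
    // normalized across the batch dimension.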
    let bn1 = BatchNorm::new(32, 0.9, 1e-5, &mut rng)?;

    // Forward passes
    let dense1_output = dense1.forward(&input_2d.into_dyn())?;
    let bn1_output = bn1.forward(&dense1_output)?;

    let dense1_mean = compute_mean(&dense1_output);
    println!(
        "Dense output stats - Mean: {:.6}, Std: {:.6}",
        dense1_mean,
        compute_std(&dense1_output, dense1_mean)
    );

    let bn1_mean = compute_mean(&bn1_output);
    println!(
        "After BatchNorm - Mean: {:.6}, Std: {:.6}",
        bn1_mean,
        compute_std(&bn1_output, bn1_mean)
    );

    Ok(())
}

// Helper function to compute mean of an array
fn compute_mean<F: num_traits::Float>(arr: &Array<F, ndarray::IxDyn>) -> F {
    let n = arr.len();
    let mut sum = F::zero();

    for &val in arr.iter() {
        sum = sum + val;
    }

    sum / F::from(n).unwrap()
}

// Helper function to compute the (population) standard deviation given the mean
fn compute_std<F: num_traits::Float>(arr: &Array<F, ndarray::IxDyn>, mean: F) -> F {
    let n = arr.len();
    let mut sum_sq = F::zero();

    for &val in arr.iter() {
        let diff = val - mean;
        sum_sq = sum_sq + diff * diff;
    }

    (sum_sq / F::from(n).unwrap()).sqrt()
}