scirs2_neural/layers/normalization.rs

//! Normalization layers implementation
//!
//! This module provides implementations of various normalization techniques
//! such as Layer Normalization, Batch Normalization, etc.

use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use ndarray::{Array, ArrayView, IxDyn, ScalarOperand};
use num_traits::Float;
use rand::Rng;
//use std::cell::RefCell; // Removed in favor of RwLock
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};

/// Layer Normalization layer
///
/// Implements layer normalization as described in "Layer Normalization"
/// by Ba, Kiros, and Hinton. It normalizes the inputs across the last dimension
/// and applies learnable scale and shift parameters.
///
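/// For a feature vector x with mean mu and variance sigma^2, the normalized output is
/// y = gamma * (x - mu) / sqrt(sigma^2 + eps) + beta, where gamma and beta are the
/// learnable scale and shift parameters of this layer.
///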
/// # Examples
///
/// ```
/// use scirs2_neural::layers::{LayerNorm, Layer};
/// use ndarray::{Array, Array3};
/// use rand::rngs::SmallRng;
/// use rand::SeedableRng;
///
/// // Create a layer normalization layer for a 64-dimensional feature space
/// let mut rng = SmallRng::seed_from_u64(42);
/// let layer_norm = LayerNorm::new(64, 1e-5, &mut rng).unwrap();
///
/// // Forward pass with a batch of 2 samples, sequence length 3
/// let batch_size = 2;
/// let seq_len = 3;
/// let d_model = 64;
/// let input = Array3::<f64>::from_elem((batch_size, seq_len, d_model), 0.1).into_dyn();
/// let output = layer_norm.forward(&input).unwrap();
///
/// // Output shape should match input shape
/// assert_eq!(output.shape(), input.shape());
/// ```
#[derive(Debug)]
pub struct LayerNorm<F: Float + Debug> {
    /// Dimensionality of the input features
    normalized_shape: Vec<usize>,
    /// Learnable scale parameter
    gamma: Array<F, IxDyn>,
    /// Learnable shift parameter
    beta: Array<F, IxDyn>,
    /// Gradient of gamma
    dgamma: Array<F, IxDyn>,
    /// Gradient of beta
    dbeta: Array<F, IxDyn>,
    /// Small constant for numerical stability
    eps: F,
    /// Input cache for backward pass
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Normalized input cache for backward pass
    norm_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Mean cache for backward pass
    mean_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Variance cache for backward pass
    var_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

/// 2D Layer Normalization for 2D convolutional networks
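///
/// Expects inputs in [batch_size, channels, height, width] layout. The current
/// implementation delegates directly to an internal [`LayerNorm`], which normalizes
/// over the trailing axis; see the note in `forward` on the resulting shape requirement.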
#[derive(Debug)]
pub struct LayerNorm2D<F: Float + Debug> {
    /// Number of channels to normalize
    channels: usize,
    /// Internal layer norm implementation
    layer_norm: LayerNorm<F>,
    /// Name for the layer
    name: Option<String>,
}

impl<F: Float + Debug + ScalarOperand + 'static> LayerNorm2D<F> {
    /// Create a new 2D layer normalization layer
    pub fn new<R: Rng>(channels: usize, eps: f64, name: Option<&str>) -> Result<Self> {
        let layer_norm = LayerNorm::new(channels, eps, &mut rand::rng())?;

        Ok(Self {
            channels,
            layer_norm,
            name: name.map(String::from),
        })
    }

    /// Get the number of channels
    pub fn channels(&self) -> usize {
        self.channels
    }

    /// Get the name of the layer
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for LayerNorm2D<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // For 2D layer norm, we expect:
        // [batch_size, channels, height, width] format

        let input_shape = input.shape();
        if input_shape.len() != 4 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 4D input [batch_size, channels, height, width], got {:?}",
                input_shape
            )));
        }

        let (_batch_size, channels, _height, _width) = (
            input_shape[0],
            input_shape[1],
            input_shape[2],
            input_shape[3],
        );

        if channels != self.channels {
            return Err(NeuralError::InferenceError(format!(
                "Expected {} channels but got {}",
                self.channels, channels
            )));
        }

        // Delegate to the internal layer norm. Note that LayerNorm normalizes over the
        // trailing axis, so this delegation requires the trailing dimension to match the
        // configured channel count; a fully general 2D layer norm would move the channel
        // axis to the end (or normalize over [C, H, W]) before delegating.
        self.layer_norm.forward(input)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Delegate to the internal layer norm
        self.layer_norm.backward(input, grad_output)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Delegate to the internal layer norm
        self.layer_norm.update(learning_rate)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Clone for LayerNorm<F> {
    fn clone(&self) -> Self {
        let input_cache_clone = match self.input_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None, // Fall back to None if we can't acquire the lock
        };

        let norm_cache_clone = match self.norm_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };

        let mean_cache_clone = match self.mean_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };

        let var_cache_clone = match self.var_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };

        Self {
            normalized_shape: self.normalized_shape.clone(),
            gamma: self.gamma.clone(),
            beta: self.beta.clone(),
            dgamma: self.dgamma.clone(),
            dbeta: self.dbeta.clone(),
            eps: self.eps,
            input_cache: Arc::new(RwLock::new(input_cache_clone)),
            norm_cache: Arc::new(RwLock::new(norm_cache_clone)),
            mean_cache: Arc::new(RwLock::new(mean_cache_clone)),
            var_cache: Arc::new(RwLock::new(var_cache_clone)),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Clone for LayerNorm2D<F> {
    fn clone(&self) -> Self {
        Self {
            channels: self.channels,
            layer_norm: self.layer_norm.clone(),
            name: self.name.clone(),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> LayerNorm<F> {
    /// Create a new layer normalization layer
    ///
    /// # Arguments
    ///
    /// * `normalized_shape` - Size of the trailing feature dimension to normalize over
    /// * `eps` - Small constant added for numerical stability
    /// * `rng` - Random number generator for initialization
    ///
    /// # Returns
    ///
    /// * A new layer normalization layer
    pub fn new<R: Rng>(normalized_shape: usize, eps: f64, _rng: &mut R) -> Result<Self> {
        // Initialize gamma to ones and beta to zeros
        let gamma = Array::<F, _>::from_elem(IxDyn(&[normalized_shape]), F::one());
        let beta = Array::<F, _>::from_elem(IxDyn(&[normalized_shape]), F::zero());

        // Initialize gradient arrays to zeros
        let dgamma = Array::<F, _>::zeros(IxDyn(&[normalized_shape]));
        let dbeta = Array::<F, _>::zeros(IxDyn(&[normalized_shape]));

        // Convert epsilon to F
        let eps = F::from(eps).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert epsilon to type F".to_string())
        })?;

        Ok(Self {
            normalized_shape: vec![normalized_shape],
            gamma,
            beta,
            dgamma,
            dbeta,
            eps,
            input_cache: Arc::new(RwLock::new(None)),
            norm_cache: Arc::new(RwLock::new(None)),
            mean_cache: Arc::new(RwLock::new(None)),
            var_cache: Arc::new(RwLock::new(None)),
        })
    }

    /// Helper method to compute mean and variance along the normalization axis
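    ///
    /// Statistics are taken per sample over the trailing feature dimension of size d:
    /// mean = (1/d) * sum_j x_j and var = (1/d) * sum_j (x_j - mean)^2 (the biased
    /// variance, as is standard for layer normalization).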
    fn compute_stats(
        &self,
        input: &ArrayView<F, IxDyn>,
    ) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
        let input_shape = input.shape();
        let ndim = input.ndim();

        if ndim < 1 {
            return Err(NeuralError::InferenceError(
                "Input must have at least 1 dimension".to_string(),
            ));
        }

        // Check if the last dimension matches the normalized shape
        let feat_dim = input_shape[ndim - 1];
        if feat_dim != self.normalized_shape[0] {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Last dimension of input ({}) must match normalized_shape ({})",
                feat_dim, self.normalized_shape[0]
            )));
        }

        // Compute the batch shape (all dimensions except the last one)
        let batch_shape: Vec<usize> = input_shape[..ndim - 1].to_vec();
        let batch_size: usize = batch_shape.iter().product();

        // Reshape input to 2D: [batch_size, features]
        let reshaped = input
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {}", e)))?;

        // Initialize mean and variance arrays
        let mut mean = Array::<F, _>::zeros(IxDyn(&[batch_size, 1]));
        let mut var = Array::<F, _>::zeros(IxDyn(&[batch_size, 1]));

        // Compute mean for each sample
        for i in 0..batch_size {
            let mut sum = F::zero();
            for j in 0..feat_dim {
                sum = sum + reshaped[[i, j]];
            }
            mean[[i, 0]] = sum / F::from(feat_dim).unwrap();
        }

        // Compute variance for each sample
        for i in 0..batch_size {
            let mut sum_sq = F::zero();
            for j in 0..feat_dim {
                let diff = reshaped[[i, j]] - mean[[i, 0]];
                sum_sq = sum_sq + diff * diff;
            }
            var[[i, 0]] = sum_sq / F::from(feat_dim).unwrap();
        }

        Ok((mean, var))
    }

    /// Get the normalized shape
    pub fn normalized_shape(&self) -> usize {
        self.normalized_shape[0]
    }

    /// Get the epsilon value
    pub fn eps(&self) -> f64 {
        self.eps.to_f64().unwrap_or(1e-5)
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for LayerNorm<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache input for backward pass
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        let input_view = input.view();
        let input_shape = input.shape();
        let ndim = input.ndim();

        // Compute mean and variance
        let (mean, var) = self.compute_stats(&input_view)?;

        // Cache mean and variance for backward pass
        if let Ok(mut cache) = self.mean_cache.write() {
            *cache = Some(mean.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on mean cache".to_string(),
            ));
        }

        if let Ok(mut cache) = self.var_cache.write() {
            *cache = Some(var.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on variance cache".to_string(),
            ));
        }

        // Reshape input to 2D: [batch_size, features]
        let feat_dim = input_shape[ndim - 1];
        let batch_shape: Vec<usize> = input_shape[..ndim - 1].to_vec();
        let batch_size: usize = batch_shape.iter().product();

        let reshaped = input
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {}", e)))?;

        // Normalize the input
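        // x_norm = (x - mean) / sqrt(var + eps), then scaled by gamma and shifted by beta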
        let mut normalized = Array::<F, _>::zeros((batch_size, feat_dim));
        for i in 0..batch_size {
            for j in 0..feat_dim {
                let x_norm = (reshaped[[i, j]] - mean[[i, 0]]) / (var[[i, 0]] + self.eps).sqrt();
                normalized[[i, j]] = x_norm * self.gamma[[j]] + self.beta[[j]];
            }
        }

        // Cache normalized input for backward pass (convert to IxDyn first)
        if let Ok(mut cache) = self.norm_cache.write() {
            *cache = Some(normalized.clone().into_dimensionality::<IxDyn>().unwrap());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on normalized cache".to_string(),
            ));
        }

        // Reshape back to the original shape
        let output = normalized
            .into_shape_with_order(IxDyn(input_shape))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape output: {}", e)))?;

        Ok(output)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Retrieve cached values
        let input_ref = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };
        let norm_ref = match self.norm_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on norm cache".to_string(),
                ))
            }
        };
        let mean_ref = match self.mean_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on mean cache".to_string(),
                ))
            }
        };
        let var_ref = match self.var_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on var cache".to_string(),
                ))
            }
        };

        if input_ref.is_none() || norm_ref.is_none() || mean_ref.is_none() || var_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }

        let _cached_input = input_ref.as_ref().unwrap();
        let _x_norm = norm_ref.as_ref().unwrap();
        let _mean = mean_ref.as_ref().unwrap();
        let _var = var_ref.as_ref().unwrap();

        // Get dimensions
        let input_shape = input.shape();
        let ndim = input.ndim();
        let ndim_minus_1: usize = ndim - 1;
        let feat_dim = input_shape[ndim_minus_1];
        let batch_shape: Vec<usize> = input_shape[..ndim_minus_1].to_vec();
        let batch_size: usize = batch_shape.iter().product();

        // Reshape grad_output to 2D: [batch_size, features]
        let _grad_output_reshaped = grad_output
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| {
                NeuralError::InferenceError(format!("Failed to reshape grad_output: {}", e))
            })?;

        // Reshape input to 2D: [batch_size, features]
        let _input_reshaped = input
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {}", e)))?;

        // A complete implementation would compute dgamma/dbeta from the cached values and
        // propagate the true input gradient. For now this is a placeholder that returns a
        // zero gradient for the input.
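        // Sketch of the full layer-norm backward pass (not wired up here), with
        // d = feat_dim, sigma_i = sqrt(var_i + eps) and dxhat = grad_output * gamma:
        //   dgamma[j] = sum_i grad_output[i, j] * x_norm[i, j]
        //   dbeta[j]  = sum_i grad_output[i, j]
        //   dx[i, j]  = (dxhat[i, j] - mean_j(dxhat[i, :])
        //                - x_norm[i, j] * mean_j(dxhat[i, :] * x_norm[i, :])) / sigma_i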

        // Create a placeholder gradient input
        let grad_input = Array::<F, _>::zeros((batch_size, feat_dim));

        // Reshape back to the original shape
        let output = grad_input
            .into_shape_with_order(IxDyn(input_shape))
            .map_err(|e| {
                NeuralError::InferenceError(format!("Failed to reshape grad_input: {}", e))
            })?;

        Ok(output)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Update parameters using gradients
        // Placeholder implementation: the stored dgamma/dbeta are not used here
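        // A real update would apply the cached gradients and reset them afterwards,
        // mirroring BatchNorm::update below, e.g.:
        //   self.gamma[[i]] = self.gamma[[i]] - learning_rate * self.dgamma[[i]];
        //   self.beta[[i]] = self.beta[[i]] - learning_rate * self.dbeta[[i]];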

        // Apply a small update
        let small_change = F::from(0.001).unwrap();
        let lr = small_change * learning_rate;

        // Update gamma and beta
        for i in 0..self.normalized_shape[0] {
            self.gamma[[i]] = self.gamma[[i]] - lr;
            self.beta[[i]] = self.beta[[i]] - lr;
        }

        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> ParamLayer<F> for LayerNorm<F> {
    fn get_parameters(&self) -> Vec<&Array<F, ndarray::IxDyn>> {
        vec![&self.gamma, &self.beta]
    }

    fn get_gradients(&self) -> Vec<&Array<F, ndarray::IxDyn>> {
        vec![&self.dgamma, &self.dbeta]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters, got {}",
                params.len()
            )));
        }

        if params[0].shape() != self.gamma.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Gamma shape mismatch: expected {:?}, got {:?}",
                self.gamma.shape(),
                params[0].shape()
            )));
        }

        if params[1].shape() != self.beta.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Beta shape mismatch: expected {:?}, got {:?}",
                self.beta.shape(),
                params[1].shape()
            )));
        }

        self.gamma = params[0].clone();
        self.beta = params[1].clone();

        Ok(())
    }
}

/// Batch Normalization layer
///
/// Implements batch normalization as described in "Batch Normalization: Accelerating Deep Network
/// Training by Reducing Internal Covariate Shift" by Ioffe and Szegedy.
///
/// This normalization is applied along the feature dimension (channel dimension for CNNs)
/// over a mini-batch of data.
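///
/// During training, each channel c is normalized with the current batch statistics:
/// y = gamma_c * (x - mean_c) / sqrt(var_c + eps) + beta_c. At inference time,
/// running estimates of mean_c and var_c are used in place of the batch statistics.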
///
/// # Examples
///
/// ```
/// use scirs2_neural::layers::{BatchNorm, Layer};
/// use ndarray::{Array, Array4};
/// use rand::rngs::SmallRng;
/// use rand::SeedableRng;
///
/// // Create a batch normalization layer for 3 channels
/// let mut rng = SmallRng::seed_from_u64(42);
/// let batch_norm = BatchNorm::new(3, 0.9, 1e-5, &mut rng).unwrap();
///
/// // Forward pass with a batch of 2 samples, 3 channels, 4x4 spatial dimensions
/// let batch_size = 2;
/// let channels = 3;
/// let height = 4;
/// let width = 4;
/// let input = Array4::<f64>::from_elem((batch_size, channels, height, width), 0.1).into_dyn();
/// let output = batch_norm.forward(&input).unwrap();
///
/// // Output shape should match input shape
/// assert_eq!(output.shape(), input.shape());
/// ```
#[derive(Debug, Clone)]
pub struct BatchNorm<F: Float + Debug + Send + Sync> {
    /// Number of features (channels)
    num_features: usize,
    /// Learnable scale parameter
    gamma: Array<F, IxDyn>,
    /// Learnable shift parameter
    beta: Array<F, IxDyn>,
    /// Gradient of gamma
    dgamma: Array<F, IxDyn>,
    /// Gradient of beta
    dbeta: Array<F, IxDyn>,
    /// Running mean for inference mode
    running_mean: Array<F, IxDyn>,
    /// Running variance for inference mode
    running_var: Array<F, IxDyn>,
    /// Momentum for running statistics updates
    momentum: F,
    /// Small constant for numerical stability
    eps: F,
    /// Whether we're in training mode
    training: bool,
    /// Input cache for backward pass
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Input shape cache for backward pass
    input_shape_cache: Arc<RwLock<Option<Vec<usize>>>>,
    /// Batch mean cache for backward pass
    batch_mean_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Batch var cache for backward pass
    batch_var_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Normalized input cache for backward pass
    norm_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Std deviation cache for backward pass
    std_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Phantom for type storage
    _phantom: PhantomData<F>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> BatchNorm<F> {
    /// Create a new batch normalization layer
    ///
    /// # Arguments
    ///
    /// * `num_features` - Number of features/channels to normalize
    /// * `momentum` - Momentum for running mean/variance updates (default: 0.9)
    /// * `eps` - Small constant for numerical stability (default: 1e-5)
    /// * `rng` - Random number generator for initialization
    ///
    /// # Returns
    ///
    /// * A new batch normalization layer
    pub fn new<R: Rng>(num_features: usize, momentum: f64, eps: f64, _rng: &mut R) -> Result<Self> {
        // Initialize gamma to ones and beta to zeros
        let gamma = Array::<F, _>::from_elem(IxDyn(&[num_features]), F::one());
        let beta = Array::<F, _>::from_elem(IxDyn(&[num_features]), F::zero());

        // Initialize gradient arrays to zeros
        let dgamma = Array::<F, _>::zeros(IxDyn(&[num_features]));
        let dbeta = Array::<F, _>::zeros(IxDyn(&[num_features]));

        // Initialize the running mean to zeros and the running variance to ones
        let running_mean = Array::<F, _>::zeros(IxDyn(&[num_features]));
        let running_var = Array::<F, _>::from_elem(IxDyn(&[num_features]), F::one());

        // Convert momentum and epsilon to F
        let momentum = F::from(momentum).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert momentum to type F".to_string())
        })?;
        let eps = F::from(eps).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert epsilon to type F".to_string())
        })?;

        Ok(Self {
            num_features,
            gamma,
            beta,
            dgamma,
            dbeta,
            running_mean,
            running_var,
            momentum,
            eps,
            training: true,
            input_cache: Arc::new(RwLock::new(None)),
            input_shape_cache: Arc::new(RwLock::new(None)),
            batch_mean_cache: Arc::new(RwLock::new(None)),
            batch_var_cache: Arc::new(RwLock::new(None)),
            norm_cache: Arc::new(RwLock::new(None)),
            std_cache: Arc::new(RwLock::new(None)),
            _phantom: PhantomData,
        })
    }

    /// Set the training mode
    ///
    /// In training mode, batch statistics are used for normalization and running statistics are updated.
    /// In inference mode, running statistics are used for normalization and not updated.
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Get the number of features
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Get the momentum value
    pub fn momentum(&self) -> f64 {
        self.momentum.to_f64().unwrap_or(0.9)
    }

    /// Get the epsilon value
    pub fn eps(&self) -> f64 {
        self.eps.to_f64().unwrap_or(1e-5)
    }

    /// Get the training mode
    pub fn is_training(&self) -> bool {
        self.training
    }

    /// Helper function to reshape input for batch normalization
    /// The input should be reshaped to (N, C, -1) where:
    /// - N is the batch size
    /// - C is the number of channels/features
    /// - -1 flattens all other dimensions
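    ///
    /// For example, a 4D CNN activation of shape [2, 3, 4, 4] is viewed as
    /// [2, 3, 16], while a 2D input of shape [2, 3] becomes [2, 3, 1].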
    fn reshape_input(&self, input: &Array<F, IxDyn>) -> Result<(Array<F, IxDyn>, Vec<usize>)> {
        let input_shape = input.shape().to_vec();
        let ndim = input.ndim();

        if ndim < 2 {
            return Err(NeuralError::InvalidArchitecture(
                "Input must have at least 2 dimensions (batch, features, ...)".to_string(),
            ));
        }

        // For 2D inputs, assume shape is (batch_size, features)
        // For 3D+ inputs (e.g. CNN activations), assume shape is (batch_size, channels, dim1, dim2, ...)
        let batch_size = input_shape[0];
        let num_features = input_shape[1];

        if num_features != self.num_features {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected {} features, got {}",
                self.num_features, num_features
            )));
        }

        // Calculate the product of all spatial dimensions (if any)
        let spatial_size: usize = if ndim > 2 {
            input_shape[2..].iter().product()
        } else {
            1
        };

        // Reshape to (batch_size, num_features, spatial_size)
        let reshaped = input
            .clone()
            .into_shape_with_order(IxDyn(&[batch_size, num_features, spatial_size]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {}", e)))?;

        Ok((reshaped, input_shape))
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for BatchNorm<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache input for backward pass
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        // Reshape input to (batch_size, num_features, spatial_size)
        let (reshaped, input_shape) = self.reshape_input(input)?;
        if let Ok(mut cache) = self.input_shape_cache.write() {
            *cache = Some(input_shape.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input shape cache".to_string(),
            ));
        }

        let batch_size = reshaped.shape()[0];
        let num_features = reshaped.shape()[1];
        let spatial_size = reshaped.shape()[2];

        // Create output with same shape as reshaped input
        let mut normalized = Array::<F, _>::zeros(reshaped.shape());

        if self.training {
            // Calculate batch mean and variance
            let mut batch_mean = Array::<F, _>::zeros(IxDyn(&[num_features]));
            let mut batch_var = Array::<F, _>::zeros(IxDyn(&[num_features]));

            // Compute mean across batch and spatial dimensions for each feature
            for c in 0..num_features {
                let mut sum = F::zero();
                let spatial_elements = batch_size * spatial_size;

                for n in 0..batch_size {
                    for s in 0..spatial_size {
                        sum = sum + reshaped[[n, c, s]];
                    }
                }

                batch_mean[[c]] = sum / F::from(spatial_elements).unwrap();
            }

            // Compute variance across batch and spatial dimensions for each feature
            for c in 0..num_features {
                let mut sum_sq = F::zero();
                let spatial_elements = batch_size * spatial_size;

                for n in 0..batch_size {
                    for s in 0..spatial_size {
                        let diff = reshaped[[n, c, s]] - batch_mean[[c]];
                        sum_sq = sum_sq + diff * diff;
                    }
                }

                batch_var[[c]] = sum_sq / F::from(spatial_elements).unwrap();
            }

            // Cache batch statistics for backward pass
            if let Ok(mut cache) = self.batch_mean_cache.write() {
                *cache = Some(batch_mean.clone());
            } else {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire write lock on batch mean cache".to_string(),
                ));
            }

            if let Ok(mut cache) = self.batch_var_cache.write() {
                *cache = Some(batch_var.clone());
            } else {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire write lock on batch var cache".to_string(),
                ));
            }

            // Compute standard deviation for normalization
            let std_dev = batch_var.mapv(|x| (x + self.eps).sqrt());
            if let Ok(mut cache) = self.std_cache.write() {
                *cache = Some(std_dev.clone());
            } else {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire write lock on std cache".to_string(),
                ));
            }

            // Normalize using batch statistics
            for n in 0..batch_size {
                for c in 0..num_features {
                    for s in 0..spatial_size {
                        let x_norm = (reshaped[[n, c, s]] - batch_mean[[c]]) / std_dev[[c]];
                        normalized[[n, c, s]] = x_norm * self.gamma[[c]] + self.beta[[c]];
                    }
                }
            }

            // Update running statistics
            // running_mean = momentum * running_mean + (1 - momentum) * batch_mean
            // running_var = momentum * running_var + (1 - momentum) * batch_var
            let one = F::one();

            // Because forward() takes &self, the running statistics cannot be updated in place
            // here. A complete implementation would wrap running_mean/running_var in interior
            // mutability (e.g. RwLock, like the caches above). For now the updated values are
            // computed but not written back, so inference mode keeps using the initial values.
            let mut running_mean_updated = Array::zeros(self.running_mean.dim());
            let mut running_var_updated = Array::zeros(self.running_var.dim());

            for c in 0..num_features {
                running_mean_updated[[c]] = self.momentum * self.running_mean[[c]]
                    + (one - self.momentum) * batch_mean[[c]];
                running_var_updated[[c]] =
                    self.momentum * self.running_var[[c]] + (one - self.momentum) * batch_var[[c]];
            }

            // Cache normalized input (pre-gamma/beta) for backward pass
            let mut x_norm = Array::<F, _>::zeros(reshaped.shape());
            for n in 0..batch_size {
                for c in 0..num_features {
                    for s in 0..spatial_size {
                        x_norm[[n, c, s]] = (reshaped[[n, c, s]] - batch_mean[[c]]) / std_dev[[c]];
                    }
                }
            }
            if let Ok(mut cache) = self.norm_cache.write() {
                *cache = Some(x_norm);
            } else {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire write lock on norm cache".to_string(),
                ));
            }
        } else {
            // Use running statistics in inference mode
            let std_dev = self.running_var.mapv(|x| (x + self.eps).sqrt());

            // Normalize using running statistics
            for n in 0..batch_size {
                for c in 0..num_features {
                    for s in 0..spatial_size {
                        let x_norm = (reshaped[[n, c, s]] - self.running_mean[[c]]) / std_dev[[c]];
                        normalized[[n, c, s]] = x_norm * self.gamma[[c]] + self.beta[[c]];
                    }
                }
            }
        }

        // Reshape back to the original shape
        let output = normalized
            .into_shape_with_order(IxDyn(&input_shape))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape output: {}", e)))?;

        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Retrieve cached values
        let input_ref = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };
        let input_shape_ref = match self.input_shape_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input shape cache".to_string(),
                ))
            }
        };
        let batch_mean_ref = match self.batch_mean_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on batch mean cache".to_string(),
                ))
            }
        };
        let batch_var_ref = match self.batch_var_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on batch var cache".to_string(),
                ))
            }
        };
        let norm_ref = match self.norm_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on norm cache".to_string(),
                ))
            }
        };
        let std_ref = match self.std_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on std cache".to_string(),
                ))
            }
        };

        if input_ref.is_none()
            || input_shape_ref.is_none()
            || batch_mean_ref.is_none()
            || batch_var_ref.is_none()
            || norm_ref.is_none()
            || std_ref.is_none()
        {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }

        let _cached_input = input_ref.as_ref().unwrap();
        let input_shape = input_shape_ref.as_ref().unwrap();
        let _batch_mean = batch_mean_ref.as_ref().unwrap();
        let _batch_var = batch_var_ref.as_ref().unwrap();
        let x_norm = norm_ref.as_ref().unwrap();
        let std_dev = std_ref.as_ref().unwrap();

        // Reshape grad_output to match the cached reshaped input
        let reshaped_grad_output = grad_output
            .clone()
            .into_shape_with_order(IxDyn(x_norm.shape()))
            .map_err(|e| {
                NeuralError::InferenceError(format!("Failed to reshape grad_output: {}", e))
            })?;

        let batch_size = x_norm.shape()[0];
        let num_features = x_norm.shape()[1];
        let spatial_size = x_norm.shape()[2];
        let spatial_elements = batch_size * spatial_size;
        let spatial_elements_f = F::from(spatial_elements).unwrap();

        // Calculate gradients for gamma and beta
        let mut dgamma = Array::<F, _>::zeros(IxDyn(&[num_features]));
        let mut dbeta = Array::<F, _>::zeros(IxDyn(&[num_features]));

        // For each feature channel, compute the gradients
        for c in 0..num_features {
            let mut dgamma_sum = F::zero();
            let mut dbeta_sum = F::zero();

            for n in 0..batch_size {
                for s in 0..spatial_size {
                    dgamma_sum = dgamma_sum + reshaped_grad_output[[n, c, s]] * x_norm[[n, c, s]];
                    dbeta_sum = dbeta_sum + reshaped_grad_output[[n, c, s]];
                }
            }

            dgamma[[c]] = dgamma_sum;
            dbeta[[c]] = dbeta_sum;
        }

        // Because backward() takes &self, the dgamma/dbeta computed above are not written back
        // to the layer; update() only applies whatever is already stored in self.dgamma/self.dbeta.
        // A complete implementation would cache these gradients (e.g. behind an RwLock).

        // Calculate gradient with respect to input
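        // Standard batch-norm input gradient, per channel c with m = batch_size * spatial_size
        // and dxhat = grad_output * gamma_c:
        //   dx = (dxhat - sum(dxhat) / m - x_norm * sum(dxhat * x_norm) / m) / sqrt(var_c + eps)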
        let mut dx = Array::<F, _>::zeros(x_norm.shape());

        // Calculate intermediate gradient terms
        for c in 0..num_features {
            let mut dxhat_sum = F::zero();
            let mut dxhat_x_sum = F::zero();

            // Calculate sums needed for the gradient calculation
            for n in 0..batch_size {
                for s in 0..spatial_size {
                    let dxhat = reshaped_grad_output[[n, c, s]] * self.gamma[[c]];
                    dxhat_sum = dxhat_sum + dxhat;
                    dxhat_x_sum = dxhat_x_sum + dxhat * x_norm[[n, c, s]];
                }
            }

            // Calculate final gradients for each element
            for n in 0..batch_size {
                for s in 0..spatial_size {
                    let dxhat = reshaped_grad_output[[n, c, s]] * self.gamma[[c]];
                    let dx_term1 = dxhat;
                    let dx_term2 = dxhat_sum / spatial_elements_f;
                    let dx_term3 = x_norm[[n, c, s]] * dxhat_x_sum / spatial_elements_f;

                    dx[[n, c, s]] = (dx_term1 - dx_term2 - dx_term3) / std_dev[[c]];
                }
            }
        }

        // Reshape back to the original input shape
        let dx_output = dx
            .into_shape_with_order(IxDyn(input_shape))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape dx: {}", e)))?;

        Ok(dx_output)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // Update gamma and beta using their gradients
        let lr = learning_rate;

        for c in 0..self.num_features {
            self.gamma[[c]] = self.gamma[[c]] - lr * self.dgamma[[c]];
            self.beta[[c]] = self.beta[[c]] - lr * self.dbeta[[c]];
        }

        // Reset gradients
        self.dgamma = Array::<F, _>::zeros(IxDyn(&[self.num_features]));
        self.dbeta = Array::<F, _>::zeros(IxDyn(&[self.num_features]));

        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for BatchNorm<F> {
    fn get_parameters(&self) -> Vec<&Array<F, ndarray::IxDyn>> {
        vec![&self.gamma, &self.beta]
    }

    fn get_gradients(&self) -> Vec<&Array<F, ndarray::IxDyn>> {
        vec![&self.dgamma, &self.dbeta]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters, got {}",
                params.len()
            )));
        }

        if params[0].shape() != self.gamma.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Gamma shape mismatch: expected {:?}, got {:?}",
                self.gamma.shape(),
                params[0].shape()
            )));
        }

        if params[1].shape() != self.beta.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Beta shape mismatch: expected {:?}, got {:?}",
                self.beta.shape(),
                params[1].shape()
            )));
        }

        self.gamma = params[0].clone();
        self.beta = params[1].clone();

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::{Array3, Array4};
    use rand::rngs::SmallRng;
    use rand::SeedableRng;

    #[test]
    fn test_layer_norm_shape() {
        // Set up layer normalization
        let mut rng = SmallRng::seed_from_u64(42);
        let layer_norm = LayerNorm::<f64>::new(64, 1e-5, &mut rng).unwrap();

        // Create a batch of inputs
        let batch_size = 2;
        let seq_len = 3;
        let d_model = 64;
        let input = Array3::<f64>::from_elem((batch_size, seq_len, d_model), 0.1).into_dyn();

        // Forward pass
        let output = layer_norm.forward(&input).unwrap();

        // Check output shape
        assert_eq!(output.shape(), input.shape());
    }

    #[test]
    fn test_layer_norm_normalization() {
        // Set up layer normalization
        let mut rng = SmallRng::seed_from_u64(42);
        let d_model = 10;
        let layer_norm = LayerNorm::<f64>::new(d_model, 1e-5, &mut rng).unwrap();

        // Create a simple input with different values
        let mut input = Array3::<f64>::zeros((1, 1, d_model));
        for i in 0..d_model {
            input[[0, 0, i]] = i as f64;
        }

        // Forward pass
        let output = layer_norm.forward(&input.into_dyn()).unwrap();

        // Calculate mean and variance manually to verify
        let output_view = output.view();
        let output_slice = output_view.slice(ndarray::s![0, 0, ..]);

        // Calculate mean
        let mut sum = 0.0;
        for i in 0..d_model {
            sum += output_slice[i];
        }
        let mean = sum / (d_model as f64);

        // Calculate variance
        let mut sum_sq = 0.0;
        for i in 0..d_model {
            let diff = output_slice[i] - mean;
            sum_sq += diff * diff;
        }
        let var = sum_sq / (d_model as f64);

        // The output should have approximately zero mean and unit variance
        // We allow a larger tolerance for numerical precision
        assert_relative_eq!(mean, 0.0, epsilon = 1e-8);
        assert_relative_eq!(var, 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_batch_norm_shape() {
        // Set up batch normalization
        let mut rng = SmallRng::seed_from_u64(42);
        let batch_norm = BatchNorm::<f64>::new(3, 0.9, 1e-5, &mut rng).unwrap();

        // Create a batch of inputs (batch_size, channels, height, width)
        let batch_size = 2;
        let channels = 3;
        let height = 4;
        let width = 5;
        let input = Array4::<f64>::from_elem((batch_size, channels, height, width), 0.1).into_dyn();

        // Forward pass
        let output = batch_norm.forward(&input).unwrap();

        // Check output shape
        assert_eq!(output.shape(), input.shape());
    }

    #[test]
    fn test_batch_norm_training_mode() {
        // Set up batch normalization
        let mut rng = SmallRng::seed_from_u64(42);
        let mut batch_norm = BatchNorm::<f64>::new(3, 0.9, 1e-5, &mut rng).unwrap();

        // Ensure we're in training mode
        batch_norm.set_training(true);

        // Create input with different values per channel
        let batch_size = 2;
        let channels = 3;
        let height = 2;
        let width = 2;
        let mut input = Array4::<f64>::zeros((batch_size, channels, height, width));

        // Fill input with varying values to ensure non-zero variance within each channel
        let mut val = 0.0;
        for n in 0..batch_size {
            for c in 0..channels {
                for h in 0..height {
                    for w in 0..width {
                        // Base value for the channel plus some variation
                        input[[n, c, h, w]] = (c + 1) as f64 + val;
                        val += 0.1;
                    }
                }
            }
        }

        // Forward pass
        let output = batch_norm.forward(&input.into_dyn()).unwrap();

        // For each channel, the output should have mean ≈ 0 and variance ≈ 1
        for c in 0..channels {
            let mut sum = 0.0;
            let mut count = 0;

            // Calculate mean
            for n in 0..batch_size {
                for h in 0..height {
                    for w in 0..width {
                        sum += output.view().slice(ndarray::s![n, c, h, w]).into_scalar();
                        count += 1;
                    }
                }
            }

            let mean = sum / (count as f64);

            // Calculate variance
            let mut sum_sq = 0.0;
            for n in 0..batch_size {
                for h in 0..height {
                    for w in 0..width {
                        let diff =
                            output.view().slice(ndarray::s![n, c, h, w]).into_scalar() - mean;
                        sum_sq += diff * diff;
                    }
                }
            }

            let var = sum_sq / (count as f64);

            // The output should have approximately zero mean and unit variance within each channel
            // Allow a larger tolerance for numerical precision
            assert_relative_eq!(mean, 0.0, epsilon = 1e-8);
            assert_relative_eq!(var, 1.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_batch_norm_inference_mode() {
        // Set up batch normalization
        let mut rng = SmallRng::seed_from_u64(42);
        let mut batch_norm = BatchNorm::<f64>::new(3, 0.9, 1e-5, &mut rng).unwrap();

        // Create input with different values per channel
        let batch_size = 2;
        let channels = 3;
        let height = 2;
        let width = 2;
        let mut input = Array4::<f64>::zeros((batch_size, channels, height, width));

        // Fill input with varying values to ensure non-zero variance within each channel
        let mut val = 0.0;
        for n in 0..batch_size {
            for c in 0..channels {
                for h in 0..height {
                    for w in 0..width {
                        // Base value for the channel plus some variation
                        input[[n, c, h, w]] = (c + 1) as f64 + val;
                        val += 0.1;
                    }
                }
            }
        }

        // First forward pass in training mode to accumulate running statistics
        batch_norm.set_training(true);
        let _ = batch_norm.forward(&input.clone().into_dyn()).unwrap();

        // Create a copy of the input for inference
        let input_clone = input.clone();

        // Switch to inference mode
        batch_norm.set_training(false);
        let output = batch_norm.forward(&input_clone.into_dyn()).unwrap();

        // Check that output is consistent with input and transformation
        // In inference mode, each channel should be normalized consistently

        // We'll just verify that the output has some non-zero values
        // and that the magnitude is reasonable (for normalized data)
        let output_view = output.view();

        let mut has_non_zero = false;
        let mut max_abs_val = 0.0;

        for c in 0..channels {
            for n in 0..batch_size {
                for h in 0..height {
                    for w in 0..width {
                        let val = output_view.slice(ndarray::s![n, c, h, w]).into_scalar();
                        if val.abs() > 1e-10 {
                            has_non_zero = true;
                        }
                        max_abs_val = max_abs_val.max(val.abs());
                    }
                }
            }
        }

        // Verify the output contains non-zero values
        assert!(has_non_zero, "Output values should not all be zero");

        // For normalized data, values shouldn't be extremely large
        assert!(
            max_abs_val < 10.0,
            "Output values should be reasonably sized (normalized)"
        );
    }
}