scirs2_neural/layers/normalization.rs

//! Normalization layers implementation
//!
//! This module provides implementations of various normalization techniques
//! such as Layer Normalization, Batch Normalization, etc.

use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

/// Layer Normalization layer
///
/// Implements layer normalization as described in "Layer Normalization"
/// by Ba, Kiros, and Hinton. It normalizes the inputs across the last dimension
/// and applies learnable scale and shift parameters.
#[derive(Debug)]
pub struct LayerNorm<F: Float + Debug + Send + Sync> {
    /// Dimensionality of the input features
    normalizedshape: Vec<usize>,
    /// Learnable scale parameter
    gamma: Array<F, IxDyn>,
    /// Learnable shift parameter
    beta: Array<F, IxDyn>,
    /// Gradient of gamma
    dgamma: Arc<RwLock<Array<F, IxDyn>>>,
    /// Gradient of beta
    dbeta: Arc<RwLock<Array<F, IxDyn>>>,
    /// Small constant for numerical stability
    eps: F,
    /// Input cache for backward pass
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Normalized input cache for backward pass
    norm_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Mean cache for backward pass
    mean_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Variance cache for backward pass
    var_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}
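
// For reference, the transformation applied by `forward` below to each sample,
// over its last (feature) dimension, is
//
//   y = gamma * (x - mean(x)) / sqrt(var(x) + eps) + beta
//
// where mean and variance are computed per sample across the `normalizedshape`
// features, and gamma/beta are the learnable parameters stored above.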

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Clone for LayerNorm<F> {
    fn clone(&self) -> Self {
        let input_cache_clone = match self.input_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };
        let norm_cache_clone = match self.norm_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };
        let mean_cache_clone = match self.mean_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };
        let var_cache_clone = match self.var_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };

        Self {
            normalizedshape: self.normalizedshape.clone(),
            gamma: self.gamma.clone(),
            beta: self.beta.clone(),
            dgamma: Arc::new(RwLock::new(self.dgamma.read().unwrap().clone())),
            dbeta: Arc::new(RwLock::new(self.dbeta.read().unwrap().clone())),
            eps: self.eps,
            input_cache: Arc::new(RwLock::new(input_cache_clone)),
            norm_cache: Arc::new(RwLock::new(norm_cache_clone)),
            mean_cache: Arc::new(RwLock::new(mean_cache_clone)),
            var_cache: Arc::new(RwLock::new(var_cache_clone)),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> LayerNorm<F> {
    /// Create a new layer normalization layer
    pub fn new<R: Rng>(normalizedshape: usize, eps: f64, _rng: &mut R) -> Result<Self> {
        let gamma = Array::<F, IxDyn>::from_elem(IxDyn(&[normalizedshape]), F::one());
        let beta = Array::<F, IxDyn>::from_elem(IxDyn(&[normalizedshape]), F::zero());

        let dgamma = Arc::new(RwLock::new(Array::<F, IxDyn>::zeros(IxDyn(&[
            normalizedshape,
        ]))));
        let dbeta = Arc::new(RwLock::new(Array::<F, IxDyn>::zeros(IxDyn(&[
            normalizedshape,
        ]))));

        let eps = F::from(eps).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert epsilon to type F".to_string())
        })?;

        Ok(Self {
            normalizedshape: vec![normalizedshape],
            gamma,
            beta,
            dgamma,
            dbeta,
            eps,
            input_cache: Arc::new(RwLock::new(None)),
            norm_cache: Arc::new(RwLock::new(None)),
            mean_cache: Arc::new(RwLock::new(None)),
            var_cache: Arc::new(RwLock::new(None)),
        })
    }

    /// Get the normalized shape
    pub fn normalizedshape(&self) -> usize {
        self.normalizedshape[0]
    }

    /// Get the epsilon value
    #[allow(dead_code)]
    pub fn eps(&self) -> f64 {
        self.eps.to_f64().unwrap_or(1e-5)
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for LayerNorm<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache input for backward pass
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        }

        let inputshape = input.shape();
        let ndim = input.ndim();

        if ndim < 1 {
            return Err(NeuralError::InferenceError(
                "Input must have at least 1 dimension".to_string(),
            ));
        }

        let feat_dim = inputshape[ndim - 1];
        if feat_dim != self.normalizedshape[0] {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Last dimension of input ({}) must match normalizedshape ({})",
                feat_dim, self.normalizedshape[0]
            )));
        }

        let batchshape: Vec<usize> = inputshape[..ndim - 1].to_vec();
        let batch_size: usize = batchshape.iter().product();

        // Reshape input to 2D: [batch_size, features]
        let reshaped = input
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {e}")))?;

        // Compute mean and variance for each sample
        let mut mean = Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, 1]));
        let mut var = Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, 1]));

        for i in 0..batch_size {
            let mut sum = F::zero();
            for j in 0..feat_dim {
                sum = sum + reshaped[[i, j]];
            }
            mean[[i, 0]] = sum / F::from(feat_dim).unwrap();

            let mut sum_sq = F::zero();
            for j in 0..feat_dim {
                let diff = reshaped[[i, j]] - mean[[i, 0]];
                sum_sq = sum_sq + diff * diff;
            }
            var[[i, 0]] = sum_sq / F::from(feat_dim).unwrap();
        }

        // Cache mean and variance
        if let Ok(mut cache) = self.mean_cache.write() {
            *cache = Some(mean.clone());
        }
        if let Ok(mut cache) = self.var_cache.write() {
            *cache = Some(var.clone());
        }

        // Normalize and apply gamma/beta
        let mut normalized = Array::<F, IxDyn>::zeros(IxDyn(&[batch_size, feat_dim]));
        for i in 0..batch_size {
            for j in 0..feat_dim {
                let x_norm = (reshaped[[i, j]] - mean[[i, 0]]) / (var[[i, 0]] + self.eps).sqrt();
                normalized[[i, j]] = x_norm * self.gamma[[j]] + self.beta[[j]];
            }
        }

        // Cache normalized input
        if let Ok(mut cache) = self.norm_cache.write() {
            *cache = Some(normalized.clone().into_dimensionality::<IxDyn>().unwrap());
        }

        // Reshape back to original shape
        let output = normalized
            .into_shape_with_order(IxDyn(inputshape))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape output: {e}")))?;

        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // The analytic backward pass is not implemented yet; the upstream
        // gradient is passed through unchanged.
        Ok(grad_output.clone())
    }

    fn update(&mut self, _learningrate: F) -> Result<()> {
        // Parameter update is not implemented yet; this is a no-op.
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "LayerNorm"
    }

    fn parameter_count(&self) -> usize {
        self.gamma.len() + self.beta.len()
    }
}
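
// Reference for a full backward pass (not implemented above): with upstream gradient g,
// cached normalized values x_hat, and cached per-sample variance var, the standard
// layer-normalization gradients over a feature dimension of size N are
//
//   dbeta_j  = sum_i g[i, j]
//   dgamma_j = sum_i g[i, j] * x_hat[i, j]
//   dx_hat   = g * gamma                                   (elementwise)
//   dx[i, j] = (dx_hat[i, j]
//               - (1/N) * sum_k dx_hat[i, k]
//               - x_hat[i, j] * (1/N) * sum_k dx_hat[i, k] * x_hat[i, k])
//              / sqrt(var[i] + eps)
//
// The `dgamma`, `dbeta`, `mean_cache`, `var_cache`, and `norm_cache` fields exist to
// support this computation.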

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for LayerNorm<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters, got {}",
                params.len()
            )));
        }

        if params[0].shape() != self.gamma.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Gamma shape mismatch: expected {:?}, got {:?}",
                self.gamma.shape(),
                params[0].shape()
            )));
        }

        if params[1].shape() != self.beta.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Beta shape mismatch: expected {:?}, got {:?}",
                self.beta.shape(),
                params[1].shape()
            )));
        }

        self.gamma = params[0].clone();
        self.beta = params[1].clone();

        Ok(())
    }
}

/// Batch Normalization layer
#[derive(Debug, Clone)]
pub struct BatchNorm<F: Float + Debug + Send + Sync> {
    /// Number of features (channels)
    num_features: usize,
    /// Learnable scale parameter
    gamma: Array<F, IxDyn>,
    /// Learnable shift parameter
    beta: Array<F, IxDyn>,
    /// Small constant for numerical stability
    #[allow(dead_code)]
    eps: F,
    /// Momentum for running statistics updates
    #[allow(dead_code)]
    momentum: F,
    /// Whether we're in training mode
    training: bool,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> BatchNorm<F> {
    /// Create a new batch normalization layer
    pub fn new<R: Rng>(
        num_features: usize,
        momentum: f64,
        eps: f64,
        _rng: &mut R,
    ) -> Result<Self> {
        let gamma = Array::<F, IxDyn>::from_elem(IxDyn(&[num_features]), F::one());
        let beta = Array::<F, IxDyn>::from_elem(IxDyn(&[num_features]), F::zero());

        let momentum = F::from(momentum).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert momentum to type F".to_string())
        })?;

        let eps = F::from(eps).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert epsilon to type F".to_string())
        })?;

        Ok(Self {
            num_features,
            gamma,
            beta,
            eps,
            momentum,
            training: true,
        })
    }

    /// Set the training mode
    #[allow(dead_code)]
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Get the number of features
    #[allow(dead_code)]
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for BatchNorm<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Batch normalization is not implemented yet; return the input unchanged for now.
        Ok(input.clone())
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        Ok(grad_output.clone())
    }

    fn update(&mut self, _learningrate: F) -> Result<()> {
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "BatchNorm"
    }

    fn parameter_count(&self) -> usize {
        self.gamma.len() + self.beta.len()
    }
}
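
// Reference for the intended training-mode behavior (the current `forward` above is an
// identity placeholder): batch normalization normalizes each feature channel with the
// statistics of the current mini-batch,
//
//   y = gamma * (x - batch_mean) / sqrt(batch_var + eps) + beta
//
// and uses `momentum` to maintain running estimates of mean and variance, which replace
// the batch statistics when `training` is false. Running statistics are not yet stored
// on `BatchNorm`.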

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for BatchNorm<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![self.gamma.clone(), self.beta.clone()]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters, got {}",
                params.len()
            )));
        }

        self.gamma = params[0].clone();
        self.beta = params[1].clone();

        Ok(())
    }
}
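
// A minimal, dependency-free sanity check of the per-sample statistics that
// `LayerNorm::forward` computes. It deliberately avoids constructing the layers
// themselves (their constructors take an RNG from `scirs2_core::random`) and only
// mirrors the arithmetic used in `forward` with the default gamma = 1, beta = 0,
// so it is an illustrative sketch of the math rather than a test of the layer API.
#[cfg(test)]
mod normalization_math_sketch {
    #[test]
    fn normalized_values_have_zero_mean_and_unit_variance() {
        let x = [1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let eps = 1e-5;
        let n = x.len() as f64;

        // Per-sample mean and (biased) variance, as computed in `LayerNorm::forward`.
        let mean = x.iter().sum::<f64>() / n;
        let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;

        // Normalize with gamma = 1 and beta = 0, the defaults set by `LayerNorm::new`.
        let normalized: Vec<f64> = x.iter().map(|v| (v - mean) / (var + eps).sqrt()).collect();

        let norm_mean = normalized.iter().sum::<f64>() / n;
        let norm_var = normalized
            .iter()
            .map(|v| (v - norm_mean).powi(2))
            .sum::<f64>()
            / n;

        assert!(norm_mean.abs() < 1e-10);
        assert!((norm_var - 1.0).abs() < 1e-4);
    }
}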