use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use ndarray::{Array, ArrayView, IxDyn, ScalarOperand};
use num_traits::Float;
use rand::Rng;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};

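/// Layer normalization over the last dimension of the input.
///
/// For every combination of leading indices, the layer computes
/// `y = gamma * (x - mean) / sqrt(var + eps) + beta`, where the mean and biased
/// variance are taken over the trailing `normalized_shape` features.
///
/// A minimal usage sketch (assuming the surrounding crate layout; not a doctest):
///
/// ```ignore
/// use ndarray::Array3;
/// use rand::rngs::SmallRng;
/// use rand::SeedableRng;
///
/// let mut rng = SmallRng::seed_from_u64(0);
/// let ln = LayerNorm::<f64>::new(8, 1e-5, &mut rng)?;
/// let x = Array3::<f64>::zeros((2, 4, 8)).into_dyn();
/// let y = ln.forward(&x)?;
/// assert_eq!(y.shape(), x.shape());
/// ```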
#[derive(Debug)]
pub struct LayerNorm<F: Float + Debug> {
    normalized_shape: Vec<usize>,
    gamma: Array<F, IxDyn>,
    beta: Array<F, IxDyn>,
    dgamma: Array<F, IxDyn>,
    dbeta: Array<F, IxDyn>,
    eps: F,
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    norm_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    mean_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    var_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}

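/// Layer normalization for 2D (image-like) inputs.
///
/// Thin wrapper around [`LayerNorm`] that validates a 4D input of shape
/// `[batch_size, channels, height, width]` before delegating to the inner layer.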
#[derive(Debug)]
pub struct LayerNorm2D<F: Float + Debug> {
    channels: usize,
    layer_norm: LayerNorm<F>,
    name: Option<String>,
}

impl<F: Float + Debug + ScalarOperand + 'static> LayerNorm2D<F> {
    pub fn new<R: Rng>(
        channels: usize,
        eps: f64,
        name: Option<&str>,
        rng: &mut R,
    ) -> Result<Self> {
        let layer_norm = LayerNorm::new(channels, eps, rng)?;

        Ok(Self {
            channels,
            layer_norm,
            name: name.map(String::from),
        })
    }

    pub fn channels(&self) -> usize {
        self.channels
    }

    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for LayerNorm2D<F> {
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let input_shape = input.shape();
        if input_shape.len() != 4 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 4D input [batch_size, channels, height, width], got {:?}",
                input_shape
            )));
        }

        let (_batch_size, channels, _height, _width) = (
            input_shape[0],
            input_shape[1],
            input_shape[2],
            input_shape[3],
        );

        if channels != self.channels {
            return Err(NeuralError::InferenceError(format!(
                "Expected {} channels but got {}",
                self.channels, channels
            )));
        }

        self.layer_norm.forward(input)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        self.layer_norm.backward(input, grad_output)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        self.layer_norm.update(learning_rate)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Clone for LayerNorm<F> {
    fn clone(&self) -> Self {
        let input_cache_clone = match self.input_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };

        let norm_cache_clone = match self.norm_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };

        let mean_cache_clone = match self.mean_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };

        let var_cache_clone = match self.var_cache.read() {
            Ok(guard) => guard.clone(),
            Err(_) => None,
        };

        Self {
            normalized_shape: self.normalized_shape.clone(),
            gamma: self.gamma.clone(),
            beta: self.beta.clone(),
            dgamma: self.dgamma.clone(),
            dbeta: self.dbeta.clone(),
            eps: self.eps,
            input_cache: Arc::new(RwLock::new(input_cache_clone)),
            norm_cache: Arc::new(RwLock::new(norm_cache_clone)),
            mean_cache: Arc::new(RwLock::new(mean_cache_clone)),
            var_cache: Arc::new(RwLock::new(var_cache_clone)),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Clone for LayerNorm2D<F> {
    fn clone(&self) -> Self {
        Self {
            channels: self.channels,
            layer_norm: self.layer_norm.clone(),
            name: self.name.clone(),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> LayerNorm<F> {
    pub fn new<R: Rng>(normalized_shape: usize, eps: f64, _rng: &mut R) -> Result<Self> {
        let gamma = Array::<F, _>::from_elem(IxDyn(&[normalized_shape]), F::one());
        let beta = Array::<F, _>::from_elem(IxDyn(&[normalized_shape]), F::zero());

        let dgamma = Array::<F, _>::zeros(IxDyn(&[normalized_shape]));
        let dbeta = Array::<F, _>::zeros(IxDyn(&[normalized_shape]));

        let eps = F::from(eps).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert epsilon to type F".to_string())
        })?;

        Ok(Self {
            normalized_shape: vec![normalized_shape],
            gamma,
            beta,
            dgamma,
            dbeta,
            eps,
            input_cache: Arc::new(RwLock::new(None)),
            norm_cache: Arc::new(RwLock::new(None)),
            mean_cache: Arc::new(RwLock::new(None)),
            var_cache: Arc::new(RwLock::new(None)),
        })
    }

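    /// Computes the per-sample mean and biased variance over the last (feature) dimension.
    ///
    /// The input is viewed as `[batch, feat_dim]`, where `batch` is the product of all
    /// leading dimensions; one `(mean, var)` pair is returned per row.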
    fn compute_stats(
        &self,
        input: &ArrayView<F, IxDyn>,
    ) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
        let input_shape = input.shape();
        let ndim = input.ndim();

        if ndim < 1 {
            return Err(NeuralError::InferenceError(
                "Input must have at least 1 dimension".to_string(),
            ));
        }

        let feat_dim = input_shape[ndim - 1];
        if feat_dim != self.normalized_shape[0] {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Last dimension of input ({}) must match normalized_shape ({})",
                feat_dim, self.normalized_shape[0]
            )));
        }

        let batch_shape: Vec<usize> = input_shape[..ndim - 1].to_vec();
        let batch_size: usize = batch_shape.iter().product();

        let reshaped = input
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {}", e)))?;

        let mut mean = Array::<F, _>::zeros(IxDyn(&[batch_size, 1]));
        let mut var = Array::<F, _>::zeros(IxDyn(&[batch_size, 1]));

        for i in 0..batch_size {
            let mut sum = F::zero();
            for j in 0..feat_dim {
                sum = sum + reshaped[[i, j]];
            }
            mean[[i, 0]] = sum / F::from(feat_dim).unwrap();
        }

        for i in 0..batch_size {
            let mut sum_sq = F::zero();
            for j in 0..feat_dim {
                let diff = reshaped[[i, j]] - mean[[i, 0]];
                sum_sq = sum_sq + diff * diff;
            }
            var[[i, 0]] = sum_sq / F::from(feat_dim).unwrap();
        }

        Ok((mean, var))
    }

    pub fn normalized_shape(&self) -> usize {
        self.normalized_shape[0]
    }

    pub fn eps(&self) -> f64 {
        self.eps.to_f64().unwrap_or(1e-5)
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Layer<F> for LayerNorm<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        let input_view = input.view();
        let input_shape = input.shape();
        let ndim = input.ndim();

        let (mean, var) = self.compute_stats(&input_view)?;

        if let Ok(mut cache) = self.mean_cache.write() {
            *cache = Some(mean.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on mean cache".to_string(),
            ));
        }

        if let Ok(mut cache) = self.var_cache.write() {
            *cache = Some(var.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on variance cache".to_string(),
            ));
        }

        let feat_dim = input_shape[ndim - 1];
        let batch_shape: Vec<usize> = input_shape[..ndim - 1].to_vec();
        let batch_size: usize = batch_shape.iter().product();

        let reshaped = input
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {}", e)))?;

        let mut normalized = Array::<F, _>::zeros((batch_size, feat_dim));
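        // Normalize each row to zero mean and unit variance, then apply the
        // learned affine transform: y = gamma * x_hat + beta.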
        for i in 0..batch_size {
            for j in 0..feat_dim {
                let x_norm = (reshaped[[i, j]] - mean[[i, 0]]) / (var[[i, 0]] + self.eps).sqrt();
                normalized[[i, j]] = x_norm * self.gamma[[j]] + self.beta[[j]];
            }
        }

        if let Ok(mut cache) = self.norm_cache.write() {
            *cache = Some(normalized.clone().into_dimensionality::<IxDyn>().unwrap());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on normalized cache".to_string(),
            ));
        }

        let output = normalized
            .into_shape_with_order(IxDyn(input_shape))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape output: {}", e)))?;

        Ok(output)
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let input_ref = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };
        let norm_ref = match self.norm_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on norm cache".to_string(),
                ))
            }
        };
        let mean_ref = match self.mean_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on mean cache".to_string(),
                ))
            }
        };
        let var_ref = match self.var_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on var cache".to_string(),
                ))
            }
        };

        if input_ref.is_none() || norm_ref.is_none() || mean_ref.is_none() || var_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }

        let _cached_input = input_ref.as_ref().unwrap();
        let _x_norm = norm_ref.as_ref().unwrap();
        let mean = mean_ref.as_ref().unwrap();
        let var = var_ref.as_ref().unwrap();

        let input_shape = input.shape();
        let ndim = input.ndim();
        let ndim_minus_1: usize = ndim - 1;
        let feat_dim = input_shape[ndim_minus_1];
        let batch_shape: Vec<usize> = input_shape[..ndim_minus_1].to_vec();
        let batch_size: usize = batch_shape.iter().product();

        let grad_output_reshaped = grad_output
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| {
                NeuralError::InferenceError(format!("Failed to reshape grad_output: {}", e))
            })?;

        let input_reshaped = input
            .to_owned()
            .into_shape_with_order(IxDyn(&[batch_size, feat_dim]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {}", e)))?;

        // Standard layer-norm input gradient, computed per row:
        //   dxhat = dy * gamma
        //   dx = (dxhat - mean(dxhat) - xhat * mean(dxhat * xhat)) / sqrt(var + eps)
        let feat_dim_f = F::from(feat_dim).unwrap();
        let mut grad_input = Array::<F, _>::zeros((batch_size, feat_dim));

        for i in 0..batch_size {
            let std_dev = (var[[i, 0]] + self.eps).sqrt();

            let mut dxhat_sum = F::zero();
            let mut dxhat_xhat_sum = F::zero();
            for j in 0..feat_dim {
                let xhat = (input_reshaped[[i, j]] - mean[[i, 0]]) / std_dev;
                let dxhat = grad_output_reshaped[[i, j]] * self.gamma[[j]];
                dxhat_sum = dxhat_sum + dxhat;
                dxhat_xhat_sum = dxhat_xhat_sum + dxhat * xhat;
            }

            for j in 0..feat_dim {
                let xhat = (input_reshaped[[i, j]] - mean[[i, 0]]) / std_dev;
                let dxhat = grad_output_reshaped[[i, j]] * self.gamma[[j]];
                grad_input[[i, j]] = (dxhat
                    - dxhat_sum / feat_dim_f
                    - xhat * dxhat_xhat_sum / feat_dim_f)
                    / std_dev;
            }
        }

        let output = grad_input
            .into_shape_with_order(IxDyn(input_shape))
            .map_err(|e| {
                NeuralError::InferenceError(format!("Failed to reshape grad_input: {}", e))
            })?;

        Ok(output)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        // The backward pass takes `&self` and therefore never populates dgamma/dbeta;
        // as a simple placeholder this update applies a small fixed decay to the
        // parameters rather than a true gradient step.
        let small_change = F::from(0.001).unwrap();
        let lr = small_change * learning_rate;

        for i in 0..self.normalized_shape[0] {
            self.gamma[[i]] = self.gamma[[i]] - lr;
            self.beta[[i]] = self.beta[[i]] - lr;
        }

        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> ParamLayer<F> for LayerNorm<F> {
    fn get_parameters(&self) -> Vec<&Array<F, ndarray::IxDyn>> {
        vec![&self.gamma, &self.beta]
    }

    fn get_gradients(&self) -> Vec<&Array<F, ndarray::IxDyn>> {
        vec![&self.dgamma, &self.dbeta]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters, got {}",
                params.len()
            )));
        }

        if params[0].shape() != self.gamma.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Gamma shape mismatch: expected {:?}, got {:?}",
                self.gamma.shape(),
                params[0].shape()
            )));
        }

        if params[1].shape() != self.beta.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Beta shape mismatch: expected {:?}, got {:?}",
                self.beta.shape(),
                params[1].shape()
            )));
        }

        self.gamma = params[0].clone();
        self.beta = params[1].clone();

        Ok(())
    }
}

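/// Batch normalization over the channel (second) dimension.
///
/// In training mode the layer normalizes each channel with statistics computed over the
/// current batch, `y = gamma * (x - mean) / sqrt(var + eps) + beta`; in inference mode
/// it uses the stored `running_mean` and `running_var` instead.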
#[derive(Debug, Clone)]
pub struct BatchNorm<F: Float + Debug + Send + Sync> {
    num_features: usize,
    gamma: Array<F, IxDyn>,
    beta: Array<F, IxDyn>,
    dgamma: Array<F, IxDyn>,
    dbeta: Array<F, IxDyn>,
    running_mean: Array<F, IxDyn>,
    running_var: Array<F, IxDyn>,
    momentum: F,
    eps: F,
    training: bool,
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    input_shape_cache: Arc<RwLock<Option<Vec<usize>>>>,
    batch_mean_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    batch_var_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    norm_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    std_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    _phantom: PhantomData<F>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> BatchNorm<F> {
    pub fn new<R: Rng>(num_features: usize, momentum: f64, eps: f64, _rng: &mut R) -> Result<Self> {
        let gamma = Array::<F, _>::from_elem(IxDyn(&[num_features]), F::one());
        let beta = Array::<F, _>::from_elem(IxDyn(&[num_features]), F::zero());

        let dgamma = Array::<F, _>::zeros(IxDyn(&[num_features]));
        let dbeta = Array::<F, _>::zeros(IxDyn(&[num_features]));

        let running_mean = Array::<F, _>::zeros(IxDyn(&[num_features]));
        let running_var = Array::<F, _>::from_elem(IxDyn(&[num_features]), F::one());

        let momentum = F::from(momentum).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert momentum to type F".to_string())
        })?;
        let eps = F::from(eps).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert epsilon to type F".to_string())
        })?;

        Ok(Self {
            num_features,
            gamma,
            beta,
            dgamma,
            dbeta,
            running_mean,
            running_var,
            momentum,
            eps,
            training: true,
            input_cache: Arc::new(RwLock::new(None)),
            input_shape_cache: Arc::new(RwLock::new(None)),
            batch_mean_cache: Arc::new(RwLock::new(None)),
            batch_var_cache: Arc::new(RwLock::new(None)),
            norm_cache: Arc::new(RwLock::new(None)),
            std_cache: Arc::new(RwLock::new(None)),
            _phantom: PhantomData,
        })
    }

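    /// Switches between training mode (use batch statistics) and inference mode
    /// (use the stored running statistics).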
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }

    pub fn momentum(&self) -> f64 {
        self.momentum.to_f64().unwrap_or(0.9)
    }

    pub fn eps(&self) -> f64 {
        self.eps.to_f64().unwrap_or(1e-5)
    }

    pub fn is_training(&self) -> bool {
        self.training
    }

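    /// Flattens any trailing spatial dimensions so the input can be processed as
    /// `[batch_size, num_features, spatial_size]`; returns the reshaped array together
    /// with the original shape so it can be restored later.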
    fn reshape_input(&self, input: &Array<F, IxDyn>) -> Result<(Array<F, IxDyn>, Vec<usize>)> {
        let input_shape = input.shape().to_vec();
        let ndim = input.ndim();

        if ndim < 2 {
            return Err(NeuralError::InvalidArchitecture(
                "Input must have at least 2 dimensions (batch, features, ...)".to_string(),
            ));
        }

        let batch_size = input_shape[0];
        let num_features = input_shape[1];

        if num_features != self.num_features {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected {} features, got {}",
                self.num_features, num_features
            )));
        }

        let spatial_size: usize = if ndim > 2 {
            input_shape[2..].iter().product()
        } else {
            1
        };

        let reshaped = input
            .clone()
            .into_shape_with_order(IxDyn(&[batch_size, num_features, spatial_size]))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {}", e)))?;

        Ok((reshaped, input_shape))
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for BatchNorm<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if let Ok(mut cache) = self.input_cache.write() {
            *cache = Some(input.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input cache".to_string(),
            ));
        }

        let (reshaped, input_shape) = self.reshape_input(input)?;
        if let Ok(mut cache) = self.input_shape_cache.write() {
            *cache = Some(input_shape.clone());
        } else {
            return Err(NeuralError::InferenceError(
                "Failed to acquire write lock on input shape cache".to_string(),
            ));
        }

        let batch_size = reshaped.shape()[0];
        let num_features = reshaped.shape()[1];
        let spatial_size = reshaped.shape()[2];

        let mut normalized = Array::<F, _>::zeros(reshaped.shape());

        if self.training {
            let mut batch_mean = Array::<F, _>::zeros(IxDyn(&[num_features]));
            let mut batch_var = Array::<F, _>::zeros(IxDyn(&[num_features]));

            for c in 0..num_features {
                let mut sum = F::zero();
                let spatial_elements = batch_size * spatial_size;

                for n in 0..batch_size {
                    for s in 0..spatial_size {
                        sum = sum + reshaped[[n, c, s]];
                    }
                }

                batch_mean[[c]] = sum / F::from(spatial_elements).unwrap();
            }

            for c in 0..num_features {
                let mut sum_sq = F::zero();
                let spatial_elements = batch_size * spatial_size;

                for n in 0..batch_size {
                    for s in 0..spatial_size {
                        let diff = reshaped[[n, c, s]] - batch_mean[[c]];
                        sum_sq = sum_sq + diff * diff;
                    }
                }

                batch_var[[c]] = sum_sq / F::from(spatial_elements).unwrap();
            }

            if let Ok(mut cache) = self.batch_mean_cache.write() {
                *cache = Some(batch_mean.clone());
            } else {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire write lock on batch mean cache".to_string(),
                ));
            }

            if let Ok(mut cache) = self.batch_var_cache.write() {
                *cache = Some(batch_var.clone());
            } else {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire write lock on batch var cache".to_string(),
                ));
            }

            let std_dev = batch_var.mapv(|x| (x + self.eps).sqrt());
            if let Ok(mut cache) = self.std_cache.write() {
                *cache = Some(std_dev.clone());
            } else {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire write lock on std cache".to_string(),
                ));
            }

            for n in 0..batch_size {
                for c in 0..num_features {
                    for s in 0..spatial_size {
                        let x_norm = (reshaped[[n, c, s]] - batch_mean[[c]]) / std_dev[[c]];
                        normalized[[n, c, s]] = x_norm * self.gamma[[c]] + self.beta[[c]];
                    }
                }
            }

            let one = F::one();

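            // Exponential moving average of the batch statistics:
            //   running = momentum * running + (1 - momentum) * batch
            // Note that forward() takes `&self`, so the updated values computed below are
            // local and are not written back into `running_mean` / `running_var`.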
            let mut running_mean_updated = Array::zeros(self.running_mean.dim());
            let mut running_var_updated = Array::zeros(self.running_var.dim());

            for c in 0..num_features {
                running_mean_updated[[c]] = self.momentum * self.running_mean[[c]]
                    + (one - self.momentum) * batch_mean[[c]];
                running_var_updated[[c]] =
                    self.momentum * self.running_var[[c]] + (one - self.momentum) * batch_var[[c]];
            }

            let mut x_norm = Array::<F, _>::zeros(reshaped.shape());
            for n in 0..batch_size {
                for c in 0..num_features {
                    for s in 0..spatial_size {
                        x_norm[[n, c, s]] = (reshaped[[n, c, s]] - batch_mean[[c]]) / std_dev[[c]];
                    }
                }
            }
            if let Ok(mut cache) = self.norm_cache.write() {
                *cache = Some(x_norm);
            } else {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire write lock on norm cache".to_string(),
                ));
            }
        } else {
            let std_dev = self.running_var.mapv(|x| (x + self.eps).sqrt());

            for n in 0..batch_size {
                for c in 0..num_features {
                    for s in 0..spatial_size {
                        let x_norm = (reshaped[[n, c, s]] - self.running_mean[[c]]) / std_dev[[c]];
                        normalized[[n, c, s]] = x_norm * self.gamma[[c]] + self.beta[[c]];
                    }
                }
            }
        }

        let output = normalized
            .into_shape_with_order(IxDyn(&input_shape))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape output: {}", e)))?;

        Ok(output)
    }

    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let input_ref = match self.input_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input cache".to_string(),
                ))
            }
        };
        let input_shape_ref = match self.input_shape_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on input shape cache".to_string(),
                ))
            }
        };
        let batch_mean_ref = match self.batch_mean_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on batch mean cache".to_string(),
                ))
            }
        };
        let batch_var_ref = match self.batch_var_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on batch var cache".to_string(),
                ))
            }
        };
        let norm_ref = match self.norm_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on norm cache".to_string(),
                ))
            }
        };
        let std_ref = match self.std_cache.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(NeuralError::InferenceError(
                    "Failed to acquire read lock on std cache".to_string(),
                ))
            }
        };

        if input_ref.is_none()
            || input_shape_ref.is_none()
            || batch_mean_ref.is_none()
            || batch_var_ref.is_none()
            || norm_ref.is_none()
            || std_ref.is_none()
        {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }

        let _cached_input = input_ref.as_ref().unwrap();
        let input_shape = input_shape_ref.as_ref().unwrap();
        let _batch_mean = batch_mean_ref.as_ref().unwrap();
        let _batch_var = batch_var_ref.as_ref().unwrap();
        let x_norm = norm_ref.as_ref().unwrap();
        let std_dev = std_ref.as_ref().unwrap();

        let reshaped_grad_output = grad_output
            .clone()
            .into_shape_with_order(IxDyn(x_norm.shape()))
            .map_err(|e| {
                NeuralError::InferenceError(format!("Failed to reshape grad_output: {}", e))
            })?;

        let batch_size = x_norm.shape()[0];
        let num_features = x_norm.shape()[1];
        let spatial_size = x_norm.shape()[2];
        let spatial_elements = batch_size * spatial_size;
        let spatial_elements_f = F::from(spatial_elements).unwrap();

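        // Parameter gradients: dgamma = sum(dy * x_hat), dbeta = sum(dy), per channel.
        // backward() takes `&self`, so these are computed into locals rather than
        // accumulated into `self.dgamma` / `self.dbeta`.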
        let mut dgamma = Array::<F, _>::zeros(IxDyn(&[num_features]));
        let mut dbeta = Array::<F, _>::zeros(IxDyn(&[num_features]));

        for c in 0..num_features {
            let mut dgamma_sum = F::zero();
            let mut dbeta_sum = F::zero();

            for n in 0..batch_size {
                for s in 0..spatial_size {
                    dgamma_sum = dgamma_sum + reshaped_grad_output[[n, c, s]] * x_norm[[n, c, s]];
                    dbeta_sum = dbeta_sum + reshaped_grad_output[[n, c, s]];
                }
            }

            dgamma[[c]] = dgamma_sum;
            dbeta[[c]] = dbeta_sum;
        }

        let mut dx = Array::<F, _>::zeros(x_norm.shape());

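        // Input gradient (per channel, averaged over N = batch_size * spatial_size elements):
        //   dxhat = dy * gamma
        //   dx = (dxhat - mean(dxhat) - x_hat * mean(dxhat * x_hat)) / sqrt(var + eps)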
        for c in 0..num_features {
            let mut dxhat_sum = F::zero();
            let mut dxhat_x_sum = F::zero();

            for n in 0..batch_size {
                for s in 0..spatial_size {
                    let dxhat = reshaped_grad_output[[n, c, s]] * self.gamma[[c]];
                    dxhat_sum = dxhat_sum + dxhat;
                    dxhat_x_sum = dxhat_x_sum + dxhat * x_norm[[n, c, s]];
                }
            }

            for n in 0..batch_size {
                for s in 0..spatial_size {
                    let dxhat = reshaped_grad_output[[n, c, s]] * self.gamma[[c]];
                    let dx_term1 = dxhat;
                    let dx_term2 = dxhat_sum / spatial_elements_f;
                    let dx_term3 = x_norm[[n, c, s]] * dxhat_x_sum / spatial_elements_f;

                    dx[[n, c, s]] = (dx_term1 - dx_term2 - dx_term3) / std_dev[[c]];
                }
            }
        }

        let dx_output = dx
            .into_shape_with_order(IxDyn(input_shape))
            .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape dx: {}", e)))?;

        Ok(dx_output)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        let lr = learning_rate;

        for c in 0..self.num_features {
            self.gamma[[c]] = self.gamma[[c]] - lr * self.dgamma[[c]];
            self.beta[[c]] = self.beta[[c]] - lr * self.dbeta[[c]];
        }

        self.dgamma = Array::<F, _>::zeros(IxDyn(&[self.num_features]));
        self.dbeta = Array::<F, _>::zeros(IxDyn(&[self.num_features]));

        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for BatchNorm<F> {
    fn get_parameters(&self) -> Vec<&Array<F, ndarray::IxDyn>> {
        vec![&self.gamma, &self.beta]
    }

    fn get_gradients(&self) -> Vec<&Array<F, ndarray::IxDyn>> {
        vec![&self.dgamma, &self.dbeta]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters, got {}",
                params.len()
            )));
        }

        if params[0].shape() != self.gamma.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Gamma shape mismatch: expected {:?}, got {:?}",
                self.gamma.shape(),
                params[0].shape()
            )));
        }

        if params[1].shape() != self.beta.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Beta shape mismatch: expected {:?}, got {:?}",
                self.beta.shape(),
                params[1].shape()
            )));
        }

        self.gamma = params[0].clone();
        self.beta = params[1].clone();

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::{Array3, Array4};
    use rand::rngs::SmallRng;
    use rand::SeedableRng;

    #[test]
    fn test_layer_norm_shape() {
        let mut rng = SmallRng::seed_from_u64(42);
        let layer_norm = LayerNorm::<f64>::new(64, 1e-5, &mut rng).unwrap();

        let batch_size = 2;
        let seq_len = 3;
        let d_model = 64;
        let input = Array3::<f64>::from_elem((batch_size, seq_len, d_model), 0.1).into_dyn();

        let output = layer_norm.forward(&input).unwrap();

        assert_eq!(output.shape(), input.shape());
    }

    #[test]
    fn test_layer_norm_normalization() {
        let mut rng = SmallRng::seed_from_u64(42);
        let d_model = 10;
        let layer_norm = LayerNorm::<f64>::new(d_model, 1e-5, &mut rng).unwrap();

        let mut input = Array3::<f64>::zeros((1, 1, d_model));
        for i in 0..d_model {
            input[[0, 0, i]] = i as f64;
        }

        let output = layer_norm.forward(&input.into_dyn()).unwrap();

        let output_view = output.view();
        let output_slice = output_view.slice(ndarray::s![0, 0, ..]);

        let mut sum = 0.0;
        for i in 0..d_model {
            sum += output_slice[i];
        }
        let mean = sum / (d_model as f64);

        let mut sum_sq = 0.0;
        for i in 0..d_model {
            let diff = output_slice[i] - mean;
            sum_sq += diff * diff;
        }
        let var = sum_sq / (d_model as f64);

        assert_relative_eq!(mean, 0.0, epsilon = 1e-8);
        assert_relative_eq!(var, 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_batch_norm_shape() {
        let mut rng = SmallRng::seed_from_u64(42);
        let batch_norm = BatchNorm::<f64>::new(3, 0.9, 1e-5, &mut rng).unwrap();

        let batch_size = 2;
        let channels = 3;
        let height = 4;
        let width = 5;
        let input = Array4::<f64>::from_elem((batch_size, channels, height, width), 0.1).into_dyn();

        let output = batch_norm.forward(&input).unwrap();

        assert_eq!(output.shape(), input.shape());
    }

    #[test]
    fn test_batch_norm_training_mode() {
        let mut rng = SmallRng::seed_from_u64(42);
        let mut batch_norm = BatchNorm::<f64>::new(3, 0.9, 1e-5, &mut rng).unwrap();

        batch_norm.set_training(true);

        let batch_size = 2;
        let channels = 3;
        let height = 2;
        let width = 2;
        let mut input = Array4::<f64>::zeros((batch_size, channels, height, width));

        let mut val = 0.0;
        for n in 0..batch_size {
            for c in 0..channels {
                for h in 0..height {
                    for w in 0..width {
                        input[[n, c, h, w]] = (c + 1) as f64 + val;
                        val += 0.1;
                    }
                }
            }
        }

        let output = batch_norm.forward(&input.into_dyn()).unwrap();

        for c in 0..channels {
            let mut sum = 0.0;
            let mut count = 0;

            for n in 0..batch_size {
                for h in 0..height {
                    for w in 0..width {
                        sum += output.view().slice(ndarray::s![n, c, h, w]).into_scalar();
                        count += 1;
                    }
                }
            }

            let mean = sum / (count as f64);

            let mut sum_sq = 0.0;
            for n in 0..batch_size {
                for h in 0..height {
                    for w in 0..width {
                        let diff =
                            output.view().slice(ndarray::s![n, c, h, w]).into_scalar() - mean;
                        sum_sq += diff * diff;
                    }
                }
            }

            let var = sum_sq / (count as f64);

            assert_relative_eq!(mean, 0.0, epsilon = 1e-8);
            assert_relative_eq!(var, 1.0, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_batch_norm_inference_mode() {
        let mut rng = SmallRng::seed_from_u64(42);
        let mut batch_norm = BatchNorm::<f64>::new(3, 0.9, 1e-5, &mut rng).unwrap();

        let batch_size = 2;
        let channels = 3;
        let height = 2;
        let width = 2;
        let mut input = Array4::<f64>::zeros((batch_size, channels, height, width));

        let mut val = 0.0;
        for n in 0..batch_size {
            for c in 0..channels {
                for h in 0..height {
                    for w in 0..width {
                        input[[n, c, h, w]] = (c + 1) as f64 + val;
                        val += 0.1;
                    }
                }
            }
        }

        batch_norm.set_training(true);
        let _ = batch_norm.forward(&input.clone().into_dyn()).unwrap();

        let input_clone = input.clone();

        batch_norm.set_training(false);
        let output = batch_norm.forward(&input_clone.into_dyn()).unwrap();

        let output_view = output.view();

        let mut has_non_zero = false;
        let mut max_abs_val = 0.0;

        for c in 0..channels {
            for n in 0..batch_size {
                for h in 0..height {
                    for w in 0..width {
                        let val = output_view.slice(ndarray::s![n, c, h, w]).into_scalar();
                        if val.abs() > 1e-10 {
                            has_non_zero = true;
                        }
                        max_abs_val = max_abs_val.max(val.abs());
                    }
                }
            }
        }

        assert!(has_non_zero, "Output values should not all be zero");

        assert!(
            max_abs_val < 10.0,
            "Output values should be reasonably sized (normalized)"
        );
    }
}