1use ndarray::{ArrayD, Axis, IxDyn};
14
/// Errors reported by the normalization layers in this module.
///
/// The type derives `PartialEq`/`Eq` so that callers and tests can compare
/// errors directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NormalizationError {
    /// A tensor's shape differs from the shape the operation requires.
    ShapeMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// `axis` is not a valid axis index for a tensor with `ndim` dimensions.
    InvalidAxis { axis: usize, ndim: usize },
    /// `groups` does not evenly divide `channels`.
    InvalidNumGroups { groups: usize, channels: usize },
    /// Zero variance was encountered during normalization.
    ZeroVariance,
    /// The input tensor (or a required size/shape argument) was empty.
    EmptyInput,
    /// The tensor has `ndim` dimensions but the operation needs `required`.
    InsufficientDimensions { ndim: usize, required: usize },
}
34
35impl std::fmt::Display for NormalizationError {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 Self::ShapeMismatch { expected, got } => {
39 write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
40 }
41 Self::InvalidAxis { axis, ndim } => {
42 write!(
43 f,
44 "Invalid axis {} for tensor with {} dimensions",
45 axis, ndim
46 )
47 }
48 Self::InvalidNumGroups { groups, channels } => {
49 write!(
50 f,
51 "Invalid number of groups {}: does not evenly divide {} channels",
52 groups, channels
53 )
54 }
55 Self::ZeroVariance => write!(f, "Zero variance encountered during normalization"),
56 Self::EmptyInput => write!(f, "Empty input tensor"),
57 Self::InsufficientDimensions { ndim, required } => {
58 write!(
59 f,
60 "Insufficient dimensions: tensor has {} dims, but {} required",
61 ndim, required
62 )
63 }
64 }
65 }
66}
67
68impl std::error::Error for NormalizationError {}
69
/// Returns, as an `f64`, the product of the extents of `shape` along `axes`
/// (i.e. how many elements are combined when reducing over those axes).
///
/// An empty `axes` slice yields `1.0`.
///
/// # Panics
/// Panics if any entry of `axes` is out of bounds for `shape`.
fn axis_element_count(shape: &[usize], axes: &[usize]) -> f64 {
    let mut count = 1.0_f64;
    for &axis in axes {
        count *= shape[axis] as f64;
    }
    count
}
78
79fn mean_along_axes(input: &ArrayD<f64>, axes: &[usize]) -> ArrayD<f64> {
81 let mut result = input.clone();
82 let mut sorted_axes: Vec<usize> = axes.to_vec();
84 sorted_axes.sort_unstable();
85 sorted_axes.reverse();
86
87 let count = axis_element_count(input.shape(), axes);
88
89 for &ax in &sorted_axes {
90 result = result.sum_axis(Axis(ax)).insert_axis(Axis(ax));
91 }
92 result / count
93}
94
95fn var_along_axes(input: &ArrayD<f64>, axes: &[usize]) -> ArrayD<f64> {
97 let mean = mean_along_axes(input, axes);
98 let diff = input - &mean;
99 let sq = &diff * &diff;
100 mean_along_axes(&sq, axes)
101}
102
103#[derive(Debug, Clone)]
113pub struct RmsNorm {
114 pub normalized_shape: Vec<usize>,
116 pub eps: f64,
118 pub gamma: ArrayD<f64>,
120}
121
122impl RmsNorm {
123 pub fn new(normalized_shape: Vec<usize>, eps: f64) -> Result<Self, NormalizationError> {
128 if normalized_shape.is_empty() {
129 return Err(NormalizationError::EmptyInput);
130 }
131 let gamma = ArrayD::ones(IxDyn(&normalized_shape));
132 Ok(Self {
133 normalized_shape,
134 eps,
135 gamma,
136 })
137 }
138
139 pub fn forward(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, NormalizationError> {
144 let ndim = input.ndim();
145 let norm_ndim = self.normalized_shape.len();
146 if ndim < norm_ndim {
147 return Err(NormalizationError::InsufficientDimensions {
148 ndim,
149 required: norm_ndim,
150 });
151 }
152 if input.is_empty() {
153 return Err(NormalizationError::EmptyInput);
154 }
155
156 let trailing: Vec<usize> = input.shape()[(ndim - norm_ndim)..].to_vec();
158 if trailing != self.normalized_shape {
159 return Err(NormalizationError::ShapeMismatch {
160 expected: self.normalized_shape.clone(),
161 got: trailing,
162 });
163 }
164
165 let axes: Vec<usize> = ((ndim - norm_ndim)..ndim).collect();
167 let rms = Self::rms(input, &axes);
168
169 let rms_inv = rms.mapv(|v| 1.0 / (v + self.eps));
171 let normalized = input * &rms_inv;
172 Ok(normalized * &self.gamma)
173 }
174
175 pub fn rms(input: &ArrayD<f64>, axes: &[usize]) -> ArrayD<f64> {
177 let sq = input.mapv(|x| x * x);
178 let mean_sq = mean_along_axes(&sq, axes);
179 mean_sq.mapv(f64::sqrt)
180 }
181
182 pub fn update_gamma(&mut self, new_gamma: ArrayD<f64>) -> Result<(), NormalizationError> {
184 let expected: Vec<usize> = self.normalized_shape.clone();
185 let got: Vec<usize> = new_gamma.shape().to_vec();
186 if expected != got {
187 return Err(NormalizationError::ShapeMismatch { expected, got });
188 }
189 self.gamma = new_gamma;
190 Ok(())
191 }
192}
193
194#[derive(Debug, Clone)]
202pub struct GroupNorm {
203 pub num_groups: usize,
205 pub num_channels: usize,
207 pub eps: f64,
209 pub gamma: ArrayD<f64>,
211 pub beta: ArrayD<f64>,
213 pub affine: bool,
215}
216
217impl GroupNorm {
218 pub fn new(
222 num_groups: usize,
223 num_channels: usize,
224 eps: f64,
225 affine: bool,
226 ) -> Result<Self, NormalizationError> {
227 if num_groups == 0 || num_channels == 0 {
228 return Err(NormalizationError::EmptyInput);
229 }
230 if !num_channels.is_multiple_of(num_groups) {
231 return Err(NormalizationError::InvalidNumGroups {
232 groups: num_groups,
233 channels: num_channels,
234 });
235 }
236 let gamma = ArrayD::ones(IxDyn(&[num_channels]));
237 let beta = ArrayD::zeros(IxDyn(&[num_channels]));
238 Ok(Self {
239 num_groups,
240 num_channels,
241 eps,
242 gamma,
243 beta,
244 affine,
245 })
246 }
247
248 pub fn forward(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, NormalizationError> {
252 let ndim = input.ndim();
253 if ndim < 2 {
254 return Err(NormalizationError::InsufficientDimensions { ndim, required: 2 });
255 }
256 if input.is_empty() {
257 return Err(NormalizationError::EmptyInput);
258 }
259
260 let shape = input.shape();
261 let batch_size = shape[0];
262 let channels = shape[1];
263
264 if channels != self.num_channels {
265 return Err(NormalizationError::ShapeMismatch {
266 expected: vec![batch_size, self.num_channels],
267 got: vec![batch_size, channels],
268 });
269 }
270
271 let cpg = self.channels_per_group();
272 let spatial: Vec<usize> = shape[2..].to_vec();
273
274 let mut reshaped_dims = vec![batch_size, self.num_groups, cpg];
276 reshaped_dims.extend_from_slice(&spatial);
277
278 let reshaped = input
279 .clone()
280 .into_shape_with_order(IxDyn(&reshaped_dims))
281 .map_err(|_| NormalizationError::ShapeMismatch {
282 expected: reshaped_dims.clone(),
283 got: shape.to_vec(),
284 })?;
285
286 let norm_axes: Vec<usize> = (2..reshaped.ndim()).collect();
288 let mean = mean_along_axes(&reshaped, &norm_axes);
289 let var = var_along_axes(&reshaped, &norm_axes);
290
291 let inv_std = var.mapv(|v| 1.0 / (v + self.eps).sqrt());
292 let normalized = (&reshaped - &mean) * &inv_std;
293
294 let mut out_shape = vec![batch_size, channels];
296 out_shape.extend_from_slice(&spatial);
297 let mut output = normalized
298 .into_shape_with_order(IxDyn(&out_shape))
299 .map_err(|_| NormalizationError::ShapeMismatch {
300 expected: out_shape.clone(),
301 got: vec![],
302 })?;
303
304 if self.affine {
306 let mut broadcast_shape = vec![1usize; ndim];
308 broadcast_shape[1] = channels;
309
310 let gamma_bc = self
311 .gamma
312 .clone()
313 .into_shape_with_order(IxDyn(&broadcast_shape))
314 .map_err(|_| NormalizationError::ShapeMismatch {
315 expected: broadcast_shape.clone(),
316 got: self.gamma.shape().to_vec(),
317 })?;
318 let beta_bc = self
319 .beta
320 .clone()
321 .into_shape_with_order(IxDyn(&broadcast_shape))
322 .map_err(|_| NormalizationError::ShapeMismatch {
323 expected: broadcast_shape.clone(),
324 got: self.beta.shape().to_vec(),
325 })?;
326
327 output = output * &gamma_bc + &beta_bc;
328 }
329
330 Ok(output)
331 }
332
333 pub fn channels_per_group(&self) -> usize {
335 self.num_channels / self.num_groups
336 }
337}
338
339#[derive(Debug, Clone)]
348pub struct InstanceNorm {
349 pub num_channels: usize,
351 pub eps: f64,
353 pub gamma: ArrayD<f64>,
355 pub beta: ArrayD<f64>,
357 pub affine: bool,
359}
360
361impl InstanceNorm {
362 pub fn new(num_channels: usize, eps: f64, affine: bool) -> Result<Self, NormalizationError> {
364 if num_channels == 0 {
365 return Err(NormalizationError::EmptyInput);
366 }
367 let gamma = ArrayD::ones(IxDyn(&[num_channels]));
368 let beta = ArrayD::zeros(IxDyn(&[num_channels]));
369 Ok(Self {
370 num_channels,
371 eps,
372 gamma,
373 beta,
374 affine,
375 })
376 }
377
378 pub fn forward(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, NormalizationError> {
382 let mut gn = GroupNorm::new(self.num_channels, self.num_channels, self.eps, self.affine)?;
383 if self.affine {
384 gn.gamma = self.gamma.clone();
385 gn.beta = self.beta.clone();
386 }
387 gn.forward(input)
388 }
389}
390
391#[derive(Debug, Clone)]
400pub struct BatchNorm {
401 pub num_channels: usize,
403 pub eps: f64,
405 pub momentum: f64,
407 pub gamma: ArrayD<f64>,
409 pub beta: ArrayD<f64>,
411 pub affine: bool,
413 pub running_mean: ArrayD<f64>,
415 pub running_var: ArrayD<f64>,
417 pub training: bool,
419 pub num_batches_tracked: u64,
421}
422
423impl BatchNorm {
424 pub fn new(
426 num_channels: usize,
427 eps: f64,
428 momentum: f64,
429 affine: bool,
430 ) -> Result<Self, NormalizationError> {
431 if num_channels == 0 {
432 return Err(NormalizationError::EmptyInput);
433 }
434 let gamma = ArrayD::ones(IxDyn(&[num_channels]));
435 let beta = ArrayD::zeros(IxDyn(&[num_channels]));
436 let running_mean = ArrayD::zeros(IxDyn(&[num_channels]));
437 let running_var = ArrayD::ones(IxDyn(&[num_channels]));
438 Ok(Self {
439 num_channels,
440 eps,
441 momentum,
442 gamma,
443 beta,
444 affine,
445 running_mean,
446 running_var,
447 training: true,
448 num_batches_tracked: 0,
449 })
450 }
451
452 pub fn forward(&mut self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, NormalizationError> {
457 let ndim = input.ndim();
458 if ndim < 2 {
459 return Err(NormalizationError::InsufficientDimensions { ndim, required: 2 });
460 }
461 if input.is_empty() {
462 return Err(NormalizationError::EmptyInput);
463 }
464
465 let shape = input.shape();
466 let channels = shape[1];
467 if channels != self.num_channels {
468 return Err(NormalizationError::ShapeMismatch {
469 expected: vec![shape[0], self.num_channels],
470 got: vec![shape[0], channels],
471 });
472 }
473
474 let reduce_axes: Vec<usize> = std::iter::once(0).chain(2..ndim).collect();
477
478 let (mean, var) = if self.training {
479 let batch_mean = mean_along_axes(input, &reduce_axes);
480 let batch_var = var_along_axes(input, &reduce_axes);
481
482 let mean_1d = batch_mean
484 .clone()
485 .into_shape_with_order(IxDyn(&[channels]))
486 .map_err(|_| NormalizationError::ShapeMismatch {
487 expected: vec![channels],
488 got: batch_mean.shape().to_vec(),
489 })?;
490 let var_1d = batch_var
491 .clone()
492 .into_shape_with_order(IxDyn(&[channels]))
493 .map_err(|_| NormalizationError::ShapeMismatch {
494 expected: vec![channels],
495 got: batch_var.shape().to_vec(),
496 })?;
497
498 let mom = self.momentum;
500 self.running_mean =
501 self.running_mean.mapv(|r| r * (1.0 - mom)) + mean_1d.mapv(|m| m * mom);
502 self.running_var =
503 self.running_var.mapv(|r| r * (1.0 - mom)) + var_1d.mapv(|v| v * mom);
504 self.num_batches_tracked += 1;
505
506 (batch_mean, batch_var)
507 } else {
508 let mut bc_shape = vec![1usize; ndim];
510 bc_shape[1] = channels;
511
512 let mean = self
513 .running_mean
514 .clone()
515 .into_shape_with_order(IxDyn(&bc_shape))
516 .map_err(|_| NormalizationError::ShapeMismatch {
517 expected: bc_shape.clone(),
518 got: self.running_mean.shape().to_vec(),
519 })?;
520 let var = self
521 .running_var
522 .clone()
523 .into_shape_with_order(IxDyn(&bc_shape))
524 .map_err(|_| NormalizationError::ShapeMismatch {
525 expected: bc_shape.clone(),
526 got: self.running_var.shape().to_vec(),
527 })?;
528 (mean, var)
529 };
530
531 let inv_std = var.mapv(|v| 1.0 / (v + self.eps).sqrt());
532 let mut output = (input - &mean) * &inv_std;
533
534 if self.affine {
535 let mut bc_shape = vec![1usize; ndim];
536 bc_shape[1] = channels;
537
538 let gamma_bc = self
539 .gamma
540 .clone()
541 .into_shape_with_order(IxDyn(&bc_shape))
542 .map_err(|_| NormalizationError::ShapeMismatch {
543 expected: bc_shape.clone(),
544 got: self.gamma.shape().to_vec(),
545 })?;
546 let beta_bc = self
547 .beta
548 .clone()
549 .into_shape_with_order(IxDyn(&bc_shape))
550 .map_err(|_| NormalizationError::ShapeMismatch {
551 expected: bc_shape.clone(),
552 got: self.beta.shape().to_vec(),
553 })?;
554 output = output * &gamma_bc + &beta_bc;
555 }
556
557 Ok(output)
558 }
559
560 pub fn eval_mode(&mut self) {
562 self.training = false;
563 }
564
565 pub fn train_mode(&mut self) {
567 self.training = true;
568 }
569
570 pub fn is_training(&self) -> bool {
572 self.training
573 }
574
575 pub fn reset_running_stats(&mut self) {
577 self.running_mean = ArrayD::zeros(IxDyn(&[self.num_channels]));
578 self.running_var = ArrayD::ones(IxDyn(&[self.num_channels]));
579 self.num_batches_tracked = 0;
580 }
581}
582
/// Weight normalization: splits a weight tensor into a per-slice magnitude
/// `g` and a direction tensor `v`.
#[derive(Debug, Clone)]
pub struct WeightNorm {
    /// Axis that is kept; norms are taken over all other axes.
    pub dim: usize,
}
595
596impl WeightNorm {
597 pub fn new(dim: usize) -> Self {
599 Self { dim }
600 }
601
602 pub fn apply(
605 &self,
606 weight: &ArrayD<f64>,
607 ) -> Result<(ArrayD<f64>, ArrayD<f64>), NormalizationError> {
608 let ndim = weight.ndim();
609 if ndim == 0 {
610 return Err(NormalizationError::EmptyInput);
611 }
612 if self.dim >= ndim {
613 return Err(NormalizationError::InvalidAxis {
614 axis: self.dim,
615 ndim,
616 });
617 }
618
619 let reduce_axes: Vec<usize> = (0..ndim).filter(|&a| a != self.dim).collect();
621
622 let sq = weight.mapv(|x| x * x);
624 let mut sum_sq = sq;
625 let mut sorted_axes = reduce_axes.clone();
627 sorted_axes.sort_unstable();
628 sorted_axes.reverse();
629 for &ax in &sorted_axes {
630 sum_sq = sum_sq.sum_axis(Axis(ax));
631 }
632 let g = sum_sq.mapv(f64::sqrt);
634
635 let mut bc_shape = vec![1usize; ndim];
637 bc_shape[self.dim] = weight.shape()[self.dim];
638 let g_bc = g
639 .clone()
640 .into_shape_with_order(IxDyn(&bc_shape))
641 .map_err(|_| NormalizationError::ShapeMismatch {
642 expected: bc_shape.clone(),
643 got: g.shape().to_vec(),
644 })?;
645
646 let v = weight / &g_bc.mapv(|val| if val.abs() < 1e-12 { 1e-12 } else { val });
648
649 Ok((g, v))
650 }
651
652 pub fn reparametrize(
654 g: &ArrayD<f64>,
655 v: &ArrayD<f64>,
656 dim: usize,
657 ) -> Result<ArrayD<f64>, NormalizationError> {
658 let ndim = v.ndim();
659 if ndim == 0 {
660 return Err(NormalizationError::EmptyInput);
661 }
662 if dim >= ndim {
663 return Err(NormalizationError::InvalidAxis { axis: dim, ndim });
664 }
665
666 let reduce_axes: Vec<usize> = (0..ndim).filter(|&a| a != dim).collect();
668 let sq = v.mapv(|x| x * x);
669 let mut sum_sq = sq;
670 let mut sorted_axes = reduce_axes;
671 sorted_axes.sort_unstable();
672 sorted_axes.reverse();
673 for &ax in &sorted_axes {
674 sum_sq = sum_sq.sum_axis(Axis(ax));
675 }
676 let v_norm = sum_sq.mapv(f64::sqrt);
677
678 let mut bc_shape = vec![1usize; ndim];
680 bc_shape[dim] = v.shape()[dim];
681
682 let g_bc = g
683 .clone()
684 .into_shape_with_order(IxDyn(&bc_shape))
685 .map_err(|_| NormalizationError::ShapeMismatch {
686 expected: bc_shape.clone(),
687 got: g.shape().to_vec(),
688 })?;
689 let v_norm_bc = v_norm
690 .into_shape_with_order(IxDyn(&bc_shape))
691 .map_err(|_| NormalizationError::ShapeMismatch {
692 expected: bc_shape.clone(),
693 got: vec![],
694 })?;
695
696 let v_norm_safe = v_norm_bc.mapv(|val| if val.abs() < 1e-12 { 1e-12 } else { val });
697 Ok(v * &g_bc / &v_norm_safe)
698 }
699}
700
/// Scalar summary statistics of a normalization layer's input, output and
/// affine parameters.
#[derive(Debug, Clone)]
pub struct NormStats {
    /// Mean over all elements of the input tensor.
    pub input_mean: f64,
    /// Population standard deviation over all elements of the input tensor.
    pub input_std: f64,
    /// Mean over all elements of the output tensor.
    pub output_mean: f64,
    /// Population standard deviation over all elements of the output tensor.
    pub output_std: f64,
    /// Mean over all elements of `gamma`.
    pub gamma_mean: f64,
    /// Mean over all elements of `beta`.
    pub beta_mean: f64,
}
721
722impl NormStats {
723 pub fn compute(
725 input: &ArrayD<f64>,
726 output: &ArrayD<f64>,
727 gamma: &ArrayD<f64>,
728 beta: &ArrayD<f64>,
729 ) -> Self {
730 let input_mean = Self::array_mean(input);
731 let input_std = Self::array_std(input, input_mean);
732 let output_mean = Self::array_mean(output);
733 let output_std = Self::array_std(output, output_mean);
734 let gamma_mean = Self::array_mean(gamma);
735 let beta_mean = Self::array_mean(beta);
736
737 Self {
738 input_mean,
739 input_std,
740 output_mean,
741 output_std,
742 gamma_mean,
743 beta_mean,
744 }
745 }
746
747 pub fn summary(&self) -> String {
749 format!(
750 "NormStats {{ input: mean={:.6}, std={:.6} | output: mean={:.6}, std={:.6} | gamma_mean={:.6}, beta_mean={:.6} }}",
751 self.input_mean, self.input_std,
752 self.output_mean, self.output_std,
753 self.gamma_mean, self.beta_mean,
754 )
755 }
756
757 fn array_mean(arr: &ArrayD<f64>) -> f64 {
760 if arr.is_empty() {
761 return 0.0;
762 }
763 arr.sum() / arr.len() as f64
764 }
765
766 fn array_std(arr: &ArrayD<f64>, mean: f64) -> f64 {
767 if arr.len() <= 1 {
768 return 0.0;
769 }
770 let var = arr.mapv(|x| (x - mean).powi(2)).sum() / arr.len() as f64;
771 var.sqrt()
772 }
773}
774
/// Unit tests covering construction, shape preservation and basic numerical
/// properties of every layer in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::ArrayD;

    /// Builds a `(batch, channels, h, w)` tensor whose i-th element in
    /// row-major order is `i * 0.01 + 0.1`.
    fn make_input_4d(batch: usize, channels: usize, h: usize, w: usize) -> ArrayD<f64> {
        let dims = [batch, channels, h, w];
        let count: usize = dims.iter().product();
        let values: Vec<f64> = (0..count).map(|i| (i as f64) * 0.01 + 0.1).collect();
        ArrayD::from_shape_vec(IxDyn(&dims), values)
            .expect("test helper: shape matches element count")
    }

    /// Builds a `(rows, cols)` tensor whose i-th element in row-major order
    /// is `i * 0.05 + 0.5`.
    fn make_input_2d(rows: usize, cols: usize) -> ArrayD<f64> {
        let dims = [rows, cols];
        let count: usize = dims.iter().product();
        let values: Vec<f64> = (0..count).map(|i| (i as f64) * 0.05 + 0.5).collect();
        ArrayD::from_shape_vec(IxDyn(&dims), values)
            .expect("test helper: shape matches element count")
    }

    // ---- RmsNorm ----

    #[test]
    fn test_rmsnorm_new_valid() {
        let built = RmsNorm::new(vec![64], 1e-5);
        assert!(built.is_ok());
        let layer = built.expect("already checked");
        assert_eq!(layer.normalized_shape, vec![64]);
    }

    #[test]
    fn test_rmsnorm_forward_shape_preserved() {
        let layer = RmsNorm::new(vec![8], 1e-5).expect("valid config");
        let x = make_input_2d(4, 8);
        let y = layer.forward(&x).expect("forward should succeed");
        assert_eq!(y.shape(), x.shape());
    }

    #[test]
    fn test_rmsnorm_output_scale() {
        let layer = RmsNorm::new(vec![16], 1e-8).expect("valid config");
        let values: Vec<f64> = (0..32).map(|i| (i as f64) * 0.1 + 1.0).collect();
        let x = ArrayD::from_shape_vec(IxDyn(&[2, 16]), values).expect("test data");

        let y = layer.forward(&x).expect("forward");
        // With gamma = 1, every normalized row should itself have RMS ~ 1.
        for row_idx in 0..2 {
            let mut sum_sq = 0.0;
            for col_idx in 0..16 {
                let v = y[[row_idx, col_idx]];
                sum_sq += v * v;
            }
            let row_rms = (sum_sq / 16.0).sqrt();
            assert!(
                (row_rms - 1.0).abs() < 0.1,
                "RMS should be close to 1, got {}",
                row_rms
            );
        }
    }

    #[test]
    fn test_rmsnorm_update_gamma() {
        let mut layer = RmsNorm::new(vec![4], 1e-5).expect("valid");
        let doubled =
            ArrayD::from_shape_vec(IxDyn(&[4]), vec![2.0, 2.0, 2.0, 2.0]).expect("test data");
        assert!(layer.update_gamma(doubled).is_ok());
        assert!((layer.gamma[[0]] - 2.0).abs() < 1e-10);
    }

    // ---- GroupNorm ----

    #[test]
    fn test_groupnorm_new_valid() {
        assert!(GroupNorm::new(4, 16, 1e-5, true).is_ok());
    }

    #[test]
    fn test_groupnorm_invalid_groups() {
        let result = GroupNorm::new(5, 16, 1e-5, true);
        assert!(result.is_err());
        if let Err(NormalizationError::InvalidNumGroups { groups, channels }) = result {
            assert_eq!(groups, 5);
            assert_eq!(channels, 16);
        } else {
            panic!("Expected InvalidNumGroups error");
        }
    }

    #[test]
    fn test_groupnorm_forward_shape_preserved() {
        let layer = GroupNorm::new(4, 8, 1e-5, true).expect("valid");
        let x = make_input_4d(2, 8, 4, 4);
        let y = layer.forward(&x).expect("forward");
        assert_eq!(y.shape(), x.shape());
    }

    #[test]
    fn test_groupnorm_channels_per_group() {
        let layer = GroupNorm::new(4, 16, 1e-5, true).expect("valid");
        assert_eq!(layer.channels_per_group(), 4);
    }

    // ---- InstanceNorm ----

    #[test]
    fn test_instancenorm_new_valid() {
        assert!(InstanceNorm::new(8, 1e-5, true).is_ok());
    }

    #[test]
    fn test_instancenorm_forward_shape_preserved() {
        let layer = InstanceNorm::new(4, 1e-5, true).expect("valid");
        let x = make_input_4d(2, 4, 3, 3);
        let y = layer.forward(&x).expect("forward");
        assert_eq!(y.shape(), x.shape());
    }

    #[test]
    fn test_instancenorm_normalizes_per_instance() {
        let layer = InstanceNorm::new(2, 1e-8, false).expect("valid");
        let values: Vec<f64> = (0..16).map(|i| (i as f64) * 0.5 + 1.0).collect();
        let x = ArrayD::from_shape_vec(IxDyn(&[2, 2, 4]), values).expect("test data");

        let y = layer.forward(&x).expect("forward");
        // Each (sample, channel) slice must be centered on its own.
        for b in 0..2 {
            for c in 0..2 {
                let mut sum = 0.0;
                for s in 0..4 {
                    sum += y[[b, c, s]];
                }
                let slice_mean = sum / 4.0;
                assert!(
                    slice_mean.abs() < 0.01,
                    "Expected ~0 mean, got {} at b={}, c={}",
                    slice_mean,
                    b,
                    c
                );
            }
        }
    }

    // ---- BatchNorm ----

    #[test]
    fn test_batchnorm_new_valid() {
        let built = BatchNorm::new(16, 1e-5, 0.1, true);
        assert!(built.is_ok());
        let layer = built.expect("valid");
        assert!(layer.is_training());
    }

    #[test]
    fn test_batchnorm_forward_training() {
        let mut layer = BatchNorm::new(4, 1e-5, 0.1, true).expect("valid");
        let x = make_input_4d(2, 4, 3, 3);
        let y = layer.forward(&x).expect("forward");
        assert_eq!(y.shape(), x.shape());
    }

    #[test]
    fn test_batchnorm_running_stats_update() {
        let mut layer = BatchNorm::new(4, 1e-5, 0.1, true).expect("valid");
        let mean_before = layer.running_mean.clone();
        let x = make_input_4d(2, 4, 3, 3);
        let _y = layer.forward(&x).expect("forward");
        assert_ne!(layer.running_mean, mean_before);
        assert_eq!(layer.num_batches_tracked, 1);
    }

    #[test]
    fn test_batchnorm_eval_mode() {
        let mut layer = BatchNorm::new(4, 1e-5, 0.1, true).expect("valid");
        let x = make_input_4d(2, 4, 3, 3);
        // One training pass so the running statistics are populated.
        let _warmup = layer.forward(&x).expect("forward training");

        layer.eval_mode();
        assert!(!layer.is_training());
        let tracked_before = layer.num_batches_tracked;
        let y = layer.forward(&x).expect("forward eval");
        assert_eq!(y.shape(), x.shape());
        // Evaluation must not count as a tracked batch.
        assert_eq!(layer.num_batches_tracked, tracked_before);
    }

    #[test]
    fn test_batchnorm_train_eval_toggle() {
        let mut layer = BatchNorm::new(4, 1e-5, 0.1, true).expect("valid");
        assert!(layer.is_training());
        layer.eval_mode();
        assert!(!layer.is_training());
        layer.train_mode();
        assert!(layer.is_training());
    }

    // ---- WeightNorm ----

    #[test]
    fn test_weightnorm_apply() {
        let normalizer = WeightNorm::new(0);
        let values: Vec<f64> = (0..12).map(|i| (i as f64) * 0.1 + 0.1).collect();
        let weight = ArrayD::from_shape_vec(IxDyn(&[3, 4]), values).expect("test data");

        let (g, v) = normalizer.apply(&weight).expect("apply");
        assert_eq!(g.shape(), &[3]);
        assert_eq!(v.shape(), weight.shape());
    }

    #[test]
    fn test_weightnorm_reparametrize() {
        let normalizer = WeightNorm::new(0);
        let values: Vec<f64> = (0..12).map(|i| (i as f64) * 0.3 + 0.5).collect();
        let weight = ArrayD::from_shape_vec(IxDyn(&[3, 4]), values).expect("test data");

        let (g, v) = normalizer.apply(&weight).expect("apply");
        let rebuilt = WeightNorm::reparametrize(&g, &v, 0).expect("reparametrize");

        assert_eq!(rebuilt.shape(), weight.shape());
        // Decomposing and recombining must give back the original weights.
        for (orig, recon) in weight.iter().zip(rebuilt.iter()) {
            assert!(
                (orig - recon).abs() < 1e-8,
                "Mismatch: orig={}, recon={}",
                orig,
                recon
            );
        }
    }

    // ---- NormStats ----

    #[test]
    fn test_norm_stats_compute() {
        let x = make_input_2d(4, 8);
        let y = make_input_2d(4, 8);
        let gamma = ArrayD::ones(IxDyn(&[8]));
        let beta = ArrayD::zeros(IxDyn(&[8]));

        let stats = NormStats::compute(&x, &y, &gamma, &beta);
        assert!(stats.input_mean.is_finite());
        assert!(stats.input_std.is_finite());
        assert!(stats.output_mean.is_finite());
        assert!(stats.output_std.is_finite());
        assert!(stats.gamma_mean.is_finite());
        assert!(stats.beta_mean.is_finite());
    }

    #[test]
    fn test_norm_stats_summary_nonempty() {
        let x = make_input_2d(2, 4);
        let y = make_input_2d(2, 4);
        let gamma = ArrayD::ones(IxDyn(&[4]));
        let beta = ArrayD::zeros(IxDyn(&[4]));

        let text = NormStats::compute(&x, &y, &gamma, &beta).summary();
        assert!(!text.is_empty());
        assert!(text.contains("NormStats"));
    }
}