Skip to main content

tensorlogic_trustformers/
normalization_variants.rs

1//! Advanced normalization variants for transformer architectures.
2//!
3//! This module provides numerical (ndarray-based) implementations of normalization
4//! techniques beyond standard LayerNorm, including:
5//!
6//! - **RmsNorm**: Root Mean Square Layer Normalization (no mean centering)
7//! - **GroupNorm**: Group Normalization (divides channels into groups)
8//! - **InstanceNorm**: Instance Normalization (per-instance, per-channel)
9//! - **BatchNorm**: Batch Normalization (across-batch statistics)
10//! - **WeightNorm**: Weight Normalization (weight reparametrization)
11//! - **NormStats**: Normalization statistics for debugging/monitoring
12
13use ndarray::{ArrayD, Axis, IxDyn};
14
/// Errors that can occur during normalization operations.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers and tests can check
/// results with `assert_eq!` instead of hand-written `match` blocks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NormalizationError {
    /// Shape mismatch between expected and actual tensor shapes.
    ShapeMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// Invalid axis for the given tensor dimensionality.
    InvalidAxis {
        /// Requested axis index.
        axis: usize,
        /// Number of dimensions of the tensor.
        ndim: usize,
    },
    /// Number of groups does not evenly divide number of channels.
    InvalidNumGroups {
        /// Requested number of groups.
        groups: usize,
        /// Number of channels that had to be divided.
        channels: usize,
    },
    /// Encountered zero variance during normalization.
    ZeroVariance,
    /// Input tensor is empty.
    EmptyInput,
    /// Tensor does not have enough dimensions for the operation.
    InsufficientDimensions {
        /// Number of dimensions of the tensor.
        ndim: usize,
        /// Minimum number of dimensions the operation needs.
        required: usize,
    },
}

impl std::fmt::Display for NormalizationError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ShapeMismatch { expected, got } => {
                write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
            }
            Self::InvalidAxis { axis, ndim } => {
                write!(
                    f,
                    "Invalid axis {} for tensor with {} dimensions",
                    axis, ndim
                )
            }
            Self::InvalidNumGroups { groups, channels } => {
                write!(
                    f,
                    "Invalid number of groups {}: does not evenly divide {} channels",
                    groups, channels
                )
            }
            Self::ZeroVariance => f.write_str("Zero variance encountered during normalization"),
            Self::EmptyInput => f.write_str("Empty input tensor"),
            Self::InsufficientDimensions { ndim, required } => {
                write!(
                    f,
                    "Insufficient dimensions: tensor has {} dims, but {} required",
                    ndim, required
                )
            }
        }
    }
}

impl std::error::Error for NormalizationError {}
69
70// ---------------------------------------------------------------------------
71// Helper: compute mean and variance along a set of axes
72// ---------------------------------------------------------------------------
73
/// Number of elements spanned by `axes` within `shape`, returned as `f64`.
///
/// Multiplies the extents `shape[a]` for every `a` in `axes`; an empty `axes`
/// list yields `1.0`. Panics if an axis index is out of bounds for `shape`.
fn axis_element_count(shape: &[usize], axes: &[usize]) -> f64 {
    let mut count = 1.0_f64;
    for &axis in axes {
        count *= shape[axis] as f64;
    }
    count
}
78
79/// Compute mean along specified axes, keeping dimensions.
80fn mean_along_axes(input: &ArrayD<f64>, axes: &[usize]) -> ArrayD<f64> {
81    let mut result = input.clone();
82    // Process axes in descending order so indices remain valid
83    let mut sorted_axes: Vec<usize> = axes.to_vec();
84    sorted_axes.sort_unstable();
85    sorted_axes.reverse();
86
87    let count = axis_element_count(input.shape(), axes);
88
89    for &ax in &sorted_axes {
90        result = result.sum_axis(Axis(ax)).insert_axis(Axis(ax));
91    }
92    result / count
93}
94
95/// Compute variance along specified axes, keeping dimensions.
96fn var_along_axes(input: &ArrayD<f64>, axes: &[usize]) -> ArrayD<f64> {
97    let mean = mean_along_axes(input, axes);
98    let diff = input - &mean;
99    let sq = &diff * &diff;
100    mean_along_axes(&sq, axes)
101}
102
103// ---------------------------------------------------------------------------
104// RmsNorm
105// ---------------------------------------------------------------------------
106
107/// Root Mean Square Layer Normalization (no mean centering).
108///
109/// Normalizes by: `x / sqrt(mean(x^2) + eps) * gamma`
110///
111/// Used in LLaMA and other modern transformer architectures.
112#[derive(Debug, Clone)]
113pub struct RmsNorm {
114    /// Shape of the dimensions to normalize over (typically the last N dims).
115    pub normalized_shape: Vec<usize>,
116    /// Small constant for numerical stability.
117    pub eps: f64,
118    /// Learnable scale parameter.
119    pub gamma: ArrayD<f64>,
120}
121
122impl RmsNorm {
123    /// Create a new RmsNorm layer.
124    ///
125    /// `normalized_shape` specifies the trailing dimensions to normalize over.
126    /// `gamma` is initialized to ones.
127    pub fn new(normalized_shape: Vec<usize>, eps: f64) -> Result<Self, NormalizationError> {
128        if normalized_shape.is_empty() {
129            return Err(NormalizationError::EmptyInput);
130        }
131        let gamma = ArrayD::ones(IxDyn(&normalized_shape));
132        Ok(Self {
133            normalized_shape,
134            eps,
135            gamma,
136        })
137    }
138
139    /// Forward pass: normalize the input tensor.
140    ///
141    /// The last `normalized_shape.len()` dimensions of the input must match
142    /// `normalized_shape`.
143    pub fn forward(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, NormalizationError> {
144        let ndim = input.ndim();
145        let norm_ndim = self.normalized_shape.len();
146        if ndim < norm_ndim {
147            return Err(NormalizationError::InsufficientDimensions {
148                ndim,
149                required: norm_ndim,
150            });
151        }
152        if input.is_empty() {
153            return Err(NormalizationError::EmptyInput);
154        }
155
156        // Verify trailing shape matches
157        let trailing: Vec<usize> = input.shape()[(ndim - norm_ndim)..].to_vec();
158        if trailing != self.normalized_shape {
159            return Err(NormalizationError::ShapeMismatch {
160                expected: self.normalized_shape.clone(),
161                got: trailing,
162            });
163        }
164
165        // Axes to reduce over (trailing dimensions)
166        let axes: Vec<usize> = ((ndim - norm_ndim)..ndim).collect();
167        let rms = Self::rms(input, &axes);
168
169        // x / (rms + eps) * gamma
170        let rms_inv = rms.mapv(|v| 1.0 / (v + self.eps));
171        let normalized = input * &rms_inv;
172        Ok(normalized * &self.gamma)
173    }
174
175    /// Compute Root Mean Square along specified axes (keeping dims).
176    pub fn rms(input: &ArrayD<f64>, axes: &[usize]) -> ArrayD<f64> {
177        let sq = input.mapv(|x| x * x);
178        let mean_sq = mean_along_axes(&sq, axes);
179        mean_sq.mapv(f64::sqrt)
180    }
181
182    /// Update the learnable scale parameter gamma.
183    pub fn update_gamma(&mut self, new_gamma: ArrayD<f64>) -> Result<(), NormalizationError> {
184        let expected: Vec<usize> = self.normalized_shape.clone();
185        let got: Vec<usize> = new_gamma.shape().to_vec();
186        if expected != got {
187            return Err(NormalizationError::ShapeMismatch { expected, got });
188        }
189        self.gamma = new_gamma;
190        Ok(())
191    }
192}
193
194// ---------------------------------------------------------------------------
195// GroupNorm
196// ---------------------------------------------------------------------------
197
198/// Group Normalization: divides channels into groups and normalizes within each group.
199///
200/// Input shape: `[batch, channels, ...spatial_dims...]`
201#[derive(Debug, Clone)]
202pub struct GroupNorm {
203    /// Number of groups to divide channels into.
204    pub num_groups: usize,
205    /// Total number of channels.
206    pub num_channels: usize,
207    /// Small constant for numerical stability.
208    pub eps: f64,
209    /// Learnable scale parameter `[channels]`.
210    pub gamma: ArrayD<f64>,
211    /// Learnable shift parameter `[channels]`.
212    pub beta: ArrayD<f64>,
213    /// Whether to apply learnable affine transformation.
214    pub affine: bool,
215}
216
217impl GroupNorm {
218    /// Create a new GroupNorm layer.
219    ///
220    /// `num_channels` must be evenly divisible by `num_groups`.
221    pub fn new(
222        num_groups: usize,
223        num_channels: usize,
224        eps: f64,
225        affine: bool,
226    ) -> Result<Self, NormalizationError> {
227        if num_groups == 0 || num_channels == 0 {
228            return Err(NormalizationError::EmptyInput);
229        }
230        if !num_channels.is_multiple_of(num_groups) {
231            return Err(NormalizationError::InvalidNumGroups {
232                groups: num_groups,
233                channels: num_channels,
234            });
235        }
236        let gamma = ArrayD::ones(IxDyn(&[num_channels]));
237        let beta = ArrayD::zeros(IxDyn(&[num_channels]));
238        Ok(Self {
239            num_groups,
240            num_channels,
241            eps,
242            gamma,
243            beta,
244            affine,
245        })
246    }
247
248    /// Forward pass: normalize within each group.
249    ///
250    /// Input shape: `[batch, channels, ...spatial_dims...]` (at least 2-D).
251    pub fn forward(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, NormalizationError> {
252        let ndim = input.ndim();
253        if ndim < 2 {
254            return Err(NormalizationError::InsufficientDimensions { ndim, required: 2 });
255        }
256        if input.is_empty() {
257            return Err(NormalizationError::EmptyInput);
258        }
259
260        let shape = input.shape();
261        let batch_size = shape[0];
262        let channels = shape[1];
263
264        if channels != self.num_channels {
265            return Err(NormalizationError::ShapeMismatch {
266                expected: vec![batch_size, self.num_channels],
267                got: vec![batch_size, channels],
268            });
269        }
270
271        let cpg = self.channels_per_group();
272        let spatial: Vec<usize> = shape[2..].to_vec();
273
274        // Build reshaped dims: [B, G, C/G, *spatial]
275        let mut reshaped_dims = vec![batch_size, self.num_groups, cpg];
276        reshaped_dims.extend_from_slice(&spatial);
277
278        let reshaped = input
279            .clone()
280            .into_shape_with_order(IxDyn(&reshaped_dims))
281            .map_err(|_| NormalizationError::ShapeMismatch {
282                expected: reshaped_dims.clone(),
283                got: shape.to_vec(),
284            })?;
285
286        // Normalize over axes 2.. (C/G and spatial dims)
287        let norm_axes: Vec<usize> = (2..reshaped.ndim()).collect();
288        let mean = mean_along_axes(&reshaped, &norm_axes);
289        let var = var_along_axes(&reshaped, &norm_axes);
290
291        let inv_std = var.mapv(|v| 1.0 / (v + self.eps).sqrt());
292        let normalized = (&reshaped - &mean) * &inv_std;
293
294        // Reshape back to [B, C, *spatial]
295        let mut out_shape = vec![batch_size, channels];
296        out_shape.extend_from_slice(&spatial);
297        let mut output = normalized
298            .into_shape_with_order(IxDyn(&out_shape))
299            .map_err(|_| NormalizationError::ShapeMismatch {
300                expected: out_shape.clone(),
301                got: vec![],
302            })?;
303
304        // Apply affine: broadcast gamma/beta over batch and spatial dims
305        if self.affine {
306            // Build broadcast shape for gamma/beta: [1, C, 1, 1, ...]
307            let mut broadcast_shape = vec![1usize; ndim];
308            broadcast_shape[1] = channels;
309
310            let gamma_bc = self
311                .gamma
312                .clone()
313                .into_shape_with_order(IxDyn(&broadcast_shape))
314                .map_err(|_| NormalizationError::ShapeMismatch {
315                    expected: broadcast_shape.clone(),
316                    got: self.gamma.shape().to_vec(),
317                })?;
318            let beta_bc = self
319                .beta
320                .clone()
321                .into_shape_with_order(IxDyn(&broadcast_shape))
322                .map_err(|_| NormalizationError::ShapeMismatch {
323                    expected: broadcast_shape.clone(),
324                    got: self.beta.shape().to_vec(),
325                })?;
326
327            output = output * &gamma_bc + &beta_bc;
328        }
329
330        Ok(output)
331    }
332
333    /// Number of channels per group.
334    pub fn channels_per_group(&self) -> usize {
335        self.num_channels / self.num_groups
336    }
337}
338
339// ---------------------------------------------------------------------------
340// InstanceNorm
341// ---------------------------------------------------------------------------
342
343/// Instance Normalization: normalizes each (batch, channel) independently.
344///
345/// Equivalent to GroupNorm with `num_groups == num_channels`.
346/// Input shape: `[batch, channels, ...spatial_dims...]`
347#[derive(Debug, Clone)]
348pub struct InstanceNorm {
349    /// Number of channels.
350    pub num_channels: usize,
351    /// Small constant for numerical stability.
352    pub eps: f64,
353    /// Learnable scale parameter.
354    pub gamma: ArrayD<f64>,
355    /// Learnable shift parameter.
356    pub beta: ArrayD<f64>,
357    /// Whether to apply learnable affine transformation.
358    pub affine: bool,
359}
360
361impl InstanceNorm {
362    /// Create a new InstanceNorm layer.
363    pub fn new(num_channels: usize, eps: f64, affine: bool) -> Result<Self, NormalizationError> {
364        if num_channels == 0 {
365            return Err(NormalizationError::EmptyInput);
366        }
367        let gamma = ArrayD::ones(IxDyn(&[num_channels]));
368        let beta = ArrayD::zeros(IxDyn(&[num_channels]));
369        Ok(Self {
370            num_channels,
371            eps,
372            gamma,
373            beta,
374            affine,
375        })
376    }
377
378    /// Forward pass: normalize each (batch, channel) slice independently.
379    ///
380    /// Delegates to GroupNorm with `num_groups == num_channels`.
381    pub fn forward(&self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, NormalizationError> {
382        let mut gn = GroupNorm::new(self.num_channels, self.num_channels, self.eps, self.affine)?;
383        if self.affine {
384            gn.gamma = self.gamma.clone();
385            gn.beta = self.beta.clone();
386        }
387        gn.forward(input)
388    }
389}
390
391// ---------------------------------------------------------------------------
392// BatchNorm
393// ---------------------------------------------------------------------------
394
395/// Batch Normalization: normalizes across the batch dimension.
396///
397/// Tracks running mean/variance for evaluation mode.
398/// Input shape: `[batch, channels, ...spatial_dims...]`
399#[derive(Debug, Clone)]
400pub struct BatchNorm {
401    /// Number of channels (features).
402    pub num_channels: usize,
403    /// Small constant for numerical stability.
404    pub eps: f64,
405    /// Momentum for running statistics update (EMA coefficient).
406    pub momentum: f64,
407    /// Learnable scale parameter.
408    pub gamma: ArrayD<f64>,
409    /// Learnable shift parameter.
410    pub beta: ArrayD<f64>,
411    /// Whether to apply learnable affine transformation.
412    pub affine: bool,
413    /// Running mean used during evaluation.
414    pub running_mean: ArrayD<f64>,
415    /// Running variance used during evaluation.
416    pub running_var: ArrayD<f64>,
417    /// Whether the module is in training mode.
418    pub training: bool,
419    /// Number of mini-batches tracked.
420    pub num_batches_tracked: u64,
421}
422
423impl BatchNorm {
424    /// Create a new BatchNorm layer.
425    pub fn new(
426        num_channels: usize,
427        eps: f64,
428        momentum: f64,
429        affine: bool,
430    ) -> Result<Self, NormalizationError> {
431        if num_channels == 0 {
432            return Err(NormalizationError::EmptyInput);
433        }
434        let gamma = ArrayD::ones(IxDyn(&[num_channels]));
435        let beta = ArrayD::zeros(IxDyn(&[num_channels]));
436        let running_mean = ArrayD::zeros(IxDyn(&[num_channels]));
437        let running_var = ArrayD::ones(IxDyn(&[num_channels]));
438        Ok(Self {
439            num_channels,
440            eps,
441            momentum,
442            gamma,
443            beta,
444            affine,
445            running_mean,
446            running_var,
447            training: true,
448            num_batches_tracked: 0,
449        })
450    }
451
452    /// Forward pass: normalize across batch (and spatial) dimensions per channel.
453    ///
454    /// In training mode: compute batch statistics and update running stats via EMA.
455    /// In eval mode: use running statistics.
456    pub fn forward(&mut self, input: &ArrayD<f64>) -> Result<ArrayD<f64>, NormalizationError> {
457        let ndim = input.ndim();
458        if ndim < 2 {
459            return Err(NormalizationError::InsufficientDimensions { ndim, required: 2 });
460        }
461        if input.is_empty() {
462            return Err(NormalizationError::EmptyInput);
463        }
464
465        let shape = input.shape();
466        let channels = shape[1];
467        if channels != self.num_channels {
468            return Err(NormalizationError::ShapeMismatch {
469                expected: vec![shape[0], self.num_channels],
470                got: vec![shape[0], channels],
471            });
472        }
473
474        // Axes to reduce: batch (0) and all spatial dims (2..)
475        // We keep the channel axis (1).
476        let reduce_axes: Vec<usize> = std::iter::once(0).chain(2..ndim).collect();
477
478        let (mean, var) = if self.training {
479            let batch_mean = mean_along_axes(input, &reduce_axes);
480            let batch_var = var_along_axes(input, &reduce_axes);
481
482            // Squeeze to [C] for running stat update
483            let mean_1d = batch_mean
484                .clone()
485                .into_shape_with_order(IxDyn(&[channels]))
486                .map_err(|_| NormalizationError::ShapeMismatch {
487                    expected: vec![channels],
488                    got: batch_mean.shape().to_vec(),
489                })?;
490            let var_1d = batch_var
491                .clone()
492                .into_shape_with_order(IxDyn(&[channels]))
493                .map_err(|_| NormalizationError::ShapeMismatch {
494                    expected: vec![channels],
495                    got: batch_var.shape().to_vec(),
496                })?;
497
498            // EMA update: running = (1 - momentum) * running + momentum * batch
499            let mom = self.momentum;
500            self.running_mean =
501                self.running_mean.mapv(|r| r * (1.0 - mom)) + mean_1d.mapv(|m| m * mom);
502            self.running_var =
503                self.running_var.mapv(|r| r * (1.0 - mom)) + var_1d.mapv(|v| v * mom);
504            self.num_batches_tracked += 1;
505
506            (batch_mean, batch_var)
507        } else {
508            // Build broadcast shape: [1, C, 1, 1, ...]
509            let mut bc_shape = vec![1usize; ndim];
510            bc_shape[1] = channels;
511
512            let mean = self
513                .running_mean
514                .clone()
515                .into_shape_with_order(IxDyn(&bc_shape))
516                .map_err(|_| NormalizationError::ShapeMismatch {
517                    expected: bc_shape.clone(),
518                    got: self.running_mean.shape().to_vec(),
519                })?;
520            let var = self
521                .running_var
522                .clone()
523                .into_shape_with_order(IxDyn(&bc_shape))
524                .map_err(|_| NormalizationError::ShapeMismatch {
525                    expected: bc_shape.clone(),
526                    got: self.running_var.shape().to_vec(),
527                })?;
528            (mean, var)
529        };
530
531        let inv_std = var.mapv(|v| 1.0 / (v + self.eps).sqrt());
532        let mut output = (input - &mean) * &inv_std;
533
534        if self.affine {
535            let mut bc_shape = vec![1usize; ndim];
536            bc_shape[1] = channels;
537
538            let gamma_bc = self
539                .gamma
540                .clone()
541                .into_shape_with_order(IxDyn(&bc_shape))
542                .map_err(|_| NormalizationError::ShapeMismatch {
543                    expected: bc_shape.clone(),
544                    got: self.gamma.shape().to_vec(),
545                })?;
546            let beta_bc = self
547                .beta
548                .clone()
549                .into_shape_with_order(IxDyn(&bc_shape))
550                .map_err(|_| NormalizationError::ShapeMismatch {
551                    expected: bc_shape.clone(),
552                    got: self.beta.shape().to_vec(),
553                })?;
554            output = output * &gamma_bc + &beta_bc;
555        }
556
557        Ok(output)
558    }
559
560    /// Switch to evaluation mode (use running statistics).
561    pub fn eval_mode(&mut self) {
562        self.training = false;
563    }
564
565    /// Switch to training mode (compute batch statistics).
566    pub fn train_mode(&mut self) {
567        self.training = true;
568    }
569
570    /// Check whether the module is in training mode.
571    pub fn is_training(&self) -> bool {
572        self.training
573    }
574
575    /// Reset running statistics to initial values.
576    pub fn reset_running_stats(&mut self) {
577        self.running_mean = ArrayD::zeros(IxDyn(&[self.num_channels]));
578        self.running_var = ArrayD::ones(IxDyn(&[self.num_channels]));
579        self.num_batches_tracked = 0;
580    }
581}
582
583// ---------------------------------------------------------------------------
584// WeightNorm
585// ---------------------------------------------------------------------------
586
587/// Weight Normalization: reparametrizes weight as `w = g * v / ||v||`.
588///
589/// This is a weight reparametrization technique, not a layer normalization.
590#[derive(Debug, Clone)]
591pub struct WeightNorm {
592    /// Dimension along which to compute the norm.
593    pub dim: usize,
594}
595
596impl WeightNorm {
597    /// Create a new WeightNorm reparametrization.
598    pub fn new(dim: usize) -> Self {
599        Self { dim }
600    }
601
602    /// Decompose a weight tensor into `(g, v)` where `g = ||w||` per slice
603    /// along `self.dim` and `v = w / ||w||`.
604    pub fn apply(
605        &self,
606        weight: &ArrayD<f64>,
607    ) -> Result<(ArrayD<f64>, ArrayD<f64>), NormalizationError> {
608        let ndim = weight.ndim();
609        if ndim == 0 {
610            return Err(NormalizationError::EmptyInput);
611        }
612        if self.dim >= ndim {
613            return Err(NormalizationError::InvalidAxis {
614                axis: self.dim,
615                ndim,
616            });
617        }
618
619        // Compute ||w|| along all axes except self.dim
620        let reduce_axes: Vec<usize> = (0..ndim).filter(|&a| a != self.dim).collect();
621
622        // Compute squared sum along those axes
623        let sq = weight.mapv(|x| x * x);
624        let mut sum_sq = sq;
625        // Reduce in descending order to keep indices valid
626        let mut sorted_axes = reduce_axes.clone();
627        sorted_axes.sort_unstable();
628        sorted_axes.reverse();
629        for &ax in &sorted_axes {
630            sum_sq = sum_sq.sum_axis(Axis(ax));
631        }
632        // g shape: [dim_size]
633        let g = sum_sq.mapv(f64::sqrt);
634
635        // Build broadcast shape for g: all 1s except self.dim
636        let mut bc_shape = vec![1usize; ndim];
637        bc_shape[self.dim] = weight.shape()[self.dim];
638        let g_bc = g
639            .clone()
640            .into_shape_with_order(IxDyn(&bc_shape))
641            .map_err(|_| NormalizationError::ShapeMismatch {
642                expected: bc_shape.clone(),
643                got: g.shape().to_vec(),
644            })?;
645
646        // v = w / ||w||, avoiding division by zero
647        let v = weight / &g_bc.mapv(|val| if val.abs() < 1e-12 { 1e-12 } else { val });
648
649        Ok((g, v))
650    }
651
652    /// Reparametrize: given `(g, v)` compute `g * v / ||v||`.
653    pub fn reparametrize(
654        g: &ArrayD<f64>,
655        v: &ArrayD<f64>,
656        dim: usize,
657    ) -> Result<ArrayD<f64>, NormalizationError> {
658        let ndim = v.ndim();
659        if ndim == 0 {
660            return Err(NormalizationError::EmptyInput);
661        }
662        if dim >= ndim {
663            return Err(NormalizationError::InvalidAxis { axis: dim, ndim });
664        }
665
666        // Compute ||v|| along all axes except dim
667        let reduce_axes: Vec<usize> = (0..ndim).filter(|&a| a != dim).collect();
668        let sq = v.mapv(|x| x * x);
669        let mut sum_sq = sq;
670        let mut sorted_axes = reduce_axes;
671        sorted_axes.sort_unstable();
672        sorted_axes.reverse();
673        for &ax in &sorted_axes {
674            sum_sq = sum_sq.sum_axis(Axis(ax));
675        }
676        let v_norm = sum_sq.mapv(f64::sqrt);
677
678        // Broadcast g and v_norm to weight shape
679        let mut bc_shape = vec![1usize; ndim];
680        bc_shape[dim] = v.shape()[dim];
681
682        let g_bc = g
683            .clone()
684            .into_shape_with_order(IxDyn(&bc_shape))
685            .map_err(|_| NormalizationError::ShapeMismatch {
686                expected: bc_shape.clone(),
687                got: g.shape().to_vec(),
688            })?;
689        let v_norm_bc = v_norm
690            .into_shape_with_order(IxDyn(&bc_shape))
691            .map_err(|_| NormalizationError::ShapeMismatch {
692                expected: bc_shape.clone(),
693                got: vec![],
694            })?;
695
696        let v_norm_safe = v_norm_bc.mapv(|val| if val.abs() < 1e-12 { 1e-12 } else { val });
697        Ok(v * &g_bc / &v_norm_safe)
698    }
699}
700
701// ---------------------------------------------------------------------------
702// NormStats
703// ---------------------------------------------------------------------------
704
705/// Normalization statistics for debugging and monitoring.
706#[derive(Debug, Clone)]
707pub struct NormStats {
708    /// Mean of the input tensor.
709    pub input_mean: f64,
710    /// Standard deviation of the input tensor.
711    pub input_std: f64,
712    /// Mean of the output tensor.
713    pub output_mean: f64,
714    /// Standard deviation of the output tensor.
715    pub output_std: f64,
716    /// Mean of the gamma (scale) parameter.
717    pub gamma_mean: f64,
718    /// Mean of the beta (shift) parameter.
719    pub beta_mean: f64,
720}
721
722impl NormStats {
723    /// Compute normalization statistics from input, output, gamma, and beta tensors.
724    pub fn compute(
725        input: &ArrayD<f64>,
726        output: &ArrayD<f64>,
727        gamma: &ArrayD<f64>,
728        beta: &ArrayD<f64>,
729    ) -> Self {
730        let input_mean = Self::array_mean(input);
731        let input_std = Self::array_std(input, input_mean);
732        let output_mean = Self::array_mean(output);
733        let output_std = Self::array_std(output, output_mean);
734        let gamma_mean = Self::array_mean(gamma);
735        let beta_mean = Self::array_mean(beta);
736
737        Self {
738            input_mean,
739            input_std,
740            output_mean,
741            output_std,
742            gamma_mean,
743            beta_mean,
744        }
745    }
746
747    /// Produce a human-readable summary string.
748    pub fn summary(&self) -> String {
749        format!(
750            "NormStats {{ input: mean={:.6}, std={:.6} | output: mean={:.6}, std={:.6} | gamma_mean={:.6}, beta_mean={:.6} }}",
751            self.input_mean, self.input_std,
752            self.output_mean, self.output_std,
753            self.gamma_mean, self.beta_mean,
754        )
755    }
756
757    // -- private helpers --
758
759    fn array_mean(arr: &ArrayD<f64>) -> f64 {
760        if arr.is_empty() {
761            return 0.0;
762        }
763        arr.sum() / arr.len() as f64
764    }
765
766    fn array_std(arr: &ArrayD<f64>, mean: f64) -> f64 {
767        if arr.len() <= 1 {
768            return 0.0;
769        }
770        let var = arr.mapv(|x| (x - mean).powi(2)).sum() / arr.len() as f64;
771        var.sqrt()
772    }
773}
774
775// ===========================================================================
776// Tests
777// ===========================================================================
778
779#[cfg(test)]
780mod tests {
781    use super::*;
782    use ndarray::ArrayD;
783
784    fn make_input_4d(batch: usize, channels: usize, h: usize, w: usize) -> ArrayD<f64> {
785        let total = batch * channels * h * w;
786        ArrayD::from_shape_vec(
787            IxDyn(&[batch, channels, h, w]),
788            (0..total).map(|i| (i as f64) * 0.01 + 0.1).collect(),
789        )
790        .expect("test helper: shape matches element count")
791    }
792
793    fn make_input_2d(rows: usize, cols: usize) -> ArrayD<f64> {
794        let total = rows * cols;
795        ArrayD::from_shape_vec(
796            IxDyn(&[rows, cols]),
797            (0..total).map(|i| (i as f64) * 0.05 + 0.5).collect(),
798        )
799        .expect("test helper: shape matches element count")
800    }
801
802    // -----------------------------------------------------------------------
803    // RmsNorm tests
804    // -----------------------------------------------------------------------
805
806    #[test]
807    fn test_rmsnorm_new_valid() {
808        let rms = RmsNorm::new(vec![64], 1e-5);
809        assert!(rms.is_ok());
810        let rms = rms.expect("already checked");
811        assert_eq!(rms.normalized_shape, vec![64]);
812    }
813
814    #[test]
815    fn test_rmsnorm_forward_shape_preserved() {
816        let rms = RmsNorm::new(vec![8], 1e-5).expect("valid config");
817        let input = make_input_2d(4, 8);
818        let output = rms.forward(&input).expect("forward should succeed");
819        assert_eq!(output.shape(), input.shape());
820    }
821
822    #[test]
823    fn test_rmsnorm_output_scale() {
824        // After RMSNorm with gamma=1, the RMS of the output along the last dim
825        // should be close to 1 for each row.
826        let rms = RmsNorm::new(vec![16], 1e-8).expect("valid config");
827        let input = ArrayD::from_shape_vec(
828            IxDyn(&[2, 16]),
829            (0..32).map(|i| (i as f64) * 0.1 + 1.0).collect(),
830        )
831        .expect("test data");
832
833        let output = rms.forward(&input).expect("forward");
834        // Check RMS of each row is close to 1
835        for row_idx in 0..2 {
836            let mut sum_sq = 0.0;
837            for col_idx in 0..16 {
838                let v = output[[row_idx, col_idx]];
839                sum_sq += v * v;
840            }
841            let row_rms = (sum_sq / 16.0).sqrt();
842            assert!(
843                (row_rms - 1.0).abs() < 0.1,
844                "RMS should be close to 1, got {}",
845                row_rms
846            );
847        }
848    }
849
/// `update_gamma` accepts a replacement tensor of matching shape, after
/// which the layer's `gamma` field holds the new values.
#[test]
fn test_rmsnorm_update_gamma() {
    let mut rms = RmsNorm::new(vec![4], 1e-5).expect("valid");
    let new_gamma =
        ArrayD::from_shape_vec(IxDyn(&[4]), vec![2.0, 2.0, 2.0, 2.0]).expect("test data");
    assert!(rms.update_gamma(new_gamma).is_ok());

    // Inspect the whole parameter rather than only element 0, so that a
    // partial or mis-shaped update cannot pass unnoticed.
    assert_eq!(rms.gamma.shape(), &[4]);
    for (idx, &value) in rms.gamma.iter().enumerate() {
        assert!(
            (value - 2.0).abs() < 1e-10,
            "gamma[{}] should be 2.0 after update, got {}",
            idx,
            value
        );
    }
}
858
859    // -----------------------------------------------------------------------
860    // GroupNorm tests
861    // -----------------------------------------------------------------------
862
/// Four groups over sixteen channels is an accepted `GroupNorm`
/// configuration, so the constructor must return `Ok`.
#[test]
fn test_groupnorm_new_valid() {
    assert!(GroupNorm::new(4, 16, 1e-5, true).is_ok());
}
868
/// Five groups do not evenly divide sixteen channels, so construction must
/// fail with `InvalidNumGroups` carrying both offending values.
#[test]
fn test_groupnorm_invalid_groups() {
    let outcome = GroupNorm::new(5, 16, 1e-5, true);
    assert!(outcome.is_err());

    // Any other result (success or a different error variant) is a failure.
    if let Err(NormalizationError::InvalidNumGroups { groups, channels }) = outcome {
        assert_eq!(groups, 5);
        assert_eq!(channels, 16);
    } else {
        panic!("Expected InvalidNumGroups error");
    }
}
881
/// `GroupNorm::forward` must return a tensor with the same shape as the
/// 4-D input produced by `make_input_4d(2, 8, 4, 4)`.
#[test]
fn test_groupnorm_forward_shape_preserved() {
    let layer = GroupNorm::new(4, 8, 1e-5, true).expect("valid");
    let input = make_input_4d(2, 8, 4, 4);

    let normalized = layer.forward(&input).expect("forward");

    let expected_shape = input.shape();
    assert_eq!(normalized.shape(), expected_shape);
}
889
/// With four groups over sixteen channels, `channels_per_group` reports 4.
#[test]
fn test_groupnorm_channels_per_group() {
    let layer = GroupNorm::new(4, 16, 1e-5, true).expect("valid");
    let per_group = layer.channels_per_group();
    assert_eq!(per_group, 4);
}
895
896    // -----------------------------------------------------------------------
897    // InstanceNorm tests
898    // -----------------------------------------------------------------------
899
/// `InstanceNorm::new(8, 1e-5, true)` is an accepted configuration and must
/// return `Ok`.
#[test]
fn test_instancenorm_new_valid() {
    assert!(InstanceNorm::new(8, 1e-5, true).is_ok());
}
905
/// `InstanceNorm::forward` must return a tensor with the same shape as the
/// 4-D input produced by `make_input_4d(2, 4, 3, 3)`.
#[test]
fn test_instancenorm_forward_shape_preserved() {
    let layer = InstanceNorm::new(4, 1e-5, true).expect("valid");
    let input = make_input_4d(2, 4, 3, 3);

    let normalized = layer.forward(&input).expect("forward");

    let expected_shape = input.shape();
    assert_eq!(normalized.shape(), expected_shape);
}
913
/// After `InstanceNorm::forward`, every `[b, c, ..]` slice of a `[2, 2, 4]`
/// input should average out to approximately zero.
#[test]
fn test_instancenorm_normalizes_per_instance() {
    let layer = InstanceNorm::new(2, 1e-8, false).expect("valid");

    // 16 evenly spaced values starting at 1.0, arranged as [2, 2, 4].
    let values: Vec<f64> = (0..16).map(|i| (i as f64) * 0.5 + 1.0).collect();
    let input = ArrayD::from_shape_vec(IxDyn(&[2, 2, 4]), values).expect("test data");

    let output = layer.forward(&input).expect("forward");

    for b in 0..2usize {
        for c in 0..2usize {
            // Sum the four trailing-axis entries in index order.
            let total = (0..4usize).fold(0.0, |acc, s| acc + output[[b, c, s]]);
            let slice_mean = total / 4.0;
            assert!(
                slice_mean.abs() < 0.01,
                "Expected ~0 mean, got {} at b={}, c={}",
                slice_mean,
                b,
                c
            );
        }
    }
}
942
943    // -----------------------------------------------------------------------
944    // BatchNorm tests
945    // -----------------------------------------------------------------------
946
/// `BatchNorm::new(16, 1e-5, 0.1, true)` succeeds and the new layer reports
/// that it is in training mode.
#[test]
fn test_batchnorm_new_valid() {
    let built = BatchNorm::new(16, 1e-5, 0.1, true);
    assert!(built.is_ok());

    let layer = built.expect("valid");
    assert!(layer.is_training());
}
954
/// A forward pass on a freshly built `BatchNorm` returns a tensor with the
/// same shape as its input.
#[test]
fn test_batchnorm_forward_training() {
    let mut layer = BatchNorm::new(4, 1e-5, 0.1, true).expect("valid");
    let batch = make_input_4d(2, 4, 3, 3);

    let normalized = layer.forward(&batch).expect("forward");

    let expected_shape = batch.shape();
    assert_eq!(normalized.shape(), expected_shape);
}
962
/// One forward pass on a freshly built `BatchNorm` must change
/// `running_mean` and bring `num_batches_tracked` to 1.
#[test]
fn test_batchnorm_running_stats_update() {
    let mut layer = BatchNorm::new(4, 1e-5, 0.1, true).expect("valid");
    let batch = make_input_4d(2, 4, 3, 3);

    // Snapshot taken before the pass so the two states can be compared.
    let mean_before = layer.running_mean.clone();
    let _normalized = layer.forward(&batch).expect("forward");

    assert_ne!(layer.running_mean, mean_before);
    assert_eq!(layer.num_batches_tracked, 1);
}
973
/// After switching to eval mode, a forward pass still preserves the input
/// shape and leaves `num_batches_tracked` untouched.
#[test]
fn test_batchnorm_eval_mode() {
    let mut layer = BatchNorm::new(4, 1e-5, 0.1, true).expect("valid");
    let batch = make_input_4d(2, 4, 3, 3);

    // First pass happens before the mode switch, as a warm-up.
    let _warmup = layer.forward(&batch).expect("forward training");

    layer.eval_mode();
    assert!(!layer.is_training());

    // Record the counter, run once more, and make sure it did not move.
    let tracked_before = layer.num_batches_tracked;
    let evaluated = layer.forward(&batch).expect("forward eval");
    assert_eq!(evaluated.shape(), batch.shape());
    assert_eq!(layer.num_batches_tracked, tracked_before);
}
990
/// `eval_mode` and `train_mode` flip the flag reported by `is_training`;
/// a freshly built layer starts out reporting `true`.
#[test]
fn test_batchnorm_train_eval_toggle() {
    let mut layer = BatchNorm::new(4, 1e-5, 0.1, true).expect("valid");
    assert!(layer.is_training());

    layer.eval_mode();
    assert!(!layer.is_training());

    layer.train_mode();
    assert!(layer.is_training());
}
1000
1001    // -----------------------------------------------------------------------
1002    // WeightNorm tests
1003    // -----------------------------------------------------------------------
1004
/// `WeightNorm::apply` on a `[3, 4]` weight yields a pair `(g, v)` where
/// `g` has shape `[3]` and `v` keeps the weight's shape.
#[test]
fn test_weightnorm_apply() {
    let values: Vec<f64> = (0..12).map(|i| (i as f64) * 0.1 + 0.1).collect();
    let weight = ArrayD::from_shape_vec(IxDyn(&[3, 4]), values).expect("test data");

    let normalizer = WeightNorm::new(0);
    let (norms, v_tensor) = normalizer.apply(&weight).expect("apply");

    // One entry in `g` per row of the weight.
    assert_eq!(norms.shape(), &[3]);
    // `v` is shaped exactly like the weight it came from.
    assert_eq!(v_tensor.shape(), weight.shape());
}
1020
/// Round trip: splitting a weight with `apply` and recombining the parts
/// with `reparametrize` must reproduce the original weight to within 1e-8.
#[test]
fn test_weightnorm_reparametrize() {
    let values: Vec<f64> = (0..12).map(|i| (i as f64) * 0.3 + 0.5).collect();
    let weight = ArrayD::from_shape_vec(IxDyn(&[3, 4]), values).expect("test data");

    let normalizer = WeightNorm::new(0);
    let (g_part, v_part) = normalizer.apply(&weight).expect("apply");
    let rebuilt = WeightNorm::reparametrize(&g_part, &v_part, 0).expect("reparametrize");

    assert_eq!(rebuilt.shape(), weight.shape());

    // Element-wise comparison in iteration order.
    let pairs = weight.iter().zip(rebuilt.iter());
    for (expected, actual) in pairs {
        assert!(
            (expected - actual).abs() < 1e-8,
            "Mismatch: orig={}, recon={}",
            expected,
            actual
        );
    }
}
1044
1045    // -----------------------------------------------------------------------
1046    // NormStats tests
1047    // -----------------------------------------------------------------------
1048
/// `NormStats::compute` must fill every statistic it reports with a finite
/// number when given ordinary inputs.
#[test]
fn test_norm_stats_compute() {
    // Scale of ones and shift of zeros, both of length 8.
    let gamma = ArrayD::ones(IxDyn(&[8]));
    let beta = ArrayD::zeros(IxDyn(&[8]));
    let input = make_input_2d(4, 8);
    let output = make_input_2d(4, 8);

    let stats = NormStats::compute(&input, &output, &gamma, &beta);

    // Input-side statistics.
    assert!(stats.input_mean.is_finite());
    assert!(stats.input_std.is_finite());
    // Output-side statistics.
    assert!(stats.output_mean.is_finite());
    assert!(stats.output_std.is_finite());
    // Parameter statistics.
    assert!(stats.gamma_mean.is_finite());
    assert!(stats.beta_mean.is_finite());
}
1065
/// `NormStats::summary` must produce non-empty text that mentions
/// "NormStats".
#[test]
fn test_norm_stats_summary_nonempty() {
    let gamma = ArrayD::ones(IxDyn(&[4]));
    let beta = ArrayD::zeros(IxDyn(&[4]));
    let input = make_input_2d(2, 4);
    let output = make_input_2d(2, 4);

    let stats = NormStats::compute(&input, &output, &gamma, &beta);
    let report = stats.summary();

    assert!(!report.is_empty());
    assert!(report.contains("NormStats"));
}
1078}