Skip to main content

svod_tensor/nn/
norm.rs

//! Normalization: layernorm, rms_norm, lp_normalize, mean_variance_normalize, group_norm.
2
3use bon::bon;
4use snafu::ResultExt;
5use svod_dtype::DType;
6use svod_ir::{ConstValue, UOp};
7
8use crate::Tensor;
9use crate::error::{NdimMinimumSnafu, ParamRangeSnafu, UOpSnafu};
10use crate::reduce::AxisSpec;
11
12type Result<T> = crate::Result<T>;
13
14#[bon]
15impl Tensor {
16    /// Layer normalization over axes `[axis..ndim)`. Casts to f32 internally
17    /// for numerical stability.
18    ///
19    /// Normalizes the input so that the slice along the specified trailing axes
20    /// has zero mean and unit variance, then returns the result cast back to
21    /// the original dtype.
22    ///
23    /// # Examples
24    ///
25    /// ```
26    /// # use svod_tensor::Tensor;
27    /// # use ndarray::array;
28    /// let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
29    /// let mut y = x.layernorm(-1, 1e-5).unwrap();
30    /// y.realize().unwrap();
31    /// let vals = y.as_vec::<f32>().unwrap();
32    /// // Each row is independently normalized to mean~0, std~1
33    /// assert!((vals[0] + vals[1] + vals[2]).abs() < 1e-5);
34    /// ```
35    pub fn layernorm(&self, axis: isize, eps: f64) -> Result<Tensor> {
36        let (normed, _, _) = self.layernorm_with_stats(axis, eps)?;
37        Ok(normed)
38    }
39
40    /// Layer normalization returning `(normalized, mean, inv_std_dev)`.
41    ///
42    /// Computes in f32 for numerical stability (matches ONNX `stash_type=1`).
43    /// The `mean` and `inv_std_dev` tensors remain in f32 regardless of input dtype.
44    ///
45    /// # Examples
46    ///
47    /// ```
48    /// # use svod_tensor::Tensor;
49    /// # use ndarray::array;
50    /// let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0]]);
51    /// let (_normed, mut mean, _inv_std) = x.layernorm_with_stats(-1, 1e-5).unwrap();
52    /// mean.realize().unwrap();
53    /// let mean_val = mean.as_vec::<f32>().unwrap();
54    /// assert!((mean_val[0] - 2.0).abs() < 1e-5);
55    /// ```
56    pub fn layernorm_with_stats(&self, axis: isize, eps: f64) -> Result<(Tensor, Tensor, Tensor)> {
57        let ndim = self.ndim()?;
58        let norm_axis = Tensor::normalize_axis(axis, ndim)?;
59        let axes: Vec<isize> = (norm_axis..ndim).map(|a| a as isize).collect();
60        let axes_spec = AxisSpec::Multiple(axes);
61
62        let original_dtype = self.uop().dtype();
63        let x32 = if original_dtype != DType::Float32 { self.cast(DType::Float32)? } else { self.clone() };
64
65        let mean = x32.mean_with().axes(axes_spec.clone()).keepdim(true).call()?;
66        let centered = x32.try_sub(&mean)?;
67        let variance = centered.square()?.mean_with().axes(axes_spec).keepdim(true).call()?;
68        let eps_t = Tensor::new(UOp::const_(DType::Float32, ConstValue::Float(eps)));
69        let inv_std = variance.try_add(&eps_t)?.try_rsqrt()?;
70        let normalized = centered.try_mul(&inv_std)?;
71
72        let normalized = if original_dtype != DType::Float32 { normalized.cast(original_dtype)? } else { normalized };
73        Ok((normalized, mean, inv_std))
74    }
75
76    /// RMS normalization over axes `[axis..ndim)`.
77    ///
78    /// Like layernorm but without mean subtraction: divides each element by the
79    /// root-mean-square of its slice. Computes the normalization factor in f32,
80    /// then multiplies the original (unconverted) input.
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// # use svod_tensor::Tensor;
86    /// # use ndarray::array;
87    /// let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0]]);
88    /// let mut y = x.rms_norm(-1, 1e-5).unwrap();
89    /// y.realize().unwrap();
90    /// let vals = y.as_vec::<f32>().unwrap();
91    /// // RMS of [1,2,3] = sqrt((1+4+9)/3) ≈ 2.16
92    /// // Output ≈ [0.46, 0.93, 1.39]
93    /// assert!((vals[0] - 1.0 / (14.0f32 / 3.0).sqrt()).abs() < 1e-4);
94    /// ```
95    pub fn rms_norm(&self, axis: isize, eps: f64) -> Result<Tensor> {
96        let ndim = self.ndim()?;
97        let norm_axis = Tensor::normalize_axis(axis, ndim)?;
98        let axes: Vec<isize> = (norm_axis..ndim).map(|a| a as isize).collect();
99        let axes_spec = AxisSpec::Multiple(axes);
100
101        let original_dtype = self.uop().dtype();
102        let x32 = if original_dtype != DType::Float32 { self.cast(DType::Float32)? } else { self.clone() };
103
104        let norm = x32
105            .square()?
106            .mean_with()
107            .axes(axes_spec)
108            .keepdim(true)
109            .call()?
110            .try_add(&Tensor::new(UOp::const_(DType::Float32, ConstValue::Float(eps))))?
111            .try_rsqrt()?;
112
113        self.try_mul(&norm)
114    }
115
116    /// Lp normalization along an axis.
117    ///
118    /// Divides each element by the Lp norm of its slice along `axis`,
119    /// so that every such slice has unit Lp norm. Only `p=1` (L1) and
120    /// `p=2` (L2) are implemented; any `p != 1` defaults to L2.
121    ///
122    /// # Examples
123    ///
124    /// L2 normalization (default `p=2`):
125    ///
126    /// ```
127    /// # use svod_tensor::Tensor;
128    /// # use ndarray::array;
129    /// let x = Tensor::from_ndarray(&array![[3.0f32, 4.0]]);
130    /// let mut y = x.lp_normalize(-1, 2).unwrap();
131    /// y.realize().unwrap();
132    /// let vals = y.as_vec::<f32>().unwrap();
133    /// // L2 norm of [3,4] = 5, so output ≈ [0.6, 0.8]
134    /// assert!((vals[0] - 0.6).abs() < 1e-5);
135    /// assert!((vals[1] - 0.8).abs() < 1e-5);
136    /// ```
137    ///
138    /// L1 normalization (`p=1`):
139    ///
140    /// ```
141    /// # use svod_tensor::Tensor;
142    /// # use ndarray::array;
143    /// let x = Tensor::from_ndarray(&array![[3.0f32, 4.0]]);
144    /// let mut y = x.lp_normalize(-1, 1).unwrap();
145    /// y.realize().unwrap();
146    /// let vals = y.as_vec::<f32>().unwrap();
147    /// // L1 norm of [3,4] = 7, so output ≈ [3/7, 4/7]
148    /// assert!((vals[0] - 3.0 / 7.0).abs() < 1e-5);
149    /// ```
150    pub fn lp_normalize(&self, axis: isize, p: i64) -> Result<Tensor> {
151        let norm = match p {
152            1 => self.try_abs()?.sum_with().axes(AxisSpec::Single(axis)).keepdim(true).call()?,
153            _ => self.square()?.sum_with().axes(AxisSpec::Single(axis)).keepdim(true).call()?.try_sqrt()?,
154        };
155        let eps = self.uop().dtype().base().min_positive();
156        self.try_div(&norm.try_add(&Tensor::const_(eps, self.uop().dtype()))?)
157    }
158
159    /// Mean Variance Normalization.
160    ///
161    /// Subtracts the mean and divides by the population standard deviation
162    /// (plus `eps`) over the given axes. Implements the ONNX
163    /// `MeanVarianceNormalization` operator.
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// # use svod_tensor::Tensor;
169    /// # use ndarray::array;
170    /// let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
171    /// let mut y = x.mean_variance_normalize(&[0, 1], 1e-5).unwrap();
172    /// y.realize().unwrap();
173    /// let vals = y.as_vec::<f32>().unwrap();
174    /// // Global mean = 3.5, std ≈ 1.708
175    /// assert!((vals[0] - (1.0 - 3.5) / (35.0f32 / 12.0).sqrt()).abs() < 1e-4);
176    /// assert!(vals[0] < 0.0);
177    /// assert!(vals[5] > 0.0);
178    /// ```
179    pub fn mean_variance_normalize(&self, axes: &[isize], eps: f64) -> Result<Tensor> {
180        let axes_spec = AxisSpec::Multiple(axes.to_vec());
181        let mean = self.mean_with().axes(axes_spec.clone()).keepdim(true).call()?;
182        let centered = self.try_sub(&mean)?;
183        let pop_std = centered.square()?.mean_with().axes(axes_spec).keepdim(true).call()?.try_sqrt()?;
184        let eps = Tensor::const_(eps, self.uop().dtype());
185        centered.try_div(&pop_std.try_add(&eps)?)
186    }
187
188    /// Group normalization: reshape into groups, layernorm each group, then
189    /// apply per-channel scale and bias.
190    ///
191    /// Input must be at least 2-D with shape `[N, C, ...]`. Channels are split
192    /// into `num_groups` groups and each group is independently normalized.
193    /// Casts to f32 internally for numerical stability.
194    ///
195    /// # Examples
196    ///
197    /// ```
198    /// # use svod_tensor::Tensor;
199    /// # use ndarray::Array4;
200    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 4, 2, 2), 1.0f32));
201    /// let scale = Tensor::from_slice([1.0f32; 4]);
202    /// let bias = Tensor::from_slice([0.0f32; 4]);
203    /// let y = x.group_norm().scale(&scale).bias(&bias).num_groups(2).call().unwrap();
204    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
205    /// assert_eq!(shape, [1, 4, 2, 2]);
206    /// ```
207    ///
208    /// Custom epsilon:
209    ///
210    /// ```
211    /// # use svod_tensor::Tensor;
212    /// # use ndarray::Array4;
213    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 4, 2, 2), 1.0f32));
214    /// let scale = Tensor::from_slice([1.0f32; 4]);
215    /// let bias = Tensor::from_slice([0.0f32; 4]);
216    /// let y = x.group_norm().scale(&scale).bias(&bias).num_groups(2).eps(1e-6).call().unwrap();
217    /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
218    /// assert_eq!(shape, [1, 4, 2, 2]);
219    /// ```
220    #[builder]
221    pub fn group_norm(
222        &self,
223        scale: &Tensor,
224        bias: &Tensor,
225        num_groups: usize,
226        #[builder(default = 1e-5)] eps: f64,
227    ) -> Result<Tensor> {
228        let x_shape = self.shape()?;
229        let ndim = x_shape.len();
230        snafu::ensure!(ndim >= 2, NdimMinimumSnafu { op: "group_norm", min: 2_usize, actual: ndim });
231        snafu::ensure!(
232            num_groups > 0,
233            ParamRangeSnafu { op: "group_norm", param: "num_groups", value: num_groups.to_string(), constraint: "> 0" }
234        );
235        let batch = x_shape[0].as_const().unwrap();
236
237        // Reshape to (batch, num_groups, -1), cast to f32 before layernorm
238        let reshaped = self.try_reshape([batch as isize, num_groups as isize, -1])?;
239        let reshaped = if reshaped.uop().dtype() != DType::Float32 { reshaped.cast(DType::Float32)? } else { reshaped };
240        let normed = reshaped.layernorm(-1, eps)?;
241        // Cast back and reshape to original
242        let normed = if self.uop().dtype() != DType::Float32 { normed.cast(self.uop().dtype())? } else { normed };
243        let orig_shape = svod_ir::shape::to_vec_isize(&x_shape).context(UOpSnafu)?;
244        let normed = normed.try_reshape(&orig_shape)?;
245
246        // Scale and bias: reshape to (1, C, 1, 1, ...)
247        let mut sb_shape: Vec<isize> = vec![1, -1];
248        sb_shape.extend(std::iter::repeat_n(1isize, ndim - 2));
249        let scale = scale.try_reshape(&sb_shape)?;
250        let bias = bias.try_reshape(&sb_shape)?;
251        normed.try_mul(&scale)?.try_add(&bias)
252    }
253}