svod_tensor/nn/norm.rs
//! Normalization: layernorm, rms_norm, lp_normalize, mean_variance_normalize, group_norm.
2
3use bon::bon;
4use snafu::ResultExt;
5use svod_dtype::DType;
6use svod_ir::{ConstValue, UOp};
7
8use crate::Tensor;
9use crate::error::{NdimMinimumSnafu, ParamRangeSnafu, UOpSnafu};
10use crate::reduce::AxisSpec;
11
/// Module-local shorthand for the crate-wide [`crate::Result`].
type Result<T> = crate::Result<T>;
13
14#[bon]
15impl Tensor {
16 /// Layer normalization over axes `[axis..ndim)`. Casts to f32 internally
17 /// for numerical stability.
18 ///
19 /// Normalizes the input so that the slice along the specified trailing axes
20 /// has zero mean and unit variance, then returns the result cast back to
21 /// the original dtype.
22 ///
23 /// # Examples
24 ///
25 /// ```
26 /// # use svod_tensor::Tensor;
27 /// # use ndarray::array;
28 /// let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
29 /// let mut y = x.layernorm(-1, 1e-5).unwrap();
30 /// y.realize().unwrap();
31 /// let vals = y.as_vec::<f32>().unwrap();
32 /// // Each row is independently normalized to mean~0, std~1
33 /// assert!((vals[0] + vals[1] + vals[2]).abs() < 1e-5);
34 /// ```
35 pub fn layernorm(&self, axis: isize, eps: f64) -> Result<Tensor> {
36 let (normed, _, _) = self.layernorm_with_stats(axis, eps)?;
37 Ok(normed)
38 }
39
40 /// Layer normalization returning `(normalized, mean, inv_std_dev)`.
41 ///
42 /// Computes in f32 for numerical stability (matches ONNX `stash_type=1`).
43 /// The `mean` and `inv_std_dev` tensors remain in f32 regardless of input dtype.
44 ///
45 /// # Examples
46 ///
47 /// ```
48 /// # use svod_tensor::Tensor;
49 /// # use ndarray::array;
50 /// let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0]]);
51 /// let (_normed, mut mean, _inv_std) = x.layernorm_with_stats(-1, 1e-5).unwrap();
52 /// mean.realize().unwrap();
53 /// let mean_val = mean.as_vec::<f32>().unwrap();
54 /// assert!((mean_val[0] - 2.0).abs() < 1e-5);
55 /// ```
56 pub fn layernorm_with_stats(&self, axis: isize, eps: f64) -> Result<(Tensor, Tensor, Tensor)> {
57 let ndim = self.ndim()?;
58 let norm_axis = Tensor::normalize_axis(axis, ndim)?;
59 let axes: Vec<isize> = (norm_axis..ndim).map(|a| a as isize).collect();
60 let axes_spec = AxisSpec::Multiple(axes);
61
62 let original_dtype = self.uop().dtype();
63 let x32 = if original_dtype != DType::Float32 { self.cast(DType::Float32)? } else { self.clone() };
64
65 let mean = x32.mean_with().axes(axes_spec.clone()).keepdim(true).call()?;
66 let centered = x32.try_sub(&mean)?;
67 let variance = centered.square()?.mean_with().axes(axes_spec).keepdim(true).call()?;
68 let eps_t = Tensor::new(UOp::const_(DType::Float32, ConstValue::Float(eps)));
69 let inv_std = variance.try_add(&eps_t)?.try_rsqrt()?;
70 let normalized = centered.try_mul(&inv_std)?;
71
72 let normalized = if original_dtype != DType::Float32 { normalized.cast(original_dtype)? } else { normalized };
73 Ok((normalized, mean, inv_std))
74 }
75
76 /// RMS normalization over axes `[axis..ndim)`.
77 ///
78 /// Like layernorm but without mean subtraction: divides each element by the
79 /// root-mean-square of its slice. Computes the normalization factor in f32,
80 /// then multiplies the original (unconverted) input.
81 ///
82 /// # Examples
83 ///
84 /// ```
85 /// # use svod_tensor::Tensor;
86 /// # use ndarray::array;
87 /// let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0]]);
88 /// let mut y = x.rms_norm(-1, 1e-5).unwrap();
89 /// y.realize().unwrap();
90 /// let vals = y.as_vec::<f32>().unwrap();
91 /// // RMS of [1,2,3] = sqrt((1+4+9)/3) ≈ 2.16
92 /// // Output ≈ [0.46, 0.93, 1.39]
93 /// assert!((vals[0] - 1.0 / (14.0f32 / 3.0).sqrt()).abs() < 1e-4);
94 /// ```
95 pub fn rms_norm(&self, axis: isize, eps: f64) -> Result<Tensor> {
96 let ndim = self.ndim()?;
97 let norm_axis = Tensor::normalize_axis(axis, ndim)?;
98 let axes: Vec<isize> = (norm_axis..ndim).map(|a| a as isize).collect();
99 let axes_spec = AxisSpec::Multiple(axes);
100
101 let original_dtype = self.uop().dtype();
102 let x32 = if original_dtype != DType::Float32 { self.cast(DType::Float32)? } else { self.clone() };
103
104 let norm = x32
105 .square()?
106 .mean_with()
107 .axes(axes_spec)
108 .keepdim(true)
109 .call()?
110 .try_add(&Tensor::new(UOp::const_(DType::Float32, ConstValue::Float(eps))))?
111 .try_rsqrt()?;
112
113 self.try_mul(&norm)
114 }
115
116 /// Lp normalization along an axis.
117 ///
118 /// Divides each element by the Lp norm of its slice along `axis`,
119 /// so that every such slice has unit Lp norm. Only `p=1` (L1) and
120 /// `p=2` (L2) are implemented; any `p != 1` defaults to L2.
121 ///
122 /// # Examples
123 ///
124 /// L2 normalization (default `p=2`):
125 ///
126 /// ```
127 /// # use svod_tensor::Tensor;
128 /// # use ndarray::array;
129 /// let x = Tensor::from_ndarray(&array![[3.0f32, 4.0]]);
130 /// let mut y = x.lp_normalize(-1, 2).unwrap();
131 /// y.realize().unwrap();
132 /// let vals = y.as_vec::<f32>().unwrap();
133 /// // L2 norm of [3,4] = 5, so output ≈ [0.6, 0.8]
134 /// assert!((vals[0] - 0.6).abs() < 1e-5);
135 /// assert!((vals[1] - 0.8).abs() < 1e-5);
136 /// ```
137 ///
138 /// L1 normalization (`p=1`):
139 ///
140 /// ```
141 /// # use svod_tensor::Tensor;
142 /// # use ndarray::array;
143 /// let x = Tensor::from_ndarray(&array![[3.0f32, 4.0]]);
144 /// let mut y = x.lp_normalize(-1, 1).unwrap();
145 /// y.realize().unwrap();
146 /// let vals = y.as_vec::<f32>().unwrap();
147 /// // L1 norm of [3,4] = 7, so output ≈ [3/7, 4/7]
148 /// assert!((vals[0] - 3.0 / 7.0).abs() < 1e-5);
149 /// ```
150 pub fn lp_normalize(&self, axis: isize, p: i64) -> Result<Tensor> {
151 let norm = match p {
152 1 => self.try_abs()?.sum_with().axes(AxisSpec::Single(axis)).keepdim(true).call()?,
153 _ => self.square()?.sum_with().axes(AxisSpec::Single(axis)).keepdim(true).call()?.try_sqrt()?,
154 };
155 let eps = self.uop().dtype().base().min_positive();
156 self.try_div(&norm.try_add(&Tensor::const_(eps, self.uop().dtype()))?)
157 }
158
159 /// Mean Variance Normalization.
160 ///
161 /// Subtracts the mean and divides by the population standard deviation
162 /// (plus `eps`) over the given axes. Implements the ONNX
163 /// `MeanVarianceNormalization` operator.
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// # use svod_tensor::Tensor;
169 /// # use ndarray::array;
170 /// let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
171 /// let mut y = x.mean_variance_normalize(&[0, 1], 1e-5).unwrap();
172 /// y.realize().unwrap();
173 /// let vals = y.as_vec::<f32>().unwrap();
174 /// // Global mean = 3.5, std ≈ 1.708
175 /// assert!((vals[0] - (1.0 - 3.5) / (35.0f32 / 12.0).sqrt()).abs() < 1e-4);
176 /// assert!(vals[0] < 0.0);
177 /// assert!(vals[5] > 0.0);
178 /// ```
179 pub fn mean_variance_normalize(&self, axes: &[isize], eps: f64) -> Result<Tensor> {
180 let axes_spec = AxisSpec::Multiple(axes.to_vec());
181 let mean = self.mean_with().axes(axes_spec.clone()).keepdim(true).call()?;
182 let centered = self.try_sub(&mean)?;
183 let pop_std = centered.square()?.mean_with().axes(axes_spec).keepdim(true).call()?.try_sqrt()?;
184 let eps = Tensor::const_(eps, self.uop().dtype());
185 centered.try_div(&pop_std.try_add(&eps)?)
186 }
187
188 /// Group normalization: reshape into groups, layernorm each group, then
189 /// apply per-channel scale and bias.
190 ///
191 /// Input must be at least 2-D with shape `[N, C, ...]`. Channels are split
192 /// into `num_groups` groups and each group is independently normalized.
193 /// Casts to f32 internally for numerical stability.
194 ///
195 /// # Examples
196 ///
197 /// ```
198 /// # use svod_tensor::Tensor;
199 /// # use ndarray::Array4;
200 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 4, 2, 2), 1.0f32));
201 /// let scale = Tensor::from_slice([1.0f32; 4]);
202 /// let bias = Tensor::from_slice([0.0f32; 4]);
203 /// let y = x.group_norm().scale(&scale).bias(&bias).num_groups(2).call().unwrap();
204 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
205 /// assert_eq!(shape, [1, 4, 2, 2]);
206 /// ```
207 ///
208 /// Custom epsilon:
209 ///
210 /// ```
211 /// # use svod_tensor::Tensor;
212 /// # use ndarray::Array4;
213 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 4, 2, 2), 1.0f32));
214 /// let scale = Tensor::from_slice([1.0f32; 4]);
215 /// let bias = Tensor::from_slice([0.0f32; 4]);
216 /// let y = x.group_norm().scale(&scale).bias(&bias).num_groups(2).eps(1e-6).call().unwrap();
217 /// let shape: Vec<_> = y.shape().unwrap().iter().map(|d| d.as_const().unwrap()).collect();
218 /// assert_eq!(shape, [1, 4, 2, 2]);
219 /// ```
220 #[builder]
221 pub fn group_norm(
222 &self,
223 scale: &Tensor,
224 bias: &Tensor,
225 num_groups: usize,
226 #[builder(default = 1e-5)] eps: f64,
227 ) -> Result<Tensor> {
228 let x_shape = self.shape()?;
229 let ndim = x_shape.len();
230 snafu::ensure!(ndim >= 2, NdimMinimumSnafu { op: "group_norm", min: 2_usize, actual: ndim });
231 snafu::ensure!(
232 num_groups > 0,
233 ParamRangeSnafu { op: "group_norm", param: "num_groups", value: num_groups.to_string(), constraint: "> 0" }
234 );
235 let batch = x_shape[0].as_const().unwrap();
236
237 // Reshape to (batch, num_groups, -1), cast to f32 before layernorm
238 let reshaped = self.try_reshape([batch as isize, num_groups as isize, -1])?;
239 let reshaped = if reshaped.uop().dtype() != DType::Float32 { reshaped.cast(DType::Float32)? } else { reshaped };
240 let normed = reshaped.layernorm(-1, eps)?;
241 // Cast back and reshape to original
242 let normed = if self.uop().dtype() != DType::Float32 { normed.cast(self.uop().dtype())? } else { normed };
243 let orig_shape = svod_ir::shape::to_vec_isize(&x_shape).context(UOpSnafu)?;
244 let normed = normed.try_reshape(&orig_shape)?;
245
246 // Scale and bias: reshape to (1, C, 1, 1, ...)
247 let mut sb_shape: Vec<isize> = vec![1, -1];
248 sb_shape.extend(std::iter::repeat_n(1isize, ndim - 2));
249 let scale = scale.try_reshape(&sb_shape)?;
250 let bias = bias.try_reshape(&sb_shape)?;
251 normed.try_mul(&scale)?.try_add(&bias)
252 }
253}