Skip to main content

tract_core/ops/nn/
rms_norm.rs

1use crate::internal::*;
2use crate::ops::binary::{BinMiniOp, TypedBinOp};
3use crate::ops::element_wise::ElementWiseOp;
4use crate::ops::math::{Add, Mul, Rsqrt};
5use crate::ops::nn::{Reduce, Reducer};
6use tract_itertools::Itertools;
7
/// Root-mean-square normalisation operator: `x * rsqrt(mean_of_squares(x, axis) + eps)`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct RmsNorm {
    /// Index of the input axis over which the mean of squares is reduced.
    pub axis: usize,
    /// Stabilising constant added to the mean of squares before the reciprocal
    /// square root. `output_facts` requires it to be a rank-0 tensor; `eval`
    /// casts it to `f32` before use, so it may be stored with another dtype.
    pub eps: Arc<Tensor>,
}
13
14impl Op for RmsNorm {
15    fn name(&self) -> StaticName {
16        "RmsNorm".to_string().into()
17    }
18    fn info(&self) -> TractResult<Vec<String>> {
19        Ok(vec![format!("axis: {:?}, eps: {:?}", self.axis, self.eps)])
20    }
21    op_as_typed_op!();
22}
23
24impl EvalOp for RmsNorm {
25    fn is_stateless(&self) -> bool {
26        true
27    }
28
29    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
30        let input = args_1!(inputs);
31        let in_dt = input.datum_type();
32
33        // Fast path: F32 or F16 input where the normalised axis is the last
34        // (contiguous) one. Use the fused tract_linalg::rms_norm_f32 kernel
35        // (AVX-512 when available; scalar fallback otherwise) instead of the
36        // 4-call MeanOfSquares + Add + Rsqrt + Mul composition below. ~16-18x
37        // faster on Cascade Lake AVX-512, ~equivalent on the scalar fallback
38        // since the composition is also memory-bandwidth bound.
39        if matches!(in_dt, DatumType::F32 | DatumType::F16)
40            && input.rank() > 0
41            && self.axis == input.rank() - 1
42        {
43            let eps_f32: f32 = self.eps.cast_to_scalar::<f32>()?;
44            let mut buf = input.cast_to::<f32>()?.into_owned();
45            let row_len = buf.shape()[self.axis];
46            if row_len > 0 {
47                let n_rows: usize = buf.shape().iter().take(self.axis).product();
48                let data = unsafe { buf.as_slice_mut_unchecked::<f32>() };
49                let rms_norm = &tract_linalg::ops().rms_norm_f32;
50                for r in 0..n_rows {
51                    let start = r * row_len;
52                    rms_norm(&mut data[start..start + row_len], eps_f32);
53                }
54            }
55            return Ok(tvec![buf.cast_to_dt(in_dt)?.into_owned().into()]);
56        }
57
58        // Slow path: original 4-call composition (kept for non-contiguous axes).
59        let input_f32 = input.cast_to::<f32>()?.into_owned();
60        // eps inherits the input dtype from the declutter pattern (F16 when the
61        // surrounding LayerNorm chain is F16). The MeanOfSquares + Add + Rsqrt
62        // + Mul chain below all runs at F32, so eps must be cast to match —
63        // otherwise the Add::eval call below panics with
64        //   "tensor is F32, accessed as F16"
65        // when input is F16.
66        let eps = self.eps.cast_to::<f32>()?.into_owned();
67        let a1 = Reducer::MeanOfSquares.reduce(&[self.axis], &input_f32)?;
68        let mut a2 = Add.eval(a1.into_tvalue(), eps.into_tvalue(), DatumType::F32)?;
69        Rsqrt {}.eval_in_place(&mut a2, None)?;
70        let a3 = Mul.eval(a2.into_tvalue(), input_f32.into_tvalue(), DatumType::F32)?;
71        Ok(tvec![a3.cast_to_dt(in_dt)?.into_owned().into()])
72    }
73}
74
75impl TypedOp for RmsNorm {
76    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
77        ensure!(self.eps.rank() == 0, "RmsNorm: eps must be a rank-0 tensor");
78        ensure!(
79            self.axis < inputs[0].rank(),
80            "RmsNorm: axis {} is out of bounds for input rank {}",
81            self.axis,
82            inputs[0].rank()
83        );
84        let dt = inputs[0].datum_type;
85        let fact = dt.fact(inputs[0].shape.clone());
86        Ok(tvec!(fact))
87    }
88
89    fn input_roi(
90        &self,
91        model: &TypedModel,
92        node: &TypedNode,
93    ) -> TractResult<Option<TVec<Option<TDim>>>> {
94        crate::optim::propagate_roi::bubble_roi(model, node)
95    }
96
97    fn axes_mapping(
98        &self,
99        inputs: &[&TypedFact],
100        _outputs: &[&TypedFact],
101    ) -> TractResult<AxesMapping> {
102        let rank = inputs[0].rank();
103        let mut letters = 'a'..;
104        let axes = (0..rank)
105            .map(|ix| {
106                Axis::new(letters.next().unwrap(), inputs.len(), 1).input(0, ix).output(0, ix)
107            })
108            .collect_vec();
109        AxesMapping::new(1, 1, axes)
110    }
111
112    fn change_axes(
113        &self,
114        model: &TypedModel,
115        node: &TypedNode,
116        _io: InOut,
117        change: &AxisOp,
118    ) -> TractResult<Option<AxisChangeConsequence>> {
119        if let Some(axis) = change.transform_axis(self.axis) {
120            let op = Some(Box::new(RmsNorm { axis, eps: self.eps.clone() }) as _);
121            Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
122        } else {
123            Ok(None)
124        }
125    }
126
127    fn slice(
128        &self,
129        patch: &mut TypedModelPatch,
130        _model: &TypedModel,
131        node: &TypedNode,
132        _prefix: &str,
133        inputs: &[OutletId],
134        output_axis: usize,
135        _start: &TDim,
136        _end: &TDim,
137    ) -> TractResult<Option<TVec<OutletId>>> {
138        rule_if!(output_axis != self.axis);
139        patch.wire_node(&node.name, self.clone(), inputs).map(Some)
140    }
141
142    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
143        let dt = inputs[0].datum_type;
144        let count: TDim = inputs[0].shape.iter().product();
145        // per element: square + accumulate + mul by rsqrt ≈ 3 FMA
146        // per reduction group: 1 div (rsqrt)
147        let groups: TDim = inputs[0]
148            .shape
149            .iter()
150            .enumerate()
151            .filter(|(i, _)| *i != self.axis)
152            .map(|(_, d)| d)
153            .product();
154        Ok(tvec!((Cost::FMA(dt), count * 3), (Cost::Div(dt), groups)))
155    }
156
157    as_op!();
158}
159
160/// Search pattern => A = A * RSQRT(MEAN_OF_SQUARES(A) + EPS)
161pub fn detect_rms_norm(
162    op: &Reduce,
163    model: &TypedModel,
164    node: &TypedNode,
165) -> TractResult<Option<TypedModelPatch>> {
166    rule_if!(op.reducer == Reducer::MeanOfSquares);
167    rule_if!(op.axes.len() == 1);
168    let axis = op.axes[0];
169
170    let in_fact = model.node_input_facts(node.id)?[0];
171    let dt = in_fact.datum_type;
172
173    // Only F16 and F32 is supported.
174    rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
175
176    // Identify Add operator
177    rule_if_some!(add_succ = model.single_succ(node.id)?);
178    rule_if_some!(add_succ_op = add_succ.op_as::<TypedBinOp>());
179    rule_if!(add_succ_op.0.is::<Add>());
180
181    // Retrieve epsilon
182    let add_consts = model.collect_const_inputs(add_succ);
183    rule_if!(add_consts.len() == 1);
184    let eps = add_consts[0].val().clone();
185    rule_if!(eps.len() == 1);
186    rule_if!(eps.datum_type() == dt);
187    let eps = eps.into_tensor().into_shape(&[])?.into_arc_tensor();
188
189    // Identify Rsqrt
190    rule_if_some!(rsqrt_succ = model.single_succ(add_succ.id)?);
191    rule_if_some!(rsqrt_succ_op = rsqrt_succ.op_as::<ElementWiseOp>());
192    rule_if!(rsqrt_succ_op.0.is::<Rsqrt>());
193
194    // Identify Mul: RSQRT(...) * A
195    rule_if_some!(mul_succ = model.find_succ_bin_with_outlet::<Mul>(rsqrt_succ, &node.inputs[0]));
196
197    let mut patch = TypedModelPatch::default();
198    let rsm_input = patch.taps(model, &node.inputs)?;
199    let out =
200        patch.wire_node(format!("{}.rms_norm", node.name), RmsNorm { axis, eps }, &rsm_input)?;
201
202    patch.shunt_outside(model, mul_succ.id.into(), out[0])?;
203    Ok(Some(patch))
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::nn::RmsNorm;

    /// F16 input with an F16 `eps`, as produced by the declutter pattern
    /// (`detect_rms_norm` requires `eps.datum_type() == dt`).
    ///
    /// The input is rank 1 and `axis` is 0, i.e. the trailing axis, so this
    /// exercises the fused fast path of `RmsNorm::eval`, which reads `eps`
    /// through `cast_to_scalar::<f32>()`. The slow-path counterpart is
    /// `eval_with_f16_eps_and_f16_input_non_trailing_axis`.
    #[test]
    fn eval_with_f16_eps_and_f16_input() {
        let to_h = |x: f32| f16::from_f32(x);
        let input = tensor1(&[to_h(1.0), to_h(2.0), to_h(3.0), to_h(4.0)]);
        let eps = tensor0(to_h(1e-5)).into_arc_tensor();
        let op = RmsNorm { axis: 0, eps };
        let out = op.eval(tvec!(input.into())).expect("eval should not panic");
        let out = out.into_iter().next().unwrap().into_tensor();
        assert_eq!(out.datum_type(), DatumType::F16);
        assert_eq!(out.shape(), &[4]);
        // Reference: rms = sqrt((1+4+9+16)/4 + eps) = sqrt(7.5 + 1e-5) ≈ 2.7386
        // normalised: [1, 2, 3, 4] / 2.7386 ≈ [0.365, 0.730, 1.095, 1.461]
        // SAFETY: the datum type was asserted to be F16 just above.
        let got = unsafe { out.as_slice_unchecked::<f16>() };
        let expected = [0.365_f32, 0.730, 1.095, 1.461];
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            let diff = (g.to_f32() - e).abs();
            assert!(diff < 0.01, "lane {i}: got {} expected {}", g.to_f32(), e);
        }
    }

    /// Regression: the slow path of `RmsNorm::eval` runs its MeanOfSquares +
    /// Add + Rsqrt + Mul chain at F32, so it must cast an F16 `eps` to F32
    /// first. Without that cast the `Add` step panics with
    /// "tensor is F32, accessed as F16".
    ///
    /// A rank-2 input normalised over axis 0 (not the trailing axis) is needed
    /// to actually reach the slow path.
    #[test]
    fn eval_with_f16_eps_and_f16_input_non_trailing_axis() {
        let to_h = |x: f32| f16::from_f32(x);
        let input = tensor2(&[
            [to_h(1.0), to_h(2.0), to_h(3.0)],
            [to_h(4.0), to_h(5.0), to_h(6.0)],
        ]);
        let eps = tensor0(to_h(1e-5)).into_arc_tensor();
        let op = RmsNorm { axis: 0, eps };
        let out = op.eval(tvec!(input.into())).expect("eval should not panic");
        let out = out.into_iter().next().unwrap().into_tensor();
        assert_eq!(out.datum_type(), DatumType::F16);
        assert_eq!(out.shape(), &[2, 3]);
        // SAFETY: the datum type was asserted to be F16 just above.
        let got = unsafe { out.as_slice_unchecked::<f16>() };
        // Same per-column reference as the F32 test below, with eps folded in.
        let inv = |ms: f32| (ms + 1e-5).sqrt().recip();
        let expected: [f32; 6] = [
            1.0 * inv(8.5),
            2.0 * inv(14.5),
            3.0 * inv(22.5),
            4.0 * inv(8.5),
            5.0 * inv(14.5),
            6.0 * inv(22.5),
        ];
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            let diff = (g.to_f32() - e).abs();
            assert!(diff < 0.01, "lane {i}: got {} expected {}", g.to_f32(), e);
        }
    }

    /// Slow path: when the normalised axis is NOT the trailing one, the fast
    /// path in `eval` (which dispatches to `tract_linalg::ops().rms_norm_f32`)
    /// is skipped and the original 4-call `MeanOfSquares` + `Add` + `Rsqrt` +
    /// `Mul` composition runs. Asserts the result matches a hand-computed
    /// reference, so the slow path stays correct after the fast-path addition.
    #[test]
    fn eval_with_non_trailing_axis_f32() {
        // 2x3 input, axis=0 means we normalise across the 2 rows for each
        // column independently:
        //   col 0: [1, 4] → mean_sq = (1 + 16) / 2 =  8.5 → 1/√8.5
        //   col 1: [2, 5] → mean_sq = (4 + 25) / 2 = 14.5 → 1/√14.5
        //   col 2: [3, 6] → mean_sq = (9 + 36) / 2 = 22.5 → 1/√22.5
        let input = tensor2(&[[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let eps = tensor0(0.0_f32).into_arc_tensor();
        let op = RmsNorm { axis: 0, eps };
        let out = op.eval(tvec!(input.into())).expect("eval should not panic");
        let out = out.into_iter().next().unwrap().into_tensor();
        assert_eq!(out.datum_type(), DatumType::F32);
        assert_eq!(out.shape(), &[2, 3]);
        // SAFETY: the datum type was asserted to be F32 just above.
        let got = unsafe { out.as_slice_unchecked::<f32>() };
        let inv = |ms: f32| ms.sqrt().recip();
        let expected: [f32; 6] = [
            1.0 * inv(8.5),
            2.0 * inv(14.5),
            3.0 * inv(22.5),
            4.0 * inv(8.5),
            5.0 * inv(14.5),
            6.0 * inv(22.5),
        ];
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            let diff = (g - e).abs();
            assert!(diff < 1e-5, "lane {i}: got {g}, want {e}, diff {diff}");
        }
    }
}