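//! `RmsNorm` operator (`tract_core/ops/nn/rms_norm.rs`) and the `detect_rms_norm` declutter
//! rule that fuses the `A * RSQRT(MEAN_OF_SQUARES(A) + EPS)` subgraph into it.
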
use crate::internal::*;
use crate::ops::binary::{BinMiniOp, TypedBinOp};
use crate::ops::element_wise::ElementWiseOp;
use crate::ops::math::{Add, Mul, Rsqrt};
use crate::ops::nn::{Reduce, Reducer};
use tract_itertools::Itertools;

8#[derive(Clone, Debug, Hash, PartialEq, Eq)]
9pub struct RmsNorm {
10    pub axis: usize,
11    pub eps: Arc<Tensor>,
12}
13
14impl Op for RmsNorm {
15    fn name(&self) -> StaticName {
16        "RmsNorm".to_string().into()
17    }
18    fn info(&self) -> TractResult<Vec<String>> {
19        Ok(vec![format!("axis: {:?}, eps: {:?}", self.axis, self.eps)])
20    }
21    op_as_typed_op!();
22}
23
24impl EvalOp for RmsNorm {
25    fn is_stateless(&self) -> bool {
26        true
27    }
28
29    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
30        let input = args_1!(inputs);
31
32        let input_f32 = input.cast_to::<f32>()?.into_owned();
33        // eps inherits the input dtype from the declutter pattern (F16 when the
34        // surrounding LayerNorm chain is F16). The MeanOfSquares + Add + Rsqrt
35        // + Mul chain below all runs at F32, so eps must be cast to match —
36        // otherwise the Add::eval call below panics with
37        //   "tensor is F32, accessed as F16"
38        // when input is F16.
39        let eps = self.eps.cast_to::<f32>()?.into_owned();
40        let a1 = Reducer::MeanOfSquares.reduce(&[self.axis], &input_f32)?;
41        let mut a2 = Add.eval(a1.into_tvalue(), eps.into_tvalue(), DatumType::F32)?;
42        Rsqrt {}.eval_in_place(&mut a2, None)?;
43        let a3 = Mul.eval(a2.into_tvalue(), input_f32.into_tvalue(), DatumType::F32)?;
44        Ok(tvec![a3.cast_to_dt(input.datum_type())?.into_owned().into()])
45    }
46}
47
48impl TypedOp for RmsNorm {
49    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
50        ensure!(self.eps.rank() == 0, "RmsNorm: eps must be a rank-0 tensor");
51        ensure!(
52            self.axis < inputs[0].rank(),
53            "RmsNorm: axis {} is out of bounds for input rank {}",
54            self.axis,
55            inputs[0].rank()
56        );
57        let dt = inputs[0].datum_type;
58        let fact = dt.fact(inputs[0].shape.clone());
59        Ok(tvec!(fact))
60    }
61
62    fn input_roi(
63        &self,
64        model: &TypedModel,
65        node: &TypedNode,
66    ) -> TractResult<Option<TVec<Option<TDim>>>> {
67        crate::optim::propagate_roi::bubble_roi(model, node)
68    }
69
70    fn axes_mapping(
71        &self,
72        inputs: &[&TypedFact],
73        _outputs: &[&TypedFact],
74    ) -> TractResult<AxesMapping> {
75        let rank = inputs[0].rank();
76        let mut letters = 'a'..;
77        let axes = (0..rank)
78            .map(|ix| {
79                Axis::new(letters.next().unwrap(), inputs.len(), 1).input(0, ix).output(0, ix)
80            })
81            .collect_vec();
82        AxesMapping::new(1, 1, axes)
83    }
84
85    fn change_axes(
86        &self,
87        model: &TypedModel,
88        node: &TypedNode,
89        _io: InOut,
90        change: &AxisOp,
91    ) -> TractResult<Option<AxisChangeConsequence>> {
92        if let Some(axis) = change.transform_axis(self.axis) {
93            let op = Some(Box::new(RmsNorm { axis, eps: self.eps.clone() }) as _);
94            Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
95        } else {
96            Ok(None)
97        }
98    }
99
100    fn slice(
101        &self,
102        patch: &mut TypedModelPatch,
103        _model: &TypedModel,
104        node: &TypedNode,
105        _prefix: &str,
106        inputs: &[OutletId],
107        output_axis: usize,
108        _start: &TDim,
109        _end: &TDim,
110    ) -> TractResult<Option<TVec<OutletId>>> {
111        rule_if!(output_axis != self.axis);
112        patch.wire_node(&node.name, self.clone(), inputs).map(Some)
113    }
114
115    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
116        let dt = inputs[0].datum_type;
117        let count: TDim = inputs[0].shape.iter().product();
118        // per element: square + accumulate + mul by rsqrt ≈ 3 FMA
119        // per reduction group: 1 div (rsqrt)
120        let groups: TDim = inputs[0]
121            .shape
122            .iter()
123            .enumerate()
124            .filter(|(i, _)| *i != self.axis)
125            .map(|(_, d)| d)
126            .product();
127        Ok(tvec!((Cost::FMA(dt), count * 3), (Cost::Div(dt), groups)))
128    }
129
130    as_op!();
131}
132
133/// Search pattern => A = A * RSQRT(MEAN_OF_SQUARES(A) + EPS)
134pub fn detect_rms_norm(
135    op: &Reduce,
136    model: &TypedModel,
137    node: &TypedNode,
138) -> TractResult<Option<TypedModelPatch>> {
139    rule_if!(op.reducer == Reducer::MeanOfSquares);
140    rule_if!(op.axes.len() == 1);
141    let axis = op.axes[0];
142
143    let in_fact = model.node_input_facts(node.id)?[0];
144    let dt = in_fact.datum_type;
145
146    // Only F16 and F32 is supported.
147    rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
148
149    // Identify Add operator
150    rule_if_some!(add_succ = model.single_succ(node.id)?);
151    rule_if_some!(add_succ_op = add_succ.op_as::<TypedBinOp>());
152    rule_if!(add_succ_op.0.is::<Add>());
153
154    // Retrieve epsilon
155    let add_consts = model.collect_const_inputs(add_succ);
156    rule_if!(add_consts.len() == 1);
157    let eps = add_consts[0].val().clone();
158    rule_if!(eps.len() == 1);
159    rule_if!(eps.datum_type() == dt);
160    let eps = eps.into_tensor().into_shape(&[])?.into_arc_tensor();
161
162    // Identify Rsqrt
163    rule_if_some!(rsqrt_succ = model.single_succ(add_succ.id)?);
164    rule_if_some!(rsqrt_succ_op = rsqrt_succ.op_as::<ElementWiseOp>());
165    rule_if!(rsqrt_succ_op.0.is::<Rsqrt>());
166
167    // Identify Mul: RSQRT(...) * A
168    rule_if_some!(mul_succ = model.find_succ_bin_with_outlet::<Mul>(rsqrt_succ, &node.inputs[0]));
169
170    let mut patch = TypedModelPatch::default();
171    let rsm_input = patch.taps(model, &node.inputs)?;
172    let out =
173        patch.wire_node(format!("{}.rms_norm", node.name), RmsNorm { axis, eps }, &rsm_input)?;
174
175    patch.shunt_outside(model, mul_succ.id.into(), out[0])?;
176    Ok(Some(patch))
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference RMS normalisation of one reduction group, computed directly in f32.
    fn reference(xs: &[f32], eps: f32) -> Vec<f32> {
        let mean_sq = xs.iter().map(|x| x * x).sum::<f32>() / xs.len() as f32;
        let inv_rms = (mean_sq + eps).sqrt().recip();
        xs.iter().map(|x| x * inv_rms).collect()
    }

    /// Regression test for F16 `eps` with F16 input.
    ///
    /// `detect_rms_norm` stores `eps` with the input dtype (F16 when the surrounding chain is
    /// F16), via its `rule_if!(eps.datum_type() == dt)` guard. The eval path runs at F32, so it
    /// must cast `self.eps` to F32 first. Without that cast in `RmsNorm::eval`, evaluation fails
    /// with "tensor is F32, accessed as F16".
    #[test]
    fn eval_with_f16_eps_and_f16_input() -> TractResult<()> {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let input = tensor1(&values.map(f16::from_f32));
        let op = RmsNorm { axis: 0, eps: tensor0(f16::from_f32(1e-5)).into_arc_tensor() };
        let out = op
            .eval(tvec!(input.into()))?
            .into_iter()
            .next()
            .expect("RmsNorm yields exactly one output")
            .into_tensor();
        assert_eq!(out.datum_type(), DatumType::F16);
        assert_eq!(out.shape(), &[4]);
        // The dtype was just asserted, so the checked accessor cannot fail on a mismatch.
        let got = out.as_slice::<f16>()?;
        for (i, (g, e)) in got.iter().zip(reference(&values, 1e-5)).enumerate() {
            let diff = (g.to_f32() - e).abs();
            assert!(diff < 0.01, "lane {i}: got {} expected {}", g.to_f32(), e);
        }
        Ok(())
    }

    /// Normalising along the last axis of a rank-2 F32 tensor treats each row independently.
    #[test]
    fn eval_f32_normalises_each_row_independently() -> TractResult<()> {
        let input = tensor2(&[[1.0f32, 2.0], [3.0, 4.0]]);
        let op = RmsNorm { axis: 1, eps: tensor0(1e-5f32).into_arc_tensor() };
        let out = op
            .eval(tvec!(input.into()))?
            .into_iter()
            .next()
            .expect("RmsNorm yields exactly one output")
            .into_tensor();
        assert_eq!(out.datum_type(), DatumType::F32);
        assert_eq!(out.shape(), &[2, 2]);
        let got = out.as_slice::<f32>()?;
        let expected = [reference(&[1.0, 2.0], 1e-5), reference(&[3.0, 4.0], 1e-5)].concat();
        for (i, (g, e)) in got.iter().zip(&expected).enumerate() {
            assert!((g - e).abs() < 1e-5, "lane {i}: got {g} expected {e}");
        }
        Ok(())
    }
}
208}