Skip to main content

tract_core/ops/nn/
rms_norm.rs

1use crate::internal::*;
2use crate::ops::binary::{BinMiniOp, TypedBinOp};
3use crate::ops::element_wise::ElementWiseOp;
4use crate::ops::math::{Add, Mul, Rsqrt};
5use crate::ops::nn::{Reduce, Reducer};
6use tract_itertools::Itertools;
7
8#[derive(Clone, Debug, Hash)]
9pub struct RmsNorm {
10    pub axis: usize,
11    pub eps: Arc<Tensor>,
12}
13
14impl Op for RmsNorm {
15    fn name(&self) -> StaticName {
16        "RmsNorm".to_string().into()
17    }
18    fn info(&self) -> TractResult<Vec<String>> {
19        Ok(vec![format!("axis: {:?}, eps: {:?}", self.axis, self.eps)])
20    }
21    op_as_typed_op!();
22}
23
24impl EvalOp for RmsNorm {
25    fn is_stateless(&self) -> bool {
26        true
27    }
28
29    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
30        let input = args_1!(inputs);
31
32        let input_f32 = input.cast_to::<f32>()?.into_owned();
33        let a1 = Reducer::MeanOfSquares.reduce(&[self.axis], &input_f32)?;
34        let mut a2 = Add.eval(a1.into_tvalue(), self.eps.clone().into_tvalue(), DatumType::F32)?;
35        Rsqrt {}.eval_in_place(&mut a2, None)?;
36        let a3 = Mul.eval(a2.into_tvalue(), input_f32.into_tvalue(), DatumType::F32)?;
37        Ok(tvec![a3.cast_to_dt(input.datum_type())?.into_owned().into()])
38    }
39}
40
41impl TypedOp for RmsNorm {
42    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
43        ensure!(self.eps.rank() == 0, "RmsNorm: eps must be a rank-0 tensor");
44        ensure!(
45            self.axis < inputs[0].rank(),
46            "RmsNorm: axis {} is out of bounds for input rank {}",
47            self.axis,
48            inputs[0].rank()
49        );
50        let dt = inputs[0].datum_type;
51        let fact = dt.fact(inputs[0].shape.clone());
52        Ok(tvec!(fact))
53    }
54
55    fn axes_mapping(
56        &self,
57        inputs: &[&TypedFact],
58        _outputs: &[&TypedFact],
59    ) -> TractResult<AxesMapping> {
60        let rank = inputs[0].rank();
61        let mut letters = 'a'..;
62        let axes = (0..rank)
63            .map(|ix| {
64                Axis::new(letters.next().unwrap(), inputs.len(), 1).input(0, ix).output(0, ix)
65            })
66            .collect_vec();
67        AxesMapping::new(1, 1, axes)
68    }
69
70    fn change_axes(
71        &self,
72        model: &TypedModel,
73        node: &TypedNode,
74        _io: InOut,
75        change: &AxisOp,
76    ) -> TractResult<Option<AxisChangeConsequence>> {
77        if let Some(axis) = change.transform_axis(self.axis) {
78            let op = Some(Box::new(RmsNorm { axis, eps: self.eps.clone() }) as _);
79            Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
80        } else {
81            Ok(None)
82        }
83    }
84
85    fn slice(
86        &self,
87        patch: &mut TypedModelPatch,
88        _model: &TypedModel,
89        node: &TypedNode,
90        _prefix: &str,
91        inputs: &[OutletId],
92        output_axis: usize,
93        _start: &TDim,
94        _end: &TDim,
95    ) -> TractResult<Option<TVec<OutletId>>> {
96        if output_axis == self.axis {
97            return Ok(None);
98        }
99        patch.wire_node(&node.name, self.clone(), inputs).map(Some)
100    }
101
102    as_op!();
103}
104
105/// Search pattern => A = A * RSQRT(MEAN_OF_SQUARES(A) + EPS)
106pub fn detect_rms_norm(
107    op: &Reduce,
108    model: &TypedModel,
109    node: &TypedNode,
110) -> TractResult<Option<TypedModelPatch>> {
111    rule_if!(op.reducer == Reducer::MeanOfSquares);
112    rule_if!(op.axes.len() == 1);
113    let axis = op.axes[0];
114
115    let in_fact = model.node_input_facts(node.id)?[0];
116    let dt = in_fact.datum_type;
117
118    // Only F16 and F32 is supported.
119    rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
120
121    // Identify Add operator
122    rule_if_some!(add_succ = model.single_succ(node.id)?);
123    rule_if_some!(add_succ_op = add_succ.op_as::<TypedBinOp>());
124    rule_if!(add_succ_op.0.is::<Add>());
125
126    // Retrieve epsilon
127    let add_consts = model.collect_const_inputs(add_succ);
128    rule_if!(add_consts.len() == 1);
129    let eps = add_consts[0].val().clone();
130    rule_if!(eps.len() == 1);
131    rule_if!(eps.datum_type() == dt);
132    let eps = eps.into_tensor().into_shape(&[])?.into_arc_tensor();
133
134    // Identify Rsqrt
135    rule_if_some!(rsqrt_succ = model.single_succ(add_succ.id)?);
136    rule_if_some!(rsqrt_succ_op = rsqrt_succ.op_as::<ElementWiseOp>());
137    rule_if!(rsqrt_succ_op.0.is::<Rsqrt>());
138
139    // Identify Mul: RSQRT(...) * A
140    rule_if_some!(mul_succ = model.find_succ_bin_with_outlet::<Mul>(rsqrt_succ, &node.inputs[0]));
141
142    let mut patch = TypedModelPatch::default();
143    let rsm_input = patch.taps(model, &node.inputs)?;
144    let out =
145        patch.wire_node(format!("{}.rms_norm", node.name), RmsNorm { axis, eps }, &rsm_input)?;
146
147    patch.shunt_outside(model, mul_succ.id.into(), out[0])?;
148    Ok(Some(patch))
149}