// tract_core/ops/nn/rms_norm.rs
use crate::internal::*;
use crate::ops::binary::{BinMiniOp, TypedBinOp};
use crate::ops::element_wise::ElementWiseOp;
use crate::ops::math::{Add, Mul, Rsqrt};
use crate::ops::nn::{Reduce, Reducer};
use tract_itertools::Itertools;

8#[derive(Clone, Debug, Hash, PartialEq, Eq)]
9pub struct RmsNorm {
10 pub axis: usize,
11 pub eps: Arc<Tensor>,
12}
13
14impl Op for RmsNorm {
15 fn name(&self) -> StaticName {
16 "RmsNorm".to_string().into()
17 }
18 fn info(&self) -> TractResult<Vec<String>> {
19 Ok(vec![format!("axis: {:?}, eps: {:?}", self.axis, self.eps)])
20 }
21 op_as_typed_op!();
22}
23
24impl EvalOp for RmsNorm {
25 fn is_stateless(&self) -> bool {
26 true
27 }
28
29 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
30 let input = args_1!(inputs);
31
32 let input_f32 = input.cast_to::<f32>()?.into_owned();
33 let a1 = Reducer::MeanOfSquares.reduce(&[self.axis], &input_f32)?;
34 let mut a2 = Add.eval(a1.into_tvalue(), self.eps.clone().into_tvalue(), DatumType::F32)?;
35 Rsqrt {}.eval_in_place(&mut a2, None)?;
36 let a3 = Mul.eval(a2.into_tvalue(), input_f32.into_tvalue(), DatumType::F32)?;
37 Ok(tvec![a3.cast_to_dt(input.datum_type())?.into_owned().into()])
38 }
39}
40
41impl TypedOp for RmsNorm {
42 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
43 ensure!(self.eps.rank() == 0, "RmsNorm: eps must be a rank-0 tensor");
44 ensure!(
45 self.axis < inputs[0].rank(),
46 "RmsNorm: axis {} is out of bounds for input rank {}",
47 self.axis,
48 inputs[0].rank()
49 );
50 let dt = inputs[0].datum_type;
51 let fact = dt.fact(inputs[0].shape.clone());
52 Ok(tvec!(fact))
53 }
54
55 fn input_roi(
56 &self,
57 model: &TypedModel,
58 node: &TypedNode,
59 ) -> TractResult<Option<TVec<Option<TDim>>>> {
60 crate::optim::propagate_roi::bubble_roi(model, node)
61 }
62
63 fn axes_mapping(
64 &self,
65 inputs: &[&TypedFact],
66 _outputs: &[&TypedFact],
67 ) -> TractResult<AxesMapping> {
68 let rank = inputs[0].rank();
69 let mut letters = 'a'..;
70 let axes = (0..rank)
71 .map(|ix| {
72 Axis::new(letters.next().unwrap(), inputs.len(), 1).input(0, ix).output(0, ix)
73 })
74 .collect_vec();
75 AxesMapping::new(1, 1, axes)
76 }
77
78 fn change_axes(
79 &self,
80 model: &TypedModel,
81 node: &TypedNode,
82 _io: InOut,
83 change: &AxisOp,
84 ) -> TractResult<Option<AxisChangeConsequence>> {
85 if let Some(axis) = change.transform_axis(self.axis) {
86 let op = Some(Box::new(RmsNorm { axis, eps: self.eps.clone() }) as _);
87 Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
88 } else {
89 Ok(None)
90 }
91 }
92
93 fn slice(
94 &self,
95 patch: &mut TypedModelPatch,
96 _model: &TypedModel,
97 node: &TypedNode,
98 _prefix: &str,
99 inputs: &[OutletId],
100 output_axis: usize,
101 _start: &TDim,
102 _end: &TDim,
103 ) -> TractResult<Option<TVec<OutletId>>> {
104 if output_axis == self.axis {
105 return Ok(None);
106 }
107 patch.wire_node(&node.name, self.clone(), inputs).map(Some)
108 }
109
110 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
111 let dt = inputs[0].datum_type;
112 let count: TDim = inputs[0].shape.iter().product();
113 let groups: TDim = inputs[0]
116 .shape
117 .iter()
118 .enumerate()
119 .filter(|(i, _)| *i != self.axis)
120 .map(|(_, d)| d)
121 .product();
122 Ok(tvec!((Cost::FMA(dt), count * 3), (Cost::Div(dt), groups)))
123 }
124
125 as_op!();
126}
127
128pub fn detect_rms_norm(
130 op: &Reduce,
131 model: &TypedModel,
132 node: &TypedNode,
133) -> TractResult<Option<TypedModelPatch>> {
134 rule_if!(op.reducer == Reducer::MeanOfSquares);
135 rule_if!(op.axes.len() == 1);
136 let axis = op.axes[0];
137
138 let in_fact = model.node_input_facts(node.id)?[0];
139 let dt = in_fact.datum_type;
140
141 rule_if!(matches!(dt, DatumType::F32 | DatumType::F16));
143
144 rule_if_some!(add_succ = model.single_succ(node.id)?);
146 rule_if_some!(add_succ_op = add_succ.op_as::<TypedBinOp>());
147 rule_if!(add_succ_op.0.is::<Add>());
148
149 let add_consts = model.collect_const_inputs(add_succ);
151 rule_if!(add_consts.len() == 1);
152 let eps = add_consts[0].val().clone();
153 rule_if!(eps.len() == 1);
154 rule_if!(eps.datum_type() == dt);
155 let eps = eps.into_tensor().into_shape(&[])?.into_arc_tensor();
156
157 rule_if_some!(rsqrt_succ = model.single_succ(add_succ.id)?);
159 rule_if_some!(rsqrt_succ_op = rsqrt_succ.op_as::<ElementWiseOp>());
160 rule_if!(rsqrt_succ_op.0.is::<Rsqrt>());
161
162 rule_if_some!(mul_succ = model.find_succ_bin_with_outlet::<Mul>(rsqrt_succ, &node.inputs[0]));
164
165 let mut patch = TypedModelPatch::default();
166 let rsm_input = patch.taps(model, &node.inputs)?;
167 let out =
168 patch.wire_node(format!("{}.rms_norm", node.name), RmsNorm { axis, eps }, &rsm_input)?;
169
170 patch.shunt_outside(model, mul_succ.id.into(), out[0])?;
171 Ok(Some(patch))
172}