Skip to main content

tract_core/ops/
element_wise.rs

1use crate::internal::*;
2use crate::ops::array::MultiBroadcastTo;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt;
6
7pub trait ElementWiseMiniOp:
8    fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast
9{
10    fn name(&self) -> String;
11    fn prefix(&self) -> &'static str {
12        ""
13    }
14    fn validation(&self) -> Validation {
15        Validation::Accurate
16    }
17    #[allow(unused_variables)]
18    fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
19        None
20    }
21    #[allow(unused_variables)]
22    fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
23        bail!("Element wise eval in-place not defined");
24    }
25    #[allow(unused_variables)]
26    fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
27        bail!("Element wise eval out-of-place place not defined");
28    }
29    #[allow(unused_variables)]
30    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
31        tvec!()
32    }
33    #[allow(unused_variables)]
34    fn operating_datum_type(&self, dt: DatumType) -> DatumType {
35        dt
36    }
37    #[allow(unused_variables)]
38    fn declutter(
39        &self,
40        model: &TypedModel,
41        node: &TypedNode,
42    ) -> TractResult<Option<TypedModelPatch>> {
43        Ok(None)
44    }
45
46    #[allow(unused_variables)]
47    fn quantize(
48        &self,
49        dt: DatumType,
50        scale: f32,
51        zero_point: i32,
52    ) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
53        Ok(None)
54    }
55    #[allow(unused_variables)]
56    fn info(&self) -> TractResult<Vec<String>> {
57        Ok(vec![])
58    }
59}
60
// Trait-object plumbing for `Box<dyn ElementWiseMiniOp>`: clone, equality and downcasting
// support, which `ElementWiseOp`'s `#[derive(Clone, PartialEq, Eq)]` and the
// `downcast_ref` calls on the boxed mini-op rely on.
dyn_clone::clone_trait_object!(ElementWiseMiniOp);
dyn_eq::eq_trait_object!(ElementWiseMiniOp);
downcast_rs::impl_downcast!(ElementWiseMiniOp);
64
65#[derive(Debug, Clone, PartialEq, Eq)]
66pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>, pub Option<DatumType>);
67
68impl ElementWiseOp {
69    fn output_datum_type(&self, input_dt: DatumType) -> DatumType {
70        self.1.unwrap_or(self.0.operating_datum_type(input_dt))
71    }
72}
73
74impl Op for ElementWiseOp {
75    fn name(&self) -> StaticName {
76        self.0.name().into()
77    }
78
79    fn info(&self) -> TractResult<Vec<String>> {
80        self.0.info()
81    }
82
83    fn validation(&self) -> Validation {
84        self.0.validation()
85    }
86
87    op_as_typed_op!();
88}
89
90impl EvalOp for ElementWiseOp {
91    fn is_stateless(&self) -> bool {
92        true
93    }
94
95    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
96        if let Some(_dt) = self.0.output_type(inputs[0].datum_type()) {
97            Ok(tvec!(self.0.eval_out_of_place(&inputs[0], self.1)?.into_tvalue()))
98        } else {
99            let mut m = inputs.remove(0).into_tensor();
100            self.0.eval_in_place(&mut m, self.1)?;
101            Ok(tvec!(m.into()))
102        }
103    }
104}
105
106impl TypedOp for ElementWiseOp {
107    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
108        let mut fact = inputs[0].clone().without_value();
109        let dt = self.output_datum_type(fact.datum_type);
110        if let Some(dt) = self.1 {
111            fact.datum_type = dt;
112        } else if let Some(dt) = self.0.output_type(dt) {
113            fact.datum_type = dt;
114        }
115        // Propagate uniform_tdim through this op.
116        if let Some(tdim) = &inputs[0].uniform_tdim {
117            // Logical NOT on bool tensors: NOT(x) = 1 - x for 0/1 values.
118            // Not is bool-only by definition. BitNot is bitwise (valid on integers
119            // where ~x ≠ 1-x), so only apply this for bool input.
120            let is_logical_not = self.0.downcast_ref::<crate::ops::logic::Not>().is_some()
121                || (self.0.downcast_ref::<crate::ops::logic::BitNot>().is_some()
122                    && inputs[0].datum_type == bool::datum_type());
123            if is_logical_not {
124                fact.uniform_tdim = Some((TDim::Val(1) - tdim.clone()).reduce());
125            } else {
126                // General path: evaluate the op on a TDim scalar.
127                // Ops with a TDim arm (e.g. Floor → identity) pass the value through;
128                // ops without one return an error and uniform_tdim stays None.
129                let mut tmp = tensor0(tdim.clone());
130                if self.0.eval_in_place(&mut tmp, None).is_ok() {
131                    fact.uniform_tdim = tmp
132                        .try_as_plain()
133                        .ok()
134                        .and_then(|d| d.as_slice::<TDim>().ok())
135                        .and_then(|s| s.first())
136                        .cloned()
137                        .map(|d| d.reduce());
138                }
139            }
140        }
141        Ok(tvec!(fact))
142    }
143
144    fn input_roi(
145        &self,
146        model: &TypedModel,
147        node: &TypedNode,
148    ) -> TractResult<Option<TVec<Option<TDim>>>> {
149        crate::optim::propagate_roi::bubble_roi(model, node)
150    }
151
152    fn change_axes(
153        &self,
154        model: &TypedModel,
155        node: &TypedNode,
156        _io: InOut,
157        change: &AxisOp,
158    ) -> TractResult<Option<AxisChangeConsequence>> {
159        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
160    }
161
162    fn declutter(
163        &self,
164        model: &TypedModel,
165        node: &TypedNode,
166    ) -> TractResult<Option<TypedModelPatch>> {
167        // linear_prec (fan-in=1, fan-out=1) rather than single_prec: swapping
168        // through a fan-out predecessor clones it, and the clone can break
169        // downstream pattern detectors (e.g. Square+Reduce<Sum>+Mul fusion
170        // into Reduce<MeanOfSquares> feeding RmsNorm detection).
171        if let Some(prec) = model.linear_prec(node.id)?
172            && (prec.op_is::<AxisOp>()
173                || prec.op_is::<IntoShape>()
174                || prec.op_is::<MultiBroadcastTo>())
175        {
176            let mut patch = TypedModelPatch::default();
177            let mut wire = tvec!(patch.tap_model(model, prec.inputs[0])?);
178            wire = patch.wire_node(&node.name, &node.op, &wire)?;
179            wire = patch.wire_node(&prec.name, &prec.op, &wire)?;
180            patch.shunt_outside(model, node.id.into(), wire[0])?;
181            return Ok(Some(patch));
182        }
183        self.0.declutter(model, node)
184    }
185
186    fn axes_mapping(
187        &self,
188        inputs: &[&TypedFact],
189        outputs: &[&TypedFact],
190    ) -> TractResult<AxesMapping> {
191        AxesMapping::natural(inputs, outputs)
192    }
193
194    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
195        let count: TDim = inputs[0].shape.iter().product();
196        Ok(self
197            .0
198            .cost_per_element(inputs[0].datum_type)
199            .into_iter()
200            .map(|(c, n)| (c, count.clone() * n))
201            .collect())
202    }
203
204    fn quantize(
205        &self,
206        _model: &TypedModel,
207        _node: &TypedNode,
208        dt: DatumType,
209        scale: f32,
210        zero_point: i32,
211    ) -> TractResult<Option<Box<dyn TypedOp>>> {
212        if let Some(mini) = self.0.quantize(dt, scale, zero_point)? {
213            Ok(Some(Box::new(ElementWiseOp(mini, self.1))))
214        } else {
215            Ok(None)
216        }
217    }
218
219    fn slice(
220        &self,
221        patch: &mut TypedModelPatch,
222        _model: &TypedModel,
223        node: &TypedNode,
224        _prefix: &str,
225        inputs: &[OutletId],
226        _output_axis: usize,
227        _start: &TDim,
228        _end: &TDim,
229    ) -> TractResult<Option<TVec<OutletId>>> {
230        patch.wire_node(&node.name, &node.op, inputs).map(Some)
231    }
232
233    as_op!();
234}
235
/// Declares an in-place element-wise mini-op `$Op` (implementing `ElementWiseMiniOp`).
/// Also declares a constructor `$func` returning an `ElementWiseOp` that wraps it with no
/// output-type override.
///
/// Each `[types] => f` arm handles the listed datum types with a
/// `fn(&Self, &mut [T]) -> TractResult<()>` that rewrites the tensor's slice in place.
///
/// The optional `q:` section handles quantized types. Each element is dequantized with the
/// input's zero point and scale, passed through the given function as an `f32`, and
/// requantized with the output's zero point and scale.
///
/// The other optional sections (`cost`, `declutter`, `operating_datum_type`, `prefix`,
/// `quantize`, `validation`) override the corresponding trait methods.
#[macro_export]
macro_rules! element_wise {
    ($func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $([$($typ:ident),*] => $f:expr ),*
        $(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
        $(; cost: $cost:expr )?
        $(; declutter: $declutter:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone, PartialEq)]
        pub struct $Op { $( $( $(#[$meta])? pub $var: $var_typ),* )? }
        impl Eq for $Op {}
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
                // Plain arms: dispatch on the effective type (override if any, else the
                // tensor's own type); the first matching arm runs and returns.
                $(
                    $(if out_dt.unwrap_or(t.datum_type()) == $typ::datum_type() {
                        let mut t_plain = t.try_as_plain_mut()?;
                        let t: &mut[$typ] = t_plain.as_slice_mut::<$typ>()?;
                        let f: fn(&Self, &mut[$typ]) -> TractResult<()> = $f;
                        f(self, t)?;
                        return Ok(())
                    }
                    )*
                )*
                // Quantized arms (only when a `q:` section is given): matched on the
                // unquantized storage type of the effective output type.
                $(
                    $(
                       $(
                        let mut input_dt = t.datum_type();
                        let sout_dt = out_dt.unwrap_or(input_dt);
                        if sout_dt.unquantized() == <$typ_dt>::datum_type().unquantized() {
                           if input_dt.unquantized() != sout_dt.unquantized() {
                               // align unquantized input type to unquantized output type
                               *t = match input_dt.unquantized() {
                                   DatumType::U8 => t.clone().into_arc_tensor().offset_u8_as_i8(),
                                   DatumType::I8 => t.clone().into_arc_tensor().offset_i8_as_u8(),
                                   unknown_dt => bail!("unexpected quantization input dt {:?}", unknown_dt)
                               }.into_tensor();
                               input_dt = t.datum_type(); // because zero_point change
                           }
                           // SAFETY(review): relabels the tensor as `sout_dt` without touching
                           // its data; presumably sound because the storage type now matches
                           // `sout_dt.unquantized()` after the alignment above — confirm.
                           unsafe { t.set_datum_type(sout_dt) } // force cast
                           let mut t_plain = t.try_as_plain_mut()?;
                           let t: &mut[$typ_dt] = t_plain.as_slice_mut::<$typ_dt>()?;
                           // Dequantize with the input's (zp, scale), apply the f32 function,
                           // requantize with the output's (zp, scale).
                           let f: fn(&Self, &mut[$typ_dt], DatumType, DatumType) -> TractResult<()> = |_, xs, input_dt, out_dt| {
                               let (izp, iscale) = input_dt.zp_scale();
                               let (ozp, oscale) = out_dt.zp_scale();
                               xs.iter_mut().for_each(|x| {
                                   let x_f32 = (*x as f32 - izp as f32) * iscale;
                                   *x = (($f_f32(x_f32) / oscale) + ozp as f32).as_()
                               });
                               Ok(())
                           };
                           f(self, t, input_dt, sout_dt)?;
                           return Ok(())
                       }
                       )*
                   )*
                )?
                // No arm matched the effective type.
                bail!("{} does not support {:?}", self.name(), out_dt.unwrap_or(t.datum_type()));
            }
            $(
            fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                $cost(dt)
            }
            )?
            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    $declutter(model, node)
                }
            )?
            $(
            fn prefix(&self) -> &'static str {
                $prefix
            }
            )?
            $(
            fn quantize(
                &self,
                dt: DatumType,
                scale: f32,
                zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    $quantize(&self, dt, scale, zero_point)
            }
            )?
            $(
            fn validation(&self) -> Validation {
                $validation
            }
            )?
            $(
            fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                ($operating_datum_type)(dt)
            }
            )?
        }
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}
345
/// Declares an out-of-place element-wise mini-op `$Op` (implementing `ElementWiseMiniOp`).
/// Also declares a constructor `$func` returning an `ElementWiseOp` that wraps it with no
/// output-type override.
///
/// Each `[src types] => dst_type f` arm maps inputs of any listed source type to a new tensor
/// of `dst_type` with the same shape. It calls `f: fn(&Self, &[Src], &mut [Dst])`, and the
/// first arm whose source type equals the input's datum type wins.
///
/// The optional sections (`cost`, `info`, `operating_datum_type`, `prefix`, `quantize`,
/// `validation`) override the corresponding trait methods. `quantize` is called as
/// `$quantize(&self, dt, scale, zero_point)`, exactly like in `element_wise!`.
#[macro_export]
macro_rules! element_wise_oop {
    ($(#[$fmeta:meta])* $func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $( [$($typ:ident),*] => $typ_dst:ident $f:expr ),*
        $(; cost: $cost:expr )?
        $(; info: $info:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone)]
        pub struct $Op { $( $($(#[$meta])? pub $var: $var_typ),* )? }
        // Field-wise equality over the declared parameters (always true for field-less ops).
        impl PartialEq for $Op {
            #[allow(unused_variables)] // `other` is unused when the op has no fields
            fn eq(&self, other: &Self) -> bool {
                $( $( if &self.$var != &other.$var { return false; })* )?
                true
            }
        }
        impl Eq for $Op {}
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
                $(
                    $(if input_type == $typ::datum_type() {
                        return Some(<$typ_dst>::datum_type())
                    }
                    )*
                )*
                None
            }
            fn eval_out_of_place(&self, t: &Tensor, _out_dt: Option<DatumType>) -> TractResult<Tensor> {
                $(
                    $(if t.datum_type() == $typ::datum_type() {
                        // Allocate the output only once an arm matches, so unrelated arms
                        // neither allocate nor can fail the call.
                        // SAFETY(review): the buffer starts uninitialized and `f` is expected to
                        // write every element of the output slice before it is returned.
                        let mut dst = unsafe { Tensor::uninitialized_dt(<$typ_dst>::datum_type(), &t.shape())? };
                        let f: fn(&Self, &[$typ], &mut[$typ_dst]) -> TractResult<()> = $f;
                        let mut dst_plain = dst.try_as_plain_mut()?;
                        f(self, t.try_as_plain()?.as_slice::<$typ>()?, dst_plain.as_slice_mut::<$typ_dst>()?)?;
                        return Ok(dst)
                    }
                    )*
                )*
                bail!("{} does not support {:?}", self.name(), t.datum_type());
            }
            $(
            fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                $cost(dt)
            }
            )?
            $(
            fn info(&self) -> TractResult<Vec<String>> {
                $info(self)
            }
            )?
            $(
            fn prefix(&self) -> &'static str {
                $prefix
            }
            )?
            $(
            fn quantize(
                &self,
                dt: DatumType,
                scale: f32,
                zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    // Previously referenced an undefined `ft` and dropped `dt`; call the
                    // user-supplied hook the same way `element_wise!` does.
                    $quantize(&self, dt, scale, zero_point)
            }
            )?
            $(
            fn validation(&self) -> Validation {
                $validation
            }
            )?
            $(
            fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                ($operating_datum_type)(dt)
            }
            )?
        }
        $(#[$fmeta])*
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}
429        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
430            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
431        }
432    }
433}