Skip to main content

tract_core/ops/
element_wise.rs

1use crate::internal::*;
2use downcast_rs::Downcast;
3use dyn_eq::DynEq;
4use std::fmt;
5
6pub trait ElementWiseMiniOp:
7    fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast
8{
9    fn name(&self) -> String;
10    fn prefix(&self) -> &'static str {
11        ""
12    }
13    fn validation(&self) -> Validation {
14        Validation::Accurate
15    }
16    #[allow(unused_variables)]
17    fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
18        None
19    }
20    #[allow(unused_variables)]
21    fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
22        bail!("Element wise eval in-place not defined");
23    }
24    #[allow(unused_variables)]
25    fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
26        bail!("Element wise eval out-of-place place not defined");
27    }
28    #[allow(unused_variables)]
29    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
30        tvec!()
31    }
32    #[allow(unused_variables)]
33    fn operating_datum_type(&self, dt: DatumType) -> DatumType {
34        dt
35    }
36    #[allow(unused_variables)]
37    fn declutter(
38        &self,
39        model: &TypedModel,
40        node: &TypedNode,
41    ) -> TractResult<Option<TypedModelPatch>> {
42        Ok(None)
43    }
44
45    #[allow(unused_variables)]
46    fn quantize(
47        &self,
48        dt: DatumType,
49        scale: f32,
50        zero_point: i32,
51    ) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
52        Ok(None)
53    }
54    #[allow(unused_variables)]
55    fn info(&self) -> TractResult<Vec<String>> {
56        Ok(vec![])
57    }
58}
59
// Trait-object plumbing for `Box<dyn ElementWiseMiniOp>`: `ElementWiseOp` derives
// `Clone`, `PartialEq` and `Eq` over such a box, and `output_facts` calls
// `downcast_ref` on it to recognise specific mini-ops. These three invocations are
// presumably what provide those capabilities — confirm against the dyn_clone,
// dyn_eq and downcast_rs crate documentation.
dyn_clone::clone_trait_object!(ElementWiseMiniOp);
dyn_eq::eq_trait_object!(ElementWiseMiniOp);
downcast_rs::impl_downcast!(ElementWiseMiniOp);
63
/// Op that applies an [`ElementWiseMiniOp`] to its first input.
///
/// * `.0` — the boxed mini-op doing the per-element work.
/// * `.1` — optional datum type override: when `Some`, `output_facts` reports it
///   as the output datum type and `eval` passes it to the mini-op as `out_dt`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>, pub Option<DatumType>);
66
67impl ElementWiseOp {
68    fn output_datum_type(&self, input_dt: DatumType) -> DatumType {
69        self.1.unwrap_or(self.0.operating_datum_type(input_dt))
70    }
71}
72
73impl Op for ElementWiseOp {
74    fn name(&self) -> StaticName {
75        self.0.name().into()
76    }
77
78    fn info(&self) -> TractResult<Vec<String>> {
79        self.0.info()
80    }
81
82    fn validation(&self) -> Validation {
83        self.0.validation()
84    }
85
86    op_as_typed_op!();
87}
88
89impl EvalOp for ElementWiseOp {
90    fn is_stateless(&self) -> bool {
91        true
92    }
93
94    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
95        if let Some(_dt) = self.0.output_type(inputs[0].datum_type()) {
96            Ok(tvec!(self.0.eval_out_of_place(&inputs[0], self.1)?.into_tvalue()))
97        } else {
98            let mut m = inputs.remove(0).into_tensor();
99            self.0.eval_in_place(&mut m, self.1)?;
100            Ok(tvec!(m.into()))
101        }
102    }
103}
104
105impl TypedOp for ElementWiseOp {
106    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
107        let mut fact = inputs[0].clone().without_value();
108        let dt = self.output_datum_type(fact.datum_type);
109        if let Some(dt) = self.1 {
110            fact.datum_type = dt;
111        } else if let Some(dt) = self.0.output_type(dt) {
112            fact.datum_type = dt;
113        }
114        // Propagate uniform_tdim through this op.
115        if let Some(tdim) = &inputs[0].uniform_tdim {
116            // Logical NOT on bool tensors: NOT(x) = 1 - x for 0/1 values.
117            // Not is bool-only by definition. BitNot is bitwise (valid on integers
118            // where ~x ≠ 1-x), so only apply this for bool input.
119            let is_logical_not = self.0.downcast_ref::<crate::ops::logic::Not>().is_some()
120                || (self.0.downcast_ref::<crate::ops::logic::BitNot>().is_some()
121                    && inputs[0].datum_type == bool::datum_type());
122            if is_logical_not {
123                fact.uniform_tdim = Some((TDim::Val(1) - tdim.clone()).reduce());
124            } else {
125                // General path: evaluate the op on a TDim scalar.
126                // Ops with a TDim arm (e.g. Floor → identity) pass the value through;
127                // ops without one return an error and uniform_tdim stays None.
128                let mut tmp = tensor0(tdim.clone());
129                if self.0.eval_in_place(&mut tmp, None).is_ok() {
130                    fact.uniform_tdim = tmp
131                        .try_as_plain()
132                        .ok()
133                        .and_then(|d| d.as_slice::<TDim>().ok())
134                        .and_then(|s| s.first())
135                        .cloned()
136                        .map(|d| d.reduce());
137                }
138            }
139        }
140        Ok(tvec!(fact))
141    }
142
143    fn input_roi(
144        &self,
145        model: &TypedModel,
146        node: &TypedNode,
147    ) -> TractResult<Option<TVec<Option<TDim>>>> {
148        crate::optim::propagate_roi::bubble_roi(model, node)
149    }
150
151    fn change_axes(
152        &self,
153        model: &TypedModel,
154        node: &TypedNode,
155        _io: InOut,
156        change: &AxisOp,
157    ) -> TractResult<Option<AxisChangeConsequence>> {
158        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
159    }
160
161    fn declutter(
162        &self,
163        model: &TypedModel,
164        node: &TypedNode,
165    ) -> TractResult<Option<TypedModelPatch>> {
166        if let Some(prec) = model.single_prec(node.id)?
167            && (prec.op_is::<AxisOp>() || prec.op_is::<IntoShape>())
168        {
169            let mut patch = TypedModelPatch::default();
170            let mut wire = tvec!(patch.tap_model(model, prec.inputs[0])?);
171            wire = patch.wire_node(&node.name, &node.op, &wire)?;
172            wire = patch.wire_node(&prec.name, &prec.op, &wire)?;
173            patch.shunt_outside(model, node.id.into(), wire[0])?;
174            return Ok(Some(patch));
175        }
176        self.0.declutter(model, node)
177    }
178
179    fn axes_mapping(
180        &self,
181        inputs: &[&TypedFact],
182        outputs: &[&TypedFact],
183    ) -> TractResult<AxesMapping> {
184        AxesMapping::natural(inputs, outputs)
185    }
186
187    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
188        let count: TDim = inputs[0].shape.iter().product();
189        Ok(self
190            .0
191            .cost_per_element(inputs[0].datum_type)
192            .into_iter()
193            .map(|(c, n)| (c, count.clone() * n))
194            .collect())
195    }
196
197    fn quantize(
198        &self,
199        _model: &TypedModel,
200        _node: &TypedNode,
201        dt: DatumType,
202        scale: f32,
203        zero_point: i32,
204    ) -> TractResult<Option<Box<dyn TypedOp>>> {
205        if let Some(mini) = self.0.quantize(dt, scale, zero_point)? {
206            Ok(Some(Box::new(ElementWiseOp(mini, self.1))))
207        } else {
208            Ok(None)
209        }
210    }
211
212    fn slice(
213        &self,
214        patch: &mut TypedModelPatch,
215        _model: &TypedModel,
216        node: &TypedNode,
217        _prefix: &str,
218        inputs: &[OutletId],
219        _output_axis: usize,
220        _start: &TDim,
221        _end: &TDim,
222    ) -> TractResult<Option<TVec<OutletId>>> {
223        patch.wire_node(&node.name, &node.op, inputs).map(Some)
224    }
225
226    as_op!();
227}
228
/// Declares an in-place element-wise mini-op and its constructor function.
///
/// Invocation shape:
///
/// ```text
/// element_wise!(func, Op { field: Type, ... },
///     [T1, T2] => |op, xs| { ... }, ...
///     ; q: [Q1, Q2] => |x: f32| ..., ...
///     ; cost: ...
///     ; declutter: ...
///     ; operating_datum_type: ...
///     ; prefix: ...
///     ; quantize: ...
///     ; validation: ...
/// );
/// ```
///
/// The field list and every `;`-introduced section are optional, but sections
/// must appear in the order above. The expansion contains:
///
/// * `pub struct $Op` with the given public fields, deriving `Debug`, `Clone`,
///   `PartialEq`, plus an `Eq` impl;
/// * an `ElementWiseMiniOp` impl whose `name()` is the prefix followed by the
///   type name, and whose `eval_in_place` dispatches on datum type;
/// * `pub fn $func(fields...)` returning `ElementWiseOp(Box::new($Op { .. }), None)`.
///
/// Each `[types] => f` arm expects `f` to coerce to
/// `fn(&Self, &mut [T]) -> TractResult<()>`. Each `q:` arm gives an expression
/// called with the dequantized `f32` value of every element.
#[macro_export]
macro_rules! element_wise {
    ($func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $([$($typ:ident),*] => $f:expr ),*
        $(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
        $(; cost: $cost:expr )?
        $(; declutter: $declutter:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone, PartialEq)]
        pub struct $Op { $( $( $(#[$meta])? pub $var: $var_typ),* )? }
        impl Eq for $Op {}
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
                // Plain arms: the first listed type equal to the requested output type
                // (or the tensor's own type when none is requested) handles the call.
                $(
                    $(if out_dt.unwrap_or(t.datum_type()) == $typ::datum_type() {
                        let mut t_plain = t.try_as_plain_mut()?;
                        let t: &mut[$typ] = t_plain.as_slice_mut::<$typ>()?;
                        let f: fn(&Self, &mut[$typ]) -> TractResult<()> = $f;
                        f(self, t)?;
                        return Ok(())
                    }
                    )*
                )*
                // Quantized arms (only present when a `q:` section was given): types are
                // compared after `unquantized()`.
                $(
                    $(
                       $(
                        let mut input_dt = t.datum_type();
                        let sout_dt = out_dt.unwrap_or(input_dt);
                        if sout_dt.unquantized() == <$typ_dt>::datum_type().unquantized() {
                           if input_dt.unquantized() != sout_dt.unquantized() {
                               // align unquantized input type to unquantized output type
                               *t = match input_dt.unquantized() {
                                   DatumType::U8 => t.clone().into_arc_tensor().offset_u8_as_i8(),
                                   DatumType::I8 => t.clone().into_arc_tensor().offset_i8_as_u8(),
                                   unknown_dt => bail!("unexpected quantization input dt {:?}", unknown_dt)
                               }.into_tensor();
                               input_dt = t.datum_type(); // because zero_point change
                           }
                           unsafe { t.set_datum_type(sout_dt) } // force cast
                           let mut t_plain = t.try_as_plain_mut()?;
                           let t: &mut[$typ_dt] = t_plain.as_slice_mut::<$typ_dt>()?;
                           // Per element: dequantize with the input zero-point/scale, apply the
                           // f32 expression, requantize with the output zero-point/scale.
                           let f: fn(&Self, &mut[$typ_dt], DatumType, DatumType) -> TractResult<()> = |_, xs, input_dt, out_dt| {
                               let (izp, iscale) = input_dt.zp_scale();
                               let (ozp, oscale) = out_dt.zp_scale();
                               xs.iter_mut().for_each(|x| {
                                   let x_f32 = (*x as f32 - izp as f32) * iscale;
                                   *x = (($f_f32(x_f32) / oscale) + ozp as f32).as_()
                               });
                               Ok(())
                           };
                           f(self, t, input_dt, sout_dt)?;
                           return Ok(())
                       }
                       )*
                   )*
                )?
                // No arm matched the requested datum type.
                bail!("{} does not support {:?}", self.name(), out_dt.unwrap_or(t.datum_type()));
            }
            // Each optional section below overrides the corresponding trait default.
            $(
            fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                $cost(dt)
            }
            )?
            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    $declutter(model, node)
                }
            )?
            $(
            fn prefix(&self) -> &'static str {
                $prefix
            }
            )?
            $(
            fn quantize(
                &self,
                dt: DatumType,
                scale: f32,
                zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    $quantize(&self, dt, scale, zero_point)
            }
            )?
            $(
            fn validation(&self) -> Validation {
                $validation
            }
            )?
            $(
            fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                ($operating_datum_type)(dt)
            }
            )?
        }
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}
338
/// Declares an out-of-place element-wise mini-op (one whose output datum type may
/// differ from its input datum type) and its constructor function.
///
/// Invocation shape:
///
/// ```text
/// element_wise_oop!(#[attrs] func, Op { field: Type, ... },
///     [Src1, Src2] => Dst |op, src, dst| { ... }, ...
///     ; cost: ...
///     ; info: ...
///     ; operating_datum_type: ...
///     ; prefix: ...
///     ; quantize: ...
///     ; validation: ...
/// );
/// ```
///
/// The attributes, the field list and every `;`-introduced section are optional,
/// but sections must appear in the order above. The expansion contains:
///
/// * `pub struct $Op` with the given public fields, deriving `Debug` and `Clone`,
///   with field-by-field `PartialEq` and an `Eq` impl;
/// * an `ElementWiseMiniOp` impl whose `output_type` maps each listed source type
///   to its arm's destination type, and whose `eval_out_of_place` runs the arm's
///   function from the input slice into a freshly allocated destination tensor;
/// * `pub fn $func(fields...)` (carrying the given attributes) returning
///   `ElementWiseOp(Box::new($Op { .. }), None)`.
///
/// Each arm's function must coerce to
/// `fn(&Self, &[Src], &mut [Dst]) -> TractResult<()>`. The `quantize:` expression
/// is called as `(&self, dt, scale, zero_point)`, the same way `element_wise!`
/// calls its own.
#[macro_export]
macro_rules! element_wise_oop {
    ($(#[$fmeta:meta])* $func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $( [$($typ:ident),*] => $typ_dst:ident $f:expr ),*
        $(; cost: $cost:expr )?
        $(; info: $info:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone)]
        pub struct $Op { $( $($(#[$meta])? pub $var: $var_typ),* )? }
        impl PartialEq for $Op {
            #[allow(unused_variables)]
            fn eq(&self, other: &Self) -> bool {
                $( $( if &self.$var != &other.$var { return false; })* )?
                true
            }
        }
        impl Eq for $Op {}
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
                // First arm listing `input_type` as a source decides the output type.
                $(
                    $(if input_type == $typ::datum_type() {
                        return Some(<$typ_dst>::datum_type())
                    }
                    )*
                )*
                None
            }
            fn eval_out_of_place(&self, t: &Tensor, _out_dt: Option<DatumType>) -> TractResult<Tensor> {
                $(
                    $(if t.datum_type() == $typ::datum_type() {
                        let f: fn(&Self, &[$typ], &mut[$typ_dst]) -> TractResult<()> = $f;
                        // Allocate the destination only once the source type matched, so
                        // non-matching arms cost nothing.
                        // SAFETY (assumed contract, not enforced here): the tensor starts
                        // uninitialized and `$f` is expected to write every element of
                        // `dst` before it is returned — confirm for each arm.
                        let mut dst = unsafe { Tensor::uninitialized_dt(<$typ_dst>::datum_type(), &t.shape())? };
                        let mut dst_plain = dst.try_as_plain_mut()?;
                        f(self, t.try_as_plain()?.as_slice::<$typ>()?, dst_plain.as_slice_mut::<$typ_dst>()?)?;
                        return Ok(dst)
                    }
                    )*
                )*
                // No arm accepts this input datum type.
                bail!("{} does not support {:?}", self.name(), t.datum_type());
            }
            // Each optional section below overrides the corresponding trait default.
            $(
            fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                $cost(dt)
            }
            )?
            $(
            fn info(&self) -> TractResult<Vec<String>> {
                $info(self)
            }
            )?
            $(
            fn prefix(&self) -> &'static str {
                $prefix
            }
            )?
            $(
            fn quantize(
                &self,
                dt: DatumType,
                scale: f32,
                zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    // Same call convention as `element_wise!`.
                    $quantize(&self, dt, scale, zero_point)
            }
            )?
            $(
            fn validation(&self) -> Validation {
                $validation
            }
            )?
            $(
            fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                ($operating_datum_type)(dt)
            }
            )?
        }
        $(#[$fmeta])*
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}
426}