// tract_core/ops/element_wise.rs

1use crate::internal::*;
2use downcast_rs::Downcast;
3use std::fmt;
4
5pub trait ElementWiseMiniOp:
6    fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast
7{
8    fn name(&self) -> String;
9    fn prefix(&self) -> &'static str {
10        ""
11    }
12    fn validation(&self) -> Validation {
13        Validation::Accurate
14    }
15    #[allow(unused_variables)]
16    fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
17        None
18    }
19    #[allow(unused_variables)]
20    fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
21        bail!("Element wise eval in-place not defined");
22    }
23    #[allow(unused_variables)]
24    fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
25        bail!("Element wise eval out-of-place place not defined");
26    }
27    #[allow(unused_variables)]
28    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
29        tvec!()
30    }
31    #[allow(unused_variables)]
32    fn operating_datum_type(&self, dt: DatumType) -> DatumType {
33        dt
34    }
35    #[allow(unused_variables)]
36    fn declutter(
37        &self,
38        model: &TypedModel,
39        node: &TypedNode,
40    ) -> TractResult<Option<TypedModelPatch>> {
41        Ok(None)
42    }
43
44    #[allow(unused_variables)]
45    fn quantize(
46        &self,
47        dt: DatumType,
48        scale: f32,
49        zero_point: i32,
50    ) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
51        Ok(None)
52    }
53    #[allow(unused_variables)]
54    fn info(&self) -> TractResult<Vec<String>> {
55        Ok(vec![])
56    }
57
58    #[allow(unused_variables)]
59    fn same_as(&self, other: &dyn ElementWiseMiniOp) -> bool {
60        false
61    }
62}
63
// Makes `Box<dyn ElementWiseMiniOp>` cloneable; `#[derive(Clone)]` on `ElementWiseOp` relies on it.
dyn_clone::clone_trait_object!(ElementWiseMiniOp);
// Adds `downcast_ref` & co. on `dyn ElementWiseMiniOp`; the macro-generated `same_as` uses it.
downcast_rs::impl_downcast!(ElementWiseMiniOp);
66
67#[derive(Debug, Clone)]
68pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>, pub Option<DatumType>);
69
70impl ElementWiseOp {
71    fn output_datum_type(&self, input_dt: DatumType) -> DatumType {
72        self.1.unwrap_or(self.0.operating_datum_type(input_dt))
73    }
74}
75
76impl Op for ElementWiseOp {
77    fn name(&self) -> Cow<str> {
78        self.0.name().into()
79    }
80
81    fn info(&self) -> TractResult<Vec<String>> {
82        self.0.info()
83    }
84
85    fn validation(&self) -> Validation {
86        self.0.validation()
87    }
88
89    fn same_as(&self, other: &dyn Op) -> bool {
90        let Some(other) = other.downcast_ref::<ElementWiseOp>() else { return false };
91        self.1 == other.1 && self.0.same_as(&*other.0)
92    }
93
94    op_as_typed_op!();
95}
96
97impl EvalOp for ElementWiseOp {
98    fn is_stateless(&self) -> bool {
99        true
100    }
101
102    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
103        if let Some(_dt) = self.0.output_type(inputs[0].datum_type()) {
104            Ok(tvec!(self.0.eval_out_of_place(&inputs[0], self.1)?.into_tvalue()))
105        } else {
106            let mut m = inputs.remove(0).into_tensor();
107            self.0.eval_in_place(&mut m, self.1)?;
108            Ok(tvec!(m.into()))
109        }
110    }
111}
112
113impl TypedOp for ElementWiseOp {
114    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
115        let mut fact = inputs[0].clone().without_value();
116        let dt = self.output_datum_type(fact.datum_type);
117        if let Some(dt) = self.1 {
118            fact.datum_type = dt;
119        } else if let Some(dt) = self.0.output_type(dt) {
120            fact.datum_type = dt;
121        }
122        Ok(tvec!(fact))
123    }
124
125    fn change_axes(
126        &self,
127        model: &TypedModel,
128        node: &TypedNode,
129        _io: InOut,
130        change: &AxisOp,
131    ) -> TractResult<Option<AxisChangeConsequence>> {
132        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
133    }
134
135    fn declutter(
136        &self,
137        model: &TypedModel,
138        node: &TypedNode,
139    ) -> TractResult<Option<TypedModelPatch>> {
140        if let Some(prec) = model.single_prec(node.id)? {
141            if prec.op_is::<AxisOp>() || prec.op_is::<IntoShape>() {
142                let mut patch = TypedModelPatch::default();
143                let mut wire = tvec!(patch.tap_model(model, prec.inputs[0])?);
144                wire = patch.wire_node(&node.name, &node.op, &wire)?;
145                wire = patch.wire_node(&prec.name, &prec.op, &wire)?;
146                patch.shunt_outside(model, node.id.into(), wire[0])?;
147                return Ok(Some(patch));
148            }
149        }
150        self.0.declutter(model, node)
151    }
152
153    fn axes_mapping(
154        &self,
155        inputs: &[&TypedFact],
156        outputs: &[&TypedFact],
157    ) -> TractResult<AxesMapping> {
158        AxesMapping::natural(inputs, outputs)
159    }
160
161    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
162        let count: TDim = inputs[0].shape.iter().product();
163        Ok(self
164            .0
165            .cost_per_element(inputs[0].datum_type)
166            .into_iter()
167            .map(|(c, n)| (c, count.clone() * n))
168            .collect())
169    }
170
171    fn quantize(
172        &self,
173        _model: &TypedModel,
174        _node: &TypedNode,
175        dt: DatumType,
176        scale: f32,
177        zero_point: i32,
178    ) -> TractResult<Option<Box<dyn TypedOp>>> {
179        if let Some(mini) = self.0.quantize(dt, scale, zero_point)? {
180            Ok(Some(Box::new(ElementWiseOp(mini, self.1))))
181        } else {
182            Ok(None)
183        }
184    }
185
186    fn slice(
187        &self,
188        patch: &mut TypedModelPatch,
189        _model: &TypedModel,
190        node: &TypedNode,
191        _prefix: &str,
192        inputs: &[OutletId],
193        _output_axis: usize,
194        _start: &TDim,
195        _end: &TDim,
196    ) -> TractResult<Option<TVec<OutletId>>> {
197        patch.wire_node(&node.name, &node.op, inputs).map(Some)
198    }
199
200    as_op!();
201}
202
/// Declares an in-place element-wise mini-op.
///
/// The macro defines the struct `$Op`, implements `ElementWiseMiniOp` for it, and emits a
/// constructor `pub fn $func(fields..) -> ElementWiseOp` with no forced output type.
///
/// - `[T1, T2, ..] => closure` arms: the closure is coerced to
///   `fn(&Self, &mut [T]) -> TractResult<()>` and runs when the effective type (forced
///   output type, else the tensor's own type) is one of the listed types.
/// - `; q: [T..] => f32_fn` arms: for (quantized) types whose unquantized form matches `T`.
///   Each element is dequantized with the input zero point/scale, passed through `f32_fn`,
///   then requantized with the output zero point/scale.
/// - Optional `cost:`, `declutter:`, `operating_datum_type:`, `prefix:`, `quantize:` and
///   `validation:` sections override the matching trait methods. `quantize:` is called as
///   `$quantize(&self, dt, scale, zero_point)`.
///
/// The generated `same_as` compares every field with `!=`, so field types must implement
/// `PartialEq`. The expansion names `Tensor`, `DatumType`, `TractResult`, `bail!`,
/// `ElementWiseMiniOp`, ... unqualified. Presumably the call site has the crate's internal
/// prelude in scope — NOTE(review): confirm.
#[macro_export]
macro_rules! element_wise {
    ($func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $([$($typ:ident),*] => $f:expr ),*
        $(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
        $(; cost: $cost:expr )?
        $(; declutter: $declutter:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone)]
        pub struct $Op { $( $( $(#[$meta])? pub $var: $var_typ),* )? }
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            // `other` is unused (beyond the downcast) for field-less ops.
            #[allow(unused_variables)]
            fn same_as(&self, other: &dyn ElementWiseMiniOp) -> bool {
                let Some(other) = other.downcast_ref::<$Op>() else { return false };
                $( $( if self.$var != other.$var { return false; })* )?
                true
            }
            fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
                // Plain arms: the first arm whose type equals the effective type wins.
                $(
                    $(if out_dt.unwrap_or(t.datum_type()) == $typ::datum_type() {
                        let t: &mut[$typ] = t.as_slice_mut::<$typ>()?;
                        let f: fn(&Self, &mut[$typ]) -> TractResult<()> = $f;
                        f(self, t)?;
                        return Ok(())
                    }
                    )*
                )*
                // Quantized arms: only reached when no plain arm matched.
                $(
                    $(
                       $(
                        let mut input_dt = t.datum_type();
                        let sout_dt = out_dt.unwrap_or(input_dt);
                        if sout_dt.unquantized() == <$typ_dt>::datum_type().unquantized() {
                           if input_dt.unquantized() != sout_dt.unquantized() {
                               // align unquantized input type to unquantized output type
                               // (this clones the whole tensor before re-offsetting it)
                               *t = match input_dt.unquantized() {
                                   DatumType::U8 => t.clone().into_arc_tensor().offset_u8_as_i8(),
                                   DatumType::I8 => t.clone().into_arc_tensor().offset_i8_as_u8(),
                                   unknown_dt => bail!("unexpected quantization input dt {:?}", unknown_dt)
                               }.into_tensor();
                               input_dt = t.datum_type(); // because zero_point change
                           }
                           // Relabels the buffer with the output type without touching the data;
                           // the loop below then rewrites every element.
                           unsafe { t.set_datum_type(sout_dt) } // force cast
                           let t: &mut[$typ_dt] = t.as_slice_mut::<$typ_dt>()?;
                           let f: fn(&Self, &mut[$typ_dt], DatumType, DatumType) -> TractResult<()> = |_, xs, input_dt, out_dt| {
                               let (izp, iscale) = input_dt.zp_scale();
                               let (ozp, oscale) = out_dt.zp_scale();
                               xs.iter_mut().for_each(|x| {
                                   // dequantize -> apply f32 kernel -> requantize
                                   let x_f32 = (*x as f32 - izp as f32) * iscale;
                                   // NOTE(review): `.as_()` (presumably num_traits::AsPrimitive) has
                                   // `as`-cast semantics, i.e. truncates toward zero instead of
                                   // rounding — confirm this is intended.
                                   *x = (($f_f32(x_f32) / oscale) + ozp as f32).as_()
                               });
                               Ok(())
                           };
                           f(self, t, input_dt, sout_dt)?;
                           return Ok(())
                       }
                       )*
                   )*
                )?
                bail!("{} does not support {:?}", self.name(), out_dt.unwrap_or(t.datum_type()));
            }
            $(
            fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                $cost(dt)
            }
            )?
            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    $declutter(model, node)
                }
            )?
            $(
            fn prefix(&self) -> &'static str {
                $prefix
            }
            )?
            $(
            fn quantize(
                &self,
                dt: DatumType,
                scale: f32,
                zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    $quantize(&self, dt, scale, zero_point)
            }
            )?
            $(
            fn validation(&self) -> Validation {
                $validation
            }
            )?
            $(
            fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                ($operating_datum_type)(dt)
            }
            )?
        }
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}
315
/// Declares an out-of-place element-wise mini-op.
///
/// The macro defines the struct `$Op`, implements `ElementWiseMiniOp` for it, and emits a
/// constructor `pub fn $func(fields..) -> ElementWiseOp` with no forced output type.
///
/// Each `[In1, In2, ..] => Out closure` arm handles inputs whose datum type is one of the listed
/// types. The closure is coerced to `fn(&Self, &[In], &mut [Out]) -> TractResult<()>` and fills
/// a new tensor of type `Out` with the input's shape. That tensor comes from
/// `Tensor::uninitialized_dt`, so the closure is expected to write every element.
///
/// Optional sections, in this order:
/// - `cost:` (called with the `DatumType`, returns `TVec<(Cost, usize)>`)
/// - `info:` (called with `self`, returns `TractResult<Vec<String>>`)
/// - `operating_datum_type:`, `prefix:`
/// - `quantize:` (called as `$quantize(&self, dt, scale, zero_point)`, the same convention as
///   `element_wise!`)
/// - `validation:`
///
/// Notes:
/// - The forced output type carried by `ElementWiseOp` is ignored by the generated
///   `eval_out_of_place`. The output type is always the matching arm's `Out`.
/// - No `same_as` is generated, so the trait default (`false`) applies.
#[macro_export]
macro_rules! element_wise_oop {
    ($(#[$fmeta:meta])* $func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $( [$($typ:ident),*] => $typ_dst:ident $f:expr ),*
        $(; cost: $cost:expr )?
        $(; info: $info:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone)]
        pub struct $Op { $( $($(#[$meta])? pub $var: $var_typ),* )? }
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            // `input_type` is unused when the macro is invoked without any arm.
            #[allow(unused_variables)]
            fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
                $(
                    $(if input_type == $typ::datum_type() {
                        return Some(<$typ_dst>::datum_type())
                    }
                    )*
                )*
                None
            }
            fn eval_out_of_place(&self, t: &Tensor, _out_dt: Option<DatumType>) -> TractResult<Tensor> {
                $(
                    $(if t.datum_type() == $typ::datum_type() {
                        // Allocate only once this arm is known to match, so arms that are
                        // tried and rejected cost nothing.
                        // SAFETY: the arm's closure is expected to overwrite every element of
                        // `dst` before the tensor is returned and read.
                        let mut dst = unsafe { Tensor::uninitialized_dt(<$typ_dst>::datum_type(), t.shape())? };
                        let f: fn(&Self, &[$typ], &mut[$typ_dst]) -> TractResult<()> = $f;
                        f(self, t.as_slice::<$typ>()?, dst.as_slice_mut::<$typ_dst>()?)?;
                        return Ok(dst)
                    }
                    )*
                )*
                bail!("{} does not support {:?}", self.name(), t.datum_type());
            }
            $(
            fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                $cost(dt)
            }
            )?
            $(
            fn info(&self) -> TractResult<Vec<String>> {
                $info(self)
            }
            )?
            $(
            fn prefix(&self) -> &'static str {
                $prefix
            }
            )?
            $(
            fn quantize(
                &self,
                dt: DatumType,
                scale: f32,
                zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    // Same calling convention as `element_wise!`'s `quantize:` hook.
                    $quantize(&self, dt, scale, zero_point)
            }
            )?
            $(
            fn validation(&self) -> Validation {
                $validation
            }
            )?
            $(
            fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                ($operating_datum_type)(dt)
            }
            )?
        }
        $(#[$fmeta])*
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}