1use crate::internal::*;
2use downcast_rs::Downcast;
3use dyn_eq::DynEq;
4use std::fmt;
5
6pub trait ElementWiseMiniOp:
7 fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast
8{
9 fn name(&self) -> String;
10 fn prefix(&self) -> &'static str {
11 ""
12 }
13 fn validation(&self) -> Validation {
14 Validation::Accurate
15 }
16 #[allow(unused_variables)]
17 fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
18 None
19 }
20 #[allow(unused_variables)]
21 fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
22 bail!("Element wise eval in-place not defined");
23 }
24 #[allow(unused_variables)]
25 fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
26 bail!("Element wise eval out-of-place place not defined");
27 }
28 #[allow(unused_variables)]
29 fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
30 tvec!()
31 }
32 #[allow(unused_variables)]
33 fn operating_datum_type(&self, dt: DatumType) -> DatumType {
34 dt
35 }
36 #[allow(unused_variables)]
37 fn declutter(
38 &self,
39 model: &TypedModel,
40 node: &TypedNode,
41 ) -> TractResult<Option<TypedModelPatch>> {
42 Ok(None)
43 }
44
45 #[allow(unused_variables)]
46 fn quantize(
47 &self,
48 dt: DatumType,
49 scale: f32,
50 zero_point: i32,
51 ) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
52 Ok(None)
53 }
54 #[allow(unused_variables)]
55 fn info(&self) -> TractResult<Vec<String>> {
56 Ok(vec![])
57 }
58}
59
// Trait-object support for `Box<dyn ElementWiseMiniOp>`. It lets `ElementWiseOp` derive
// `Clone` / `PartialEq` / `Eq` over its boxed mini-op, and lets callers recover the concrete
// mini-op type with `downcast_ref` (as `output_facts` does for `Not` / `BitNot`).
dyn_clone::clone_trait_object!(ElementWiseMiniOp);
dyn_eq::eq_trait_object!(ElementWiseMiniOp);
downcast_rs::impl_downcast!(ElementWiseMiniOp);
63
64#[derive(Debug, Clone, PartialEq, Eq)]
65pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>, pub Option<DatumType>);
66
67impl ElementWiseOp {
68 fn output_datum_type(&self, input_dt: DatumType) -> DatumType {
69 self.1.unwrap_or(self.0.operating_datum_type(input_dt))
70 }
71}
72
73impl Op for ElementWiseOp {
74 fn name(&self) -> StaticName {
75 self.0.name().into()
76 }
77
78 fn info(&self) -> TractResult<Vec<String>> {
79 self.0.info()
80 }
81
82 fn validation(&self) -> Validation {
83 self.0.validation()
84 }
85
86 op_as_typed_op!();
87}
88
89impl EvalOp for ElementWiseOp {
90 fn is_stateless(&self) -> bool {
91 true
92 }
93
94 fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
95 if let Some(_dt) = self.0.output_type(inputs[0].datum_type()) {
96 Ok(tvec!(self.0.eval_out_of_place(&inputs[0], self.1)?.into_tvalue()))
97 } else {
98 let mut m = inputs.remove(0).into_tensor();
99 self.0.eval_in_place(&mut m, self.1)?;
100 Ok(tvec!(m.into()))
101 }
102 }
103}
104
105impl TypedOp for ElementWiseOp {
106 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
107 let mut fact = inputs[0].clone().without_value();
108 let dt = self.output_datum_type(fact.datum_type);
109 if let Some(dt) = self.1 {
110 fact.datum_type = dt;
111 } else if let Some(dt) = self.0.output_type(dt) {
112 fact.datum_type = dt;
113 }
114 if let Some(tdim) = &inputs[0].uniform_tdim {
116 let is_logical_not = self.0.downcast_ref::<crate::ops::logic::Not>().is_some()
120 || (self.0.downcast_ref::<crate::ops::logic::BitNot>().is_some()
121 && inputs[0].datum_type == bool::datum_type());
122 if is_logical_not {
123 fact.uniform_tdim = Some((TDim::Val(1) - tdim.clone()).reduce());
124 } else {
125 let mut tmp = tensor0(tdim.clone());
129 if self.0.eval_in_place(&mut tmp, None).is_ok() {
130 fact.uniform_tdim = tmp
131 .try_as_plain()
132 .ok()
133 .and_then(|d| d.as_slice::<TDim>().ok())
134 .and_then(|s| s.first())
135 .cloned()
136 .map(|d| d.reduce());
137 }
138 }
139 }
140 Ok(tvec!(fact))
141 }
142
143 fn input_roi(
144 &self,
145 model: &TypedModel,
146 node: &TypedNode,
147 ) -> TractResult<Option<TVec<Option<TDim>>>> {
148 crate::optim::propagate_roi::bubble_roi(model, node)
149 }
150
151 fn change_axes(
152 &self,
153 model: &TypedModel,
154 node: &TypedNode,
155 _io: InOut,
156 change: &AxisOp,
157 ) -> TractResult<Option<AxisChangeConsequence>> {
158 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
159 }
160
161 fn declutter(
162 &self,
163 model: &TypedModel,
164 node: &TypedNode,
165 ) -> TractResult<Option<TypedModelPatch>> {
166 if let Some(prec) = model.single_prec(node.id)?
167 && (prec.op_is::<AxisOp>() || prec.op_is::<IntoShape>())
168 {
169 let mut patch = TypedModelPatch::default();
170 let mut wire = tvec!(patch.tap_model(model, prec.inputs[0])?);
171 wire = patch.wire_node(&node.name, &node.op, &wire)?;
172 wire = patch.wire_node(&prec.name, &prec.op, &wire)?;
173 patch.shunt_outside(model, node.id.into(), wire[0])?;
174 return Ok(Some(patch));
175 }
176 self.0.declutter(model, node)
177 }
178
179 fn axes_mapping(
180 &self,
181 inputs: &[&TypedFact],
182 outputs: &[&TypedFact],
183 ) -> TractResult<AxesMapping> {
184 AxesMapping::natural(inputs, outputs)
185 }
186
187 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
188 let count: TDim = inputs[0].shape.iter().product();
189 Ok(self
190 .0
191 .cost_per_element(inputs[0].datum_type)
192 .into_iter()
193 .map(|(c, n)| (c, count.clone() * n))
194 .collect())
195 }
196
197 fn quantize(
198 &self,
199 _model: &TypedModel,
200 _node: &TypedNode,
201 dt: DatumType,
202 scale: f32,
203 zero_point: i32,
204 ) -> TractResult<Option<Box<dyn TypedOp>>> {
205 if let Some(mini) = self.0.quantize(dt, scale, zero_point)? {
206 Ok(Some(Box::new(ElementWiseOp(mini, self.1))))
207 } else {
208 Ok(None)
209 }
210 }
211
212 fn slice(
213 &self,
214 patch: &mut TypedModelPatch,
215 _model: &TypedModel,
216 node: &TypedNode,
217 _prefix: &str,
218 inputs: &[OutletId],
219 _output_axis: usize,
220 _start: &TDim,
221 _end: &TDim,
222 ) -> TractResult<Option<TVec<OutletId>>> {
223 patch.wire_node(&node.name, &node.op, inputs).map(Some)
224 }
225
226 as_op!();
227}
228
/// Defines an in-place element-wise mini-op struct `$Op`, its `ElementWiseMiniOp`
/// implementation, and a constructor `$func` returning the wrapping `ElementWiseOp`.
///
/// * `[T, ...] => f` arms: `f(&self, &mut [T])` rewrites the tensor buffer in place when the
///   target datum type (the forced `out_dt`, else the tensor's own type) is `T`.
/// * `q: [T, ...] => f_f32` arms apply when the target type's unquantized form is `T`:
///   - each element is dequantized to `f32` with the input zero point / scale;
///   - it is mapped through `f_f32`;
///   - it is requantized with the output zero point / scale.
/// * The remaining optional sections override the matching `ElementWiseMiniOp` methods.
#[macro_export]
macro_rules! element_wise {
    ($func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
     $([$($typ:ident),*] => $f:expr ),*
     $(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
     $(; cost: $cost:expr )?
     $(; declutter: $declutter:expr )?
     $(; operating_datum_type: $operating_datum_type:expr )?
     $(; prefix: $prefix:expr )?
     $(; quantize: $quantize:expr )?
     $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone, PartialEq)]
        pub struct $Op { $( $( $(#[$meta])? pub $var: $var_typ),* )? }
        impl Eq for $Op {}
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
                // Plain arms: dispatch on the target type (forced `out_dt`, else the tensor's
                // own type) and hand the mutable buffer to the user function.
                $(
                    $(if out_dt.unwrap_or(t.datum_type()) == $typ::datum_type() {
                        let mut t_plain = t.try_as_plain_mut()?;
                        let t: &mut[$typ] = t_plain.as_slice_mut::<$typ>()?;
                        let f: fn(&Self, &mut[$typ]) -> TractResult<()> = $f;
                        f(self, t)?;
                        return Ok(())
                    }
                    )*
                )*
                // Quantized arms: match on the unquantized form of the target type.
                $(
                    $(
                        $(
                            let mut input_dt = t.datum_type();
                            let sout_dt = out_dt.unwrap_or(input_dt);
                            if sout_dt.unquantized() == <$typ_dt>::datum_type().unquantized() {
                                // The input storage type differs from the output's: re-encode
                                // u8 <-> i8 via the offset helpers before working in place.
                                if input_dt.unquantized() != sout_dt.unquantized() {
                                    *t = match input_dt.unquantized() {
                                        DatumType::U8 => t.clone().into_arc_tensor().offset_u8_as_i8(),
                                        DatumType::I8 => t.clone().into_arc_tensor().offset_i8_as_u8(),
                                        unknown_dt => bail!("unexpected quantization input dt {:?}", unknown_dt)
                                    }.into_tensor();
                                    input_dt = t.datum_type();
                                }
                                // SAFETY(review): relies on the buffer's storage type now matching
                                // `sout_dt.unquantized()`. The arm condition plus the offset above
                                // are meant to ensure it. Only the quantization parameters change.
                                unsafe { t.set_datum_type(sout_dt) }
                                let mut t_plain = t.try_as_plain_mut()?;
                                let t: &mut[$typ_dt] = t_plain.as_slice_mut::<$typ_dt>()?;
                                // Dequantize with the input zp/scale, apply `f_f32`, and
                                // requantize with the output zp/scale. The closure ignores its
                                // `&Self` argument.
                                let f: fn(&Self, &mut[$typ_dt], DatumType, DatumType) -> TractResult<()> = |_, xs, input_dt, out_dt| {
                                    let (izp, iscale) = input_dt.zp_scale();
                                    let (ozp, oscale) = out_dt.zp_scale();
                                    xs.iter_mut().for_each(|x| {
                                        let x_f32 = (*x as f32 - izp as f32) * iscale;
                                        *x = (($f_f32(x_f32) / oscale) + ozp as f32).as_()
                                    });
                                    Ok(())
                                };
                                f(self, t, input_dt, sout_dt)?;
                                return Ok(())
                            }
                        )*
                    )*
                )?
                bail!("{} does not support {:?}", self.name(), out_dt.unwrap_or(t.datum_type()));
            }
            $(
                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                    $cost(dt)
                }
            )?
            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    $declutter(model, node)
                }
            )?
            $(
                fn prefix(&self) -> &'static str {
                    $prefix
                }
            )?
            $(
                fn quantize(
                    &self,
                    dt: DatumType,
                    scale: f32,
                    zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    $quantize(&self, dt, scale, zero_point)
                }
            )?
            $(
                fn validation(&self) -> Validation {
                    $validation
                }
            )?
            $(
                fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                    ($operating_datum_type)(dt)
                }
            )?
        }
        // Constructor wrapping the mini-op in an `ElementWiseOp` with no forced output type.
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}
338
/// Defines an out-of-place element-wise mini-op struct `$Op`, its `ElementWiseMiniOp`
/// implementation, and a constructor `$func` returning the wrapping `ElementWiseOp`.
///
/// Each `[T, ...] => Dst f` arm declares that inputs of type `T` produce a fresh tensor of
/// type `Dst`. The arm's `f(&self, &[T], &mut [Dst])` fills it. The remaining optional sections
/// override the matching `ElementWiseMiniOp` methods.
#[macro_export]
macro_rules! element_wise_oop {
    ($(#[$fmeta:meta])* $func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
     $( [$($typ:ident),*] => $typ_dst:ident $f:expr ),*
     $(; cost: $cost:expr )?
     $(; info: $info:expr )?
     $(; operating_datum_type: $operating_datum_type:expr )?
     $(; prefix: $prefix:expr )?
     $(; quantize: $quantize:expr )?
     $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone)]
        pub struct $Op { $( $($(#[$meta])? pub $var: $var_typ),* )? }
        impl PartialEq for $Op {
            #[allow(unused_variables)] // `other` is unused when the op has no fields
            fn eq(&self, other: &Self) -> bool {
                $( $( if &self.$var != &other.$var { return false; })* )?
                true
            }
        }
        impl Eq for $Op {}
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            // Declared output type for a supported input type. `None` otherwise.
            fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
                $(
                    $(if input_type == $typ::datum_type() {
                        return Some(<$typ_dst>::datum_type())
                    }
                    )*
                )*
                None
            }
            fn eval_out_of_place(&self, t: &Tensor, _out_dt: Option<DatumType>) -> TractResult<Tensor> {
                $(
                    $(if t.datum_type() == $typ::datum_type() {
                        // The destination is allocated only once the input type has matched,
                        // not speculatively for every `[..] => Dst` group.
                        // SAFETY(review): assumes `$f` writes every element of the destination
                        // slice before the tensor is returned. TODO confirm for each instantiation.
                        let mut dst = unsafe { Tensor::uninitialized_dt(<$typ_dst>::datum_type(), &t.shape())? };
                        let f: fn(&Self, &[$typ], &mut[$typ_dst]) -> TractResult<()> = $f;
                        let mut dst_plain = dst.try_as_plain_mut()?;
                        f(self, t.try_as_plain()?.as_slice::<$typ>()?, dst_plain.as_slice_mut::<$typ_dst>()?)?;
                        return Ok(dst)
                    }
                    )*
                )*
                bail!("{} does not support {:?}", self.name(), t.datum_type());
            }
            $(
                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                    $cost(dt)
                }
            )?
            $(
                fn info(&self) -> TractResult<Vec<String>> {
                    $info(self)
                }
            )?
            $(
                fn prefix(&self) -> &'static str {
                    $prefix
                }
            )?
            $(
                // Forwarded with the same argument shape as `element_wise!`. The previous
                // expansion referenced an undefined `ft` and did not pass `self`.
                fn quantize(
                    &self,
                    dt: DatumType,
                    scale: f32,
                    zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    $quantize(&self, dt, scale, zero_point)
                }
            )?
            $(
                fn validation(&self) -> Validation {
                    $validation
                }
            )?
            $(
                fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                    ($operating_datum_type)(dt)
                }
            )?
        }
        $(#[$fmeta])*
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}