1use crate::internal::*;
2use crate::ops::array::MultiBroadcastTo;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt;
6
7pub trait ElementWiseMiniOp:
8 fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast
9{
10 fn name(&self) -> String;
11 fn prefix(&self) -> &'static str {
12 ""
13 }
14 fn validation(&self) -> Validation {
15 Validation::Accurate
16 }
17 #[allow(unused_variables)]
18 fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
19 None
20 }
21 #[allow(unused_variables)]
22 fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
23 bail!("Element wise eval in-place not defined");
24 }
25 #[allow(unused_variables)]
26 fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
27 bail!("Element wise eval out-of-place place not defined");
28 }
29 #[allow(unused_variables)]
30 fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
31 tvec!()
32 }
33 #[allow(unused_variables)]
34 fn operating_datum_type(&self, dt: DatumType) -> DatumType {
35 dt
36 }
37 #[allow(unused_variables)]
38 fn declutter(
39 &self,
40 model: &TypedModel,
41 node: &TypedNode,
42 ) -> TractResult<Option<TypedModelPatch>> {
43 Ok(None)
44 }
45
46 #[allow(unused_variables)]
47 fn quantize(
48 &self,
49 dt: DatumType,
50 scale: f32,
51 zero_point: i32,
52 ) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
53 Ok(None)
54 }
55 #[allow(unused_variables)]
56 fn info(&self) -> TractResult<Vec<String>> {
57 Ok(vec![])
58 }
59}
60
61dyn_clone::clone_trait_object!(ElementWiseMiniOp);
62dyn_eq::eq_trait_object!(ElementWiseMiniOp);
63downcast_rs::impl_downcast!(ElementWiseMiniOp);
64
65#[derive(Debug, Clone, PartialEq, Eq)]
66pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>, pub Option<DatumType>);
67
68impl ElementWiseOp {
69 fn output_datum_type(&self, input_dt: DatumType) -> DatumType {
70 self.1.unwrap_or(self.0.operating_datum_type(input_dt))
71 }
72}
73
74impl Op for ElementWiseOp {
75 fn name(&self) -> StaticName {
76 self.0.name().into()
77 }
78
79 fn info(&self) -> TractResult<Vec<String>> {
80 self.0.info()
81 }
82
83 fn validation(&self) -> Validation {
84 self.0.validation()
85 }
86
87 op_as_typed_op!();
88}
89
90impl EvalOp for ElementWiseOp {
91 fn is_stateless(&self) -> bool {
92 true
93 }
94
95 fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
96 if let Some(_dt) = self.0.output_type(inputs[0].datum_type()) {
97 Ok(tvec!(self.0.eval_out_of_place(&inputs[0], self.1)?.into_tvalue()))
98 } else {
99 let mut m = inputs.remove(0).into_tensor();
100 self.0.eval_in_place(&mut m, self.1)?;
101 Ok(tvec!(m.into()))
102 }
103 }
104}
105
impl TypedOp for ElementWiseOp {
    /// Output fact: the shape of `inputs[0]`, with its value dropped.
    ///
    /// The datum type is chosen in this order:
    /// - `self.1` when set;
    /// - otherwise the mini-op's `output_type` (queried with `output_datum_type`) when it
    ///   returns `Some`;
    /// - otherwise the input datum type is kept.
    ///
    /// A uniform symbolic value (`uniform_tdim`) on the input is propagated when it can be
    /// computed (see below).
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut fact = inputs[0].clone().without_value();
        // Only consulted in the `else if` branch, i.e. when `self.1` is `None`, where it
        // equals the mini-op's operating datum type for the input.
        let dt = self.output_datum_type(fact.datum_type);
        if let Some(dt) = self.1 {
            fact.datum_type = dt;
        } else if let Some(dt) = self.0.output_type(dt) {
            fact.datum_type = dt;
        }
        if let Some(tdim) = &inputs[0].uniform_tdim {
            // Logical negation: `Not`, or `BitNot` on a boolean input.
            let is_logical_not = self.0.downcast_ref::<crate::ops::logic::Not>().is_some()
                || (self.0.downcast_ref::<crate::ops::logic::BitNot>().is_some()
                    && inputs[0].datum_type == bool::datum_type());
            if is_logical_not {
                // Maps v to 1 - v (assumes the uniform value is 0 or 1).
                fact.uniform_tdim = Some((TDim::Val(1) - tdim.clone()).reduce());
            } else {
                // Otherwise, evaluate the mini-op directly on a scalar TDim tensor.
                // If `eval_in_place` errors (e.g. the default impl, or no TDim kernel),
                // the error is ignored and `fact.uniform_tdim` is left as produced by
                // `without_value()`.
                // NOTE(review): confirm `without_value()` clears `uniform_tdim`.
                let mut tmp = tensor0(tdim.clone());
                if self.0.eval_in_place(&mut tmp, None).is_ok() {
                    fact.uniform_tdim = tmp
                        .try_as_plain()
                        .ok()
                        .and_then(|d| d.as_slice::<TDim>().ok())
                        .and_then(|s| s.first())
                        .cloned()
                        .map(|d| d.reduce());
                }
            }
        }
        Ok(tvec!(fact))
    }

    /// Delegates region-of-interest propagation to `bubble_roi`.
    // Presumably forwards the output ROI to the input unchanged, since the op is
    // element-wise — TODO confirm against `bubble_roi`.
    fn input_roi(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TVec<Option<TDim>>>> {
        crate::optim::propagate_roi::bubble_roi(model, node)
    }

    /// Accepts any axis change, on either the input or the output side.
    // The `None` argument presumably means the op itself needs no substitution.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
    }

    /// Moves this op upstream of a shape-only predecessor, or defers to the mini-op.
    ///
    /// When `model.linear_prec(node.id)` yields an `AxisOp`, `IntoShape` or
    /// `MultiBroadcastTo`, the patch:
    /// 1. applies this op to the predecessor's input;
    /// 2. re-applies the predecessor;
    /// 3. reroutes this node's consumers to the result.
    ///
    /// For `MultiBroadcastTo`, this means the function runs before broadcasting.
    /// Otherwise the mini-op's own `declutter` is used.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(prec) = model.linear_prec(node.id)?
            && (prec.op_is::<AxisOp>()
                || prec.op_is::<IntoShape>()
                || prec.op_is::<MultiBroadcastTo>())
        {
            let mut patch = TypedModelPatch::default();
            let mut wire = tvec!(patch.tap_model(model, prec.inputs[0])?);
            wire = patch.wire_node(&node.name, &node.op, &wire)?;
            wire = patch.wire_node(&prec.name, &prec.op, &wire)?;
            patch.shunt_outside(model, node.id.into(), wire[0])?;
            return Ok(Some(patch));
        }
        self.0.declutter(model, node)
    }

    /// Identity axes mapping: each input axis corresponds to the same output axis.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural(inputs, outputs)
    }

    /// Total cost: the mini-op's per-element cost for the input datum type, times the
    /// input element count.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = inputs[0].shape.iter().product();
        Ok(self
            .0
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    /// Wraps the mini-op's quantized variant, if any, keeping the forced output type `self.1`.
    fn quantize(
        &self,
        _model: &TypedModel,
        _node: &TypedNode,
        dt: DatumType,
        scale: f32,
        zero_point: i32,
    ) -> TractResult<Option<Box<dyn TypedOp>>> {
        if let Some(mini) = self.0.quantize(dt, scale, zero_point)? {
            Ok(Some(Box::new(ElementWiseOp(mini, self.1))))
        } else {
            Ok(None)
        }
    }

    /// Re-wires the same op on the already-sliced `inputs`.
    ///
    /// The axis and bounds are ignored because an element-wise op commutes with slicing.
    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        node: &TypedNode,
        _prefix: &str,
        inputs: &[OutletId],
        _output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        patch.wire_node(&node.name, &node.op, inputs).map(Some)
    }

    as_op!();
}
235
/// Declares an in-place element-wise mini-op.
///
/// Expands to three items:
/// - a `pub struct $Op` with the listed public fields;
/// - its `ElementWiseMiniOp` implementation;
/// - a constructor `pub fn $func(fields..)` returning
///   `ElementWiseOp(Box::new($Op { .. }), None)`.
///
/// Grammar:
/// - `[T1, T2, ..] => kernel`: when the effective output type (`out_dt`, or the input's
///   type if none) is `Ti`, `kernel: fn(&$Op, &mut [Ti]) -> TractResult<()>` rewrites the
///   buffer in place. Clauses are tried in order and the first match wins.
/// - `; q: [T, ..] => f` (optional): applies when the output storage type
///   (`out_dt.unquantized()`) is one of the listed `T`. Each element is dequantized with
///   the input's `zp_scale()`, mapped through `f` in f32, and requantized with the
///   output's `zp_scale()`. Of the input/output storage mismatches, only u8 <-> i8 is
///   supported.
/// - `; cost:`, `; declutter:`, `; operating_datum_type:`, `; prefix:`, `; quantize:`,
///   `; validation:` (optional) override the corresponding trait methods.
#[macro_export]
macro_rules! element_wise {
    ($func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $([$($typ:ident),*] => $f:expr ),*
        $(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
        $(; cost: $cost:expr )?
        $(; declutter: $declutter:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone, PartialEq)]
        pub struct $Op { $( $( $(#[$meta])? pub $var: $var_typ),* )? }
        impl Eq for $Op {}
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
                // Plain kernels, keyed on the effective output type.
                $(
                    $(if out_dt.unwrap_or(t.datum_type()) == $typ::datum_type() {
                        let mut t_plain = t.try_as_plain_mut()?;
                        let t: &mut[$typ] = t_plain.as_slice_mut::<$typ>()?;
                        let f: fn(&Self, &mut[$typ]) -> TractResult<()> = $f;
                        f(self, t)?;
                        return Ok(())
                    }
                    )*
                )*
                // Quantized kernels: dequantize -> f32 function -> requantize.
                $(
                    $(
                        $(
                            let mut input_dt = t.datum_type();
                            let sout_dt = out_dt.unwrap_or(input_dt);
                            if sout_dt.unquantized() == <$typ_dt>::datum_type().unquantized() {
                                if input_dt.unquantized() != sout_dt.unquantized() {
                                    // Only u8 <-> i8 storage can be re-offset to the output
                                    // storage type. Any other pair (e.g. an i8 input with an
                                    // i32 output clause) is rejected. Otherwise the buffer would
                                    // be converted to a byte type and then relabelled as a
                                    // wider type by `set_datum_type` below.
                                    *t = match (input_dt.unquantized(), sout_dt.unquantized()) {
                                        (DatumType::U8, DatumType::I8) => t.clone().into_arc_tensor().offset_u8_as_i8(),
                                        (DatumType::I8, DatumType::U8) => t.clone().into_arc_tensor().offset_i8_as_u8(),
                                        (unknown_dt, _) => bail!("unexpected quantization input dt {:?}", unknown_dt)
                                    }.into_tensor();
                                    input_dt = t.datum_type();
                                }
                                // SAFETY: the storage type of `t` now equals
                                // `sout_dt.unquantized()`: either it already did, or it was
                                // converted above (assuming `offset_u8_as_i8` /
                                // `offset_i8_as_u8` yield i8 / u8 storage). Only the
                                // quantization parameters carried by the datum type change.
                                unsafe { t.set_datum_type(sout_dt) }
                                let mut t_plain = t.try_as_plain_mut()?;
                                let t: &mut[$typ_dt] = t_plain.as_slice_mut::<$typ_dt>()?;
                                let f: fn(&Self, &mut[$typ_dt], DatumType, DatumType) -> TractResult<()> = |_, xs, input_dt, out_dt| {
                                    let (izp, iscale) = input_dt.zp_scale();
                                    let (ozp, oscale) = out_dt.zp_scale();
                                    xs.iter_mut().for_each(|x| {
                                        let x_f32 = (*x as f32 - izp as f32) * iscale;
                                        // NOTE(review): `.as_()` presumably has `as`-cast semantics
                                        // (truncate toward zero, saturate), not round-to-nearest —
                                        // confirm this is the intended requantization.
                                        *x = (($f_f32(x_f32) / oscale) + ozp as f32).as_()
                                    });
                                    Ok(())
                                };
                                f(self, t, input_dt, sout_dt)?;
                                return Ok(())
                            }
                        )*
                    )*
                )?
                bail!("{} does not support {:?}", self.name(), out_dt.unwrap_or(t.datum_type()));
            }
            $(
                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                    $cost(dt)
                }
            )?
            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    $declutter(model, node)
                }
            )?
            $(
                fn prefix(&self) -> &'static str {
                    $prefix
                }
            )?
            $(
                fn quantize(
                    &self,
                    dt: DatumType,
                    scale: f32,
                    zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    $quantize(&self, dt, scale, zero_point)
                }
            )?
            $(
                fn validation(&self) -> Validation {
                    $validation
                }
            )?
            $(
                fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                    ($operating_datum_type)(dt)
                }
            )?
        }
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}
345
/// Declares an out-of-place element-wise mini-op.
///
/// The op writes its results to a newly allocated tensor, possibly of another datum type.
///
/// Grammar:
/// - `$(#[attr])* func, Op { field: Type, .. }`: the constructor `func` (the attributes are
///   attached to it) and the `pub struct Op` with its public fields.
/// - `[T1, T2, ..] => Dst kernel`: for an input of any listed type `Ti`, a `Dst` tensor of
///   the same shape is allocated. `kernel: fn(&Op, &[Ti], &mut [Dst]) -> TractResult<()>`
///   then fills it.
/// - `; cost:`, `; info:`, `; operating_datum_type:`, `; prefix:`, `; quantize:`,
///   `; validation:` (optional) override the corresponding `ElementWiseMiniOp` methods.
///   `quantize` is invoked exactly as in `element_wise!`, i.e. as
///   `$quantize(&self, dt, scale, zero_point)`.
///
/// `output_type` reports `Dst` for every listed input type. That is what makes
/// `ElementWiseOp::eval` route evaluation to `eval_out_of_place`.
#[macro_export]
macro_rules! element_wise_oop {
    ($(#[$fmeta:meta])* $func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $( [$($typ:ident),*] => $typ_dst:ident $f:expr ),*
        $(; cost: $cost:expr )?
        $(; info: $info:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone)]
        pub struct $Op { $( $($(#[$meta])? pub $var: $var_typ),* )? }
        impl PartialEq for $Op {
            // `other` is unused when `$Op` has no fields.
            #[allow(unused_variables)]
            fn eq(&self, other: &Self) -> bool {
                $( $( if &self.$var != &other.$var { return false; })* )?
                true
            }
        }
        impl Eq for $Op {}
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
                $(
                    $(if input_type == $typ::datum_type() {
                        return Some(<$typ_dst>::datum_type())
                    }
                    )*
                )*
                None
            }
            fn eval_out_of_place(&self, t: &Tensor, _out_dt: Option<DatumType>) -> TractResult<Tensor> {
                $(
                    $(if t.datum_type() == $typ::datum_type() {
                        // Allocate only for the clause that matched, so non-matching clauses
                        // cost nothing.
                        // SAFETY: the buffer is uninitialized; `$f` is required to write every
                        // element of the destination slice before the tensor is returned.
                        let mut dst = unsafe { Tensor::uninitialized_dt(<$typ_dst>::datum_type(), &t.shape())? };
                        let f: fn(&Self, &[$typ], &mut[$typ_dst]) -> TractResult<()> = $f;
                        let mut dst_plain = dst.try_as_plain_mut()?;
                        f(self, t.try_as_plain()?.as_slice::<$typ>()?, dst_plain.as_slice_mut::<$typ_dst>()?)?;
                        return Ok(dst)
                    }
                    )*
                )*
                bail!("{} does not support {:?}", self.name(), t.datum_type());
            }
            $(
                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                    $cost(dt)
                }
            )?
            $(
                fn info(&self) -> TractResult<Vec<String>> {
                    $info(self)
                }
            )?
            $(
                fn prefix(&self) -> &'static str {
                    $prefix
                }
            )?
            $(
                fn quantize(
                    &self,
                    dt: DatumType,
                    scale: f32,
                    zero_point: i32) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
                    // Same calling convention as `element_wise!`. This was previously
                    // `$quantize(ft, ..)`, which referenced an undefined name and dropped `dt`.
                    $quantize(&self, dt, scale, zero_point)
                }
            )?
            $(
                fn validation(&self) -> Validation {
                    $validation
                }
            )?
            $(
                fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                    ($operating_datum_type)(dt)
                }
            )?
        }
        $(#[$fmeta])*
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}