1use crate::internal::*;
2use downcast_rs::Downcast;
3use std::fmt;
4
5pub trait ElementWiseMiniOp:
6 fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast
7{
8 fn name(&self) -> String;
9 fn prefix(&self) -> &'static str {
10 ""
11 }
12 fn validation(&self) -> Validation {
13 Validation::Accurate
14 }
15 #[allow(unused_variables)]
16 fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
17 None
18 }
19 #[allow(unused_variables)]
20 fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
21 bail!("Element wise eval in-place not defined");
22 }
23 #[allow(unused_variables)]
24 fn eval_out_of_place(&self, t: &Tensor, out_dt: Option<DatumType>) -> TractResult<Tensor> {
25 bail!("Element wise eval out-of-place place not defined");
26 }
27 #[allow(unused_variables)]
28 fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
29 tvec!()
30 }
31 #[allow(unused_variables)]
32 fn operating_datum_type(&self, dt: DatumType) -> DatumType {
33 dt
34 }
35 #[allow(unused_variables)]
36 fn declutter(
37 &self,
38 model: &TypedModel,
39 node: &TypedNode,
40 ) -> TractResult<Option<TypedModelPatch>> {
41 Ok(None)
42 }
43
44 #[allow(unused_variables)]
45 fn quantize(
46 &self,
47 dt: DatumType,
48 scale: f32,
49 zero_point: i32,
50 ) -> TractResult<Option<Box<dyn ElementWiseMiniOp>>> {
51 Ok(None)
52 }
53 #[allow(unused_variables)]
54 fn info(&self) -> TractResult<Vec<String>> {
55 Ok(vec![])
56 }
57
58 #[allow(unused_variables)]
59 fn same_as(&self, other: &dyn ElementWiseMiniOp) -> bool {
60 false
61 }
62}
63
64dyn_clone::clone_trait_object!(ElementWiseMiniOp);
65downcast_rs::impl_downcast!(ElementWiseMiniOp);
66
67#[derive(Debug, Clone)]
68pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>, pub Option<DatumType>);
69
70impl ElementWiseOp {
71 fn output_datum_type(&self, input_dt: DatumType) -> DatumType {
72 self.1.unwrap_or(self.0.operating_datum_type(input_dt))
73 }
74}
75
76impl Op for ElementWiseOp {
77 fn name(&self) -> StaticName {
78 self.0.name().into()
79 }
80
81 fn info(&self) -> TractResult<Vec<String>> {
82 self.0.info()
83 }
84
85 fn validation(&self) -> Validation {
86 self.0.validation()
87 }
88
89 fn same_as(&self, other: &dyn Op) -> bool {
90 let Some(other) = other.downcast_ref::<ElementWiseOp>() else { return false };
91 self.1 == other.1 && self.0.same_as(&*other.0)
92 }
93
94 op_as_typed_op!();
95}
96
97impl EvalOp for ElementWiseOp {
98 fn is_stateless(&self) -> bool {
99 true
100 }
101
102 fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
103 if let Some(_dt) = self.0.output_type(inputs[0].datum_type()) {
104 Ok(tvec!(self.0.eval_out_of_place(&inputs[0], self.1)?.into_tvalue()))
105 } else {
106 let mut m = inputs.remove(0).into_tensor();
107 self.0.eval_in_place(&mut m, self.1)?;
108 Ok(tvec!(m.into()))
109 }
110 }
111}
112
113impl TypedOp for ElementWiseOp {
114 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
115 let mut fact = inputs[0].clone().without_value();
116 let dt = self.output_datum_type(fact.datum_type);
117 if let Some(dt) = self.1 {
118 fact.datum_type = dt;
119 } else if let Some(dt) = self.0.output_type(dt) {
120 fact.datum_type = dt;
121 }
122 Ok(tvec!(fact))
123 }
124
125 fn change_axes(
126 &self,
127 model: &TypedModel,
128 node: &TypedNode,
129 _io: InOut,
130 change: &AxisOp,
131 ) -> TractResult<Option<AxisChangeConsequence>> {
132 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
133 }
134
135 fn declutter(
136 &self,
137 model: &TypedModel,
138 node: &TypedNode,
139 ) -> TractResult<Option<TypedModelPatch>> {
140 if let Some(prec) = model.single_prec(node.id)? {
141 if prec.op_is::<AxisOp>() || prec.op_is::<IntoShape>() {
142 let mut patch = TypedModelPatch::default();
143 let mut wire = tvec!(patch.tap_model(model, prec.inputs[0])?);
144 wire = patch.wire_node(&node.name, &node.op, &wire)?;
145 wire = patch.wire_node(&prec.name, &prec.op, &wire)?;
146 patch.shunt_outside(model, node.id.into(), wire[0])?;
147 return Ok(Some(patch));
148 }
149 }
150 self.0.declutter(model, node)
151 }
152
153 fn axes_mapping(
154 &self,
155 inputs: &[&TypedFact],
156 outputs: &[&TypedFact],
157 ) -> TractResult<AxesMapping> {
158 AxesMapping::natural(inputs, outputs)
159 }
160
161 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
162 let count: TDim = inputs[0].shape.iter().product();
163 Ok(self
164 .0
165 .cost_per_element(inputs[0].datum_type)
166 .into_iter()
167 .map(|(c, n)| (c, count.clone() * n))
168 .collect())
169 }
170
171 fn quantize(
172 &self,
173 _model: &TypedModel,
174 _node: &TypedNode,
175 dt: DatumType,
176 scale: f32,
177 zero_point: i32,
178 ) -> TractResult<Option<Box<dyn TypedOp>>> {
179 if let Some(mini) = self.0.quantize(dt, scale, zero_point)? {
180 Ok(Some(Box::new(ElementWiseOp(mini, self.1))))
181 } else {
182 Ok(None)
183 }
184 }
185
186 fn slice(
187 &self,
188 patch: &mut TypedModelPatch,
189 _model: &TypedModel,
190 node: &TypedNode,
191 _prefix: &str,
192 inputs: &[OutletId],
193 _output_axis: usize,
194 _start: &TDim,
195 _end: &TDim,
196 ) -> TractResult<Option<TVec<OutletId>>> {
197 patch.wire_node(&node.name, &node.op, inputs).map(Some)
198 }
199
200 as_op!();
201}
202
/// Declares an in-place element-wise mini-op and its constructor function.
///
/// ```text
/// element_wise!(func, Op { field: Type, .. },
///     [T1, T2] => kernel, ..
///     ; q: [Q1, Q2] => f32_kernel, ..
///     ; cost: ..
///     ; declutter: ..
///     ; operating_datum_type: ..
///     ; prefix: ..
///     ; quantize: ..
///     ; validation: ..
/// );
/// ```
///
/// The expansion contains:
/// * `pub struct Op` (deriving `Debug` and `Clone`) with the listed public fields;
/// * an `ElementWiseMiniOp` implementation for `Op`;
/// * `pub fn func(fields..) -> ElementWiseOp` wrapping a new `Op`, with no
///   output type override.
///
/// Every `kernel` must coerce to `fn(&Op, &mut [T]) -> TractResult<()>` and is
/// applied to the tensor's data when the requested (or, failing that, actual)
/// datum type is `T`. Every `f32_kernel` in the `q:` section is applied to the
/// dequantized value of each element, and its result is requantized.
/// All `;`-introduced sections are optional but must appear in the order shown.
///
/// The generated `same_as` compares fields with `!=`.
#[macro_export]
macro_rules! element_wise {
    ($func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $([$($typ:ident),*] => $f:expr ),*
        $(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
        $(; cost: $cost:expr )?
        $(; declutter: $declutter:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone)]
        pub struct $Op { $( $( $(#[$meta])? pub $var: $var_typ),* )? }
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            // The trait is named through `$crate` so the expansion does not
            // depend on the call site having imported it.
            #[allow(unused_variables)]
            fn same_as(&self, other: &dyn $crate::ops::element_wise::ElementWiseMiniOp) -> bool {
                let Some(other) = other.downcast_ref::<$Op>() else { return false };
                $( $( if self.$var != other.$var { return false; })* )?
                true
            }
            fn eval_in_place(&self, t: &mut Tensor, out_dt: Option<DatumType>) -> TractResult<()> {
                // Plain kernels: the first type matching the target datum type wins.
                $(
                    $(if out_dt.unwrap_or(t.datum_type()) == $typ::datum_type() {
                        let mut t_dense = t.try_as_dense_mut()?;
                        let t: &mut[$typ] = t_dense.as_slice_mut::<$typ>()?;
                        let f: fn(&Self, &mut[$typ]) -> TractResult<()> = $f;
                        f(self, t)?;
                        return Ok(())
                    }
                    )*
                )*
                // Quantized kernels: matched on the unquantized storage type.
                $(
                    $(
                        $(
                            let mut input_dt = t.datum_type();
                            let sout_dt = out_dt.unwrap_or(input_dt);
                            if sout_dt.unquantized() == <$typ_dt>::datum_type().unquantized() {
                                if input_dt.unquantized() != sout_dt.unquantized() {
                                    // Input and output storage types differ: convert the input
                                    // between u8 and i8 storage first.
                                    *t = match input_dt.unquantized() {
                                        DatumType::U8 => t.clone().into_arc_tensor().offset_u8_as_i8(),
                                        DatumType::I8 => t.clone().into_arc_tensor().offset_i8_as_u8(),
                                        unknown_dt => bail!("unexpected quantization input dt {:?}", unknown_dt)
                                    }.into_tensor();
                                    // Re-read the datum type of the converted tensor; its
                                    // quantization parameters are used to dequantize below.
                                    input_dt = t.datum_type();
                                }
                                // Force the tensor's datum type to the output type before
                                // rewriting its data.
                                // NOTE(review): presumably sound because both types share the
                                // same unquantized storage type at this point -- confirm.
                                unsafe { t.set_datum_type(sout_dt) }
                                let mut t_dense = t.try_as_dense_mut()?;
                                let t: &mut[$typ_dt] = t_dense.as_slice_mut::<$typ_dt>()?;
                                let f: fn(&Self, &mut[$typ_dt], DatumType, DatumType) -> TractResult<()> = |_, xs, input_dt, out_dt| {
                                    let (izp, iscale) = input_dt.zp_scale();
                                    let (ozp, oscale) = out_dt.zp_scale();
                                    xs.iter_mut().for_each(|x| {
                                        // dequantize, apply the f32 kernel, requantize
                                        let x_f32 = (*x as f32 - izp as f32) * iscale;
                                        *x = (($f_f32(x_f32) / oscale) + ozp as f32).as_()
                                    });
                                    Ok(())
                                };
                                f(self, t, input_dt, sout_dt)?;
                                return Ok(())
                            }
                        )*
                    )*
                )?
                bail!("{} does not support {:?}", self.name(), out_dt.unwrap_or(t.datum_type()));
            }
            $(
            fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                $cost(dt)
            }
            )?
            $(
            fn declutter(
                &self,
                model: &TypedModel,
                node: &TypedNode,
            ) -> TractResult<Option<TypedModelPatch>> {
                $declutter(model, node)
            }
            )?
            $(
            fn prefix(&self) -> &'static str {
                $prefix
            }
            )?
            $(
            fn quantize(
                &self,
                dt: DatumType,
                scale: f32,
                zero_point: i32,
            ) -> TractResult<Option<Box<dyn $crate::ops::element_wise::ElementWiseMiniOp>>> {
                $quantize(&self, dt, scale, zero_point)
            }
            )?
            $(
            fn validation(&self) -> Validation {
                $validation
            }
            )?
            $(
            fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                ($operating_datum_type)(dt)
            }
            )?
        }
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}
317
/// Declares an out-of-place element-wise mini-op and its constructor function.
///
/// ```text
/// element_wise_oop!(#[attrs] func, Op { field: Type, .. },
///     [T1, T2] => Dst kernel, ..
///     ; cost: ..
///     ; info: ..
///     ; operating_datum_type: ..
///     ; prefix: ..
///     ; quantize: ..
///     ; validation: ..
/// );
/// ```
///
/// The expansion contains:
/// * `pub struct Op` (deriving `Debug` and `Clone`) with the listed public fields;
/// * an `ElementWiseMiniOp` implementation for `Op` whose `output_type` maps
///   each listed input type to its `Dst` type;
/// * `pub fn func(fields..) -> ElementWiseOp` (carrying the leading attributes)
///   wrapping a new `Op`, with no output type override.
///
/// Every `kernel` must coerce to `fn(&Op, &[T], &mut [Dst]) -> TractResult<()>`.
/// The `quantize:` expression is called as `(dt, scale, zero_point)`.
/// All `;`-introduced sections are optional but must appear in the order shown.
#[macro_export]
macro_rules! element_wise_oop {
    ($(#[$fmeta:meta])* $func:ident, $Op:ident $({$( $(#[$meta: meta])? $var: ident : $var_typ: path),*})?,
        $( [$($typ:ident),*] => $typ_dst:ident $f:expr ),*
        $(; cost: $cost:expr )?
        $(; info: $info:expr )?
        $(; operating_datum_type: $operating_datum_type:expr )?
        $(; prefix: $prefix:expr )?
        $(; quantize: $quantize:expr )?
        $(; validation: $validation:expr )?
    ) => {
        #[derive(Debug, Clone)]
        pub struct $Op { $( $($(#[$meta])? pub $var: $var_typ),* )? }
        impl $crate::ops::element_wise::ElementWiseMiniOp for $Op {
            fn name(&self) -> String {
                format!("{}{}", self.prefix(), stringify!($Op))
            }
            fn output_type(&self, input_type: DatumType) -> Option<DatumType> {
                $(
                    $(if input_type == $typ::datum_type() {
                        return Some(<$typ_dst>::datum_type())
                    }
                    )*
                )*
                None
            }
            fn eval_out_of_place(&self, t: &Tensor, _out_dt: Option<DatumType>) -> TractResult<Tensor> {
                $(
                    $(if t.datum_type() == $typ::datum_type() {
                        // The destination is only allocated once the input type is known
                        // to match, so non-matching kernel groups cost nothing.
                        // The tensor is created uninitialized; the kernel is expected to
                        // write every element.
                        let mut dst = unsafe { Tensor::uninitialized_dt(<$typ_dst>::datum_type(), &t.shape())? };
                        let f: fn(&Self, &[$typ], &mut[$typ_dst]) -> TractResult<()> = $f;
                        {
                            let mut dst_dense = dst.try_as_dense_mut()?;
                            f(self, t.try_as_dense()?.as_slice::<$typ>()?, dst_dense.as_slice_mut::<$typ_dst>()?)?;
                        }
                        return Ok(dst)
                    }
                    )*
                )*
                bail!("{} does not support {:?}", self.name(), t.datum_type());
            }
            $(
            fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                $cost(dt)
            }
            )?
            $(
            fn info(&self) -> TractResult<Vec<String>> {
                $info(self)
            }
            )?
            $(
            fn prefix(&self) -> &'static str {
                $prefix
            }
            )?
            $(
            // The trait is named through `$crate` so the expansion does not
            // depend on the call site having imported it.
            fn quantize(
                &self,
                dt: DatumType,
                scale: f32,
                zero_point: i32,
            ) -> TractResult<Option<Box<dyn $crate::ops::element_wise::ElementWiseMiniOp>>> {
                $quantize(dt, scale, zero_point)
            }
            )?
            $(
            fn validation(&self) -> Validation {
                $validation
            }
            )?
            $(
            fn operating_datum_type(&self, dt: DatumType) -> DatumType {
                ($operating_datum_type)(dt)
            }
            )?
        }
        $(#[$fmeta])*
        pub fn $func($( $($var: $var_typ),* )?) -> $crate::ops::element_wise::ElementWiseOp {
            $crate::ops::element_wise::ElementWiseOp(Box::new($Op { $( $($var),* )? }), None)
        }
    }
}