use crate::internal::*;
use crate::ndarray::Dimension;
use downcast_rs::Downcast;
use std::fmt::{self, Debug};
use tract_data::itertools::izip;
use tract_itertools::Itertools;
use tract_linalg::{BinOp, LinalgFn};

use super::{cast::cast, math::SubF};

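/// The arithmetic core of a binary operator: per-type evaluation (in place in
/// the first operand, or out of place into a pre-allocated output), type
/// promotion rules, and optional optimization hooks. `TypedBinOp` wraps a
/// `BinMiniOp` to expose it as a graph operator.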
pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast {
    fn name(&self) -> &'static str;
    fn validation(&self) -> Validation {
        Validation::Accurate
    }
    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
    }
    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;

    fn is_commutative(&self) -> bool {
        true
    }
    fn neutral_element(&self) -> Option<i64> {
        None
    }

    #[allow(unused_variables)]
    fn maybe_eval_qbinary_as_float_op(
        &self,
        a: &TValue,
        b: &TValue,
        c_dt: &DatumType,
    ) -> TractResult<Option<Tensor>> {
        Ok(None)
    }

    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
            Ok(tensor)
        } else {
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
            if &*c_shape == a.shape() && c_dt == a.datum_type() {
                let mut a = a.into_tensor();
                self.eval_in_a(&mut a, &b)?;
                Ok(a)
            } else {
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }
        }
    }
    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        self.generic_eval(a, b, c_dt)
    }
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    #[allow(unused_variables)]
    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
        tvec!()
    }
    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
        None
    }

    #[allow(unused_variables)]
    fn same_as(&self, other: &dyn BinMiniOp) -> bool {
        false
    }
}
dyn_clone::clone_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);

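/// Graph-level binary operator: wraps a `BinMiniOp` kernel together with an
/// optional override of the output datum type.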
#[derive(Debug, Clone)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);

impl Op for TypedBinOp {
    fn name(&self) -> Cow<str> {
        self.0.name().into()
    }

    fn validation(&self) -> Validation {
        self.0.validation()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<TypedBinOp>() else { return false };
        self.1 == other.1 && self.0.same_as(&*other.0)
    }

    op_as_typed_op!();
}

impl TypedBinOp {
    fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
        if let Some(dt) = self.1 {
            Ok(dt)
        } else {
            self.0.result_datum_type(a_dt, b_dt)
        }
    }
}

impl EvalOp for TypedBinOp {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        ensure!(a.rank() == b.rank());
        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
    }
}

impl TypedOp for TypedBinOp {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if inputs[0].rank() != inputs[1].rank() {
            bail!(
                "Typed ops require rank match. Invalid inputs for {}: {}",
                self.name(),
                inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
            );
        }
        let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        Ok(tvec!(out_dt.fact(&*crate::broadcast::multi_broadcast(&[
            &inputs[0].shape.to_tvec(),
            &inputs[1].shape.to_tvec()
        ])?)))
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if let AxisOp::Rm(rm) = change {
            let (inputs, outputs) = model.node_facts(node.id)?;
            if !inputs[0].shape[*rm].is_one()
                || !inputs[1].shape[*rm].is_one()
                || !outputs[0].shape[*rm].is_one()
            {
                return Ok(None);
            }
        }
        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural(inputs, outputs)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .0
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        _node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        _output_axis: usize,
        _start: usize,
        _end: usize,
    ) -> TractResult<Option<TVec<OutletId>>> {
        Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
            (a.datum_type().unwrap(), b.datum_type().unwrap())
        } else {
            unreachable!("TypedBinOp has two inputs.")
        };
        if let Some(neutral_patch) =
            declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
        {
            return Ok(Some(neutral_patch));
        }
        if let Some(broadcast_patch) =
            declutter_broadcasting_operand_1(model, node, self.0.clone())?
        {
            return Ok(Some(broadcast_patch));
        }
        self.0.declutter(model, node)
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
            let (operand_1, operand_2, should_swap, op) =
                if let &[a, b] = &*model.node_input_facts(node.id)? {
                    let num_elements_1 = a.shape.iter().product::<TDim>();
                    let num_elements_2 = b.shape.iter().product::<TDim>();
                    let operand_1_should_be_broadcast =
                        (num_elements_1 - num_elements_2).prove_strict_negative();

                    let is_sub = linalg_bin_op == BinOp::Sub;
                    if operand_1_should_be_broadcast & is_sub {
                        let sub_flipped: Box<dyn BinMiniOp> = Box::new(SubF {});
                        (b, a, true, sub_flipped)
                    } else {
                        (a, b, false, self.0.clone())
                    }
                } else {
                    unreachable!("TypedBinOp has two inputs.")
                };

            let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
                find_most_efficient_config(model, node, should_swap)?;

            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
            let op_is_quant = c_dt.is_quantized()
                || operand_1.datum_type.is_quantized()
                || operand_2.datum_type.is_quantized();

            let c_shape = crate::broadcast::multi_broadcast(&[
                operand_1.shape.clone(),
                operand_2.shape.clone(),
            ])?;
            let can_eval_in_a =
                (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);

            let inputs = if should_swap {
                let mut swap_input = node.inputs.clone();
                swap_input.swap(0, 1);
                swap_input
            } else {
                node.inputs.clone()
            };

            let dt = model.node_input_facts(node.id)?[0].datum_type().unwrap();
            if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
                let Some(func) = tract_linalg::bin_by_scalar(dt, linalg_bin_op) else {
                    return Ok(None);
                };
                let eval_fn = Arc::from(func);
                return Ok(Some(
                    TypedModelPatch::replace_single_op(
                        model,
                        node,
                        &inputs,
                        OptBinByScalar { binop: op, eval_fn },
                    )?
                    .with_context("ByScalar"),
                ));
            }

            if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
                let Some(func) = tract_linalg::bin_unicast(dt, linalg_bin_op) else {
                    return Ok(None);
                };
                let eval_fn = Arc::from(func);
                return Ok(Some(
                    TypedModelPatch::replace_single_op(
                        model,
                        node,
                        &inputs,
                        OptBinUnicast { binop: op, eval_fn },
                    )?
                    .with_context("Unicast"),
                ));
            }
        }

        Ok(None)
    }
    as_op!();
}

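/// If the first operand is the one being broadcast (it has provably fewer
/// elements than the second) and the mini-op is commutative, swap the inputs
/// so that the larger operand comes first.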
fn declutter_broadcasting_operand_1(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: Box<dyn BinMiniOp>,
) -> TractResult<Option<TypedModelPatch>> {
    let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
        (a.shape.clone(), b.shape.clone())
    } else {
        unreachable!("TypedBinOp has two inputs.")
    };

    let a_num_elements = a_shape.iter().product::<TDim>();
    let b_num_elements = b_shape.iter().product::<TDim>();
    let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
    if a_should_be_broadcast & mini_op.is_commutative() {
        let mut swap_input = node.inputs.clone();
        swap_input.swap(0, 1);
        return Ok(Some(TypedModelPatch::replace_single_op(
            model,
            node,
            &swap_input,
            TypedBinOp(mini_op, None),
        )?));
    }

    Ok(None)
}

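/// If one input is a uniform tensor holding the mini-op's neutral element,
/// bypass the operation: when the uniform value is quantized, replace the node
/// with a cast to the output type, otherwise rewire the variable input
/// straight to the output.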
fn declutter_neutral(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: &dyn BinMiniOp,
    out_dt: DatumType,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let is_neutral = mini_op
            .neutral_element()
            .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
            .unwrap_or(false);

        let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;

        if is_neutral && pos_checked {
            if uniform.uni.datum_type().is_quantized() {
                return Ok(Some(TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &[node.inputs[0]],
                    cast(out_dt),
                )?));
            } else {
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|_, inputs| Ok(inputs.into()),
                )?));
            }
        }
    }
    Ok(None)
}

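/// Estimate whether the "by scalar" and the "unicast" kernels would each
/// process enough elements (at least 32) to be worth using. Returns
/// `(by_scalar_should_be_efficient, unicast_should_be_efficient)`.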
fn find_most_efficient_config(
    model: &TypedModel,
    node: &TypedNode,
    swap_input: bool,
) -> TractResult<(bool, bool)> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };

        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
        let num_by_scalar_elements = if by_scalar_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
                .map(|(rev_a_dim, _)| rev_a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
        let num_unicast_elements = if unicast_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
                .map(|(a_dim, _)| a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        let min_num_elements = 32;
        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
    }
    Ok((false, false))
}

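/// True if `x` is provably greater than or equal to `min_val`.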
pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
    TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
}

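/// Optimized binary operator for the case where the second operand reduces to
/// a scalar over the trailing axes: evaluation runs the linalg "by scalar"
/// kernel in place over the first operand.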
#[derive(Clone)]
pub struct OptBinByScalar {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}

impl Debug for OptBinByScalar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
    }
}

impl OptBinByScalar {
    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        if a_shape.len() != b_shape.len() {
            return false;
        };

        a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
            .all(|(_, b_dim)| *b_dim == 1.to_dim())
    }
}

impl Op for OptBinByScalar {
    fn name(&self) -> Cow<str> {
        format!("Opt{}ByScalar", self.binop.name()).into()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<OptBinByScalar>() else { return false };
        self.binop.same_as(&*other.binop)
    }

    op_as_typed_op!();
}

impl EvalOp for OptBinByScalar {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        let a = a.into_tensor();
        let b_shape = b.shape();

        let first_unary_axis = b_shape
            .iter()
            .enumerate()
            .rev()
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .last()
            .context("Cannot use by_scalar when no trailing dimensions are unary")?;

        let iterating_shape = &a.shape()[..first_unary_axis];
        if !iterating_shape.is_empty() {
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            let mut view = a.view();
            let b_view = b.view();
            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
            (self.eval_fn)(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}

impl TypedOp for OptBinByScalar {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}

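/// Optimized binary operator for the case where, past the leading broadcast
/// axes, both operands share the same trailing shape: evaluation runs the
/// linalg "unicast" kernel in place over the first operand.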
#[derive(Clone)]
pub struct OptBinUnicast {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}

impl Debug for OptBinUnicast {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
    }
}

impl OptBinUnicast {
    fn check_b_alignment(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        let num_iterations: TDim = a_shape
            .iter()
            .zip(b_shape.iter())
            .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .map(|(a_dim, _)| a_dim)
            .product();

        if num_iterations.is_one() {
            return true;
        }

        let elements_per_iteration: TDim = a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .map(|(_, b_dim)| b_dim)
            .product();

        if let Ok(num_element) = elements_per_iteration.to_i64() {
            let required_alignment = vector_size();
            (num_element as usize % required_alignment) == 0
        } else {
            false
        }
    }

    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        if a_shape.len() != b_shape.len() {
            return false;
        };

        let unicast_possible = a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .all(|(a_dim, b_dim)| a_dim == b_dim);
        let unicast_is_aligned = Self::check_b_alignment(a_shape, b_shape);

        unicast_possible && unicast_is_aligned
    }
}

impl Op for OptBinUnicast {
    fn name(&self) -> Cow<str> {
        format!("Opt{}Unicast", self.binop.name()).into()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<OptBinUnicast>() else { return false };
        self.binop.same_as(&*other.binop)
    }

    op_as_typed_op!();
}

impl EvalOp for OptBinUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        let a = a.into_tensor();
        let b_shape = b.shape();
        let b_view = b.view();
        let first_non_unary_axis =
            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

        if let Some(first_non_unary_axis) = first_non_unary_axis {
            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            let mut view = a.view();
            debug_assert_eq!(view.shape(), b_view.shape());
            (self.eval_fn)(&mut view, &b_view)?;
        }

        Ok(tvec!(a.into_tvalue()))
    }
}

impl TypedOp for OptBinUnicast {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}

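/// Declares a binary element-wise operator: expands to a zero-sized struct
/// implementing `BinMiniOp` from the per-type kernels given as arguments, plus
/// a constructor function returning the corresponding `TypedBinOp`.
///
/// A minimal sketch of an invocation, with a hypothetical `ExampleMul` op
/// restricted to `f32` (the real declarations, e.g. in `ops/math.rs`, cover
/// more types and options):
///
/// ```ignore
/// bin_to_super_type!(example_mul, ExampleMul,
///     linalg: Mul,
///     neutral_element: 1,
///     [f32] => |c, a, b| *c = *a * *b);
/// ```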
#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_in_a: $eval_in_a:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(is_commutative: $is_commutative:expr,)?
     $(neutral_element: $neutral_element:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $(q_op_on_f32: $q_op_on_f32:expr,)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash)]
        pub struct $Op;
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn same_as(&self, other: &dyn $crate::ops::binary::BinMiniOp) -> bool {
                other.downcast_ref::<$Op>().is_some()
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                $(
                    $(if c.datum_type() == $typ::datum_type() {
                        let a = a.to_array_view::<$typ>()?;
                        let b = b.to_array_view::<$typ>()?;
                        let mut c = c.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                        return Ok(())
                    })*
                )*
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let a = a.to_array_view::<$typ_dt>()?;
                            let b = b.to_array_view::<$typ_dt>()?;
                            let mut c = c.to_array_view_mut::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                            return Ok(())
                        }
                        )*
                    )*
                )?
                bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            $(fn is_commutative(&self) -> bool {
                $is_commutative
            })?
            $(fn neutral_element(&self) -> Option<i64> {
                Some($neutral_element)
            })?

            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                $(if $eval_in_a(a, b)? { return Ok(()) } )?
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let b = b.to_array_view::<$typ>()?;
                        let mut a = a.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                )*
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let mut a = a.to_array_view_mut::<$typ_dt>()?;
                            let b = b.to_array_view::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
                                cab(a, &(a.clone()), b, zp, scale)
                            });
                            return Ok(())
                        })*
                    )*
                )?
                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
            }

            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                $eval_override(a, b, c_dt)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    } else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    ($declutter)(self, model, node)
                }
            )?
            $(
                fn codegen(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    ($codegen)(self, model, node)
                }
            )?
            $(
                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                    ($cost)(dt)
                }
            )?
            $(
                fn validation(&self) -> Validation {
                    $validation
                }
            )?
            $(
                fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
                    Some(tract_linalg::BinOp::$linalg)
                }
            )?
            $(
                fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                    ($operating_datum_type)(a, b)
                }
            )?

            #[allow(unused_variables)]
            fn maybe_eval_qbinary_as_float_op(
                &self,
                a: &TValue,
                b: &TValue,
                c_dt: &DatumType,
            ) -> TractResult<Option<Tensor>> {
                $(
                    fn memory_optimised_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                            DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                            DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                            (a.datum_type(), b.datum_type(), c_dt)
                        {
                            let c_inv_scale = 1.0 / c_scale;
                            let a = a.to_array_view::<u8>()?;
                            let b = b.to_array_view::<u8>()?;
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
                            let view = c.to_array_view_mut::<u8>()?;
                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
                                *c = (scale_by($q_op_on_f32(
                                    ((*a as i32 - a_zp as i32) as f32 * a_scale),
                                    ((*b as i32 - b_zp as i32) as f32 * b_scale),
                                ), c_inv_scale) as i32
                                    + *c_zp as i32)
                                    .clamp_cast()
                            });
                            return Ok(Some(c));
                        }
                        Ok(None)
                    }

                    fn generic_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                        accumulator_dt: DatumType
                    ) -> TractResult<Option<Tensor>> {
                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
                            match accumulator_dt {
                                DatumType::F32 => {
                                    let view = c.to_array_view_mut::<f32>()?;
                                    $crate::ndarray::Zip::from(view).and_broadcast(a.to_array_view()?).and_broadcast(b.to_array_view()?).for_each(|c, a, b| {
                                        *c = $q_op_on_f32(*a, *b);
                                    })
                                },
                                other => bail!("unexpected accumulator data type as {:?}", other)
                            };

                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
                        }
                        Ok(None)
                    }

                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
                        return Ok(Some(c));
                    }
                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
                        return Ok(Some(d));
                    }
                )?
                Ok(None)
            }
        }

        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
        }
    };
}

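/// A binary node input pair where one side is a uniform (single-value) tensor.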
#[derive(Debug)]
pub(crate) struct OneUniformInput {
    pub uni: Arc<Tensor>,
    pub var: OutletId,
    pub left_is_uniform: bool,
}

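/// If one of the node inputs is a uniform tensor whose shape broadcasts into
/// the other input (every dimension is 1 or matches), return it along with the
/// outlet of the variable input.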
pub(crate) fn one_input_is_uniform(
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<OneUniformInput>> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let uni = if let Some(a) = &a.uniform {
            OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
        } else if let Some(b) = &b.uniform {
            OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
        } else {
            return Ok(None);
        };
        let var_fact = [a, b][uni.left_is_uniform as usize];
        let uni_fact = [a, b][!uni.left_is_uniform as usize];
        if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
            return Ok(Some(uni));
        }
    }
    Ok(None)
}