1use crate::internal::*;
2use crate::ndarray::Dimension;
3use downcast_rs::Downcast;
4use std::fmt::{self, Debug};
5use tract_data::itertools::izip;
6use tract_itertools::Itertools;
7use tract_linalg::{BinOp, LinalgFn};
8
9use super::math::{Add, Max, Min, Mul, Sub};
10use super::{cast::cast, math::SubF};
11
/// Minimal interface for an elementwise binary operation (Add, Mul, Max, ...).
///
/// `TypedBinOp` wraps a boxed `BinMiniOp` and drives broadcasting, datum-type
/// resolution and optimization; implementors supply the per-element kernels
/// and, optionally, declutter/codegen hooks.
pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast {
    /// Operator name, surfaced through `Op::name`.
    fn name(&self) -> &'static str;
    /// Numerical accuracy contract of the implementation.
    fn validation(&self) -> Validation {
        Validation::Accurate
    }
    /// Datum type in which the computation is carried out for inputs `a` and `b`.
    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
    }
    /// Datum type of the result for input types `a` and `b`.
    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
    /// Evaluate in place, writing the result into `a` (`b` broadcast over `a`).
    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
    /// Evaluate out of place into the preallocated output `c`.
    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;

    /// Whether `op(a, b) == op(b, a)`; enables operand-swapping optimizations.
    fn is_commutative(&self) -> bool {
        true
    }
    /// Neutral element `e` such that `op(x, e) == x`, if any (e.g. 0 for Add).
    fn neutral_element(&self) -> Option<i64> {
        None
    }

    /// Hook allowing quantized inputs to be evaluated through a float kernel.
    /// Default returns `Ok(None)`: no shortcut applies.
    #[allow(unused_variables)]
    fn maybe_eval_qbinary_as_float_op(
        &self,
        a: &TValue,
        b: &TValue,
        c_dt: &DatumType,
    ) -> TractResult<Option<Tensor>> {
        Ok(None)
    }

    /// Generic evaluation path: tries the quantized-as-float shortcut first,
    /// then evaluates in place in `a` when the output has exactly `a`'s shape
    /// and type, falling back to an out-of-place evaluation otherwise.
    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
            Ok(tensor)
        } else {
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
            if &*c_shape == a.shape() && c_dt == a.datum_type() {
                // Output aliases a exactly: reuse a's buffer.
                let mut a = a.into_tensor();
                self.eval_in_a(&mut a, &b)?;
                Ok(a)
            } else {
                // Safe because eval_out_of_place fully overwrites c.
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }
        }
    }
    /// Entry point used by `TypedBinOp::eval`; overridable by implementors.
    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        self.generic_eval(a, b, c_dt)
    }
    /// Operator-specific graph simplification hook (optional, default no-op).
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    /// Operator-specific code generation hook (optional, default no-op).
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    /// Cost model contribution per output element (empty by default).
    #[allow(unused_variables)]
    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
        tvec!()
    }
    /// Mapping to a tract-linalg binary operator, when one exists.
    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
        None
    }

    /// Structural equality between mini-ops (conservative default: false).
    #[allow(unused_variables)]
    fn same_as(&self, other: &dyn BinMiniOp) -> bool {
        false
    }
}
dyn_clone::clone_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);
91
/// A typed binary operation: the mini-op kernel plus an optional forced
/// output datum type (when `Some`, it overrides the mini-op's
/// `result_datum_type`).
#[derive(Debug, Clone)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);
94
95impl Op for TypedBinOp {
96 fn name(&self) -> StaticName {
97 self.0.name().into()
98 }
99
100 fn validation(&self) -> Validation {
101 self.0.validation()
102 }
103
104 fn same_as(&self, other: &dyn Op) -> bool {
105 let Some(other) = other.downcast_ref::<TypedBinOp>() else { return false };
106 self.1 == other.1 && self.0.same_as(&*other.0)
107 }
108
109 op_as_typed_op!();
110}
111
112impl TypedBinOp {
113 fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
114 if let Some(dt) = self.1 { Ok(dt) } else { self.0.result_datum_type(a_dt, b_dt) }
115 }
116}
117
118impl EvalOp for TypedBinOp {
119 fn is_stateless(&self) -> bool {
120 true
121 }
122
123 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
124 let (a, b) = args_2!(inputs);
125 ensure!(a.rank() == b.rank());
126 let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
127 Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
128 }
129}
130
131impl TypedOp for TypedBinOp {
132 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
133 if inputs[0].rank() != inputs[1].rank() {
134 bail!(
135 "Typed ops require rank match. Invalid inputs for {}: {}",
136 self.name(),
137 inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
138 );
139 }
140 let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
141 Ok(tvec!(out_dt.fact(&*crate::broadcast::multi_broadcast(&[
142 &inputs[0].shape.to_tvec(),
143 &inputs[1].shape.to_tvec()
144 ])?)))
145 }
146
147 fn change_axes(
148 &self,
149 model: &TypedModel,
150 node: &TypedNode,
151 _io: InOut,
152 change: &AxisOp,
153 ) -> TractResult<Option<AxisChangeConsequence>> {
154 if let AxisOp::Rm(rm) = change {
155 let (inputs, outputs) = model.node_facts(node.id)?;
156 rule_if!(inputs[0].shape[*rm].is_one());
157 rule_if!(inputs[1].shape[*rm].is_one());
158 rule_if!(outputs[0].shape[*rm].is_one());
159 }
160 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
161 }
162
163 fn axes_mapping(
164 &self,
165 inputs: &[&TypedFact],
166 outputs: &[&TypedFact],
167 ) -> TractResult<AxesMapping> {
168 AxesMapping::natural(inputs, outputs)
169 }
170
171 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
172 let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
173 Ok(self
174 .0
175 .cost_per_element(inputs[0].datum_type)
176 .into_iter()
177 .map(|(c, n)| (c, count.clone() * n))
178 .collect())
179 }
180
181 fn slice(
182 &self,
183 patch: &mut TypedModelPatch,
184 _model: &TypedModel,
185 _node: &TypedNode,
186 prefix: &str,
187 inputs: &[OutletId],
188 _output_axis: usize,
189 _start: &TDim,
190 _end: &TDim,
191 ) -> TractResult<Option<TVec<OutletId>>> {
192 Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
193 }
194
195 fn declutter(
196 &self,
197 model: &TypedModel,
198 node: &TypedNode,
199 ) -> TractResult<Option<TypedModelPatch>> {
200 let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
201 (a.datum_type().unwrap(), b.datum_type().unwrap())
202 } else {
203 unreachable!("TypedBinOp has two inputs.")
204 };
205 if let Some(neutral_patch) =
206 declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
207 {
208 return Ok(Some(neutral_patch));
209 }
210 if let Some(broadcast_patch) =
211 declutter_broadcasting_operand_1(model, node, self.0.clone())?
212 {
213 return Ok(Some(broadcast_patch));
214 }
215 self.0.declutter(model, node)
216 }
217
218 fn codegen(
219 &self,
220 model: &TypedModel,
221 node: &TypedNode,
222 ) -> TractResult<Option<TypedModelPatch>> {
223 if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
224 let input_facts = model.node_input_facts(node.id)?;
225 let must_swap_inputs =
226 input_facts.iter().collect_tuple().is_some_and(|(a_fact, b_fact)| {
227 (a_fact.shape.volume() - b_fact.shape.volume()).prove_strict_negative()
228 });
229 let (operand_1, operand_2) = if must_swap_inputs {
230 (input_facts[1], input_facts[0])
231 } else {
232 (input_facts[0], input_facts[1])
233 };
234
235 let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
236 find_most_efficient_config(model, node, must_swap_inputs)?;
237
238 let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
240 let op_is_quant = c_dt.is_quantized()
241 || operand_1.datum_type.is_quantized()
242 || operand_2.datum_type.is_quantized();
243
244 let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
246 let c_shape = crate::broadcast::multi_broadcast(&[
247 operand_1.shape.clone(),
248 operand_2.shape.clone(),
249 ])?;
250 let can_eval_in_a =
251 (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);
252
253 let inputs = if must_swap_inputs {
255 let mut swap_input = node.inputs.clone();
256 swap_input.swap(0, 1);
257 swap_input
258 } else {
259 node.inputs.clone()
260 };
261 let actual_linalg_op =
262 if must_swap_inputs { linalg_bin_op.flip() } else { linalg_bin_op };
263 let actual_core_op = core_op_for_linalg_op(&actual_linalg_op);
264
265 let dt = model.node_input_facts(node.id)?[0].datum_type;
266 if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
267 rule_if_some!(func = tract_linalg::bin_by_scalar(dt, actual_linalg_op));
268 let eval_fn = Arc::from(func);
269 return Ok(Some(
270 TypedModelPatch::replace_single_op(
271 model,
272 node,
273 &inputs,
274 OptBinByScalar { binop: actual_core_op, eval_fn },
275 )?
276 .with_context("ByScalar"),
277 ));
278 }
279
280 if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
281 rule_if_some!(func = tract_linalg::bin_unicast(dt, actual_linalg_op));
282 let eval_fn = Arc::from(func);
283 return Ok(Some(
284 TypedModelPatch::replace_single_op(
285 model,
286 node,
287 &inputs,
288 OptBinUnicast { binop: actual_core_op, eval_fn },
289 )?
290 .with_context("Unicast"),
291 ));
292 }
293 }
294
295 Ok(None)
296 }
297 as_op!();
298}
299
300fn core_op_for_linalg_op(linalg: &BinOp) -> Box<dyn BinMiniOp> {
301 match linalg {
302 BinOp::Min => Box::new(Min),
303 BinOp::Max => Box::new(Max),
304 BinOp::Add => Box::new(Add),
305 BinOp::Mul => Box::new(Mul),
306 BinOp::Sub => Box::new(Sub),
307 BinOp::SubF => Box::new(SubF),
308 }
309}
310fn declutter_broadcasting_operand_1(
311 model: &TypedModel,
312 node: &TypedNode,
313 mini_op: Box<dyn BinMiniOp>,
314) -> TractResult<Option<TypedModelPatch>> {
315 let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
316 (a.shape.clone(), b.shape.clone())
317 } else {
318 unreachable!("TypedBinOp has two inputs.")
319 };
320
321 let a_num_elements = a_shape.iter().product::<TDim>();
322 let b_num_elements = b_shape.iter().product::<TDim>();
323 let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
324 if a_should_be_broadcast & mini_op.is_commutative() {
325 let mut swap_input = node.inputs.clone();
326 swap_input.swap(0, 1);
327 return Ok(Some(TypedModelPatch::replace_single_op(
328 model,
329 node,
330 &swap_input,
331 TypedBinOp(mini_op, None),
332 )?));
333 }
334
335 Ok(None)
336}
337
338fn declutter_neutral(
339 model: &TypedModel,
340 node: &TypedNode,
341 mini_op: &dyn BinMiniOp,
342 out_dt: DatumType,
343) -> TractResult<Option<TypedModelPatch>> {
344 if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
345 let is_neutral = mini_op
346 .neutral_element()
347 .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
348 .unwrap_or(false);
349
350 let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;
353
354 if is_neutral && pos_checked {
355 if uniform.uni.datum_type().is_quantized() {
360 return Ok(Some(TypedModelPatch::replace_single_op(
361 model,
362 node,
363 &[node.inputs[0]],
364 cast(out_dt),
365 )?));
366 } else {
368 return Ok(Some(TypedModelPatch::rewire(
369 model,
370 &[uniform.var],
371 &[node.id.into()],
372 &|_, inputs| Ok(inputs.into()),
373 )?));
374 }
375 }
376 }
377 Ok(None)
378}
379
/// Estimates whether the by-scalar and unicast optimized kernels would process
/// enough contiguous elements on this node to be worthwhile.
///
/// Returns `(by_scalar_should_be_efficient, unicast_should_be_efficient)`.
/// `swap_input` must reflect the operand swap already decided by the caller.
fn find_most_efficient_config(
    model: &TypedModel,
    node: &TypedNode,
    swap_input: bool,
) -> TractResult<(bool, bool)> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        // Work on the post-swap operand order.
        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };

        // By-scalar payoff: elements of `a` covered by b's trailing unary dims.
        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
        let num_by_scalar_elements = if by_scalar_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
                .map(|(rev_a_dim, _)| rev_a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        // Unicast payoff: trailing elements where a and b dims agree.
        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
        let num_unicast_elements = if unicast_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
                .map(|(a_dim, _)| a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        // Heuristic threshold: the optimized kernels only pay off past 32
        // provable contiguous elements.
        let min_num_elements = 32;
        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
    }
    Ok((false, false))
}
422
423pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
424 TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
425}
426
/// Optimized binary op where `b` degenerates to a scalar per iteration slice:
/// evaluated in place in `a` through a tract-linalg kernel.
#[derive(Clone)]
pub struct OptBinByScalar {
    /// Core mini-op, kept for naming, cost and output fact computation.
    pub binop: Box<dyn BinMiniOp>,
    /// Linalg kernel performing the in-place by-scalar evaluation.
    eval_fn: Arc<LinalgFn>,
}
432
433impl Debug for OptBinByScalar {
434 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
435 f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
436 }
437}
438
439impl OptBinByScalar {
440 fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
441 if a_shape.len() != b_shape.len() {
442 return false;
443 };
444
445 a_shape
446 .iter()
447 .zip(b_shape.iter())
448 .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
449 .all(|(_, b_dim)| *b_dim == 1.to_dim())
450 }
451}
452
453impl Op for OptBinByScalar {
454 fn name(&self) -> StaticName {
455 format!("Opt{}ByScalar", self.binop.name()).into()
456 }
457
458 fn same_as(&self, other: &dyn Op) -> bool {
459 let Some(other) = other.downcast_ref::<OptBinByScalar>() else { return false };
460 self.binop.same_as(&*other.binop)
461 }
462
463 op_as_typed_op!();
464}
465
impl EvalOp for OptBinByScalar {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluates in place in `a`, calling the scalar kernel once per prefix
    /// coordinate; each corresponding slice of `b` holds a single element.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // The result is accumulated directly into a's buffer.
        let a = a.into_tensor();
        let b_shape = b.shape();

        // Index of the first axis of b's trailing run of unary axes.
        let first_unary_axis = b_shape
            .iter()
            .enumerate()
            .rev()
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .last()
            .context("Cannot use by_scalar when no trailing dimensions are unary")?;

        let iterating_shape = &a.shape()[..first_unary_axis];
        if !iterating_shape.is_empty() {
            // One kernel call per prefix coordinate; b's slice is a scalar.
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            // Degenerate case: b is a full scalar, one kernel call suffices.
            let mut view = a.view();
            let b_view = b.view();
            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
            (self.eval_fn)(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}
504
505impl TypedOp for OptBinByScalar {
506 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
507 ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
508 let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
509 let out_shape = inputs[0].shape.clone();
510 Ok(tvec!(out_dt.fact(out_shape)))
511 }
512
513 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
514 let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
515 Ok(self
516 .binop
517 .cost_per_element(inputs[0].datum_type)
518 .into_iter()
519 .map(|(c, n)| (c, count.clone() * n))
520 .collect())
521 }
522
523 as_op!();
524}
525
/// Optimized binary op where `b`'s trailing dimensions match `a`'s exactly
/// (after leading unary axes): evaluated in place in `a` via a linalg kernel.
#[derive(Clone)]
pub struct OptBinUnicast {
    /// Core mini-op, kept for naming, cost and output fact computation.
    pub binop: Box<dyn BinMiniOp>,
    /// Linalg kernel performing the in-place unicast evaluation.
    eval_fn: Arc<LinalgFn>,
}
531
532impl Debug for OptBinUnicast {
533 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
534 f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
535 }
536}
537
538impl OptBinUnicast {
539 fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
540 let num_iterations: TDim = a_shape
541 .iter()
542 .zip(b_shape.iter())
543 .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
544 .map(|(a_dim, _)| a_dim)
545 .product();
546
547 if num_iterations.is_one() {
548 return true;
549 }
550
551 let elements_per_iteration: TDim = a_shape
552 .iter()
553 .zip(b_shape.iter())
554 .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
555 .map(|(_, b_dim)| b_dim)
556 .product();
557
558 if let Ok(num_element) = elements_per_iteration.to_i64() {
559 let required_alignment = vector_size();
560 (num_element as usize % required_alignment) == 0
561 } else {
562 false
563 }
564 }
565 fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
566 if a_shape.len() != b_shape.len() {
567 return false;
568 };
569
570 let unicast_possible = a_shape
571 .iter()
572 .zip(b_shape.iter())
573 .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
574 .all(|(a_dim, b_dim)| a_dim == b_dim);
575 let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);
576
577 unicast_possible && unicast_is_aligned
578 }
579}
580
581impl Op for OptBinUnicast {
582 fn name(&self) -> StaticName {
583 format!("Opt{}Unicast", self.binop.name()).into()
584 }
585
586 fn same_as(&self, other: &dyn Op) -> bool {
587 let Some(other) = other.downcast_ref::<OptBinUnicast>() else { return false };
588 self.binop.same_as(&*other.binop)
589 }
590 op_as_typed_op!();
591}
592
impl EvalOp for OptBinUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluates in place in `a`, calling the unicast kernel once per
    /// coordinate of b's leading unary axes.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // The result is accumulated directly into a's buffer.
        let a = a.into_tensor();
        let b_shape = b.shape();
        let b_view = b.view();
        // One past the last leading unary axis of b, if any.
        let first_non_unary_axis =
            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

        if let Some(first_non_unary_axis) = first_non_unary_axis {
            // Iterate over the prefix of a spanned by b's leading unary axes.
            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            // Shapes match exactly: one kernel call over the whole tensors.
            let mut view = a.view();
            debug_assert_eq!(view.shape(), b_view.shape());
            (self.eval_fn)(&mut view, &b_view)?;
        }

        Ok(tvec!(a.into_tvalue()))
    }
}
625
626impl TypedOp for OptBinUnicast {
627 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
628 ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
629 let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
630 let out_shape = inputs[0].shape.clone();
631 Ok(tvec!(out_dt.fact(out_shape)))
632 }
633
634 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
635 let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
636 Ok(self
637 .binop
638 .cost_per_element(inputs[0].datum_type)
639 .into_iter()
640 .map(|(c, n)| (c, count.clone() * n))
641 .collect())
642 }
643
644 as_op!();
645}
646
/// Generates a `BinMiniOp` implementor `$Op` plus a `$func()` constructor
/// returning a `TypedBinOp`, for elementwise operators whose result type is
/// the common super type of the operands.
///
/// Optional arms:
/// - `declutter:` / `codegen:` — graph transformation hooks;
/// - `cost:` / `validation:` / `operating_datum_type:` — metadata overrides;
/// - `eval_in_a:` / `out_of_place:` — fast-path hooks tried before the
///   generic per-type kernels;
/// - `eval_override:` — replaces the whole `eval` entry point;
/// - `linalg:` — maps the op to a `tract_linalg::BinOp` variant;
/// - `is_commutative:` / `neutral_element:` — algebraic properties;
/// - `q:` — per-type quantized kernels; `q_op_on_f32:` — f32 fallback used to
///   evaluate quantized inputs through float arithmetic.
#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_in_a: $eval_in_a:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(is_commutative: $is_commutative:expr,)?
     $(neutral_element: $neutral_element:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $(q_op_on_f32: $q_op_on_f32:expr,)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash)]
        pub struct $Op;
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn same_as(&self, other: &dyn $crate::ops::binary::BinMiniOp) -> bool {
                other.downcast_ref::<$Op>().is_some()
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                // Optional caller-supplied fast path.
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                // Generic per-type kernels, dispatched on the output datum type.
                $(
                    $(if c.datum_type() == $typ::datum_type() {
                        let a = a.to_array_view::<$typ>()?;
                        let b = b.to_array_view::<$typ>()?;
                        let mut c = c.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                        return Ok(())
                    })*
                )*
                // Quantized kernels, dispatched on the unquantized input type;
                // kernels receive the zero point and scale of a's qparams.
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let a = a.to_array_view::<$typ_dt>()?;
                            let b = b.to_array_view::<$typ_dt>()?;
                            let mut c = c.to_array_view_mut::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                            return Ok(())
                        }
                        )*
                    )*
                )?
                bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            $(fn is_commutative(&self) -> bool {
                $is_commutative
            })?
            $(fn neutral_element(&self) -> Option<i64> {
                Some($neutral_element)
            })?
            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                // Optional caller-supplied fast path.
                $(if $eval_in_a(a, b)? { return Ok(()) } )?
                // Generic per-type kernels, accumulating into a.
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let b = b.to_array_view::<$typ>()?;
                        let mut a = a.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                )*
                // Quantized kernels, accumulating into a.
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let mut a = a.to_array_view_mut::<$typ_dt>()?;
                            let b = b.to_array_view::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
                                cab(a, &(a.clone()), b, zp, scale)
                            });
                            return Ok(())
                        })*
                    )*
                )?
                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
            }

            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                $eval_override(a, b, c_dt)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                // Same unquantized type: prefer the quantized flavor when
                // exactly one side carries quantization parameters.
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    }
                    else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    ($declutter)(self, model, node)
                }
            )?
            $(
                // FIX: the generated method must match the signature of
                // `BinMiniOp::codegen(&self, model, node)`; the previous extra
                // `a: &Arc<Tensor>` parameter made any use of the `codegen:`
                // arm fail to compile.
                fn codegen(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    ($codegen)(self, model, node)
                }
            )?
            $(
                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                    ($cost)(dt)
                }
            )?
            $(
                fn validation(&self) -> Validation {
                    $validation
                }
            )?
            $(
                fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
                    Some(tract_linalg::BinOp::$linalg)
                }
            )?
            $(
                fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                    ($operating_datum_type)(a, b)
                })?


            #[allow(unused_variables)]
            fn maybe_eval_qbinary_as_float_op(
                &self,
                a: &TValue,
                b: &TValue,
                c_dt: &DatumType,
            ) -> TractResult<Option<Tensor>> {
                $(
                    // Fast path for QU8 with zero-point/scale params: dequantize
                    // inline, apply the f32 op, requantize and clamp — without
                    // materializing f32 copies of the operands.
                    fn memory_optimised_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                            (a.datum_type(), b.datum_type(), c_dt)
                        {
                            let c_inv_scale = 1.0 / c_scale;
                            let a = a.to_array_view::<u8>()?;
                            let b = b.to_array_view::<u8>()?;
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
                            let view = c.to_array_view_mut::<u8>()?;
                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
                                *c = (scale_by($q_op_on_f32(
                                    ((*a as i32 - a_zp as i32) as f32 * a_scale),
                                    ((*b as i32 - b_zp as i32) as f32 * b_scale),
                                ), c_inv_scale) as i32
                                + *c_zp as i32)
                                .clamp_cast()
                            });
                            return Ok(Some(c));
                        }
                        Ok(None)
                    }

                    // Generic fallback: cast both quantized operands to the
                    // accumulator type (f32), apply the op, cast back to c_dt.
                    fn generic_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                        accumulator_dt: DatumType
                    ) -> TractResult<Option<Tensor>> {
                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
                            match accumulator_dt {
                                DatumType::F32 => {
                                    let view = c.to_array_view_mut::<f32>()?;
                                    $crate::ndarray::Zip::from(view).and_broadcast(a.to_array_view()?).and_broadcast(b.to_array_view()?).for_each(|c, a, b| {
                                        *c = $q_op_on_f32(*a,*b);
                                    })
                                },
                                other => bail!("unexpected accumulator data type as {:?}", other)
                            };

                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
                        }
                        Ok(None)
                    }

                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
                        return Ok(Some(c));
                    }
                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
                        return Ok(Some(d));
                    }
                )?
                Ok(None)
            }
        }

        /// Constructs the `TypedBinOp` wrapping this mini-op.
        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
        }
    };
}
883
/// Describes a binary node where exactly one input is uniform (a constant
/// tensor whose elements all share one value).
#[derive(Debug)]
pub(crate) struct OneUniformInput {
    /// The uniform input's value.
    pub uni: Arc<Tensor>,
    /// Outlet of the non-uniform (variable) input.
    pub var: OutletId,
    /// True when the uniform input is the left operand.
    pub left_is_uniform: bool,
}
890
891pub(crate) fn one_input_is_uniform(
892 model: &TypedModel,
893 node: &TypedNode,
894) -> TractResult<Option<OneUniformInput>> {
895 if let &[a, b] = &*model.node_input_facts(node.id)? {
896 let uni = if let Some(a) = &a.uniform {
897 OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
898 } else if let Some(b) = &b.uniform {
899 OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
900 } else {
901 return Ok(None);
902 };
903 let var_fact = [a, b][uni.left_is_uniform as usize];
904 let uni_fact = [a, b][!uni.left_is_uniform as usize];
905 if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
906 return Ok(Some(uni));
907 }
908 }
909 Ok(None)
910}