use crate::internal::*;
use crate::ndarray::Dimension;
use downcast_rs::Downcast;
use std::fmt::{self, Debug};
use tract_data::itertools::izip;
use tract_itertools::Itertools;
use tract_linalg::{BinOp, LinalgFn};

use super::math::{Add, Max, Min, Mul, Sub};
use super::{cast::cast, math::SubF};

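/// Element-wise kernel shared by all binary operations (Add, Sub, Mul, Max, ...).
///
/// Implementors supply the per-type evaluation; `generic_eval` handles broadcasting and
/// output allocation, and `TypedBinOp` wraps the mini-op as a graph operator.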
pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast {
    fn name(&self) -> &'static str;
    fn validation(&self) -> Validation {
        Validation::Accurate
    }
    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
    }
    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;

    fn is_commutative(&self) -> bool {
        true
    }
    fn neutral_element(&self) -> Option<i64> {
        None
    }

    #[allow(unused_variables)]
    fn maybe_eval_qbinary_as_float_op(
        &self,
        a: &TValue,
        b: &TValue,
        c_dt: &DatumType,
    ) -> TractResult<Option<Tensor>> {
        Ok(None)
    }

    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
            Ok(tensor)
        } else {
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
            if &*c_shape == a.shape() && c_dt == a.datum_type() {
                let mut a = a.into_tensor();
                self.eval_in_a(&mut a, &b)?;
                Ok(a)
            } else {
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }
        }
    }
    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        self.generic_eval(a, b, c_dt)
    }
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    #[allow(unused_variables)]
    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
        tvec!()
    }
    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
        None
    }

    #[allow(unused_variables)]
    fn same_as(&self, other: &dyn BinMiniOp) -> bool {
        false
    }
}
dyn_clone::clone_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);

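/// Graph-level wrapper around a `BinMiniOp`. The optional `DatumType` overrides the output
/// type otherwise derived from `BinMiniOp::result_datum_type`.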
#[derive(Debug, Clone)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);

impl Op for TypedBinOp {
    fn name(&self) -> Cow<str> {
        self.0.name().into()
    }

    fn validation(&self) -> Validation {
        self.0.validation()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<TypedBinOp>() else { return false };
        self.1 == other.1 && self.0.same_as(&*other.0)
    }

    op_as_typed_op!();
}

impl TypedBinOp {
    fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
        if let Some(dt) = self.1 {
            Ok(dt)
        } else {
            self.0.result_datum_type(a_dt, b_dt)
        }
    }
}

impl EvalOp for TypedBinOp {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        ensure!(a.rank() == b.rank());
        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
    }
}

impl TypedOp for TypedBinOp {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if inputs[0].rank() != inputs[1].rank() {
            bail!(
                "Typed ops require rank match. Invalid inputs for {}: {}",
                self.name(),
                inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
            );
        }
        let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        Ok(tvec!(out_dt.fact(&*crate::broadcast::multi_broadcast(&[
            &inputs[0].shape.to_tvec(),
            &inputs[1].shape.to_tvec()
        ])?)))
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if let AxisOp::Rm(rm) = change {
            let (inputs, outputs) = model.node_facts(node.id)?;
            if !inputs[0].shape[*rm].is_one()
                || !inputs[1].shape[*rm].is_one()
                || !outputs[0].shape[*rm].is_one()
            {
                return Ok(None);
            }
        }
        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural(inputs, outputs)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .0
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        _node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        _output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
            (a.datum_type().unwrap(), b.datum_type().unwrap())
        } else {
            unreachable!("TypedBinOp has two inputs.")
        };
        if let Some(neutral_patch) =
            declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
        {
            return Ok(Some(neutral_patch));
        }
        if let Some(broadcast_patch) =
            declutter_broadcasting_operand_1(model, node, self.0.clone())?
        {
            return Ok(Some(broadcast_patch));
        }
        self.0.declutter(model, node)
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
            let input_facts = model.node_input_facts(node.id)?;
            let must_swap_inputs =
                input_facts.iter().collect_tuple().is_some_and(|(a_fact, b_fact)| {
                    (a_fact.shape.volume() - b_fact.shape.volume()).prove_strict_negative()
                });
            let (operand_1, operand_2) = if must_swap_inputs {
                (input_facts[1], input_facts[0])
            } else {
                (input_facts[0], input_facts[1])
            };

            let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
                find_most_efficient_config(model, node, must_swap_inputs)?;

            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
            let op_is_quant = c_dt.is_quantized()
                || operand_1.datum_type.is_quantized()
                || operand_2.datum_type.is_quantized();

            let c_shape = crate::broadcast::multi_broadcast(&[
                operand_1.shape.clone(),
                operand_2.shape.clone(),
            ])?;
            let can_eval_in_a =
                (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);

            let inputs = if must_swap_inputs {
                let mut swap_input = node.inputs.clone();
                swap_input.swap(0, 1);
                swap_input
            } else {
                node.inputs.clone()
            };
            let actual_linalg_op =
                if must_swap_inputs { linalg_bin_op.flip() } else { linalg_bin_op };
            let actual_core_op = core_op_for_linalg_op(&actual_linalg_op);

            let dt = model.node_input_facts(node.id)?[0].datum_type;
            if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
                let Some(func) = tract_linalg::bin_by_scalar(dt, actual_linalg_op) else {
                    return Ok(None);
                };
                let eval_fn = Arc::from(func);
                return Ok(Some(
                    TypedModelPatch::replace_single_op(
                        model,
                        node,
                        &inputs,
                        OptBinByScalar { binop: actual_core_op, eval_fn },
                    )?
                    .with_context("ByScalar"),
                ));
            }

            if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
                let Some(func) = tract_linalg::bin_unicast(dt, actual_linalg_op) else {
                    return Ok(None);
                };
                let eval_fn = Arc::from(func);
                return Ok(Some(
                    TypedModelPatch::replace_single_op(
                        model,
                        node,
                        &inputs,
                        OptBinUnicast { binop: actual_core_op, eval_fn },
                    )?
                    .with_context("Unicast"),
                ));
            }
        }

        Ok(None)
    }
    as_op!();
}

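/// Maps a `tract_linalg::BinOp` back to the equivalent core mini-op.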
fn core_op_for_linalg_op(linalg: &BinOp) -> Box<dyn BinMiniOp> {
    match linalg {
        BinOp::Min => Box::new(Min),
        BinOp::Max => Box::new(Max),
        BinOp::Add => Box::new(Add),
        BinOp::Mul => Box::new(Mul),
        BinOp::Sub => Box::new(Sub),
        BinOp::SubF => Box::new(SubF),
    }
}
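
/// If the first operand is the one being broadcast and the mini-op is commutative, swap the
/// inputs so that the larger operand comes first.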
fn declutter_broadcasting_operand_1(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: Box<dyn BinMiniOp>,
) -> TractResult<Option<TypedModelPatch>> {
    let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
        (a.shape.clone(), b.shape.clone())
    } else {
        unreachable!("TypedBinOp has two inputs.")
    };

    let a_num_elements = a_shape.iter().product::<TDim>();
    let b_num_elements = b_shape.iter().product::<TDim>();
    let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
    if a_should_be_broadcast & mini_op.is_commutative() {
        let mut swap_input = node.inputs.clone();
        swap_input.swap(0, 1);
        return Ok(Some(TypedModelPatch::replace_single_op(
            model,
            node,
            &swap_input,
            TypedBinOp(mini_op, None),
        )?));
    }

    Ok(None)
}

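/// Removes the op when one input is uniformly the mini-op's neutral element (e.g. adding 0 or
/// multiplying by 1): the variable input is rewired through, or cast when quantized.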
fn declutter_neutral(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: &dyn BinMiniOp,
    out_dt: DatumType,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let is_neutral = mini_op
            .neutral_element()
            .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
            .unwrap_or(false);

        let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;

        if is_neutral && pos_checked {
            if uniform.uni.datum_type().is_quantized() {
                return Ok(Some(TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &[node.inputs[0]],
                    cast(out_dt),
                )?));
            } else {
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|_, inputs| Ok(inputs.into()),
                )?));
            }
        }
    }
    Ok(None)
}

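/// Estimates whether the by-scalar or unicast specializations would process enough contiguous
/// elements per kernel call to be worthwhile (at least 32), returning one flag for each.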
fn find_most_efficient_config(
    model: &TypedModel,
    node: &TypedNode,
    swap_input: bool,
) -> TractResult<(bool, bool)> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };

        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
        let num_by_scalar_elements = if by_scalar_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
                .map(|(rev_a_dim, _)| rev_a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
        let num_unicast_elements = if unicast_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
                .map(|(a_dim, _)| a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        let min_num_elements = 32;
        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
    }
    Ok((false, false))
}

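/// True if `x` is provably at least `min_val`, i.e. `min(min_val, x)` evaluates to `min_val`.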
pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
    TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
}

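/// Optimized binary op for when the right operand reduces to a single scalar over the trailing
/// axes: each slice of the left operand is updated in place against one value through a
/// pre-selected `tract_linalg` kernel.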
#[derive(Clone)]
pub struct OptBinByScalar {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}

impl Debug for OptBinByScalar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
    }
}

impl OptBinByScalar {
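    /// Shapes qualify when both have the same rank and, once the leading axes stop matching,
    /// every remaining axis of `b_shape` is 1.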
    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        if a_shape.len() != b_shape.len() {
            return false;
        };

        a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
            .all(|(_, b_dim)| *b_dim == 1.to_dim())
    }
}

impl Op for OptBinByScalar {
    fn name(&self) -> Cow<str> {
        format!("Opt{}ByScalar", self.binop.name()).into()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<OptBinByScalar>() else { return false };
        self.binop.same_as(&*other.binop)
    }

    op_as_typed_op!();
}

impl EvalOp for OptBinByScalar {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        let a = a.into_tensor();
        let b_shape = b.shape();

        let first_unary_axis = b_shape
            .iter()
            .enumerate()
            .rev()
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .last()
            .context("Cannot use by_scalar when no trailing dimensions are unary")?;

        let iterating_shape = &a.shape()[..first_unary_axis];
        if !iterating_shape.is_empty() {
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            let mut view = a.view();
            let b_view = b.view();
            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
            (self.eval_fn)(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}

impl TypedOp for OptBinByScalar {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}

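/// Optimized binary op for when the right operand matches the trailing axes of the left one
/// exactly (its leading axes being 1): each slice of the left operand is combined in place with
/// the whole right operand through a pre-selected `tract_linalg` kernel.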
#[derive(Clone)]
pub struct OptBinUnicast {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}

impl Debug for OptBinUnicast {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
    }
}

impl OptBinUnicast {
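    /// Unicast kernels need each iterated slice to stay vector-aligned: unless a single
    /// iteration covers everything, the element count per iteration must be a known multiple
    /// of the platform vector size.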
    fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        let num_iterations: TDim = a_shape
            .iter()
            .zip(b_shape.iter())
            .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .map(|(a_dim, _)| a_dim)
            .product();

        if num_iterations.is_one() {
            return true;
        }

        let elements_per_iteration: TDim = a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .map(|(_, b_dim)| b_dim)
            .product();

        if let Ok(num_element) = elements_per_iteration.to_i64() {
            let required_alignment = vector_size();
            (num_element as usize % required_alignment) == 0
        } else {
            false
        }
    }

    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        if a_shape.len() != b_shape.len() {
            return false;
        };

        let unicast_possible = a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .all(|(a_dim, b_dim)| a_dim == b_dim);
        let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);

        unicast_possible && unicast_is_aligned
    }
}

impl Op for OptBinUnicast {
    fn name(&self) -> Cow<str> {
        format!("Opt{}Unicast", self.binop.name()).into()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<OptBinUnicast>() else { return false };
        self.binop.same_as(&*other.binop)
    }
    op_as_typed_op!();
}

impl EvalOp for OptBinUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        let a = a.into_tensor();
        let b_shape = b.shape();
        let b_view = b.view();
        let first_non_unary_axis =
            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

        if let Some(first_non_unary_axis) = first_non_unary_axis {
            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            let mut view = a.view();
            debug_assert_eq!(view.shape(), b_view.shape());
            (self.eval_fn)(&mut view, &b_view)?;
        }

        Ok(tvec!(a.into_tvalue()))
    }
}

impl TypedOp for OptBinUnicast {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}

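/// Declares a binary mini-op `$Op` and its `$func()` constructor returning a `TypedBinOp`,
/// generating the `BinMiniOp` impl from per-type closures, with optional overrides for
/// declutter, codegen, cost, validation, linalg mapping, commutativity, neutral element and
/// quantized evaluation.
///
/// Purely illustrative (not an invocation from this crate), the general shape is:
/// `bin_to_super_type!(my_add, MyAdd, [f32, i32] => |c, a, b| *c = *a + *b);`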
#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_in_a: $eval_in_a:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(is_commutative: $is_commutative:expr,)?
     $(neutral_element: $neutral_element:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $(q_op_on_f32: $q_op_on_f32:expr,)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash)]
        pub struct $Op;
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn same_as(&self, other: &dyn $crate::ops::binary::BinMiniOp) -> bool {
                other.downcast_ref::<$Op>().is_some()
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                $(
                    $(if c.datum_type() == $typ::datum_type() {
                        let a = a.to_array_view::<$typ>()?;
                        let b = b.to_array_view::<$typ>()?;
                        let mut c = c.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                        return Ok(())
                    })*
                )*
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let a = a.to_array_view::<$typ_dt>()?;
                            let b = b.to_array_view::<$typ_dt>()?;
                            let mut c = c.to_array_view_mut::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                            return Ok(())
                        }
                        )*
                    )*
                )?
                bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            $(fn is_commutative(&self) -> bool {
                $is_commutative
            })?
            $(fn neutral_element(&self) -> Option<i64> {
                Some($neutral_element)
            })?
            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                $(if $eval_in_a(a, b)? { return Ok(()) } )?
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let b = b.to_array_view::<$typ>()?;
                        let mut a = a.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                )*
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let mut a = a.to_array_view_mut::<$typ_dt>()?;
                            let b = b.to_array_view::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
                                cab(a, &(a.clone()), b, zp, scale)
                            });
                            return Ok(())
                        })*
                    )*
                )?
                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
            }

            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                $eval_override(a, b, c_dt)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    } else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    ($declutter)(self, model, node)
                }
            )?
            $(
                fn codegen(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    ($codegen)(self, model, node)
                }
            )?
            $(
                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                    ($cost)(dt)
                }
            )?
            $(
                fn validation(&self) -> Validation {
                    $validation
                }
            )?
            $(
                fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
                    Some(tract_linalg::BinOp::$linalg)
                }
            )?
            $(
                fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                    ($operating_datum_type)(a, b)
                })?

            #[allow(unused_variables)]
            fn maybe_eval_qbinary_as_float_op(
                &self,
                a: &TValue,
                b: &TValue,
                c_dt: &DatumType,
            ) -> TractResult<Option<Tensor>> {
                $(
                    fn memory_optimised_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if let (DatumType::QU8(QParams::ZpScale { zero_point: a_zp, scale: a_scale }),
                            DatumType::QU8(QParams::ZpScale { zero_point: b_zp, scale: b_scale }),
                            DatumType::QU8(QParams::ZpScale { zero_point: c_zp, scale: c_scale })) =
                            (a.datum_type(), b.datum_type(), c_dt)
                        {
                            let c_inv_scale = 1.0 / c_scale;
                            let a = a.to_array_view::<u8>()?;
                            let b = b.to_array_view::<u8>()?;
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
                            let view = c.to_array_view_mut::<u8>()?;
                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
                                *c = (scale_by($q_op_on_f32(
                                    ((*a as i32 - a_zp as i32) as f32 * a_scale),
                                    ((*b as i32 - b_zp as i32) as f32 * b_scale),
                                ), c_inv_scale) as i32
                                    + *c_zp as i32)
                                    .clamp_cast()
                            });
                            return Ok(Some(c));
                        }
                        Ok(None)
                    }

                    fn generic_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                        accumulator_dt: DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
                            match accumulator_dt {
                                DatumType::F32 => {
                                    let view = c.to_array_view_mut::<f32>()?;
                                    $crate::ndarray::Zip::from(view).and_broadcast(a.to_array_view()?).and_broadcast(b.to_array_view()?).for_each(|c, a, b| {
                                        *c = $q_op_on_f32(*a, *b);
                                    })
                                },
                                other => bail!("unexpected accumulator data type as {:?}", other),
                            };

                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
                        }
                        Ok(None)
                    }

                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
                        return Ok(Some(c));
                    }
                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
                        return Ok(Some(d));
                    }
                )?
                Ok(None)
            }
        }

        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
        }
    };
}

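/// Describes the situation where exactly one input of a binary node is a uniform
/// (constant, single-valued) tensor: `uni` is that value, `var` the other input's outlet.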
#[derive(Debug)]
pub(crate) struct OneUniformInput {
    pub uni: Arc<Tensor>,
    pub var: OutletId,
    pub left_is_uniform: bool,
}

pub(crate) fn one_input_is_uniform(
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<OneUniformInput>> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let uni = if let Some(a) = &a.uniform {
            OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
        } else if let Some(b) = &b.uniform {
            OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
        } else {
            return Ok(None);
        };
        let var_fact = [a, b][uni.left_is_uniform as usize];
        let uni_fact = [a, b][!uni.left_is_uniform as usize];
        if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
            return Ok(Some(uni));
        }
    }
    Ok(None)
}