tract_core/ops/binary.rs

use crate::internal::*;
use crate::ndarray::Dimension;
use downcast_rs::Downcast;
use std::fmt::{self, Debug};
use tract_data::itertools::izip;
use tract_itertools::Itertools;
use tract_linalg::{BinOp, LinalgFn};

use super::{cast::cast, math::SubF};

pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast {
    fn name(&self) -> &'static str;
    fn validation(&self) -> Validation {
        Validation::Accurate
    }
    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
    }
    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;

    fn is_commutative(&self) -> bool {
        true
    }
    fn neutral_element(&self) -> Option<i64> {
        None
    }

    #[allow(unused_variables)]
    fn maybe_eval_qbinary_as_float_op(
        &self,
        a: &TValue,
        b: &TValue,
        c_dt: &DatumType,
    ) -> TractResult<Option<Tensor>> {
        Ok(None)
    }

    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
            Ok(tensor)
        } else {
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
            if &*c_shape == a.shape() && c_dt == a.datum_type() {
                let mut a = a.into_tensor();
                self.eval_in_a(&mut a, &b)?;
                Ok(a)
            } else {
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }
        }
    }
    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        self.generic_eval(a, b, c_dt)
    }
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    #[allow(unused_variables)]
    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
        tvec!()
    }
    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
        None
    }

    #[allow(unused_variables)]
    fn same_as(&self, other: &dyn BinMiniOp) -> bool {
        false
    }
}
dyn_clone::clone_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);
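
// The sketch below is illustrative and not part of the original module: it shows the
// minimal surface a hand-written `BinMiniOp` must provide (name, result type, and the
// in-place and out-of-place eval paths). Real ops in tract are generated by the
// `bin_to_super_type!` macro at the end of this file; `AddF32Sketch` is an invented name.
#[cfg(test)]
mod bin_mini_op_sketch {
    use super::*;

    #[derive(Debug, Clone)]
    struct AddF32Sketch;

    impl BinMiniOp for AddF32Sketch {
        fn name(&self) -> &'static str {
            "AddF32Sketch"
        }
        fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
            self.operating_datum_type(a, b)
        }
        // In-place path: accumulate b into a (a already has the output shape and type).
        fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
            let b = b.to_array_view::<f32>()?;
            let mut a = a.to_array_view_mut::<f32>()?;
            crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| *a += *b);
            Ok(())
        }
        // Out-of-place path: c is preallocated with the broadcast output shape.
        fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
            let a = a.to_array_view::<f32>()?;
            let b = b.to_array_view::<f32>()?;
            let mut c = c.to_array_view_mut::<f32>()?;
            crate::ndarray::Zip::from(&mut c)
                .and_broadcast(a)
                .and_broadcast(b)
                .for_each(|c, a, b| *c = *a + *b);
            Ok(())
        }
    }

    #[test]
    fn eval_in_a_adds_in_place() -> TractResult<()> {
        let mut a = tensor1(&[1f32, 2., 3.]);
        AddF32Sketch.eval_in_a(&mut a, &tensor1(&[10f32, 20., 30.]))?;
        assert_eq!(a, tensor1(&[11f32, 22., 33.]));
        Ok(())
    }
}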

#[derive(Debug, Clone)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);

impl Op for TypedBinOp {
    fn name(&self) -> Cow<str> {
        self.0.name().into()
    }

    fn validation(&self) -> Validation {
        self.0.validation()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<TypedBinOp>() else { return false };
        self.1 == other.1 && self.0.same_as(&*other.0)
    }

    op_as_typed_op!();
}

impl TypedBinOp {
    fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
        if let Some(dt) = self.1 {
            Ok(dt)
        } else {
            self.0.result_datum_type(a_dt, b_dt)
        }
    }
}

impl EvalOp for TypedBinOp {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        ensure!(a.rank() == b.rank());
        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
    }
}
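
// Hedged usage sketch (test-only, not in the original source): `eval` broadcasts its two
// same-rank inputs and dispatches to the mini-op's in-place or out-of-place path. It uses
// `ops::math::add()` from this crate purely as a convenient example op.
#[cfg(test)]
mod typed_bin_op_eval_sketch {
    use super::*;

    #[test]
    fn add_broadcasts_the_smaller_input() -> TractResult<()> {
        let op = crate::ops::math::add();
        let a = tensor2(&[[1f32, 2., 3.], [4., 5., 6.]]).into_tvalue();
        let b = tensor2(&[[10f32, 20., 30.]]).into_tvalue();
        let c = op.eval(tvec!(a, b))?.remove(0);
        assert_eq!(c.into_tensor(), tensor2(&[[11f32, 22., 33.], [14., 25., 36.]]));
        Ok(())
    }
}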

impl TypedOp for TypedBinOp {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if inputs[0].rank() != inputs[1].rank() {
            bail!(
                "Typed ops require rank match. Invalid inputs for {}: {}",
                self.name(),
                inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
            );
        }
        let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        Ok(tvec!(out_dt.fact(&*crate::broadcast::multi_broadcast(&[
            &inputs[0].shape.to_tvec(),
            &inputs[1].shape.to_tvec()
        ])?)))
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if let AxisOp::Rm(rm) = change {
            let (inputs, outputs) = model.node_facts(node.id)?;
            if !inputs[0].shape[*rm].is_one()
                || !inputs[1].shape[*rm].is_one()
                || !outputs[0].shape[*rm].is_one()
            {
                return Ok(None);
            }
        }
        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural(inputs, outputs)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .0
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        _node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        _output_axis: usize,
        _start: usize,
        _end: usize,
    ) -> TractResult<Option<TVec<OutletId>>> {
        Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
            (a.datum_type().unwrap(), b.datum_type().unwrap())
        } else {
            unreachable!("TypedBinOp has two inputs.")
        };
        if let Some(neutral_patch) =
            declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
        {
            return Ok(Some(neutral_patch));
        }
        if let Some(broadcast_patch) =
            declutter_broadcasting_operand_1(model, node, self.0.clone())?
        {
            return Ok(Some(broadcast_patch));
        }
        self.0.declutter(model, node)
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
            let (operand_1, operand_2, should_swap, op) =
                if let &[a, b] = &*model.node_input_facts(node.id)? {
                    let num_elements_1 = a.shape.iter().product::<TDim>();
                    let num_elements_2 = b.shape.iter().product::<TDim>();
                    let operand_1_should_be_broadcast =
                        (num_elements_1 - num_elements_2).prove_strict_negative();

                    let is_sub = linalg_bin_op == BinOp::Sub;
                    if operand_1_should_be_broadcast & is_sub {
                        let sub_flipped: Box<dyn BinMiniOp> = Box::new(SubF {});
                        (b, a, true, sub_flipped)
                    } else {
                        (a, b, false, self.0.clone())
                    }
                } else {
                    unreachable!("TypedBinOp has two inputs.")
                };

            let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
                find_most_efficient_config(model, node, should_swap)?;

            // Check if op is quantized
            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
            let op_is_quant = c_dt.is_quantized()
                || operand_1.datum_type.is_quantized()
                || operand_2.datum_type.is_quantized();

            // Check if it can be evaluated in a
            let c_shape = crate::broadcast::multi_broadcast(&[
                operand_1.shape.clone(),
                operand_2.shape.clone(),
            ])?;
            let can_eval_in_a =
                (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);

            // Swap input if required
            let inputs = if should_swap {
                let mut swap_input = node.inputs.clone();
                swap_input.swap(0, 1);
                swap_input
            } else {
                node.inputs.clone()
            };

            let dt = model.node_input_facts(node.id)?[0].datum_type().unwrap();
            if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
                let Some(func) = tract_linalg::bin_by_scalar(dt, linalg_bin_op) else {
                    return Ok(None);
                };
                let eval_fn = Arc::from(func);
                return Ok(Some(
                    TypedModelPatch::replace_single_op(
                        model,
                        node,
                        &inputs,
                        OptBinByScalar { binop: op, eval_fn },
                    )?
                    .with_context("ByScalar"),
                ));
            }

            if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
                let Some(func) = tract_linalg::bin_unicast(dt, linalg_bin_op) else {
                    return Ok(None);
                };
                let eval_fn = Arc::from(func);
                return Ok(Some(
                    TypedModelPatch::replace_single_op(
                        model,
                        node,
                        &inputs,
                        OptBinUnicast { binop: op, eval_fn },
                    )?
                    .with_context("Unicast"),
                ));
            }
        }

        Ok(None)
    }
    as_op!();
}

fn declutter_broadcasting_operand_1(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: Box<dyn BinMiniOp>,
) -> TractResult<Option<TypedModelPatch>> {
    let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
        (a.shape.clone(), b.shape.clone())
    } else {
        unreachable!("TypedBinOp has two inputs.")
    };

    let a_num_elements = a_shape.iter().product::<TDim>();
    let b_num_elements = b_shape.iter().product::<TDim>();
    let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
    if a_should_be_broadcast & mini_op.is_commutative() {
        let mut swap_input = node.inputs.clone();
        swap_input.swap(0, 1);
        return Ok(Some(TypedModelPatch::replace_single_op(
            model,
            node,
            &swap_input,
            TypedBinOp(mini_op, None),
        )?));
    }

    Ok(None)
}
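
// Worked example for declutter_broadcasting_operand_1 above (illustrative): with shapes
// a = [1, 1, 8] and b = [2, 3, 8] on a commutative op, a is the operand that would be
// broadcast. Swapping the inputs puts the large operand in slot 0 so the in-place and
// optimized codegen paths can apply.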

fn declutter_neutral(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: &dyn BinMiniOp,
    out_dt: DatumType,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let is_neutral = mini_op
            .neutral_element()
            .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
            .unwrap_or(false);

        // For some operators the neutral element is only neutral on one side:
        // e.g. with Sub, `x - 0` is a no-op but `0 - x` is not.
        let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;

        if is_neutral && pos_checked {
            // Neutral decluttering is special for quantized values:
            // with fa = (a - az) * as, fb = (b - bz) * bs = 0 and fc = (c - cz) * cs,
            // even when fa == fc the quantization parameters may differ, so the stored
            // integers differ (a != c). The node is a cast, not a no-op.
            if uniform.uni.datum_type().is_quantized() {
                return Ok(Some(TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &[node.inputs[0]],
                    cast(out_dt),
                )?));
            // In the non quantized case, it's a no_op.
            } else {
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|_, inputs| Ok(inputs.into()),
                )?));
            }
        }
    }
    Ok(None)
}
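
// Worked example for declutter_neutral above (illustrative): `x + 0` and `x - 0` reduce to
// a plain rewire of `x`, while `0 - x` does not (Sub is not commutative and 0 is only
// neutral on the right). For quantized tensors the node becomes a cast instead, since the
// output quantization parameters may differ from the input's.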

fn find_most_efficient_config(
    model: &TypedModel,
    node: &TypedNode,
    swap_input: bool,
) -> TractResult<(bool, bool)> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };

        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
        let num_by_scalar_elements = if by_scalar_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
                .map(|(rev_a_dim, _)| rev_a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
        let num_unicast_elements = if unicast_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
                .map(|(a_dim, _)| a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        let min_num_elements = 32;
        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
    }
    Ok((false, false))
}
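
// Worked example for find_most_efficient_config above (illustrative): with a = [2, 3, 100]
// and b = [2, 3, 1], the trailing unary dims of b give 100 contiguous elements per scalar,
// so the by-scalar path is considered efficient (>= 32 elements). With a = [2, 3, 100] and
// b = [1, 3, 100], the trailing matching dims give 300 unicast elements per iteration.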

pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
    TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
}
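
// Test sketch (added for illustration, not in the original file): `gt_tdim` holds when
// `x >= min_val`, since min(min_val, x) == min_val exactly in that case.
#[cfg(test)]
mod gt_tdim_sketch {
    use super::*;

    #[test]
    fn gt_tdim_is_a_greater_or_equal_check() {
        assert!(gt_tdim(TDim::Val(64), 32));
        assert!(gt_tdim(TDim::Val(32), 32));
        assert!(!gt_tdim(TDim::Val(16), 32));
    }
}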

#[derive(Clone)]
pub struct OptBinByScalar {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}

impl Debug for OptBinByScalar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
    }
}

impl OptBinByScalar {
    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        if a_shape.len() != b_shape.len() {
            return false;
        };

        a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
            .all(|(_, b_dim)| *b_dim == 1.to_dim())
    }
}
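
// Test sketch (illustrative, not in the original file): the by-scalar layout requires that
// once the dims of a and b stop matching, every remaining dim of b is 1, i.e. b degenerates
// to a scalar over the trailing axes of a.
#[cfg(test)]
mod by_scalar_shape_sketch {
    use super::*;

    #[test]
    fn trailing_unary_dims_of_b_qualify() {
        let a: TVec<TDim> = tvec!(2.to_dim(), 3.to_dim(), 4.to_dim());
        let b: TVec<TDim> = tvec!(2.to_dim(), 1.to_dim(), 1.to_dim());
        assert!(OptBinByScalar::check_input_shapes(&a, &b));
    }

    #[test]
    fn non_unary_trailing_dims_of_b_do_not() {
        let a: TVec<TDim> = tvec!(2.to_dim(), 3.to_dim(), 4.to_dim());
        let b: TVec<TDim> = tvec!(1.to_dim(), 3.to_dim(), 4.to_dim());
        assert!(!OptBinByScalar::check_input_shapes(&a, &b));
    }
}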

impl Op for OptBinByScalar {
    fn name(&self) -> Cow<str> {
        format!("Opt{}ByScalar", self.binop.name()).into()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<OptBinByScalar>() else { return false };
        self.binop.same_as(&*other.binop)
    }

    op_as_typed_op!();
}

impl EvalOp for OptBinByScalar {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // TensorView does not require an owned tensor, but `a` is mutated in place
        // through the views below, so we take ownership here. The binding is not
        // declared `mut` only because the compiler does not require it.
        let a = a.into_tensor();
        let b_shape = b.shape();

        let first_unary_axis = b_shape
            .iter()
            .enumerate()
            .rev()
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .last()
            .context("Cannot use by_scalar when no trailing dimensions are unary")?;

        let iterating_shape = &a.shape()[..first_unary_axis];
        if !iterating_shape.is_empty() {
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            let mut view = a.view();
            let b_view = b.view();
            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
            (self.eval_fn)(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}

impl TypedOp for OptBinByScalar {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}

#[derive(Clone)]
pub struct OptBinUnicast {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}

impl Debug for OptBinUnicast {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
    }
}

impl OptBinUnicast {
    fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        let num_iterations: TDim = a_shape
            .iter()
            .zip(b_shape.iter())
            .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .map(|(a_dim, _)| a_dim)
            .product();

        if num_iterations.is_one() {
            return true;
        }

        let elements_per_iteration: TDim = a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .map(|(_, b_dim)| b_dim)
            .product();

        if let Ok(num_element) = elements_per_iteration.to_i64() {
            let required_alignment = vector_size();
            (num_element as usize % required_alignment) == 0
        } else {
            false
        }
    }

    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        if a_shape.len() != b_shape.len() {
            return false;
        };

        let unicast_possible = a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .all(|(a_dim, b_dim)| a_dim == b_dim);
        let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);

        unicast_possible && unicast_is_aligned
    }
}
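
// Test sketch (illustrative, not in the original file): unicast requires that after the
// leading unary dims of b, the remaining dims of a and b match exactly, and that each
// slice of b is suitably aligned for the vectorized kernel.
#[cfg(test)]
mod unicast_shape_sketch {
    use super::*;

    #[test]
    fn identical_shapes_qualify() {
        let a: TVec<TDim> = tvec!(2.to_dim(), 3.to_dim(), 4.to_dim());
        assert!(OptBinUnicast::check_input_shapes(&a, &a));
    }

    #[test]
    fn mismatched_trailing_dims_do_not() {
        let a: TVec<TDim> = tvec!(2.to_dim(), 3.to_dim(), 4.to_dim());
        let b: TVec<TDim> = tvec!(1.to_dim(), 3.to_dim(), 5.to_dim());
        assert!(!OptBinUnicast::check_input_shapes(&a, &b));
    }
}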

impl Op for OptBinUnicast {
    fn name(&self) -> Cow<str> {
        format!("Opt{}Unicast", self.binop.name()).into()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<OptBinUnicast>() else { return false };
        self.binop.same_as(&*other.binop)
    }
    op_as_typed_op!();
}

impl EvalOp for OptBinUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // TensorView does not require an owned tensor, but `a` is mutated in place
        // through the views below, so we take ownership here. The binding is not
        // declared `mut` only because the compiler does not require it.
        let a = a.into_tensor();
        let b_shape = b.shape();
        let b_view = b.view();
        let first_non_unary_axis =
            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

        if let Some(first_non_unary_axis) = first_non_unary_axis {
613            // Iterate on outter dimensions and evaluate with unicast subviews
            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            let mut view = a.view();
            debug_assert_eq!(view.shape(), b_view.shape());
            (self.eval_fn)(&mut view, &b_view)?;
        }

        Ok(tvec!(a.into_tvalue()))
    }
}

impl TypedOp for OptBinUnicast {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}

#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_in_a: $eval_in_a:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(is_commutative: $is_commutative:expr,)?
     $(neutral_element: $neutral_element:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $(q_op_on_f32: $q_op_on_f32:expr,)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash)]
        pub struct $Op;
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn same_as(&self, other: &dyn $crate::ops::binary::BinMiniOp) -> bool {
                other.downcast_ref::<$Op>().is_some()
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                    $(
                        $(if c.datum_type() == $typ::datum_type() {
                            let a = a.to_array_view::<$typ>()?;
                            let b = b.to_array_view::<$typ>()?;
                            let mut c = c.to_array_view_mut::<$typ>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                            return Ok(())
                        })*
                     )*
                    $(
                        $(
                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                let a = a.to_array_view::<$typ_dt>()?;
                                let b = b.to_array_view::<$typ_dt>()?;
                                let mut c = c.to_array_view_mut::<$typ_dt>()?;
                                $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                                return Ok(())
                            }
                            )*
                         )*
                     )?
                    bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            $(fn is_commutative(&self) -> bool {
                $is_commutative
            })?
            $(fn neutral_element(&self) -> Option<i64> {
                Some($neutral_element)
            })?
            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                // c and a are same type
                $(if $eval_in_a(a, b)? { return Ok(()) } )?
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let b = b.to_array_view::<$typ>()?;
                        let mut a = a.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                )*
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let mut a = a.to_array_view_mut::<$typ_dt>()?;
                            let b = b.to_array_view::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
                                cab(a, &(a.clone()), b, zp, scale)
                            });
                            return Ok(())
                        })*
                    )*
                )?
                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
            }

            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                $eval_override(a, b, c_dt)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    }
                    else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

                $(
                    fn declutter(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($declutter)(self, model, node)
                    }
                 )?
                $(
                    fn codegen(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        a: &Arc<Tensor>,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($codegen)(self, model, node, a)
                    }
                 )?
                $(
                    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                        ($cost)(dt)
                    }
                 )?
                $(
                    fn validation(&self) -> Validation {
                        $validation
                    }
                 )?
                $(
                    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
                        Some(tract_linalg::BinOp::$linalg)
                    }
                 )?
                $(
                    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                        ($operating_datum_type)(a, b)
                    })?

            /// Default simple binary operation for quantized formats: dequantize,
            /// apply the requested operation in float, then requantize.
            /// Several implementations are provided, each with pros and cons.
            #[allow(unused_variables)]
            fn maybe_eval_qbinary_as_float_op(
                &self,
                a: &TValue,
                b: &TValue,
                c_dt: &DatumType,
            ) -> TractResult<Option<Tensor>> {
                $(
                    /// This implementation strives to minimize memory allocations and accesses.
                    /// It only applies when all types are QU8 with zero-point/scale parameters.
                    /// Probably better suited to large tensors.
                    fn memory_optimised_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                            (a.datum_type(), b.datum_type(), c_dt)
                        {
                            let c_inv_scale = 1.0 / c_scale;
                            let a = a.to_array_view::<u8>()?;
                            let b = b.to_array_view::<u8>()?;
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
                            let view = c.to_array_view_mut::<u8>()?;
                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
                                *c = (scale_by($q_op_on_f32(
                                            ((*a as i32 - a_zp as i32) as f32 * a_scale),
                                            ((*b as i32 - b_zp as i32) as f32 * b_scale),
                                ), c_inv_scale) as i32
                                    + *c_zp as i32)
                                    .clamp_cast()
                            });
                            return Ok(Some(c));
                        }
                        Ok(None)
                    }

                    /// Applies to all quantized types.
                    /// Uses more memory, but hopefully faster than memory_optimised_q_binary_as_float_op,
                    /// especially once cast_to_dt gets vectorized implementations.
                    fn generic_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                        accumulator_dt: DatumType
                    ) -> TractResult<Option<Tensor>> {
                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
                            match accumulator_dt {
                                DatumType::F32 => {
                                    let view = c.to_array_view_mut::<f32>()?;
                                    $crate::ndarray::Zip::from(view).and_broadcast(a.to_array_view()?).and_broadcast(b.to_array_view()?).for_each(|c, a, b| {
                                        *c = $q_op_on_f32(*a,*b);
                                    })
                                },
                                other => bail!("unexpected accumulator data type as {:?}", other)
                            };

                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
                        }
                        Ok(None)
                    }

                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
                        return Ok(Some(c));
                    }
                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
                        return Ok(Some(d));
                    }
                )?
                Ok(None)
            }
        }

        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
        }
    };
}
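
// Hedged usage sketch (test-only, not part of the original file): a minimal invocation of
// `bin_to_super_type!` exercising only the mandatory `[types] => closure` arm. The
// `example_max`/`ExampleMaxSketch` names are invented for the example; real invocations
// live in modules like `ops::math` and usually add `linalg:`, `q:` and `declutter:` args.
#[cfg(test)]
mod bin_to_super_type_sketch {
    use super::*;

    bin_to_super_type!(example_max, ExampleMaxSketch,
        [f32, i32] => |c, a, b| *c = if a > b { *a } else { *b });

    #[test]
    fn example_max_evals_elementwise() -> TractResult<()> {
        let op = example_max();
        let a = tensor1(&[1f32, 5., 3.]).into_tvalue();
        let b = tensor1(&[4f32, 2., 3.]).into_tvalue();
        let c = op.eval(tvec!(a, b))?.remove(0);
        assert_eq!(c.into_tensor(), tensor1(&[4f32, 5., 3.]));
        Ok(())
    }
}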

#[derive(Debug)]
pub(crate) struct OneUniformInput {
    pub uni: Arc<Tensor>,
    pub var: OutletId,
    pub left_is_uniform: bool,
}

pub(crate) fn one_input_is_uniform(
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<OneUniformInput>> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let uni = if let Some(a) = &a.uniform {
            OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
        } else if let Some(b) = &b.uniform {
            OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
        } else {
            return Ok(None);
        };
        let var_fact = [a, b][uni.left_is_uniform as usize];
        let uni_fact = [a, b][!uni.left_is_uniform as usize];
        if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
            return Ok(Some(uni));
        }
    }
    Ok(None)
}