tract_core/ops/binary.rs

use crate::internal::*;
use crate::ndarray::Dimension;
use downcast_rs::Downcast;
use std::fmt::{self, Debug};
use tract_data::itertools::izip;
use tract_itertools::Itertools;
use tract_linalg::{BinOp, LinalgFn};

use super::math::{Add, Max, Min, Mul, Sub};
use super::{cast::cast, math::SubF};

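/// The scalar-level contract of a binary operator: the kernel that [`TypedBinOp`]
/// and its optimized variants wrap.
///
/// A minimal hand-written implementation might look like the hedged sketch below
/// (hypothetical `FloorDiv` op, f32 only, shown for illustration; real ops in this
/// crate are declared with the `bin_to_super_type!` macro defined at the bottom of
/// this file):
///
/// ```ignore
/// #[derive(Debug, Clone, Hash)]
/// struct FloorDiv;
///
/// impl BinMiniOp for FloorDiv {
///     fn name(&self) -> &'static str {
///         "FloorDiv"
///     }
///     fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
///         self.operating_datum_type(a, b)
///     }
///     // In-place path: `a` already has the output shape and datum type.
///     fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
///         let b = b.to_array_view::<f32>()?;
///         let mut a = a.to_array_view_mut::<f32>()?;
///         crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| *a = (*a / b).floor());
///         Ok(())
///     }
///     // Out-of-place path: write into a freshly allocated `c`.
///     fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
///         let a = a.to_array_view::<f32>()?;
///         let b = b.to_array_view::<f32>()?;
///         let mut c = c.to_array_view_mut::<f32>()?;
///         crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| *c = (a / b).floor());
///         Ok(())
///     }
/// }
/// ```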
pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast {
    fn name(&self) -> &'static str;
    fn validation(&self) -> Validation {
        Validation::Accurate
    }
    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
    }
    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;

    fn is_commutative(&self) -> bool {
        true
    }
    fn neutral_element(&self) -> Option<i64> {
        None
    }

    #[allow(unused_variables)]
    fn maybe_eval_qbinary_as_float_op(
        &self,
        a: &TValue,
        b: &TValue,
        c_dt: &DatumType,
    ) -> TractResult<Option<Tensor>> {
        Ok(None)
    }

    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
            Ok(tensor)
        } else {
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
            if &*c_shape == a.shape() && c_dt == a.datum_type() {
                let mut a = a.into_tensor();
                self.eval_in_a(&mut a, &b)?;
                Ok(a)
            } else {
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }
        }
    }
    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        self.generic_eval(a, b, c_dt)
    }
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    #[allow(unused_variables)]
    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
        tvec!()
    }
    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
        None
    }

    #[allow(unused_variables)]
    fn same_as(&self, other: &dyn BinMiniOp) -> bool {
        false
    }
}
dyn_clone::clone_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);

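/// A rank- and type-checked binary operator node: a boxed [`BinMiniOp`] kernel plus
/// an optional datum type that overrides the output type otherwise inferred from
/// the inputs.
///
/// A hedged construction sketch (the `Add` mini-op is imported above; forcing the
/// output type in the second line is only an illustration):
///
/// ```ignore
/// // Output type derived from the inputs:
/// let op = TypedBinOp(Box::new(Add), None);
/// // Output type forced to f32, regardless of what result_datum_type would infer:
/// let op_f32 = TypedBinOp(Box::new(Add), Some(f32::datum_type()));
/// ```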
#[derive(Debug, Clone)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);

impl Op for TypedBinOp {
    fn name(&self) -> Cow<str> {
        self.0.name().into()
    }

    fn validation(&self) -> Validation {
        self.0.validation()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<TypedBinOp>() else { return false };
        self.1 == other.1 && self.0.same_as(&*other.0)
    }

    op_as_typed_op!();
}

impl TypedBinOp {
    fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
        if let Some(dt) = self.1 {
            Ok(dt)
        } else {
            self.0.result_datum_type(a_dt, b_dt)
        }
    }
}

impl EvalOp for TypedBinOp {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        ensure!(a.rank() == b.rank());
        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
    }
}

impl TypedOp for TypedBinOp {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        if inputs[0].rank() != inputs[1].rank() {
            bail!(
                "Typed ops require rank match. Invalid inputs for {}: {}",
                self.name(),
                inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
            );
        }
        let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        Ok(tvec!(out_dt.fact(&*crate::broadcast::multi_broadcast(&[
            &inputs[0].shape.to_tvec(),
            &inputs[1].shape.to_tvec()
        ])?)))
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        if let AxisOp::Rm(rm) = change {
            let (inputs, outputs) = model.node_facts(node.id)?;
            if !inputs[0].shape[*rm].is_one()
                || !inputs[1].shape[*rm].is_one()
                || !outputs[0].shape[*rm].is_one()
            {
                return Ok(None);
            }
        }
        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::natural(inputs, outputs)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .0
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        _node: &TypedNode,
        prefix: &str,
        inputs: &[OutletId],
        _output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
            (a.datum_type().unwrap(), b.datum_type().unwrap())
        } else {
            unreachable!("TypedBinOp has two inputs.")
        };
        if let Some(neutral_patch) =
            declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
        {
            return Ok(Some(neutral_patch));
        }
        if let Some(broadcast_patch) =
            declutter_broadcasting_operand_1(model, node, self.0.clone())?
        {
            return Ok(Some(broadcast_patch));
        }
        self.0.declutter(model, node)
    }

    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
            let input_facts = model.node_input_facts(node.id)?;
            let must_swap_inputs =
                input_facts.iter().collect_tuple().is_some_and(|(a_fact, b_fact)| {
                    (a_fact.shape.volume() - b_fact.shape.volume()).prove_strict_negative()
                });
            let (operand_1, operand_2) = if must_swap_inputs {
                (input_facts[1], input_facts[0])
            } else {
                (input_facts[0], input_facts[1])
            };

            let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
                find_most_efficient_config(model, node, must_swap_inputs)?;

            // Check if the op is quantized
            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
            let op_is_quant = c_dt.is_quantized()
                || operand_1.datum_type.is_quantized()
                || operand_2.datum_type.is_quantized();

            // Check if it can be evaluated in place in the first operand
            let c_shape = crate::broadcast::multi_broadcast(&[
                operand_1.shape.clone(),
                operand_2.shape.clone(),
            ])?;
            let can_eval_in_a =
                (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);

            // Swap inputs if required
            let inputs = if must_swap_inputs {
                let mut swap_input = node.inputs.clone();
                swap_input.swap(0, 1);
                swap_input
            } else {
                node.inputs.clone()
            };
            let actual_linalg_op =
                if must_swap_inputs { linalg_bin_op.flip() } else { linalg_bin_op };
            let actual_core_op = core_op_for_linalg_op(&actual_linalg_op);

            let dt = model.node_input_facts(node.id)?[0].datum_type;
            if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
                let Some(func) = tract_linalg::bin_by_scalar(dt, actual_linalg_op) else {
                    return Ok(None);
                };
                let eval_fn = Arc::from(func);
                return Ok(Some(
                    TypedModelPatch::replace_single_op(
                        model,
                        node,
                        &inputs,
                        OptBinByScalar { binop: actual_core_op, eval_fn },
                    )?
                    .with_context("ByScalar"),
                ));
            }

            if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
                let Some(func) = tract_linalg::bin_unicast(dt, actual_linalg_op) else {
                    return Ok(None);
                };
                let eval_fn = Arc::from(func);
                return Ok(Some(
                    TypedModelPatch::replace_single_op(
                        model,
                        node,
                        &inputs,
                        OptBinUnicast { binop: actual_core_op, eval_fn },
                    )?
                    .with_context("Unicast"),
                ));
            }
        }

        Ok(None)
    }
    as_op!();
}

fn core_op_for_linalg_op(linalg: &BinOp) -> Box<dyn BinMiniOp> {
    match linalg {
        BinOp::Min => Box::new(Min),
        BinOp::Max => Box::new(Max),
        BinOp::Add => Box::new(Add),
        BinOp::Mul => Box::new(Mul),
        BinOp::Sub => Box::new(Sub),
        BinOp::SubF => Box::new(SubF),
    }
}
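
/// If the mini-op is commutative and operand 0 is provably the smaller, broadcast
/// operand, swap the inputs so the full-size operand sits in slot 0, the only slot
/// eligible for in-place evaluation.
///
/// A hedged sketch of the element-count test this rule relies on (illustrative
/// shapes only):
///
/// ```ignore
/// let a_shape = [1.to_dim(), 1.to_dim(), 1.to_dim()];
/// let b_shape = [2.to_dim(), 3.to_dim(), 4.to_dim()];
/// let a_n = a_shape.iter().cloned().product::<TDim>(); // 1
/// let b_n = b_shape.iter().cloned().product::<TDim>(); // 24
/// // operand 0 is provably smaller, so a commutative op gets its inputs swapped
/// assert!((a_n - b_n).prove_strict_negative());
/// ```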
fn declutter_broadcasting_operand_1(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: Box<dyn BinMiniOp>,
) -> TractResult<Option<TypedModelPatch>> {
    let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
        (a.shape.clone(), b.shape.clone())
    } else {
        unreachable!("TypedBinOp has two inputs.")
    };

    let a_num_elements = a_shape.iter().product::<TDim>();
    let b_num_elements = b_shape.iter().product::<TDim>();
    let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
    if a_should_be_broadcast & mini_op.is_commutative() {
        let mut swap_input = node.inputs.clone();
        swap_input.swap(0, 1);
        return Ok(Some(TypedModelPatch::replace_single_op(
            model,
            node,
            &swap_input,
            TypedBinOp(mini_op, None),
        )?));
    }

    Ok(None)
}

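/// Removes a binary node that applies the mini-op against its neutral element:
/// when one input is a uniform tensor equal to `neutral_element()` and sits on an
/// allowed side, the node degenerates to its other input, or to a plain cast when
/// quantization parameters still have to be rewritten.
///
/// A hedged summary of the rewrites this rule performs (illustrative only):
///
/// ```ignore
/// // x + 0  -> x                 (Add, neutral element 0)
/// // x * 1  -> x                 (Mul, neutral element 1)
/// // x - 0  -> x                 (Sub is not commutative: neutral allowed on the right only)
/// // 0 - x  -> left untouched    (neutral on the wrong side of a non-commutative op)
/// // qx + 0 -> cast(qx, out_dt)  (quantized: output params may differ, so keep a cast)
/// ```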
fn declutter_neutral(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: &dyn BinMiniOp,
    out_dt: DatumType,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let is_neutral = mini_op
            .neutral_element()
            .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
            .unwrap_or(false);

        // For some operators the neutral element may sit on either side, while for
        // others only the right side is valid (neutral - 1 -> not ok, 1 - neutral -> ok)
        let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;

        if is_neutral && pos_checked {
            // Neutral decluttering for quant values is special.
            // - if (fa) (a-az)*as + (fb = 0) (b-bz)*bs = (fc) (c-cz)*cs
            // - then even if fa = fc, quant params need to be updated (a != c).
            // So it's not a no-op.
            if uniform.uni.datum_type().is_quantized() {
                return Ok(Some(TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &[node.inputs[0]],
                    cast(out_dt),
                )?));
            // In the non-quantized case, it's a no-op.
            } else {
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|_, inputs| Ok(inputs.into()),
                )?));
            }
        }
    }
    Ok(None)
}

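/// Estimates whether the "by scalar" and/or "unicast" optimized kernels are worth
/// wiring for this node: a candidate is only reported as efficient when its
/// innermost contiguous loop is provably at least 32 elements long.
///
/// A hedged walk-through on concrete shapes (numbers only, the private helpers are
/// not called directly):
///
/// ```ignore
/// // a: [2, 3, 100], b: [2, 3, 1]   -> by-scalar, 100 elements per scalar call -> efficient
/// // a: [2, 3, 100], b: [1, 3, 100] -> unicast, 300 contiguous elements        -> efficient
/// // a: [8, 4],      b: [8, 1]      -> by-scalar, only 4 elements per call     -> rejected
/// ```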
fn find_most_efficient_config(
    model: &TypedModel,
    node: &TypedNode,
    swap_input: bool,
) -> TractResult<(bool, bool)> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };

        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
        let num_by_scalar_elements = if by_scalar_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
                .map(|(rev_a_dim, _)| rev_a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
        let num_unicast_elements = if unicast_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
                .map(|(a_dim, _)| a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        let min_num_elements = 32;
        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
    }
    Ok((false, false))
}

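/// True when `x` is provably at least `min_val`; an unresolved symbolic dimension
/// cannot be proven and therefore yields false.
///
/// A hedged sketch of the expected behaviour:
///
/// ```ignore
/// assert!(gt_tdim(64.to_dim(), 32));
/// assert!(!gt_tdim(16.to_dim(), 32));
/// // a purely symbolic dimension cannot be compared, so it is rejected as well
/// ```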
pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
    TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
}

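/// Optimized binary kernel for the "by scalar" pattern: operand 1 has the same
/// rank as operand 0 but only unit dimensions once the shapes diverge, so each
/// inner slice of `a` is combined with a single scalar taken from `b`, and the
/// result is written back into `a`.
///
/// A hedged illustration of which shape pairs qualify (shapes as `TDim`s):
///
/// ```ignore
/// let a = [2.to_dim(), 3.to_dim(), 4.to_dim()];
/// let b = [2.to_dim(), 3.to_dim(), 1.to_dim()];
/// assert!(OptBinByScalar::check_input_shapes(&a, &b));  // one scalar per [4] slice of a
/// assert!(!OptBinByScalar::check_input_shapes(&b, &a)); // broadcast goes the other way
/// ```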
#[derive(Clone)]
pub struct OptBinByScalar {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}

impl Debug for OptBinByScalar {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
    }
}

impl OptBinByScalar {
    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        if a_shape.len() != b_shape.len() {
            return false;
        };

        a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
            .all(|(_, b_dim)| *b_dim == 1.to_dim())
    }
}

impl Op for OptBinByScalar {
    fn name(&self) -> Cow<str> {
        format!("Opt{}ByScalar", self.binop.name()).into()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<OptBinByScalar>() else { return false };
        self.binop.same_as(&*other.binop)
    }

    op_as_typed_op!();
}

impl EvalOp for OptBinByScalar {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // Owning `a` is not strictly required, as TensorView does not need an owned
        // tensor, but in practice `a` is mutated in place through its views (the
        // `mut` is omitted because the compiler advises removing it).
        let a = a.into_tensor();
        let b_shape = b.shape();

        let first_unary_axis = b_shape
            .iter()
            .enumerate()
            .rev()
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .last()
            .context("Cannot use by_scalar when no trailing dimensions are unary")?;

        let iterating_shape = &a.shape()[..first_unary_axis];
        if !iterating_shape.is_empty() {
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            let mut view = a.view();
            let b_view = b.view();
            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
            (self.eval_fn)(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}

impl TypedOp for OptBinByScalar {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}

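/// Optimized binary kernel for the "unicast" pattern: operand 1 has leading unit
/// dimensions and then matches operand 0 exactly, so the same contiguous block of
/// `b` is re-applied to every outer slice of `a` (provided the block size keeps
/// `b` aligned for the vectorized kernel).
///
/// A hedged illustration of which shape pairs qualify, assuming the 192-element
/// block passes the platform alignment check:
///
/// ```ignore
/// let a = [8.to_dim(), 3.to_dim(), 64.to_dim()];
/// let b = [1.to_dim(), 3.to_dim(), 64.to_dim()];
/// assert!(OptBinUnicast::check_input_shapes(&a, &b));  // b re-applied to 8 outer slices
/// let c = [8.to_dim(), 3.to_dim(), 1.to_dim()];
/// assert!(!OptBinUnicast::check_input_shapes(&a, &c)); // by-scalar pattern, not unicast
/// ```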
#[derive(Clone)]
pub struct OptBinUnicast {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}

impl Debug for OptBinUnicast {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
    }
}

impl OptBinUnicast {
    fn check_b_alignment(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        let num_iterations: TDim = a_shape
            .iter()
            .zip(b_shape.iter())
            .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .map(|(a_dim, _)| a_dim)
            .product();

        if num_iterations.is_one() {
            return true;
        }

        let elements_per_iteration: TDim = a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .map(|(_, b_dim)| b_dim)
            .product();

        if let Ok(num_element) = elements_per_iteration.to_i64() {
            let required_alignment = vector_size();
            (num_element as usize % required_alignment) == 0
        } else {
            false
        }
    }
    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
        if a_shape.len() != b_shape.len() {
            return false;
        };

        let unicast_possible = a_shape
            .iter()
            .zip(b_shape.iter())
            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
            .all(|(a_dim, b_dim)| a_dim == b_dim);
        let unicast_is_aligned = Self::check_b_alignment(a_shape, b_shape);

        unicast_possible && unicast_is_aligned
    }
}

impl Op for OptBinUnicast {
    fn name(&self) -> Cow<str> {
        format!("Opt{}Unicast", self.binop.name()).into()
    }

    fn same_as(&self, other: &dyn Op) -> bool {
        let Some(other) = other.downcast_ref::<OptBinUnicast>() else { return false };
        self.binop.same_as(&*other.binop)
    }
    op_as_typed_op!();
}

impl EvalOp for OptBinUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // Owning `a` is not strictly required, as TensorView does not need an owned
        // tensor, but in practice `a` is mutated in place through its views (the
        // `mut` is omitted because the compiler advises removing it).
        let a = a.into_tensor();
        let b_shape = b.shape();
        let b_view = b.view();
        let first_non_unary_axis =
            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

        if let Some(first_non_unary_axis) = first_non_unary_axis {
            // Iterate on outer dimensions and evaluate with unicast subviews
            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            let mut view = a.view();
            debug_assert_eq!(view.shape(), b_view.shape());
            (self.eval_fn)(&mut view, &b_view)?;
        }

        Ok(tvec!(a.into_tvalue()))
    }
}

impl TypedOp for OptBinUnicast {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}

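/// Declares a binary mini-op and its constructor in one go: the macro defines a
/// unit struct, implements [`BinMiniOp`] for it from the provided per-type
/// closures, and emits a `pub fn $func() -> TypedBinOp` helper.
///
/// A hedged, minimal usage sketch (hypothetical `my_mul`/`MyMul` names; the real
/// ops in `ops::math` pass many more arms):
///
/// ```ignore
/// bin_to_super_type!(my_mul, MyMul,
///     is_commutative: true,
///     neutral_element: 1,
///     [f32, i32] => |c, a, b| *c = *a * *b);
///
/// let op = my_mul(); // a TypedBinOp wrapping the MyMul mini-op
/// ```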
#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_in_a: $eval_in_a:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(is_commutative: $is_commutative:expr,)?
     $(neutral_element: $neutral_element:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $(q_op_on_f32: $q_op_on_f32:expr,)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash)]
        pub struct $Op;
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn same_as(&self, other: &dyn $crate::ops::binary::BinMiniOp) -> bool {
                other.downcast_ref::<$Op>().is_some()
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                    $(
                        $(if c.datum_type() == $typ::datum_type() {
                            let a = a.to_array_view::<$typ>()?;
                            let b = b.to_array_view::<$typ>()?;
                            let mut c = c.to_array_view_mut::<$typ>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                            return Ok(())
                        })*
                     )*
                    $(
                        $(
                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                let a = a.to_array_view::<$typ_dt>()?;
                                let b = b.to_array_view::<$typ_dt>()?;
                                let mut c = c.to_array_view_mut::<$typ_dt>()?;
                                $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                                return Ok(())
                            }
                            )*
                         )*
                     )?
                    bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            $(fn is_commutative(&self) -> bool {
                $is_commutative
            })?
            $(fn neutral_element(&self) -> Option<i64> {
                Some($neutral_element)
            })?
            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                // c and a are the same type
                $(if $eval_in_a(a, b)? { return Ok(()) } )?
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let b = b.to_array_view::<$typ>()?;
                        let mut a = a.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                )*
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let mut a = a.to_array_view_mut::<$typ_dt>()?;
                            let b = b.to_array_view::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
                                cab(a, &(a.clone()), b, zp, scale)
                            });
                            return Ok(())
                        })*
                    )*
                )?
                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
            }

            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                $eval_override(a, b, c_dt)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    }
                    else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

                $(
                    fn declutter(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($declutter)(self, model, node)
                    }
                 )?
                $(
                    fn codegen(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($codegen)(self, model, node)
                    }
                 )?
                $(
                    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                        ($cost)(dt)
                    }
                 )?
                $(
                    fn validation(&self) -> Validation {
                        $validation
                    }
                 )?
                $(
                    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
                        Some(tract_linalg::BinOp::$linalg)
                    }
                 )?
                $(
                    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                        ($operating_datum_type)(a, b)
                    })?

            /// Default simple binary operation for quantized formats: dequantize,
            /// apply the requested operation in float, and requantize the result.
            /// Several implementations are provided, each with pros & cons.
            #[allow(unused_variables)]
            fn maybe_eval_qbinary_as_float_op(
                &self,
                a: &TValue,
                b: &TValue,
                c_dt: &DatumType,
            ) -> TractResult<Option<Tensor>> {
                $(
                    /// This implementation strives to minimize memory allocations and accesses.
                    /// It only applies to the QU8 datum type with ZpScale parameters.
                    /// It may be better suited to large model tensors.
                    fn memory_optimised_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                            (a.datum_type(), b.datum_type(), c_dt)
                        {
                            let c_inv_scale = 1.0 / c_scale;
                            let a = a.to_array_view::<u8>()?;
                            let b = b.to_array_view::<u8>()?;
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
                            let view = c.to_array_view_mut::<u8>()?;
                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
                                *c = (scale_by($q_op_on_f32(
                                            ((*a as i32 - a_zp as i32) as f32 * a_scale),
                                            ((*b as i32 - b_zp as i32) as f32 * b_scale),
                                ), c_inv_scale) as i32
                                    + *c_zp as i32)
                                    .clamp_cast()
                            });
                            return Ok(Some(c));
                        }
                        Ok(None)
                    }

                    /// Applies to all Q types.
                    /// Takes more memory, but is hopefully faster than memory_optimised_q_binary_as_float_op,
                    /// especially once cast_to_dt gets vectorized implementations.
                    fn generic_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                        accumulator_dt: DatumType
                    ) -> TractResult<Option<Tensor>> {
                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
                            match accumulator_dt {
                                DatumType::F32 => {
                                    let view = c.to_array_view_mut::<f32>()?;
                                    $crate::ndarray::Zip::from(view).and_broadcast(a.to_array_view()?).and_broadcast(b.to_array_view()?).for_each(|c, a, b| {
                                        *c = $q_op_on_f32(*a,*b);
                                    })
                                },
                                other => bail!("unexpected accumulator data type {:?}", other)
                            };

                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
                        }
                        Ok(None)
                    }

                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
                        return Ok(Some(c));
                    }
                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
                        return Ok(Some(d));
                    }
                )?
                Ok(None)
            }
        }

        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
        }
    };
}

#[derive(Debug)]
pub(crate) struct OneUniformInput {
    pub uni: Arc<Tensor>,
    pub var: OutletId,
    pub left_is_uniform: bool,
}

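/// Detects the case where exactly one of the node's two inputs is a known uniform
/// (constant splat) tensor and the other input's shape already covers the broadcast
/// output, reporting which side is uniform and the outlet feeding the variable side.
///
/// A hedged illustration (facts and shapes only):
///
/// ```ignore
/// // a = uniform 0.0f32 with shape [1, 1], b = variable with shape [2, 3]
/// //   -> Some(OneUniformInput { uni: 0.0, var: node.inputs[1], left_is_uniform: true })
/// // a and b both variable, or the uniform shape not covered by the other input
/// //   -> None
/// ```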
pub(crate) fn one_input_is_uniform(
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<OneUniformInput>> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let uni = if let Some(a) = &a.uniform {
            OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
        } else if let Some(b) = &b.uniform {
            OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
        } else {
            return Ok(None);
        };
        let var_fact = [a, b][uni.left_is_uniform as usize];
        let uni_fact = [a, b][!uni.left_is_uniform as usize];
        if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
            return Ok(Some(uni));
        }
    }
    Ok(None)
}