//! tract_core/ops/binary.rs
1use crate::internal::*;
2use crate::ndarray::Dimension;
3use downcast_rs::Downcast;
4use std::fmt::{self, Debug};
5use tract_data::itertools::izip;
6use tract_itertools::Itertools;
7use tract_linalg::{BinOp, LinalgFn};
8
9use super::math::{Add, Max, Min, Mul, Sub};
10use super::{cast::cast, math::SubF};
11
/// Element-wise kernel of a binary operator (Add, Mul, Max, ...).
///
/// `TypedBinOp` wraps an implementor of this trait and delegates
/// evaluation, type resolution and graph-level optimisation to it.
pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast {
    fn name(&self) -> &'static str;
    /// Numerical accuracy contract of the op (defaults to exact).
    fn validation(&self) -> Validation {
        Validation::Accurate
    }
    /// Datum type used to carry out the computation: the common super
    /// type of the two inputs. Errors if no super type exists.
    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
    }
    /// Datum type of the output, given the two input types.
    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
    /// Evaluate in place, accumulating the result into `a`.
    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
    /// Evaluate out of place, writing the result into preallocated `c`.
    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;

    /// True when `a op b == b op a`; allows operand swapping during
    /// declutter/codegen. Non-commutative ops must override.
    fn is_commutative(&self) -> bool {
        true
    }
    /// Neutral element `e` such that `a op e == a`, when one exists
    /// (used by declutter_neutral to eliminate no-op applications).
    fn neutral_element(&self) -> Option<i64> {
        None
    }

    /// Hook for quantized ops: evaluate by dequantizing to float and
    /// requantizing. Default declines (returns Ok(None)).
    #[allow(unused_variables)]
    fn maybe_eval_qbinary_as_float_op(
        &self,
        a: &TValue,
        b: &TValue,
        c_dt: &DatumType,
    ) -> TractResult<Option<Tensor>> {
        Ok(None)
    }

    /// Generic evaluation path: try the quantized-as-float hook first,
    /// then evaluate in place inside `a` when the broadcast output shares
    /// a's shape and type, else evaluate into a fresh output tensor.
    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
            Ok(tensor)
        } else {
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
            if &*c_shape == a.shape() && c_dt == a.datum_type() {
                // Output matches a's shape and type: reuse a's buffer.
                let mut a = a.into_tensor();
                self.eval_in_a(&mut a, &b)?;
                Ok(a)
            } else {
                // c is uninitialized here; eval_out_of_place is expected
                // to fill it entirely before it is returned.
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }
        }
    }
    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        self.generic_eval(a, b, c_dt)
    }
    /// Op-specific graph simplification, run at declutter time.
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    /// Op-specific lowering to optimized forms, run at codegen time.
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    /// Cost model contribution, per output element.
    #[allow(unused_variables)]
    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
        tvec!()
    }
    /// The tract-linalg counterpart of this op, when one exists.
    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
        None
    }

    /// Structural equality between mini-ops (defaults to "never equal").
    #[allow(unused_variables)]
    fn same_as(&self, other: &dyn BinMiniOp) -> bool {
        false
    }
}
// Boxed BinMiniOp trait objects need Clone (for graph copies) and
// downcasting (for same_as comparisons).
dyn_clone::clone_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);
91
/// A type-checked binary operator node: the element-wise mini-op plus an
/// optional forced output datum type (None = derive from the inputs).
#[derive(Debug, Clone)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);
94
95impl Op for TypedBinOp {
96    fn name(&self) -> StaticName {
97        self.0.name().into()
98    }
99
100    fn validation(&self) -> Validation {
101        self.0.validation()
102    }
103
104    fn same_as(&self, other: &dyn Op) -> bool {
105        let Some(other) = other.downcast_ref::<TypedBinOp>() else { return false };
106        self.1 == other.1 && self.0.same_as(&*other.0)
107    }
108
109    op_as_typed_op!();
110}
111
112impl TypedBinOp {
113    fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
114        if let Some(dt) = self.1 { Ok(dt) } else { self.0.result_datum_type(a_dt, b_dt) }
115    }
116}
117
118impl EvalOp for TypedBinOp {
119    fn is_stateless(&self) -> bool {
120        true
121    }
122
123    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
124        let (a, b) = args_2!(inputs);
125        ensure!(a.rank() == b.rank());
126        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
127        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
128    }
129}
130
131impl TypedOp for TypedBinOp {
132    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
133        if inputs[0].rank() != inputs[1].rank() {
134            bail!(
135                "Typed ops require rank match. Invalid inputs for {}: {}",
136                self.name(),
137                inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
138            );
139        }
140        let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
141        Ok(tvec!(out_dt.fact(&*crate::broadcast::multi_broadcast(&[
142            &inputs[0].shape.to_tvec(),
143            &inputs[1].shape.to_tvec()
144        ])?)))
145    }
146
147    fn change_axes(
148        &self,
149        model: &TypedModel,
150        node: &TypedNode,
151        _io: InOut,
152        change: &AxisOp,
153    ) -> TractResult<Option<AxisChangeConsequence>> {
154        if let AxisOp::Rm(rm) = change {
155            let (inputs, outputs) = model.node_facts(node.id)?;
156            rule_if!(inputs[0].shape[*rm].is_one());
157            rule_if!(inputs[1].shape[*rm].is_one());
158            rule_if!(outputs[0].shape[*rm].is_one());
159        }
160        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
161    }
162
163    fn axes_mapping(
164        &self,
165        inputs: &[&TypedFact],
166        outputs: &[&TypedFact],
167    ) -> TractResult<AxesMapping> {
168        AxesMapping::natural(inputs, outputs)
169    }
170
171    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
172        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
173        Ok(self
174            .0
175            .cost_per_element(inputs[0].datum_type)
176            .into_iter()
177            .map(|(c, n)| (c, count.clone() * n))
178            .collect())
179    }
180
181    fn slice(
182        &self,
183        patch: &mut TypedModelPatch,
184        _model: &TypedModel,
185        _node: &TypedNode,
186        prefix: &str,
187        inputs: &[OutletId],
188        _output_axis: usize,
189        _start: &TDim,
190        _end: &TDim,
191    ) -> TractResult<Option<TVec<OutletId>>> {
192        Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
193    }
194
195    fn declutter(
196        &self,
197        model: &TypedModel,
198        node: &TypedNode,
199    ) -> TractResult<Option<TypedModelPatch>> {
200        let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
201            (a.datum_type().unwrap(), b.datum_type().unwrap())
202        } else {
203            unreachable!("TypedBinOp has two inputs.")
204        };
205        if let Some(neutral_patch) =
206            declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
207        {
208            return Ok(Some(neutral_patch));
209        }
210        if let Some(broadcast_patch) =
211            declutter_broadcasting_operand_1(model, node, self.0.clone())?
212        {
213            return Ok(Some(broadcast_patch));
214        }
215        self.0.declutter(model, node)
216    }
217
218    fn codegen(
219        &self,
220        model: &TypedModel,
221        node: &TypedNode,
222    ) -> TractResult<Option<TypedModelPatch>> {
223        if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
224            let input_facts = model.node_input_facts(node.id)?;
225            let must_swap_inputs =
226                input_facts.iter().collect_tuple().is_some_and(|(a_fact, b_fact)| {
227                    (a_fact.shape.volume() - b_fact.shape.volume()).prove_strict_negative()
228                });
229            let (operand_1, operand_2) = if must_swap_inputs {
230                (input_facts[1], input_facts[0])
231            } else {
232                (input_facts[0], input_facts[1])
233            };
234
235            let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
236                find_most_efficient_config(model, node, must_swap_inputs)?;
237
238            // Check if op is quantized
239            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
240            let op_is_quant = c_dt.is_quantized()
241                || operand_1.datum_type.is_quantized()
242                || operand_2.datum_type.is_quantized();
243
244            // Check if it can be evaluated in a
245            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
246            let c_shape = crate::broadcast::multi_broadcast(&[
247                operand_1.shape.clone(),
248                operand_2.shape.clone(),
249            ])?;
250            let can_eval_in_a =
251                (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);
252
253            // Swap input if required
254            let inputs = if must_swap_inputs {
255                let mut swap_input = node.inputs.clone();
256                swap_input.swap(0, 1);
257                swap_input
258            } else {
259                node.inputs.clone()
260            };
261            let actual_linalg_op =
262                if must_swap_inputs { linalg_bin_op.flip() } else { linalg_bin_op };
263            let actual_core_op = core_op_for_linalg_op(&actual_linalg_op);
264
265            let dt = model.node_input_facts(node.id)?[0].datum_type;
266            if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
267                rule_if_some!(func = tract_linalg::bin_by_scalar(dt, actual_linalg_op));
268                let eval_fn = Arc::from(func);
269                return Ok(Some(
270                    TypedModelPatch::replace_single_op(
271                        model,
272                        node,
273                        &inputs,
274                        OptBinByScalar { binop: actual_core_op, eval_fn },
275                    )?
276                    .with_context("ByScalar"),
277                ));
278            }
279
280            if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
281                rule_if_some!(func = tract_linalg::bin_unicast(dt, actual_linalg_op));
282                let eval_fn = Arc::from(func);
283                return Ok(Some(
284                    TypedModelPatch::replace_single_op(
285                        model,
286                        node,
287                        &inputs,
288                        OptBinUnicast { binop: actual_core_op, eval_fn },
289                    )?
290                    .with_context("Unicast"),
291                ));
292            }
293        }
294
295        Ok(None)
296    }
297    as_op!();
298}
299
300fn core_op_for_linalg_op(linalg: &BinOp) -> Box<dyn BinMiniOp> {
301    match linalg {
302        BinOp::Min => Box::new(Min),
303        BinOp::Max => Box::new(Max),
304        BinOp::Add => Box::new(Add),
305        BinOp::Mul => Box::new(Mul),
306        BinOp::Sub => Box::new(Sub),
307        BinOp::SubF => Box::new(SubF),
308    }
309}
310fn declutter_broadcasting_operand_1(
311    model: &TypedModel,
312    node: &TypedNode,
313    mini_op: Box<dyn BinMiniOp>,
314) -> TractResult<Option<TypedModelPatch>> {
315    let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
316        (a.shape.clone(), b.shape.clone())
317    } else {
318        unreachable!("TypedBinOp has two inputs.")
319    };
320
321    let a_num_elements = a_shape.iter().product::<TDim>();
322    let b_num_elements = b_shape.iter().product::<TDim>();
323    let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
324    if a_should_be_broadcast & mini_op.is_commutative() {
325        let mut swap_input = node.inputs.clone();
326        swap_input.swap(0, 1);
327        return Ok(Some(TypedModelPatch::replace_single_op(
328            model,
329            node,
330            &swap_input,
331            TypedBinOp(mini_op, None),
332        )?));
333    }
334
335    Ok(None)
336}
337
/// Simplify a binary node when one input is a uniform tensor equal to the
/// op's neutral element (e.g. `x + 0`, `x * 1`): the node reduces to its
/// other input, or to a cast in the quantized case.
fn declutter_neutral(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: &dyn BinMiniOp,
    out_dt: DatumType,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let is_neutral = mini_op
            .neutral_element()
            .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
            .unwrap_or(false);

        // For some ops the neutral element only works on one side
        // (e.g. Sub: `x - 0` is a no-op but `0 - x` is not). For a
        // non-commutative op the uniform must be the right-hand operand.
        let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;

        if is_neutral && pos_checked {
            // Neutral decluttering for quant values is special.
            // - if (fa) (a-az)*as + (fb = 0) (b-bz)*bs = (fc) (c-cz)*cs
            // - then even if fa = fc, quant params needs to be updated (a != c).
            // So it's not a no_op: replace with a cast to the output type.
            if uniform.uni.datum_type().is_quantized() {
                // NOTE(review): this keeps inputs[0] unconditionally; if the
                // uniform operand were the left one (commutative case) this
                // would cast the constant rather than the variable — confirm
                // one_input_is_uniform / callers guarantee the ordering.
                return Ok(Some(TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &[node.inputs[0]],
                    cast(out_dt),
                )?));
            // In the non-quantized case, it's a genuine no_op: rewire the
            // non-uniform input straight to the node's output.
            } else {
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|_, inputs| Ok(inputs.into()),
                )?));
            }
        }
    }
    Ok(None)
}
379
380fn find_most_efficient_config(
381    model: &TypedModel,
382    node: &TypedNode,
383    swap_input: bool,
384) -> TractResult<(bool, bool)> {
385    if let &[a, b] = &*model.node_input_facts(node.id)? {
386        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
387        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };
388
389        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
390        let num_by_scalar_elements = if by_scalar_is_possible {
391            a_shape
392                .iter()
393                .zip(b_shape.iter())
394                .rev()
395                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
396                .map(|(rev_a_dim, _)| rev_a_dim)
397                .product::<TDim>()
398        } else {
399            TDim::Val(0)
400        };
401
402        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
403        let num_unicast_elements = if unicast_is_possible {
404            a_shape
405                .iter()
406                .zip(b_shape.iter())
407                .rev()
408                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
409                .map(|(a_dim, _)| a_dim)
410                .product::<TDim>()
411        } else {
412            TDim::Val(0)
413        };
414
415        let min_num_elements = 32;
416        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
417        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
418        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
419    }
420    Ok((false, false))
421}
422
423pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
424    TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
425}
426
/// Optimized binary op for the case where `b` acts as a scalar over the
/// trailing axes of `a`; evaluation mutates `a`'s buffer in place through
/// a tract-linalg kernel.
#[derive(Clone)]
pub struct OptBinByScalar {
    pub binop: Box<dyn BinMiniOp>,
    // Kernel performing the by-scalar evaluation. Not Debug, hence the
    // manual Debug implementation for this struct.
    eval_fn: Arc<LinalgFn>,
}
432
433impl Debug for OptBinByScalar {
434    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
435        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
436    }
437}
438
439impl OptBinByScalar {
440    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
441        if a_shape.len() != b_shape.len() {
442            return false;
443        };
444
445        a_shape
446            .iter()
447            .zip(b_shape.iter())
448            .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
449            .all(|(_, b_dim)| *b_dim == 1.to_dim())
450    }
451}
452
453impl Op for OptBinByScalar {
454    fn name(&self) -> StaticName {
455        format!("Opt{}ByScalar", self.binop.name()).into()
456    }
457
458    fn same_as(&self, other: &dyn Op) -> bool {
459        let Some(other) = other.downcast_ref::<OptBinByScalar>() else { return false };
460        self.binop.same_as(&*other.binop)
461    }
462
463    op_as_typed_op!();
464}
465
impl EvalOp for OptBinByScalar {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // The result is accumulated into a's buffer. No `mut` binding is
        // needed because mutation happens through TensorView (the compiler
        // would flag an unused `mut` here).
        let a = a.into_tensor();
        let b_shape = b.shape();

        // Index of the first axis of b's trailing run of 1-sized axes —
        // the axes over which b acts as a scalar.
        let first_unary_axis = b_shape
            .iter()
            .enumerate()
            .rev()
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .last()
            .context("Cannot use by_scalar when no trailing dimensions are unary")?;

        let iterating_shape = &a.shape()[..first_unary_axis];
        if !iterating_shape.is_empty() {
            // Iterate over the leading coordinates, applying the kernel to
            // each trailing slice of a with its matching scalar from b.
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            // b is a scalar over the whole of a: single kernel call.
            let mut view = a.view();
            let b_view = b.view();
            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
            (self.eval_fn)(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}
504
505impl TypedOp for OptBinByScalar {
506    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
507        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
508        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
509        let out_shape = inputs[0].shape.clone();
510        Ok(tvec!(out_dt.fact(out_shape)))
511    }
512
513    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
514        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
515        Ok(self
516            .binop
517            .cost_per_element(inputs[0].datum_type)
518            .into_iter()
519            .map(|(c, n)| (c, count.clone() * n))
520            .collect())
521    }
522
523    as_op!();
524}
525
/// Optimized binary op for the case where b's shape is a's shape with
/// some leading axes collapsed to 1: the kernel is applied to each
/// matching trailing slice of a, in place.
#[derive(Clone)]
pub struct OptBinUnicast {
    pub binop: Box<dyn BinMiniOp>,
    // Kernel performing the unicast evaluation. Not Debug, hence the
    // manual Debug implementation for this struct.
    eval_fn: Arc<LinalgFn>,
}
531
532impl Debug for OptBinUnicast {
533    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
534        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
535    }
536}
537
538impl OptBinUnicast {
539    fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
540        let num_iterations: TDim = a_shape
541            .iter()
542            .zip(b_shape.iter())
543            .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
544            .map(|(a_dim, _)| a_dim)
545            .product();
546
547        if num_iterations.is_one() {
548            return true;
549        }
550
551        let elements_per_iteration: TDim = a_shape
552            .iter()
553            .zip(b_shape.iter())
554            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
555            .map(|(_, b_dim)| b_dim)
556            .product();
557
558        if let Ok(num_element) = elements_per_iteration.to_i64() {
559            let required_alignment = vector_size();
560            (num_element as usize % required_alignment) == 0
561        } else {
562            false
563        }
564    }
565    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
566        if a_shape.len() != b_shape.len() {
567            return false;
568        };
569
570        let unicast_possible = a_shape
571            .iter()
572            .zip(b_shape.iter())
573            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
574            .all(|(a_dim, b_dim)| a_dim == b_dim);
575        let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);
576
577        unicast_possible && unicast_is_aligned
578    }
579}
580
581impl Op for OptBinUnicast {
582    fn name(&self) -> StaticName {
583        format!("Opt{}Unicast", self.binop.name()).into()
584    }
585
586    fn same_as(&self, other: &dyn Op) -> bool {
587        let Some(other) = other.downcast_ref::<OptBinUnicast>() else { return false };
588        self.binop.same_as(&*other.binop)
589    }
590    op_as_typed_op!();
591}
592
impl EvalOp for OptBinUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // The result is accumulated into a's buffer. No `mut` binding is
        // needed because mutation happens through TensorView (the compiler
        // would flag an unused `mut` here).
        let a = a.into_tensor();
        let b_shape = b.shape();
        let b_view = b.view();
        // One past the last axis of b's leading run of 1-sized axes, or
        // None when b's first axis is already non-unary.
        let first_non_unary_axis =
            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

        if let Some(first_non_unary_axis) = first_non_unary_axis {
            // Iterate on outer dimensions and evaluate with unicast subviews.
            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            // Shapes match exactly: single whole-tensor kernel call.
            let mut view = a.view();
            debug_assert_eq!(view.shape(), b_view.shape());
            (self.eval_fn)(&mut view, &b_view)?;
        }

        Ok(tvec!(a.into_tvalue()))
    }
}
625
626impl TypedOp for OptBinUnicast {
627    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
628        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
629        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
630        let out_shape = inputs[0].shape.clone();
631        Ok(tvec!(out_dt.fact(out_shape)))
632    }
633
634    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
635        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
636        Ok(self
637            .binop
638            .cost_per_element(inputs[0].datum_type)
639            .into_iter()
640            .map(|(c, n)| (c, count.clone() * n))
641            .collect())
642    }
643
644    as_op!();
645}
646
647#[macro_export]
648macro_rules! bin_to_super_type {
649    ($func:ident, $Op:ident,
650     $(codegen: $codegen:expr,)?
651     $(cost: $cost:expr,)?
652     $(declutter: $declutter:expr,)?
653     $(eval_in_a: $eval_in_a:expr,)?
654     $(eval_override: $eval_override: expr,)?
655     $(linalg: $linalg:ident,)?
656     $(operating_datum_type: $operating_datum_type:expr,)?
657     $(is_commutative: $is_commutative:expr,)?
658     $(neutral_element: $neutral_element:expr,)?
659     $(out_of_place: $out_of_place:expr,)?
660     $(validation: $validation:expr,)?
661     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
662     $(q_op_on_f32: $q_op_on_f32:expr,)?
663     $( [$($typ:ident),*] => $cab:expr),*) => {
664        #[derive(Debug, Clone, Hash)]
665        pub struct $Op;
666        #[allow(clippy::redundant_closure_call)]
667        impl $crate::ops::binary::BinMiniOp for $Op {
668            fn name(&self) -> &'static str {
669                stringify!($Op)
670            }
671
672            fn same_as(&self, other: &dyn $crate::ops::binary::BinMiniOp) -> bool {
673                other.downcast_ref::<$Op>().is_some()
674            }
675
676            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
677                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
678                    $(
679                        $(if c.datum_type() == $typ::datum_type() {
680                            let a = a.to_array_view::<$typ>()?;
681                            let b = b.to_array_view::<$typ>()?;
682                            let mut c = c.to_array_view_mut::<$typ>()?;
683                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
684                            return Ok(())
685                        })*
686                     )*
687                    $(
688                        $(
689                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
690                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
691                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
692                                let a = a.to_array_view::<$typ_dt>()?;
693                                let b = b.to_array_view::<$typ_dt>()?;
694                                let mut c = c.to_array_view_mut::<$typ_dt>()?;
695                                $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
696                                return Ok(())
697                            }
698                            )*
699                         )*
700                     )?
701                    bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
702            }
703
704            $(fn is_commutative(&self) -> bool {
705                $is_commutative
706            })?
707            $(fn neutral_element(&self) -> Option<i64> {
708                Some($neutral_element)
709            })?
710            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
711                // c and a are same type
712                $(if $eval_in_a(a, b)? { return Ok(()) } )?
713                $(
714                    $(if b.datum_type() == $typ::datum_type() {
715                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
716                        let b = b.to_array_view::<$typ>()?;
717                        let mut a = a.to_array_view_mut::<$typ>()?;
718                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
719                        return Ok(())
720                    })*
721                )*
722                $(
723                    $(
724                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
725                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
726                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
727                            let mut a = a.to_array_view_mut::<$typ_dt>()?;
728                            let b = b.to_array_view::<$typ_dt>()?;
729                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
730                                cab(a, &(a.clone()), b, zp, scale)
731                            });
732                            return Ok(())
733                        })*
734                    )*
735                )?
736                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
737            }
738
739            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
740                $eval_override(a, b, c_dt)
741            })?
742
743            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
744                if a.unquantized() == b.unquantized() {
745                    if a.is_quantized() || !b.is_quantized() {
746                        return Ok(a)
747                    }
748                    else {
749                        return Ok(b)
750                    }
751                }
752                self.operating_datum_type(a, b)
753            }
754
755                $(
756                    fn declutter(
757                        &self,
758                        model: &TypedModel,
759                        node: &TypedNode,
760                        ) -> TractResult<Option<TypedModelPatch>> {
761                        ($declutter)(self, model, node)
762                    }
763                 )?
764                $(
765                    fn codegen(
766                        &self,
767                        model: &TypedModel,
768                        node: &TypedNode,
769                        a: &Arc<Tensor>,
770                        ) -> TractResult<Option<TypedModelPatch>> {
771                        ($codegen)(self, model, node, a)
772                    }
773                 )?
774                $(
775                    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
776                        ($cost)(dt)
777                    }
778                 )?
779                $(
780                    fn validation(&self) -> Validation {
781                        $validation
782                    }
783                 )?
784                $(
785                    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
786                        Some(tract_linalg::BinOp::$linalg)
787                    }
788                 )?
789                $(
790                    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
791                        ($operating_datum_type)(a, b)
792                    })?
793
794
795            /// Default simple binary operation for QFormat where
796            /// we dequantise & apply requested operation in float & requantize it
797            /// several implementation are provided with pro & con
798            #[allow(unused_variables)]
799            fn maybe_eval_qbinary_as_float_op(
800                &self,
801                a: &TValue,
802                b: &TValue,
803                c_dt: &DatumType,
804            ) -> TractResult<Option<Tensor>> {
805                $(
806                    /// Implementation strive to minimise memory allocation and access
807                    /// we apply only if type is QU8 zp_scale datum type
808                    /// maybe more suited for large models tensors
809                    fn memory_optimised_q_binary_as_float_op(
810                        a: &TValue,
811                        b: &TValue,
812                        c_dt: &DatumType,
813                    ) -> TractResult<Option<Tensor>> {
814                        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
815                                DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
816                                DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
817                            (a.datum_type(), b.datum_type(), c_dt)
818                        {
819                            let c_inv_scale = 1.0 / c_scale;
820                            let a = a.to_array_view::<u8>()?;
821                            let b = b.to_array_view::<u8>()?;
822                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
823                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
824                            let view = c.to_array_view_mut::<u8>()?;
825                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
826                                *c = (scale_by($q_op_on_f32(
827                                            ((*a as i32 - a_zp as i32) as f32 * a_scale),
828                                            ((*b as i32 - b_zp as i32) as f32 * b_scale),
829                                ), c_inv_scale) as i32
830                                    + *c_zp as i32)
831                                    .clamp_cast()
832                            });
833                            return Ok(Some(c));
834                        }
835                        Ok(None)
836                    }
837
838                    /// Apply to all Q types
839                    /// Take more memory but hopefully faster than memory_optimised_q_binary_as_float_op
840                    /// especially once cast_to_dt will have vectorized implementations
841                    fn generic_q_binary_as_float_op(
842                        a: &TValue,
843                        b: &TValue,
844                        c_dt: &DatumType,
845                        accumulator_dt: DatumType
846                    ) -> TractResult<Option<Tensor>> {
847                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
848                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
849                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
850                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
851                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
852                            match accumulator_dt {
853                                DatumType::F32 => {
854                                    let view = c.to_array_view_mut::<f32>()?;
855                                    $crate::ndarray::Zip::from(view).and_broadcast(a.to_array_view()?).and_broadcast(b.to_array_view()?).for_each(|c, a, b| {
856                                        *c = $q_op_on_f32(*a,*b);
857                                    })
858                                },
859                                other => bail!("unexpected accumulator data type as {:?}", other)
860                            };
861
862                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
863                        }
864                        Ok(None)
865                    }
866
867                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
868                        return Ok(Some(c));
869                    }
870                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
871                        return Ok(Some(d));
872                    }
873                )?
874                Ok(None)
875            }
876        }
877
878        pub fn $func() -> $crate::ops::binary::TypedBinOp {
879            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
880        }
881    };
882}
883
/// Describes a binary node where exactly one input carries a known uniform
/// value (taken from the input fact's `uniform` attribute) while the other
/// input is variable. Produced by [`one_input_is_uniform`].
#[derive(Debug)]
pub(crate) struct OneUniformInput {
    /// The uniform input's value.
    pub uni: Arc<Tensor>,
    /// Outlet feeding the non-uniform (variable) input.
    pub var: OutletId,
    /// True when the uniform tensor is the left (first) operand of the binary op.
    pub left_is_uniform: bool,
}
890
891pub(crate) fn one_input_is_uniform(
892    model: &TypedModel,
893    node: &TypedNode,
894) -> TractResult<Option<OneUniformInput>> {
895    if let &[a, b] = &*model.node_input_facts(node.id)? {
896        let uni = if let Some(a) = &a.uniform {
897            OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
898        } else if let Some(b) = &b.uniform {
899            OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
900        } else {
901            return Ok(None);
902        };
903        let var_fact = [a, b][uni.left_is_uniform as usize];
904        let uni_fact = [a, b][!uni.left_is_uniform as usize];
905        if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
906            return Ok(Some(uni));
907        }
908    }
909    Ok(None)
910}