Skip to main content

tract_core/ops/
binary.rs

1use crate::internal::*;
2use crate::ndarray::Dimension;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt::{self, Debug};
6use tract_data::itertools::izip;
7use tract_itertools::Itertools;
8use tract_linalg::{BinOp, LinalgFn};
9
10use super::math::{Add, Max, Min, Mul, Sub};
11use super::{cast::cast, math::SubF};
12
/// Minimal interface a binary element-wise operation must implement to be
/// wrapped in a `TypedBinOp`. Object-safe: instances are cloned, compared and
/// downcast through the dyn-trait helper macros below.
pub trait BinMiniOp:
    fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast
{
    /// Operator name, used for display and patch labelling.
    fn name(&self) -> &'static str;
    /// Numerical accuracy contract of the implementation.
    fn validation(&self) -> Validation {
        Validation::Accurate
    }
    /// Datum type the computation is performed in: the common super type of
    /// both inputs. Errors when no super type exists.
    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
    }
    /// Datum type of the output for the given input types.
    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
    /// In-place evaluation: accumulate the result into `a`.
    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
    /// Out-of-place evaluation: write the result into pre-allocated `c`.
    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;

    /// True when operands can be swapped without changing the result.
    fn is_commutative(&self) -> bool {
        true
    }
    /// Value `e` such that `op(x, e) == x` (e.g. 0 for Add, 1 for Mul), if any.
    fn neutral_element(&self) -> Option<i64> {
        None
    }
    /// Value `z` such that `op(x, z) == z` (e.g. 0 for Mul), if any.
    fn absorbing_element(&self) -> Option<i64> {
        None
    }

    /// Hook for quantized ops that prefer to run the computation as float.
    /// Default returns `Ok(None)`: no special handling.
    #[allow(unused_variables)]
    fn maybe_eval_qbinary_as_float_op(
        &self,
        a: &TValue,
        b: &TValue,
        c_dt: &DatumType,
    ) -> TractResult<Option<Tensor>> {
        Ok(None)
    }

    /// Generic evaluation path: try the qbinary-as-float hook first, then run
    /// in place in `a` when the broadcast shape and type match `a`, otherwise
    /// allocate an output and evaluate out of place.
    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
            Ok(tensor)
        } else {
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
            if &*c_shape == a.shape() && c_dt == a.datum_type() {
                let mut a = a.into_tensor();
                self.eval_in_a(&mut a, &b)?;
                Ok(a)
            } else {
                // SAFETY-adjacent: `c` is uninitialized but fully written by
                // eval_out_of_place before being returned.
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }
        }
    }
    /// Entry point; overridable, defaults to `generic_eval`.
    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        self.generic_eval(a, b, c_dt)
    }
    /// Op-specific graph simplification hook. Default: no patch.
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    /// Op-specific code-generation hook. Default: no patch.
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    /// Cost model contribution per output element. Default: free.
    #[allow(unused_variables)]
    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
        tvec!()
    }
    /// Mapping to an optimized tract-linalg kernel, when one exists.
    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
        None
    }

    /// Override for ops that can evaluate symbolic TDim inputs (comparisons).
    #[allow(unused_variables)]
    fn eval_symbolic(
        &self,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<Option<TVec<TValue>>> {
        Ok(None)
    }

    /// Override for ops that produce TDim-level comparison expressions (comparisons).
    #[allow(unused_variables)]
    fn uniform_tdim_comparison(&self, a: &TDim, b: &TDim) -> Option<TDim> {
        None
    }
}
dyn_clone::clone_trait_object!(BinMiniOp);
dyn_eq::eq_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);
109
/// A typed binary operator: the mini-op itself (field 0), plus an optional
/// forced output datum type (field 1) that overrides the mini-op's natural
/// `result_datum_type` when present.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);
112
impl Op for TypedBinOp {
    // Name and validation are delegated to the wrapped mini-op.
    fn name(&self) -> StaticName {
        self.0.name().into()
    }

    fn validation(&self) -> Validation {
        self.0.validation()
    }

    op_as_typed_op!();
}
124
125impl TypedBinOp {
126    fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
127        if let Some(dt) = self.1 { Ok(dt) } else { self.0.result_datum_type(a_dt, b_dt) }
128    }
129}
130
impl EvalOp for TypedBinOp {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Session-aware evaluation: give the mini-op a chance to evaluate the
    /// inputs symbolically (TDim comparisons) before falling back to the
    /// regular tensor path. Inputs are cloned because `eval_symbolic`
    /// consumes its argument and the fallback still needs them.
    fn eval_with_session(
        &self,
        _node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        if let Some(result) = self.0.eval_symbolic(session, inputs.clone())? {
            return Ok(result);
        }
        let (a, b) = args_2!(inputs);
        ensure!(a.rank() == b.rank());
        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
    }

    /// Session-less evaluation: plain tensor path only (no symbolic hook).
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        ensure!(a.rank() == b.rank());
        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
    }
}
158
impl TypedBinOp {
    /// Combine two symbolic TDim expressions through this op, when possible.
    /// Comparisons have their own combiner; other ops are evaluated on scalar
    /// TDim tensors and the reduced result is extracted.
    fn combine_uniform_tdim(&self, a: &TDim, b: &TDim) -> Option<TDim> {
        // Comparison ops provide their own TDim combination
        if let Some(result) = self.0.uniform_tdim_comparison(a, b) {
            return Some(result);
        }
        let a = tensor0(a.clone()).into_tvalue();
        let b = tensor0(b.clone()).into_tvalue();
        let result = self.0.eval(a, b, TDim::datum_type()).ok()?;
        result
            .try_as_plain()
            .ok()
            .and_then(|d| d.as_slice::<TDim>().ok())
            .and_then(|s| s.first())
            .cloned()
            .map(|d| d.reduce())
    }

    /// Combine a symbolic TDim with a scalar constant tensor, when the
    /// constant can be interpreted as an integer (or, for ops with neutral
    /// element 1 — i.e. Mul — as the reciprocal of an integer).
    fn combine_uniform_tdim_with_konst(&self, a: &TDim, konst: &Tensor) -> Option<TDim> {
        if konst.len() != 1 {
            return None;
        }
        // Integer-valued scalar (including float constants like 2.0, 1.0, 3.0)
        let b_int: Option<i64> =
            if konst.datum_type().is_integer() || konst.datum_type().is::<bool>() {
                konst.cast_to_scalar::<i64>().ok()
            } else if konst.datum_type().is_float() {
                konst.cast_to_scalar::<f64>().ok().and_then(|f| {
                    // Tolerance for float representation of an integer value.
                    if (f - f.round()).abs() < 1e-6 { Some(f.round() as i64) } else { None }
                })
            } else {
                None
            };
        if let Some(b) = b_int {
            return self.combine_uniform_tdim(a, &TDim::Val(b));
        }
        // Mul by reciprocal of integer (e.g. ×0.5 → Div(a, 2))
        if self.0.neutral_element() == Some(1)
            && let Some(f) = konst.cast_to_scalar::<f64>().ok().filter(|&f| f > 0.0)
        {
            let n = (1.0 / f).round() as u64;
            // Only accept when f is (within tolerance) exactly 1/n, n >= 2.
            if n >= 2 && (f * n as f64 - 1.0).abs() < 1e-6 {
                return Some(TDim::Div(Box::new(a.clone()), n).reduce());
            }
        }
        None
    }
}
207
208impl TypedOp for TypedBinOp {
209    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
210        if inputs[0].rank() != inputs[1].rank() {
211            bail!(
212                "Typed ops require rank match. Invalid inputs for {}: {}",
213                self.name(),
214                inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
215            );
216        }
217        let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
218        let mut fact = out_dt.fact(&*crate::broadcast::multi_broadcast(&[
219            &inputs[0].shape.to_tvec(),
220            &inputs[1].shape.to_tvec(),
221        ])?);
222        if let (Some(a), Some(b)) = (&inputs[0].uniform_tdim, &inputs[1].uniform_tdim) {
223            fact.uniform_tdim = self.combine_uniform_tdim(a, b);
224            // And(a,b) has no TDim kernel; for 0/1 booleans And == Mul
225            if fact.uniform_tdim.is_none() && self.0.is::<crate::ops::logic::And>() {
226                fact.uniform_tdim = Some(TDim::Mul(vec![a.clone(), b.clone()]).reduce());
227            }
228        }
229        // Fallback: one side has uniform_tdim, the other is a scalar constant
230        if fact.uniform_tdim.is_none() {
231            for (expr, konst_fact) in [
232                (inputs[0].uniform_tdim.as_ref(), inputs[1]),
233                (inputs[1].uniform_tdim.as_ref(), inputs[0]),
234            ] {
235                let Some(a) = expr else { continue };
236                let Some(konst) = konst_fact.konst.as_ref() else { continue };
237                fact.uniform_tdim = self.combine_uniform_tdim_with_konst(a, konst);
238                if fact.uniform_tdim.is_some() {
239                    break;
240                }
241            }
242        }
243        Ok(tvec!(fact))
244    }
245
246    fn input_roi(
247        &self,
248        model: &TypedModel,
249        node: &TypedNode,
250    ) -> TractResult<Option<TVec<Option<TDim>>>> {
251        // Introduction: Mul (or any op with neutral_element=1) with a mask
252        // that has uniform_tdim → the other input gets that expression as ROI.
253        if self.0.neutral_element() == Some(1) {
254            for (mask_ix, other_ix) in [(0usize, 1usize), (1, 0)] {
255                let fact = model.outlet_fact(node.inputs[mask_ix])?;
256                if let Some(mask_expr) = &fact.uniform_tdim {
257                    let mut rois = tvec![None; node.inputs.len()];
258                    rois[other_ix] = Some(mask_expr.clone());
259                    return Ok(Some(rois));
260                }
261            }
262        }
263        // Bubbling: delegate to the natural blanket implementation.
264        crate::optim::propagate_roi::bubble_roi(model, node)
265    }
266
267    fn change_axes(
268        &self,
269        model: &TypedModel,
270        node: &TypedNode,
271        _io: InOut,
272        change: &AxisOp,
273    ) -> TractResult<Option<AxisChangeConsequence>> {
274        if let AxisOp::Rm(rm) = change {
275            let (inputs, outputs) = model.node_facts(node.id)?;
276            if inputs.len() >= 2
277                && outputs.len() >= 1
278                && inputs[0].rank() > *rm
279                && inputs[1].rank() > *rm
280                && outputs[0].rank() > *rm
281            {
282                rule_if!(inputs[0].shape[*rm].is_one());
283                rule_if!(inputs[1].shape[*rm].is_one());
284                rule_if!(outputs[0].shape[*rm].is_one());
285            }
286        }
287        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
288    }
289
290    fn axes_mapping(
291        &self,
292        inputs: &[&TypedFact],
293        outputs: &[&TypedFact],
294    ) -> TractResult<AxesMapping> {
295        AxesMapping::natural(inputs, outputs)
296    }
297
298    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
299        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
300        Ok(self
301            .0
302            .cost_per_element(inputs[0].datum_type)
303            .into_iter()
304            .map(|(c, n)| (c, count.clone() * n))
305            .collect())
306    }
307
308    fn slice(
309        &self,
310        patch: &mut TypedModelPatch,
311        _model: &TypedModel,
312        _node: &TypedNode,
313        prefix: &str,
314        inputs: &[OutletId],
315        _output_axis: usize,
316        _start: &TDim,
317        _end: &TDim,
318    ) -> TractResult<Option<TVec<OutletId>>> {
319        Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
320    }
321
322    fn declutter(
323        &self,
324        model: &TypedModel,
325        node: &TypedNode,
326    ) -> TractResult<Option<TypedModelPatch>> {
327        let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
328            (a.datum_type().unwrap(), b.datum_type().unwrap())
329        } else {
330            unreachable!("TypedBinOp has two inputs.")
331        };
332        if let Some(neutral_patch) =
333            declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
334        {
335            return Ok(Some(neutral_patch));
336        }
337        if let Some(absorbing_patch) = declutter_absorbing(model, node, self.0.as_ref())? {
338            return Ok(Some(absorbing_patch));
339        }
340        if let Some(broadcast_patch) =
341            declutter_broadcasting_operand_1(model, node, self.0.clone())?
342        {
343            return Ok(Some(broadcast_patch));
344        }
345        self.0.declutter(model, node)
346    }
347
348    fn codegen(
349        &self,
350        model: &TypedModel,
351        node: &TypedNode,
352    ) -> TractResult<Option<TypedModelPatch>> {
353        if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
354            let input_facts = model.node_input_facts(node.id)?;
355            let must_swap_inputs =
356                input_facts.iter().collect_tuple().is_some_and(|(a_fact, b_fact)| {
357                    (a_fact.shape.volume() - b_fact.shape.volume()).prove_strict_negative()
358                });
359            let (operand_1, operand_2) = if must_swap_inputs {
360                (input_facts[1], input_facts[0])
361            } else {
362                (input_facts[0], input_facts[1])
363            };
364
365            let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
366                find_most_efficient_config(model, node, must_swap_inputs)?;
367
368            // Check if op is quantized
369            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
370            let op_is_quant = c_dt.is_quantized()
371                || operand_1.datum_type.is_quantized()
372                || operand_2.datum_type.is_quantized();
373
374            // Check if it can be evaluated in a
375            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
376            let c_shape = crate::broadcast::multi_broadcast(&[
377                operand_1.shape.clone(),
378                operand_2.shape.clone(),
379            ])?;
380            let can_eval_in_a =
381                (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);
382
383            // Swap input if required
384            let inputs = if must_swap_inputs {
385                let mut swap_input = node.inputs.clone();
386                swap_input.swap(0, 1);
387                swap_input
388            } else {
389                node.inputs.clone()
390            };
391            let actual_linalg_op =
392                if must_swap_inputs { linalg_bin_op.flip() } else { linalg_bin_op };
393            let actual_core_op = core_op_for_linalg_op(&actual_linalg_op);
394
395            let dt = model.node_input_facts(node.id)?[0].datum_type;
396            if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
397                rule_if_some!(func = tract_linalg::bin_by_scalar(dt, actual_linalg_op));
398                let eval_fn = Arc::from(func);
399                return Ok(Some(
400                    TypedModelPatch::replace_single_op(
401                        model,
402                        node,
403                        &inputs,
404                        OptBinByScalar { binop: actual_core_op, eval_fn },
405                    )?
406                    .with_context("ByScalar"),
407                ));
408            }
409
410            if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
411                rule_if_some!(func = tract_linalg::bin_unicast(dt, actual_linalg_op));
412                let eval_fn = Arc::from(func);
413                return Ok(Some(
414                    TypedModelPatch::replace_single_op(
415                        model,
416                        node,
417                        &inputs,
418                        OptBinUnicast { binop: actual_core_op, eval_fn },
419                    )?
420                    .with_context("Unicast"),
421                ));
422            }
423        }
424
425        Ok(None)
426    }
427    as_op!();
428}
429
430fn core_op_for_linalg_op(linalg: &BinOp) -> Box<dyn BinMiniOp> {
431    match linalg {
432        BinOp::Min => Box::new(Min),
433        BinOp::Max => Box::new(Max),
434        BinOp::Add => Box::new(Add),
435        BinOp::Mul => Box::new(Mul),
436        BinOp::Sub => Box::new(Sub),
437        BinOp::SubF => Box::new(SubF),
438    }
439}
440fn declutter_broadcasting_operand_1(
441    model: &TypedModel,
442    node: &TypedNode,
443    mini_op: Box<dyn BinMiniOp>,
444) -> TractResult<Option<TypedModelPatch>> {
445    let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
446        (a.shape.clone(), b.shape.clone())
447    } else {
448        unreachable!("TypedBinOp has two inputs.")
449    };
450
451    let a_num_elements = a_shape.iter().product::<TDim>();
452    let b_num_elements = b_shape.iter().product::<TDim>();
453    let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
454    if a_should_be_broadcast & mini_op.is_commutative() {
455        let mut swap_input = node.inputs.clone();
456        swap_input.swap(0, 1);
457        return Ok(Some(TypedModelPatch::replace_single_op(
458            model,
459            node,
460            &swap_input,
461            TypedBinOp(mini_op, None),
462        )?));
463    }
464
465    Ok(None)
466}
467
/// When one input is a uniform tensor equal to the op's neutral element
/// (e.g. 0 for Add, 1 for Mul), bypass the op entirely: the output is the
/// other input (possibly recast for quantized types).
fn declutter_neutral(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: &dyn BinMiniOp,
    out_dt: DatumType,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let is_neutral = mini_op
            .neutral_element()
            .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
            .unwrap_or(false);

        // For some ops the neutral element only works on the right-hand side:
        // for non-commutative ops like Sub, `neutral - x` is not `x`, while
        // `x - neutral` is. So require commutativity or a right-side uniform.
        let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;

        if is_neutral && pos_checked {
            // Neutral decluttering for quant values is special.
            // - if (fa) (a-az)*as + (fb = 0) (b-bz)*bs = (fc) (c-cz)*cs
            // - then even if fa = fc, quant params needs to be updated (a != c).
            // So it's not a no_op.
            if uniform.uni.datum_type().is_quantized() {
                return Ok(Some(TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &[node.inputs[0]],
                    cast(out_dt),
                )?));
            // In the non quantized case, it's a no_op.
            } else {
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|_, inputs| Ok(inputs.into()),
                )?));
            }
        }
    }
    Ok(None)
}
509
510/// When one input is the absorbing element (e.g. 0 for Mul, false for And),
511/// replace the entire op with the uniform (absorbing) input.
512fn declutter_absorbing(
513    model: &TypedModel,
514    node: &TypedNode,
515    mini_op: &dyn BinMiniOp,
516) -> TractResult<Option<TypedModelPatch>> {
517    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
518        let is_absorbing = mini_op
519            .absorbing_element()
520            .map(|absorb| tensor0(absorb).close_enough(&uniform.uni, false).is_ok())
521            .unwrap_or(false);
522        if is_absorbing {
523            let uni_inlet = if uniform.left_is_uniform { 0 } else { 1 };
524            return Ok(Some(TypedModelPatch::rewire(
525                model,
526                &[node.inputs[uni_inlet]],
527                &[node.id.into()],
528                &|_, inputs| Ok(inputs.into()),
529            )?));
530        }
531    }
532    Ok(None)
533}
534
/// Estimate whether the by-scalar and unicast linalg kernels would be worth
/// using for this node, by counting the elements each kernel would process
/// per call and comparing against a minimal threshold.
fn find_most_efficient_config(
    model: &TypedModel,
    node: &TypedNode,
    swap_input: bool,
) -> TractResult<(bool, bool)> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        // Normalize so that `a` is the (possibly swapped) first operand.
        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };

        // By-scalar: count the product of trailing axes where b is 1
        // (the contiguous span each scalar call would cover).
        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
        let num_by_scalar_elements = if by_scalar_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
                .map(|(rev_a_dim, _)| rev_a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        // Unicast: count the product of trailing axes where a and b agree
        // (the contiguous span each unicast call would cover).
        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
        let num_unicast_elements = if unicast_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
                .map(|(a_dim, _)| a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        // Below this many elements per call, kernel dispatch overhead is
        // assumed to dominate — TODO confirm threshold empirically.
        let min_num_elements = 32;
        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
    }
    Ok((false, false))
}
577
578pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
579    TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
580}
581
/// Optimized binary op where the second operand reduces to a scalar per
/// iteration: `eval_fn` is the pre-selected linalg kernel for the datum type.
#[derive(Clone)]
pub struct OptBinByScalar {
    pub binop: Box<dyn BinMiniOp>,
    // Kernel closure; not Debug/Eq, so it is skipped by those impls below.
    eval_fn: Arc<LinalgFn>,
}
587
impl Debug for OptBinByScalar {
    // Manual impl: `eval_fn` is a function object with no Debug, show only the binop.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
    }
}
593
594impl OptBinByScalar {
595    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
596        if a_shape.len() != b_shape.len() {
597            return false;
598        };
599
600        a_shape
601            .iter()
602            .zip(b_shape.iter())
603            .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
604            .all(|(_, b_dim)| *b_dim == 1.to_dim())
605    }
606}
607
impl PartialEq for OptBinByScalar {
    // `eval_fn` is derived from `binop` and the datum type, so comparing the
    // binop alone is sufficient (and function objects are not comparable).
    fn eq(&self, other: &Self) -> bool {
        *self.binop == *other.binop
    }
}
impl Eq for OptBinByScalar {}
614
impl Op for OptBinByScalar {
    // E.g. "OptMulByScalar" for a Mul mini-op.
    fn name(&self) -> StaticName {
        format!("Opt{}ByScalar", self.binop.name()).into()
    }

    op_as_typed_op!();
}
622
impl EvalOp for OptBinByScalar {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Accumulate in place into `a`, calling the linalg kernel once per
    /// prefix coordinate with a scalar view of `b`.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // `a` is mutated in place through the views below; taking ownership
        // of the tensor keeps that sound even though TensorView itself does
        // not require an owned tensor. (No `mut` binding: the compiler flags
        // it as unneeded since mutation happens through the views.)
        let a = a.into_tensor();
        let b_shape = b.shape();

        // Index of the first axis of the trailing all-ones run in b's shape;
        // everything before it is iterated over, everything after is the
        // contiguous span handled by one kernel call.
        let first_unary_axis = b_shape
            .iter()
            .enumerate()
            .rev()
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .last()
            .context("Cannot use by_scalar when no trailing dimensions are unary")?;

        let iterating_shape = &a.shape()[..first_unary_axis];
        if !iterating_shape.is_empty() {
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
                // By construction the b sub-view holds exactly one element.
                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            // No prefix to iterate: a single kernel call over the whole tensor.
            let mut view = a.view();
            let b_view = b.view();
            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
            (self.eval_fn)(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}
661
impl TypedOp for OptBinByScalar {
    /// Output takes operand 0's shape (the in-place accumulator) and the
    /// mini-op's result type. Re-checks the by-scalar shape precondition.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    /// Per-element cost from the mini-op, scaled by output element count.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}
682
/// Optimized binary op where both operands share the same trailing layout
/// (element-for-element "unicast" kernel): `eval_fn` is the pre-selected
/// linalg kernel for the datum type.
#[derive(Clone)]
pub struct OptBinUnicast {
    pub binop: Box<dyn BinMiniOp>,
    // Kernel closure; not Debug/Eq, so it is skipped by those impls below.
    eval_fn: Arc<LinalgFn>,
}
688
impl Debug for OptBinUnicast {
    // Manual impl: `eval_fn` is a function object with no Debug, show only the binop.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
    }
}
694
695impl OptBinUnicast {
696    fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
697        let num_iterations: TDim = a_shape
698            .iter()
699            .zip(b_shape.iter())
700            .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
701            .map(|(a_dim, _)| a_dim)
702            .product();
703
704        if num_iterations.is_one() {
705            return true;
706        }
707
708        let elements_per_iteration: TDim = a_shape
709            .iter()
710            .zip(b_shape.iter())
711            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
712            .map(|(_, b_dim)| b_dim)
713            .product();
714
715        if let Ok(num_element) = elements_per_iteration.to_i64() {
716            let required_alignment = vector_size();
717            (num_element as usize).is_multiple_of(required_alignment)
718        } else {
719            false
720        }
721    }
722    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
723        if a_shape.len() != b_shape.len() {
724            return false;
725        };
726
727        let unicast_possible = a_shape
728            .iter()
729            .zip(b_shape.iter())
730            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
731            .all(|(a_dim, b_dim)| a_dim == b_dim);
732        let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);
733
734        unicast_possible && unicast_is_aligned
735    }
736}
737
impl PartialEq for OptBinUnicast {
    // `eval_fn` is derived from `binop` and the datum type, so comparing the
    // binop alone is sufficient (and function objects are not comparable).
    fn eq(&self, other: &Self) -> bool {
        *self.binop == *other.binop
    }
}
impl Eq for OptBinUnicast {}
744
impl Op for OptBinUnicast {
    // E.g. "OptMulUnicast" for a Mul mini-op.
    fn name(&self) -> StaticName {
        format!("Opt{}Unicast", self.binop.name()).into()
    }

    op_as_typed_op!();
}
752
impl EvalOp for OptBinUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Accumulate in place into `a`, calling the linalg kernel once per outer
    /// (broadcast) coordinate with the full `b` view.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // `a` is mutated in place through the views below; taking ownership
        // of the tensor keeps that sound even though TensorView itself does
        // not require an owned tensor. (No `mut` binding: the compiler flags
        // it as unneeded since mutation happens through the views.)
        let a = a.into_tensor();
        let b_shape = b.shape();
        let b_view = b.view();
        // One past the end of b's leading run of 1-dims, if any: the axes of
        // `a` to iterate over.
        let first_non_unary_axis =
            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

        if let Some(first_non_unary_axis) = first_non_unary_axis {
            // Iterate on outer dimensions and evaluate with unicast subviews.
            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            // No broadcast prefix: a single kernel call over the whole tensors.
            let mut view = a.view();
            debug_assert_eq!(view.shape(), b_view.shape());
            (self.eval_fn)(&mut view, &b_view)?;
        }

        Ok(tvec!(a.into_tvalue()))
    }
}
785
impl TypedOp for OptBinUnicast {
    /// Output takes operand 0's shape (the in-place accumulator) and the
    /// mini-op's result type. Re-checks the unicast shape precondition.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    /// Per-element cost from the mini-op, scaled by output element count.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}
806
/// Generates a binary element-wise mini-op: a zero-sized struct `$Op`
/// implementing [`crate::ops::binary::BinMiniOp`], plus a helper `$func()`
/// returning it wrapped in a `TypedBinOp`.
///
/// Arguments (all optional ones override the trait's default methods):
/// * `$func` / `$Op` — names of the generated constructor fn and op struct.
/// * `codegen:`, `cost:`, `declutter:`, `eval_in_a:`, `eval_override:`,
///   `linalg:`, `operating_datum_type:`, `is_commutative:`,
///   `neutral_element:`, `absorbing_element:`, `out_of_place:`,
///   `validation:` — hooks mapped one-to-one onto trait methods below.
/// * `q: [types] => closure ;` — per-quantized-type kernels taking
///   `(c, a, b, zero_point, scale)`.
/// * `q_op_on_f32:` — float closure used to emulate the op on quantized
///   tensors (dequantize, apply in f32, requantize).
/// * trailing `[types] => closure` arms — plain `(c, a, b)` kernels, one
///   set per supported datum type.
#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_in_a: $eval_in_a:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(is_commutative: $is_commutative:expr,)?
     $(neutral_element: $neutral_element:expr,)?
     $(absorbing_element: $absorbing_element:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $(q_op_on_f32: $q_op_on_f32:expr,)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        pub struct $Op;
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            // Writes `op(a, b)` into a pre-allocated output tensor `c`,
            // broadcasting a and b over c's shape.
            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                // Optional user hook gets first shot; returning Ok(true) short-circuits.
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                    // Plain dispatch: run the per-type closure when c's datum type matches.
                    $(
                        $(if c.datum_type() == $typ::datum_type() {
                            let a = a.to_plain_array_view::<$typ>()?;
                            let b = b.to_plain_array_view::<$typ>()?;
                            let mut c_plain = c.try_as_plain_mut()?;
                            let mut c = c_plain.to_array_view_mut::<$typ>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                            return Ok(())
                        })*
                     )*
                    // Quantized dispatch: kernels also receive (zero_point, scale).
                    // NOTE(review): dispatch and qparams key off a's datum type while the
                    // output view uses c — assumes a and c share the same storage type
                    // in every instantiation; confirm at call sites.
                    $(
                        $(
                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                let a = a.to_plain_array_view::<$typ_dt>()?;
                                let b = b.to_plain_array_view::<$typ_dt>()?;
                                let mut c_plain = c.try_as_plain_mut()?;
                                let mut c = c_plain.to_array_view_mut::<$typ_dt>()?;
                                $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                                return Ok(())
                            }
                            )*
                         )*
                     )?
                    bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            $(fn is_commutative(&self) -> bool {
                $is_commutative
            })?
            $(fn neutral_element(&self) -> Option<i64> {
                Some($neutral_element)
            })?
            $(fn absorbing_element(&self) -> Option<i64> {
                Some($absorbing_element)
            })?
            // In-place variant: accumulates the result into `a` (a is both the
            // left operand and the destination; b broadcasts over a's shape).
            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                // c and a are same type
                // Optional user hook first; returning Ok(true) short-circuits.
                $(if $eval_in_a(a, b)? { return Ok(()) } )?
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let b = b.to_plain_array_view::<$typ>()?;
                        let mut a_plain = a.try_as_plain_mut()?;
                        let mut a = a_plain.to_array_view_mut::<$typ>()?;
                        // a's previous value is cloned so the (c, a, b) kernel can
                        // read the operand while writing the destination cell.
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                )*
                // Quantized in-place path, mirroring eval_out_of_place's q dispatch.
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let mut a_plain = a.try_as_plain_mut()?;
                            let mut a = a_plain.to_array_view_mut::<$typ_dt>()?;
                            let b = b.to_plain_array_view::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
                                cab(a, &(a.clone()), b, zp, scale)
                            });
                            return Ok(())
                        })*
                    )*
                )?
                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
            }

            // Full eval override, when the caller supplied `eval_override:`.
            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                $eval_override(a, b, c_dt)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                // Same underlying storage type: keep a, unless only b carries
                // quantization parameters, in which case prefer b.
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    }
                    else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

                // Optional trait-method overrides below are emitted only when the
                // matching macro argument was supplied.
                $(
                    fn declutter(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($declutter)(self, model, node)
                    }
                 )?
                $(
                    fn codegen(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        a: &Arc<Tensor>,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($codegen)(self, model, node, a)
                    }
                 )?
                $(
                    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                        ($cost)(dt)
                    }
                 )?
                $(
                    fn validation(&self) -> Validation {
                        $validation
                    }
                 )?
                $(
                    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
                        Some(tract_linalg::BinOp::$linalg)
                    }
                 )?
                $(
                    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                        ($operating_datum_type)(a, b)
                    })?


            /// Default simple binary operation for QFormat where
            /// we dequantise & apply requested operation in float & requantize it
            /// several implementation are provided with pro & con
            #[allow(unused_variables)]
            fn maybe_eval_qbinary_as_float_op(
                &self,
                a: &TValue,
                b: &TValue,
                c_dt: &DatumType,
            ) -> TractResult<Option<Tensor>> {
                // Whole body only emitted when `q_op_on_f32:` was supplied;
                // otherwise this falls through to Ok(None) at the end.
                $(
                    /// Implementation strive to minimise memory allocation and access
                    /// we apply only if type is QU8 zp_scale datum type
                    /// maybe more suited for large models tensors
                    fn memory_optimised_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                            (a.datum_type(), b.datum_type(), c_dt)
                        {
                            let c_inv_scale = 1.0 / c_scale;
                            let a = a.to_plain_array_view::<u8>()?;
                            let b = b.to_plain_array_view::<u8>()?;
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
                            let mut c_plain = c.try_as_plain_mut()?;
                            let view = c_plain.to_array_view_mut::<u8>()?;
                            // Per element: dequantize both operands, apply the f32 op,
                            // rescale, re-add c's zero point, and saturate back to u8.
                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
                                *c = (scale_by($q_op_on_f32(
                                            ((*a as i32 - a_zp as i32) as f32 * a_scale),
                                            ((*b as i32 - b_zp as i32) as f32 * b_scale),
                                ), c_inv_scale) as i32
                                    + *c_zp as i32)
                                    .clamp_cast()
                            });
                            return Ok(Some(c));
                        }
                        Ok(None)
                    }

                    /// Apply to all Q types
                    /// Take more memory but hopefully faster than memory_optimised_q_binary_as_float_op
                    /// especially once cast_to_dt will have will have vectorized implementations
                    fn generic_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                        accumulator_dt: DatumType
                    ) -> TractResult<Option<Tensor>> {
                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
                            // Materialize both operands in the accumulator type
                            // (currently only F32 is accepted below).
                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
                            match accumulator_dt {
                                DatumType::F32 => {
                                    let mut c_plain = c.try_as_plain_mut()?;
                                    let view = c_plain.to_array_view_mut::<f32>()?;
                                    $crate::ndarray::Zip::from(view).and_broadcast(a.try_as_plain()?.to_array_view()?).and_broadcast(b.try_as_plain()?.to_array_view()?).for_each(|c, a, b| {
                                        *c = $q_op_on_f32(*a,*b);
                                    })
                                },
                                other => bail!("unexpected accumulator data type as {:?}", other)
                            };

                            // Requantize the float accumulator to the requested output type.
                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
                        }
                        Ok(None)
                    }

                    // Try the allocation-light QU8 fast path first, then the
                    // generic cast-to-f32 fallback.
                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
                        return Ok(Some(c));
                    }
                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
                        return Ok(Some(d));
                    }
                )?
                Ok(None)
            }
        }

        // Public constructor: boxes the generated op into a TypedBinOp.
        // NOTE(review): second TypedBinOp field left as None here — see the
        // TypedBinOp definition for its meaning.
        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
        }
    };
}
1049
/// Result of [`one_input_is_uniform`]: describes a two-input node where
/// exactly one input fact carries a uniform (single-valued) tensor that is
/// broadcastable over the other input.
#[derive(Debug)]
pub(crate) struct OneUniformInput {
    /// The uniform (constant, single-value) input tensor.
    pub uni: Arc<Tensor>,
    /// Outlet feeding the non-uniform (variable) input.
    pub var: OutletId,
    /// True when the uniform input is the left (first) operand.
    pub left_is_uniform: bool,
}
1056
1057pub(crate) fn one_input_is_uniform(
1058    model: &TypedModel,
1059    node: &TypedNode,
1060) -> TractResult<Option<OneUniformInput>> {
1061    if let &[a, b] = &*model.node_input_facts(node.id)? {
1062        let uni = if let Some(a) = &a.uniform {
1063            OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
1064        } else if let Some(b) = &b.uniform {
1065            OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
1066        } else {
1067            return Ok(None);
1068        };
1069        let var_fact = [a, b][uni.left_is_uniform as usize];
1070        let uni_fact = [a, b][!uni.left_is_uniform as usize];
1071        if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
1072            return Ok(Some(uni));
1073        }
1074    }
1075    Ok(None)
1076}