Skip to main content

tract_core/ops/
binary.rs

1use crate::internal::*;
2use crate::ndarray::Dimension;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt::{self, Debug};
6use tract_data::itertools::izip;
7use tract_itertools::Itertools;
8use tract_linalg::{BinOp, LinalgFn};
9
10use super::math::{Add, Max, Min, Mul, Sub};
11use super::{cast::cast, math::SubF};
12
13pub trait BinMiniOp:
14    fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast
15{
16    fn name(&self) -> &'static str;
17    fn validation(&self) -> Validation {
18        Validation::Accurate
19    }
20    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
21        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
22    }
23    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
24    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
25    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;
26
27    fn is_commutative(&self) -> bool {
28        true
29    }
30    fn neutral_element(&self) -> Option<i64> {
31        None
32    }
33    fn absorbing_element(&self) -> Option<i64> {
34        None
35    }
36
37    #[allow(unused_variables)]
38    fn maybe_eval_qbinary_as_float_op(
39        &self,
40        a: &TValue,
41        b: &TValue,
42        c_dt: &DatumType,
43    ) -> TractResult<Option<Tensor>> {
44        Ok(None)
45    }
46
47    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
48        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
49            return Ok(tensor);
50        }
51        // Same-shape fast path: skip `multi_broadcast` allocation when shapes
52        // are already equal (very common: residuals, mask application, etc.).
53        // Correctness: equal shapes imply broadcast shape == a.shape() and the
54        // existing slow path would have taken this same branch.
55        if c_dt == a.datum_type() && a.shape() == b.shape() {
56            let mut a = a.into_tensor();
57            self.eval_in_a(&mut a, &b)?;
58            return Ok(a);
59        }
60        let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
61        if &*c_shape == a.shape() && c_dt == a.datum_type() {
62            let mut a = a.into_tensor();
63            self.eval_in_a(&mut a, &b)?;
64            Ok(a)
65        } else {
66            let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
67            self.eval_out_of_place(&mut c, &a, &b)?;
68            Ok(c)
69        }
70    }
71    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
72        self.generic_eval(a, b, c_dt)
73    }
74    #[allow(unused_variables)]
75    fn declutter(
76        &self,
77        model: &TypedModel,
78        node: &TypedNode,
79    ) -> TractResult<Option<TypedModelPatch>> {
80        Ok(None)
81    }
82    #[allow(unused_variables)]
83    fn codegen(
84        &self,
85        model: &TypedModel,
86        node: &TypedNode,
87    ) -> TractResult<Option<TypedModelPatch>> {
88        Ok(None)
89    }
90    #[allow(unused_variables)]
91    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
92        tvec!()
93    }
94    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
95        None
96    }
97
98    /// Override for ops that can evaluate symbolic TDim inputs (comparisons).
99    #[allow(unused_variables)]
100    fn eval_symbolic(
101        &self,
102        session: &TurnState,
103        inputs: TVec<TValue>,
104    ) -> TractResult<Option<TVec<TValue>>> {
105        Ok(None)
106    }
107
108    /// Override for ops that produce TDim-level comparison expressions (comparisons).
109    #[allow(unused_variables)]
110    fn uniform_tdim_comparison(&self, a: &TDim, b: &TDim) -> Option<TDim> {
111        None
112    }
113}
// Trait-object plumbing for `BinMiniOp`: cloning, equality and downcasting.
// `TypedBinOp` relies on these to derive `Clone`/`PartialEq`/`Eq` over its
// `Box<dyn BinMiniOp>` field and to test the concrete mini-op type.
dyn_clone::clone_trait_object!(BinMiniOp);
dyn_eq::eq_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);
117
/// Binary operator wrapping a [`BinMiniOp`].
///
/// Field `0` is the element-wise kernel. Field `1` is an optional output datum
/// type: when it is `Some`, it is used as the output type instead of asking the
/// kernel's `result_datum_type`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);
120
121impl Op for TypedBinOp {
122    fn name(&self) -> StaticName {
123        self.0.name().into()
124    }
125
126    fn validation(&self) -> Validation {
127        self.0.validation()
128    }
129
130    op_as_typed_op!();
131}
132
133impl TypedBinOp {
134    fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
135        if let Some(dt) = self.1 { Ok(dt) } else { self.0.result_datum_type(a_dt, b_dt) }
136    }
137}
138
139impl EvalOp for TypedBinOp {
140    fn is_stateless(&self) -> bool {
141        true
142    }
143
144    fn eval_with_session(
145        &self,
146        _node_id: usize,
147        session: &TurnState,
148        inputs: TVec<TValue>,
149    ) -> TractResult<TVec<TValue>> {
150        if let Some(result) = self.0.eval_symbolic(session, inputs.clone())? {
151            return Ok(result);
152        }
153        let (a, b) = args_2!(inputs);
154        ensure!(a.rank() == b.rank());
155        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
156        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
157    }
158
159    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
160        let (a, b) = args_2!(inputs);
161        ensure!(a.rank() == b.rank());
162        let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
163        Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
164    }
165}
166
167impl TypedBinOp {
168    fn combine_uniform_tdim(&self, a: &TDim, b: &TDim) -> Option<TDim> {
169        // Comparison ops provide their own TDim combination
170        if let Some(result) = self.0.uniform_tdim_comparison(a, b) {
171            return Some(result);
172        }
173        let a = tensor0(a.clone()).into_tvalue();
174        let b = tensor0(b.clone()).into_tvalue();
175        let result = self.0.eval(a, b, TDim::datum_type()).ok()?;
176        result
177            .try_as_plain()
178            .ok()
179            .and_then(|d| d.as_slice::<TDim>().ok())
180            .and_then(|s| s.first())
181            .cloned()
182            .map(|d| d.reduce())
183    }
184
185    fn combine_uniform_tdim_with_konst(&self, a: &TDim, konst: &Tensor) -> Option<TDim> {
186        if konst.len() != 1 {
187            return None;
188        }
189        // Integer-valued scalar (including float constants like 2.0, 1.0, 3.0)
190        let b_int: Option<i64> =
191            if konst.datum_type().is_integer() || konst.datum_type().is::<bool>() {
192                konst.cast_to_scalar::<i64>().ok()
193            } else if konst.datum_type().is_float() {
194                konst.cast_to_scalar::<f64>().ok().and_then(|f| {
195                    if (f - f.round()).abs() < 1e-6 { Some(f.round() as i64) } else { None }
196                })
197            } else {
198                None
199            };
200        if let Some(b) = b_int {
201            return self.combine_uniform_tdim(a, &TDim::Val(b));
202        }
203        // Mul by reciprocal of integer (e.g. ×0.5 → Div(a, 2))
204        if self.0.neutral_element() == Some(1)
205            && let Some(f) = konst.cast_to_scalar::<f64>().ok().filter(|&f| f > 0.0)
206        {
207            let n = (1.0 / f).round() as u64;
208            if n >= 2 && (f * n as f64 - 1.0).abs() < 1e-6 {
209                return Some(TDim::Div(Box::new(a.clone()), n).reduce());
210            }
211        }
212        None
213    }
214}
215
216impl TypedOp for TypedBinOp {
217    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
218        if inputs[0].rank() != inputs[1].rank() {
219            bail!(
220                "Typed ops require rank match. Invalid inputs for {}: {}",
221                self.name(),
222                inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
223            );
224        }
225        let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
226        let mut fact = out_dt.fact(&*crate::broadcast::multi_broadcast(&[
227            &inputs[0].shape.to_tvec(),
228            &inputs[1].shape.to_tvec(),
229        ])?);
230        if let (Some(a), Some(b)) = (&inputs[0].uniform_tdim, &inputs[1].uniform_tdim) {
231            fact.uniform_tdim = self.combine_uniform_tdim(a, b);
232            // And(a,b) has no TDim kernel; for 0/1 booleans And == Mul
233            if fact.uniform_tdim.is_none() && self.0.is::<crate::ops::logic::And>() {
234                fact.uniform_tdim = Some(TDim::Mul(vec![a.clone(), b.clone()]).reduce());
235            }
236        }
237        // Fallback: one side has uniform_tdim, the other is a scalar constant
238        if fact.uniform_tdim.is_none() {
239            for (expr, konst_fact) in [
240                (inputs[0].uniform_tdim.as_ref(), inputs[1]),
241                (inputs[1].uniform_tdim.as_ref(), inputs[0]),
242            ] {
243                let Some(a) = expr else { continue };
244                let Some(konst) = konst_fact.konst.as_ref() else { continue };
245                fact.uniform_tdim = self.combine_uniform_tdim_with_konst(a, konst);
246                if fact.uniform_tdim.is_some() {
247                    break;
248                }
249            }
250        }
251        Ok(tvec!(fact))
252    }
253
254    fn input_roi(
255        &self,
256        model: &TypedModel,
257        node: &TypedNode,
258    ) -> TractResult<Option<TVec<Option<TDim>>>> {
259        // Introduction: Mul (or any op with neutral_element=1) with a mask
260        // that has uniform_tdim → the other input gets that expression as ROI.
261        if self.0.neutral_element() == Some(1) {
262            for (mask_ix, other_ix) in [(0usize, 1usize), (1, 0)] {
263                let fact = model.outlet_fact(node.inputs[mask_ix])?;
264                if let Some(mask_expr) = &fact.uniform_tdim {
265                    let mut rois = tvec![None; node.inputs.len()];
266                    rois[other_ix] = Some(mask_expr.clone());
267                    return Ok(Some(rois));
268                }
269            }
270        }
271        // Bubbling: delegate to the natural blanket implementation.
272        crate::optim::propagate_roi::bubble_roi(model, node)
273    }
274
275    fn change_axes(
276        &self,
277        model: &TypedModel,
278        node: &TypedNode,
279        _io: InOut,
280        change: &AxisOp,
281    ) -> TractResult<Option<AxisChangeConsequence>> {
282        if let AxisOp::Rm(rm) = change {
283            let (inputs, outputs) = model.node_facts(node.id)?;
284            if inputs.len() >= 2
285                && outputs.len() >= 1
286                && inputs[0].rank() > *rm
287                && inputs[1].rank() > *rm
288                && outputs[0].rank() > *rm
289            {
290                rule_if!(inputs[0].shape[*rm].is_one());
291                rule_if!(inputs[1].shape[*rm].is_one());
292                rule_if!(outputs[0].shape[*rm].is_one());
293            }
294        }
295        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
296    }
297
298    fn axes_mapping(
299        &self,
300        inputs: &[&TypedFact],
301        outputs: &[&TypedFact],
302    ) -> TractResult<AxesMapping> {
303        AxesMapping::natural(inputs, outputs)
304    }
305
306    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
307        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
308        Ok(self
309            .0
310            .cost_per_element(inputs[0].datum_type)
311            .into_iter()
312            .map(|(c, n)| (c, count.clone() * n))
313            .collect())
314    }
315
316    fn slice(
317        &self,
318        patch: &mut TypedModelPatch,
319        _model: &TypedModel,
320        _node: &TypedNode,
321        prefix: &str,
322        inputs: &[OutletId],
323        _output_axis: usize,
324        _start: &TDim,
325        _end: &TDim,
326    ) -> TractResult<Option<TVec<OutletId>>> {
327        Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
328    }
329
330    fn declutter(
331        &self,
332        model: &TypedModel,
333        node: &TypedNode,
334    ) -> TractResult<Option<TypedModelPatch>> {
335        let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
336            (a.datum_type().unwrap(), b.datum_type().unwrap())
337        } else {
338            unreachable!("TypedBinOp has two inputs.")
339        };
340        if let Some(neutral_patch) =
341            declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
342        {
343            return Ok(Some(neutral_patch));
344        }
345        if let Some(absorbing_patch) = declutter_absorbing(model, node, self.0.as_ref())? {
346            return Ok(Some(absorbing_patch));
347        }
348        if let Some(broadcast_patch) =
349            declutter_broadcasting_operand_1(model, node, self.0.clone())?
350        {
351            return Ok(Some(broadcast_patch));
352        }
353        self.0.declutter(model, node)
354    }
355
356    fn codegen(
357        &self,
358        model: &TypedModel,
359        node: &TypedNode,
360    ) -> TractResult<Option<TypedModelPatch>> {
361        if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
362            let input_facts = model.node_input_facts(node.id)?;
363            let must_swap_inputs =
364                input_facts.iter().collect_tuple().is_some_and(|(a_fact, b_fact)| {
365                    (a_fact.shape.volume() - b_fact.shape.volume()).prove_strict_negative()
366                });
367            let (operand_1, operand_2) = if must_swap_inputs {
368                (input_facts[1], input_facts[0])
369            } else {
370                (input_facts[0], input_facts[1])
371            };
372
373            let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
374                find_most_efficient_config(model, node, must_swap_inputs)?;
375
376            // Check if op is quantized
377            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
378            let op_is_quant = c_dt.is_quantized()
379                || operand_1.datum_type.is_quantized()
380                || operand_2.datum_type.is_quantized();
381
382            // Check if it can be evaluated in a
383            let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
384            let c_shape = crate::broadcast::multi_broadcast(&[
385                operand_1.shape.clone(),
386                operand_2.shape.clone(),
387            ])?;
388            let can_eval_in_a =
389                (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);
390
391            // Swap input if required
392            let inputs = if must_swap_inputs {
393                let mut swap_input = node.inputs.clone();
394                swap_input.swap(0, 1);
395                swap_input
396            } else {
397                node.inputs.clone()
398            };
399            let actual_linalg_op =
400                if must_swap_inputs { linalg_bin_op.flip() } else { linalg_bin_op };
401            let actual_core_op = core_op_for_linalg_op(&actual_linalg_op);
402
403            let dt = model.node_input_facts(node.id)?[0].datum_type;
404            if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
405                rule_if_some!(func = tract_linalg::bin_by_scalar(dt, actual_linalg_op));
406                let eval_fn = Arc::from(func);
407                return Ok(Some(
408                    TypedModelPatch::replace_single_op(
409                        model,
410                        node,
411                        &inputs,
412                        OptBinByScalar { binop: actual_core_op, eval_fn },
413                    )?
414                    .with_context("ByScalar"),
415                ));
416            }
417
418            if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
419                rule_if_some!(func = tract_linalg::bin_unicast(dt, actual_linalg_op));
420                let eval_fn = Arc::from(func);
421                return Ok(Some(
422                    TypedModelPatch::replace_single_op(
423                        model,
424                        node,
425                        &inputs,
426                        OptBinUnicast { binop: actual_core_op, eval_fn },
427                    )?
428                    .with_context("Unicast"),
429                ));
430            }
431        }
432
433        Ok(None)
434    }
435    as_op!();
436}
437
438fn core_op_for_linalg_op(linalg: &BinOp) -> Box<dyn BinMiniOp> {
439    match linalg {
440        BinOp::Min => Box::new(Min),
441        BinOp::Max => Box::new(Max),
442        BinOp::Add => Box::new(Add),
443        BinOp::Mul => Box::new(Mul),
444        BinOp::Sub => Box::new(Sub),
445        BinOp::SubF => Box::new(SubF),
446    }
447}
448fn declutter_broadcasting_operand_1(
449    model: &TypedModel,
450    node: &TypedNode,
451    mini_op: Box<dyn BinMiniOp>,
452) -> TractResult<Option<TypedModelPatch>> {
453    let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
454        (a.shape.clone(), b.shape.clone())
455    } else {
456        unreachable!("TypedBinOp has two inputs.")
457    };
458
459    let a_num_elements = a_shape.iter().product::<TDim>();
460    let b_num_elements = b_shape.iter().product::<TDim>();
461    let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
462    if a_should_be_broadcast & mini_op.is_commutative() {
463        let mut swap_input = node.inputs.clone();
464        swap_input.swap(0, 1);
465        return Ok(Some(TypedModelPatch::replace_single_op(
466            model,
467            node,
468            &swap_input,
469            TypedBinOp(mini_op, None),
470        )?));
471    }
472
473    Ok(None)
474}
475
476fn declutter_neutral(
477    model: &TypedModel,
478    node: &TypedNode,
479    mini_op: &dyn BinMiniOp,
480    out_dt: DatumType,
481) -> TractResult<Option<TypedModelPatch>> {
482    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
483        let is_neutral = mini_op
484            .neutral_element()
485            .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
486            .unwrap_or(false);
487
488        // For some operand neural element can be the left one while for other
489        // it is not the case (neutral - 1 -> not ok, 1 - neutal -> ok)
490        let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;
491
492        if is_neutral && pos_checked {
493            // Neutral decluttering for quant values is special.
494            // - if (fa) (a-az)*as + (fb = 0) (b-bz)*bs = (fc) (c-cz)*cs
495            // - then even if fa = fc, quant params needs to be updated (a != c).
496            // So it's not a no_op.
497            if uniform.uni.datum_type().is_quantized() {
498                return Ok(Some(TypedModelPatch::replace_single_op(
499                    model,
500                    node,
501                    &[node.inputs[0]],
502                    cast(out_dt),
503                )?));
504            // In the non quantized case, it's a no_op.
505            } else {
506                return Ok(Some(TypedModelPatch::rewire(
507                    model,
508                    &[uniform.var],
509                    &[node.id.into()],
510                    &|_, inputs| Ok(inputs.into()),
511                )?));
512            }
513        }
514    }
515    Ok(None)
516}
517
518/// When one input is the absorbing element (e.g. 0 for Mul, false for And),
519/// replace the entire op with a uniform-value tensor of the output shape.
520///
521/// We can't shunt the uniform input directly: it may be lower-rank or have
522/// broadcast-from-1 dims that don't match the op's output shape (e.g.
523/// `Mul([4, 1], scalar-0)` outputs `[4, 1]`, not `[1]`).  Wire a
524/// `MultiBroadcastTo` from the uniform constant to the output shape;
525/// subsequent declutter folds it into a pure constant when the shape is
526/// fully concrete.
527fn declutter_absorbing(
528    model: &TypedModel,
529    node: &TypedNode,
530    mini_op: &dyn BinMiniOp,
531) -> TractResult<Option<TypedModelPatch>> {
532    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
533        let is_absorbing = mini_op
534            .absorbing_element()
535            .map(|absorb| tensor0(absorb).close_enough(&uniform.uni, false).is_ok())
536            .unwrap_or(false);
537        if is_absorbing {
538            let output_fact = model.outlet_fact(node.id.into())?;
539            let output_dt = output_fact.datum_type;
540            let output_shape = output_fact.shape.clone();
541            let uni_inlet = if uniform.left_is_uniform { 0 } else { 1 };
542            let uni_input_shape = &model.outlet_fact(node.inputs[uni_inlet])?.shape;
543            // Fast path: shapes and types match — shunt the absorbing input directly.
544            if uni_input_shape == &output_shape && uniform.uni.datum_type() == output_dt {
545                return Ok(Some(TypedModelPatch::rewire(
546                    model,
547                    &[node.inputs[uni_inlet]],
548                    &[node.id.into()],
549                    &|_, inputs| Ok(inputs.into()),
550                )?));
551            }
552            // General path: create a constant encoded in the output type.
553            // This handles both shape mismatches and quantization mismatches
554            // (e.g. absorbing input is QU8(Z:61 S:1) but output is QU8(Z:0 S:0.5)).
555            let absorb_val = mini_op.absorbing_element().unwrap();
556            let absorbing_const =
557                tensor0(absorb_val as f32).cast_to_dt(output_dt)?.into_owned().into_arc_tensor();
558            let mut patch = TypedModelPatch::default();
559            let uni_const =
560                patch.add_const(format!("{}.absorbing_const", node.name), absorbing_const)?;
561            let bcast = patch.wire_node(
562                format!("{}.absorbing_bcast", node.name),
563                crate::ops::array::MultiBroadcastTo { shape: output_shape },
564                &[uni_const],
565            )?[0];
566            patch.shunt_outside(model, node.id.into(), bcast)?;
567            return Ok(Some(patch));
568        }
569    }
570    Ok(None)
571}
572
573fn find_most_efficient_config(
574    model: &TypedModel,
575    node: &TypedNode,
576    swap_input: bool,
577) -> TractResult<(bool, bool)> {
578    if let &[a, b] = &*model.node_input_facts(node.id)? {
579        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
580        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };
581
582        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
583        let num_by_scalar_elements = if by_scalar_is_possible {
584            a_shape
585                .iter()
586                .zip(b_shape.iter())
587                .rev()
588                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
589                .map(|(rev_a_dim, _)| rev_a_dim)
590                .product::<TDim>()
591        } else {
592            TDim::Val(0)
593        };
594
595        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
596        let num_unicast_elements = if unicast_is_possible {
597            a_shape
598                .iter()
599                .zip(b_shape.iter())
600                .rev()
601                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
602                .map(|(a_dim, _)| a_dim)
603                .product::<TDim>()
604        } else {
605            TDim::Val(0)
606        };
607
608        let min_num_elements = 32;
609        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
610        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
611        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
612    }
613    Ok((false, false))
614}
615
616pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
617    TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
618}
619
/// Optimised binary op evaluating, in place in its first input, a mini-op
/// against a second input whose trailing dimensions are all 1 (see
/// `check_input_shapes`). `TypedBinOp::codegen` builds it in its "ByScalar"
/// patch.
#[derive(Clone)]
pub struct OptBinByScalar {
    /// Mini-op this node stands for: provides the name, the output datum type,
    /// the cost and the fallback evaluation.
    pub binop: Box<dyn BinMiniOp>,
    /// Kernel invoked on each (mutable view of `a`, view of `b`) pair; in
    /// `TypedBinOp::codegen` it comes from `tract_linalg::bin_by_scalar`.
    eval_fn: Arc<LinalgFn>,
}
625
626impl Debug for OptBinByScalar {
627    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
628        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
629    }
630}
631
632impl OptBinByScalar {
633    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
634        if a_shape.len() != b_shape.len() {
635            return false;
636        };
637
638        a_shape
639            .iter()
640            .zip(b_shape.iter())
641            .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
642            .all(|(_, b_dim)| *b_dim == 1.to_dim())
643    }
644}
645
// Equality only compares the wrapped mini-ops; `eval_fn` takes no part in it.
impl PartialEq for OptBinByScalar {
    fn eq(&self, other: &Self) -> bool {
        *self.binop == *other.binop
    }
}
impl Eq for OptBinByScalar {}
652
impl Op for OptBinByScalar {
    /// Name built from the wrapped mini-op's name: `Opt<name>ByScalar`.
    fn name(&self) -> StaticName {
        format!("Opt{}ByScalar", self.binop.name()).into()
    }

    op_as_typed_op!();
}
660
661impl EvalOp for OptBinByScalar {
662    fn is_stateless(&self) -> bool {
663        true
664    }
665
666    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
667        let (a, b) = args_2!(inputs);
668        // Same as OptBinUnicast: the fast path uses at_prefix + as_slice_mut
669        // and relies on natural C-order strides for the slice math. Fall back
670        // to the generic eval if either operand has non-natural strides or a
671        // storage size that doesn't match its declared shape (e.g. after
672        // Tensor::insert_axis which leaves non-natural strides behind).
673        let a_natural = a.len() == a.shape().iter().product::<usize>()
674            && a.strides() == &*Tensor::natural_strides(a.shape());
675        let b_natural = b.len() == b.shape().iter().product::<usize>()
676            && b.strides() == &*Tensor::natural_strides(b.shape());
677        if !a_natural || !b_natural {
678            let c_dt = self.binop.result_datum_type(a.datum_type(), b.datum_type())?;
679            return Ok(tvec!(self.binop.eval(a, b, c_dt)?.into_tvalue()));
680        }
681
682        // Not a requirement as TensorView doesn't require a owned tensor but in reality
683        // "a "should be mutable (it's omitted here as Rust compiler advise to remove it)
684        let a = a.into_tensor();
685        let b_shape = b.shape();
686
687        let first_unary_axis = b_shape
688            .iter()
689            .enumerate()
690            .rev()
691            .take_while(|&(_, &dim)| dim == 1)
692            .map(|(i, _)| i)
693            .last()
694            .context("Cannot use by_scalar when no trailing dimensions are unary")?;
695
696        let iterating_shape = &a.shape()[..first_unary_axis];
697        if !iterating_shape.is_empty() {
698            for it_coords in tract_ndarray::indices(iterating_shape) {
699                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
700                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
701                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
702                (self.eval_fn)(&mut view, &b_view)?;
703            }
704        } else {
705            let mut view = a.view();
706            let b_view = b.view();
707            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
708            (self.eval_fn)(&mut view, &b_view)?;
709        }
710        Ok(tvec!(a.into_tvalue()))
711    }
712}
713
714impl TypedOp for OptBinByScalar {
715    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
716        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
717        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
718        let out_shape = inputs[0].shape.clone();
719        Ok(tvec!(out_dt.fact(out_shape)))
720    }
721
722    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
723        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
724        Ok(self
725            .binop
726            .cost_per_element(inputs[0].datum_type)
727            .into_iter()
728            .map(|(c, n)| (c, count.clone() * n))
729            .collect())
730    }
731
732    as_op!();
733}
734
/// Optimised binary op evaluating, in place in its first input, a mini-op
/// against a second input that matches the first one element-wise once its
/// leading unary dimensions are skipped (see `check_input_shapes`).
/// `TypedBinOp::codegen` builds it in its "Unicast" patch.
#[derive(Clone)]
pub struct OptBinUnicast {
    /// Mini-op this node stands for: provides the name and the fallback
    /// evaluation.
    pub binop: Box<dyn BinMiniOp>,
    /// Kernel invoked on each (mutable view of `a`, view of `b`) pair; in
    /// `TypedBinOp::codegen` it comes from `tract_linalg::bin_unicast`.
    eval_fn: Arc<LinalgFn>,
}
740
741impl Debug for OptBinUnicast {
742    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
743        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
744    }
745}
746
747impl OptBinUnicast {
748    fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
749        let num_iterations: TDim = a_shape
750            .iter()
751            .zip(b_shape.iter())
752            .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
753            .map(|(a_dim, _)| a_dim)
754            .product();
755
756        if num_iterations.is_one() {
757            return true;
758        }
759
760        let elements_per_iteration: TDim = a_shape
761            .iter()
762            .zip(b_shape.iter())
763            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
764            .map(|(_, b_dim)| b_dim)
765            .product();
766
767        if let Ok(num_element) = elements_per_iteration.to_i64() {
768            let required_alignment = vector_size();
769            (num_element as usize).is_multiple_of(required_alignment)
770        } else {
771            false
772        }
773    }
774    fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
775        if a_shape.len() != b_shape.len() {
776            return false;
777        };
778
779        let unicast_possible = a_shape
780            .iter()
781            .zip(b_shape.iter())
782            .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
783            .all(|(a_dim, b_dim)| a_dim == b_dim);
784        let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);
785
786        unicast_possible && unicast_is_aligned
787    }
788}
789
// Equality only compares the wrapped mini-ops; `eval_fn` takes no part in it.
impl PartialEq for OptBinUnicast {
    fn eq(&self, other: &Self) -> bool {
        *self.binop == *other.binop
    }
}
impl Eq for OptBinUnicast {}
796
impl Op for OptBinUnicast {
    /// Name built from the wrapped mini-op's name: `Opt<name>Unicast`.
    fn name(&self) -> StaticName {
        format!("Opt{}Unicast", self.binop.name()).into()
    }

    op_as_typed_op!();
}
804
805impl EvalOp for OptBinUnicast {
806    fn is_stateless(&self) -> bool {
807        true
808    }
809
810    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
811        let (a, b) = args_2!(inputs);
812        // The unicast fast path indexes each input's storage via at_prefix +
813        // as_slice_mut, which uses `strides[i-1]` to size the resulting slice
814        // (data/src/tensor/view.rs:99). That formula only matches ∏(shape[i..])
815        // when the tensor has natural C-order strides. Producers like
816        // Tensor::insert_axis leave non-natural strides on a tensor (e.g.
817        // shape `[1, 1, 640]` with strides `[1, 1, 1]` after two insert_axis
818        // on a `[640]` tensor), which silently breaks the slice math. Fall
819        // back to the generic broadcasting eval when either operand is not in
820        // natural strides (or has a storage size that doesn't match the
821        // declared shape).
822        let a_natural = a.len() == a.shape().iter().product::<usize>()
823            && a.strides() == &*Tensor::natural_strides(a.shape());
824        let b_natural = b.len() == b.shape().iter().product::<usize>()
825            && b.strides() == &*Tensor::natural_strides(b.shape());
826        if !a_natural || !b_natural {
827            let c_dt = self.binop.result_datum_type(a.datum_type(), b.datum_type())?;
828            return Ok(tvec!(self.binop.eval(a, b, c_dt)?.into_tvalue()));
829        }
830
831        // Not a requirement as TensorView doesn't require a owned tensor but in reality
832        // "a "should be mutable (it's omitted here as Rust compiler advise to remove it)
833        let a = a.into_tensor();
834        let b_shape = b.shape();
835        let b_view = b.view();
836        let first_non_unary_axis =
837            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();
838
839        if let Some(first_non_unary_axis) = first_non_unary_axis {
840            // Iterate on outter dimensions and evaluate with unicast subviews
841            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
842            for it_coords in tract_ndarray::indices(iterating_shape) {
843                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
844                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
845                (self.eval_fn)(&mut view, &b_view)?;
846            }
847        } else {
848            let mut view = a.view();
849            debug_assert_eq!(view.shape(), b_view.shape());
850            (self.eval_fn)(&mut view, &b_view)?;
851        }
852
853        Ok(tvec!(a.into_tvalue()))
854    }
855}
856
857impl TypedOp for OptBinUnicast {
858    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
859        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
860        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
861        let out_shape = inputs[0].shape.clone();
862        Ok(tvec!(out_dt.fact(out_shape)))
863    }
864
865    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
866        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
867        Ok(self
868            .binop
869            .cost_per_element(inputs[0].datum_type)
870            .into_iter()
871            .map(|(c, n)| (c, count.clone() * n))
872            .collect())
873    }
874
875    as_op!();
876}
877
/// Declares a unit struct `$Op` implementing `BinMiniOp`, together with a
/// `$func()` constructor wrapping it into a `TypedBinOp`.
///
/// After `$func, $Op,` come optional `key: value,` sections, accepted only in
/// the order listed by the matcher (`codegen`, `cost`, `declutter`,
/// `eval_in_a`, `eval_override`, `linalg`, `operating_datum_type`,
/// `is_commutative`, `neutral_element`, `absorbing_element`, `out_of_place`,
/// `validation`, `q`, `q_op_on_f32`), followed by the per-type kernels
/// `[T1, T2, ...] => kernel` where `kernel` is callable as
/// `fn(&mut T, &T, &T)`. Kernels of the `q` section are callable as
/// `fn(&mut T, &T, &T, i32, f32)`, the last two arguments being the zero
/// point and scale read from `a`'s datum type.
#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_in_a: $eval_in_a:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(is_commutative: $is_commutative:expr,)?
     $(neutral_element: $neutral_element:expr,)?
     $(absorbing_element: $absorbing_element:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $(q_op_on_f32: $q_op_on_f32:expr,)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        pub struct $Op;
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                // The optional `out_of_place` hook runs first; it returns
                // `true` once it has done the computation itself.
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                    // Same-shape fast path: bypass ndarray Zip when c, a, b
                    // share the same shape (and hence same len for plain
                    // storage). Iterate over slices directly.
                    if c.shape() == a.shape() && a.shape() == b.shape() {
                        $(
                            $(if c.datum_type() == $typ::datum_type() {
                                let cab: fn(&mut $typ, &$typ, &$typ) = $cab;
                                let a_plain = a.try_as_plain()?;
                                let a_slice = a_plain.as_slice::<$typ>()?;
                                let b_plain = b.try_as_plain()?;
                                let b_slice = b_plain.as_slice::<$typ>()?;
                                let mut c_plain = c.try_as_plain_mut()?;
                                let c_slice = c_plain.as_slice_mut::<$typ>()?;
                                debug_assert_eq!(c_slice.len(), a_slice.len());
                                debug_assert_eq!(c_slice.len(), b_slice.len());
                                for ((cv, av), bv) in c_slice.iter_mut().zip(a_slice.iter()).zip(b_slice.iter()) {
                                    cab(cv, av, bv);
                                }
                                return Ok(())
                            })*
                        )*
                        $(
                            $(
                                $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                    let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) = $cab_dt;
                                    let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                    let a_plain = a.try_as_plain()?;
                                    let a_slice = a_plain.as_slice::<$typ_dt>()?;
                                    let b_plain = b.try_as_plain()?;
                                    let b_slice = b_plain.as_slice::<$typ_dt>()?;
                                    let mut c_plain = c.try_as_plain_mut()?;
                                    let c_slice = c_plain.as_slice_mut::<$typ_dt>()?;
                                    // Same guard as the plain-typed fast path:
                                    // `zip` stops at the shortest slice, so a
                                    // length mismatch would go unnoticed.
                                    debug_assert_eq!(c_slice.len(), a_slice.len());
                                    debug_assert_eq!(c_slice.len(), b_slice.len());
                                    for ((cv, av), bv) in c_slice.iter_mut().zip(a_slice.iter()).zip(b_slice.iter()) {
                                        cab(cv, av, bv, zp, scale);
                                    }
                                    return Ok(())
                                })*
                            )*
                        )?
                    }
                    // General path: broadcast a and b onto c with ndarray Zip.
                    $(
                        $(if c.datum_type() == $typ::datum_type() {
                            let a = a.to_plain_array_view::<$typ>()?;
                            let b = b.to_plain_array_view::<$typ>()?;
                            let mut c_plain = c.try_as_plain_mut()?;
                            let mut c = c_plain.to_array_view_mut::<$typ>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                            return Ok(())
                        })*
                     )*
                    $(
                        $(
                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) = $cab_dt;
                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                let a = a.to_plain_array_view::<$typ_dt>()?;
                                let b = b.to_plain_array_view::<$typ_dt>()?;
                                let mut c_plain = c.try_as_plain_mut()?;
                                let mut c = c_plain.to_array_view_mut::<$typ_dt>()?;
                                $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                                return Ok(())
                            }
                            )*
                         )*
                     )?
                    bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            $(fn is_commutative(&self) -> bool {
                $is_commutative
            })?
            $(fn neutral_element(&self) -> Option<i64> {
                Some($neutral_element)
            })?
            $(fn absorbing_element(&self) -> Option<i64> {
                Some($absorbing_element)
            })?
            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                // c and a are same type
                // The optional `eval_in_a` hook runs first; it returns `true`
                // once it has done the computation itself.
                $(if $eval_in_a(a, b)? { return Ok(()) } )?
                // Same-shape fast path: bypass ndarray Zip when a and b share
                // the same shape (and hence same len for plain storage).
                if a.shape() == b.shape() {
                    $(
                        $(if b.datum_type() == $typ::datum_type() {
                            let cab: fn(&mut $typ, &$typ, &$typ) = $cab;
                            let b_plain = b.try_as_plain()?;
                            let b_slice = b_plain.as_slice::<$typ>()?;
                            let mut a_plain = a.try_as_plain_mut()?;
                            let a_slice = a_plain.as_slice_mut::<$typ>()?;
                            debug_assert_eq!(a_slice.len(), b_slice.len());
                            for (av, bv) in a_slice.iter_mut().zip(b_slice.iter()) {
                                // The kernel reads `a` from a copy while it
                                // writes the result back into the same slot.
                                cab(av, &av.clone(), bv);
                            }
                            return Ok(())
                        })*
                    )*
                    $(
                        $(
                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) = $cab_dt;
                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                let b_plain = b.try_as_plain()?;
                                let b_slice = b_plain.as_slice::<$typ_dt>()?;
                                let mut a_plain = a.try_as_plain_mut()?;
                                let a_slice = a_plain.as_slice_mut::<$typ_dt>()?;
                                // Same guard as the plain-typed fast path:
                                // `zip` stops at the shortest slice, so a
                                // length mismatch would go unnoticed.
                                debug_assert_eq!(a_slice.len(), b_slice.len());
                                for (av, bv) in a_slice.iter_mut().zip(b_slice.iter()) {
                                    cab(av, &(av.clone()), bv, zp, scale);
                                }
                                return Ok(())
                            })*
                        )*
                    )?
                }
                // General path: broadcast b onto a with ndarray Zip.
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) = $cab;
                        let b = b.to_plain_array_view::<$typ>()?;
                        let mut a_plain = a.try_as_plain_mut()?;
                        let mut a = a_plain.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                )*
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let mut a_plain = a.try_as_plain_mut()?;
                            let mut a = a_plain.to_array_view_mut::<$typ_dt>()?;
                            let b = b.to_plain_array_view::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
                                cab(a, &(a.clone()), b, zp, scale)
                            });
                            return Ok(())
                        })*
                    )*
                )?
                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
            }

            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                $eval_override(a, b, c_dt)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                // Same underlying type: keep `a` unless `b` is the only
                // quantized one, in which case `b` wins.
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    }
                    else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

                $(
                    fn declutter(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($declutter)(self, model, node)
                    }
                 )?
                $(
                    fn codegen(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        a: &Arc<Tensor>,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($codegen)(self, model, node, a)
                    }
                 )?
                $(
                    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                        ($cost)(dt)
                    }
                 )?
                $(
                    fn validation(&self) -> Validation {
                        $validation
                    }
                 )?
                $(
                    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
                        Some(tract_linalg::BinOp::$linalg)
                    }
                 )?
                $(
                    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                        ($operating_datum_type)(a, b)
                    })?


            /// Default simple binary operation for QFormat: dequantize, apply
            /// the requested operation in float, then requantize. Two
            /// implementations are provided, each with pros and cons. Without
            /// a `q_op_on_f32` section this always returns `Ok(None)`.
            #[allow(unused_variables)]
            fn maybe_eval_qbinary_as_float_op(
                &self,
                a: &TValue,
                b: &TValue,
                c_dt: &DatumType,
            ) -> TractResult<Option<Tensor>> {
                $(
                    /// Implementation strives to minimise memory allocation and access.
                    /// It applies only when a, b and c are all QU8 with zero-point/scale
                    /// parameters; maybe more suited for large model tensors.
                    fn memory_optimised_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                                DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                            (a.datum_type(), b.datum_type(), c_dt)
                        {
                            let c_inv_scale = 1.0 / c_scale;
                            let a = a.to_plain_array_view::<u8>()?;
                            let b = b.to_plain_array_view::<u8>()?;
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
                            let mut c_plain = c.try_as_plain_mut()?;
                            let view = c_plain.to_array_view_mut::<u8>()?;
                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
                                *c = (scale_by($q_op_on_f32(
                                            ((*a as i32 - a_zp as i32) as f32 * a_scale),
                                            ((*b as i32 - b_zp as i32) as f32 * b_scale),
                                ), c_inv_scale) as i32
                                    + *c_zp as i32)
                                    .clamp_cast()
                            });
                            return Ok(Some(c));
                        }
                        Ok(None)
                    }

                    /// Applies to all Q types.
                    /// Takes more memory but is hopefully faster than
                    /// memory_optimised_q_binary_as_float_op, especially once
                    /// cast_to_dt has vectorized implementations.
                    fn generic_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                        accumulator_dt: DatumType
                    ) -> TractResult<Option<Tensor>> {
                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
                            match accumulator_dt {
                                DatumType::F32 => {
                                    let mut c_plain = c.try_as_plain_mut()?;
                                    let view = c_plain.to_array_view_mut::<f32>()?;
                                    $crate::ndarray::Zip::from(view).and_broadcast(a.try_as_plain()?.to_array_view()?).and_broadcast(b.try_as_plain()?.to_array_view()?).for_each(|c, a, b| {
                                        *c = $q_op_on_f32(*a,*b);
                                    })
                                },
                                other => bail!("unexpected accumulator data type as {:?}", other)
                            };

                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
                        }
                        Ok(None)
                    }

                    // Try the QU8-only implementation first, then the generic
                    // one with an f32 accumulator.
                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
                        return Ok(Some(c));
                    }
                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
                        return Ok(Some(d));
                    }
                )?
                Ok(None)
            }
        }

        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
        }
    };
}
1194
1195#[derive(Debug)]
1196pub(crate) struct OneUniformInput {
1197    pub uni: Arc<Tensor>,
1198    pub var: OutletId,
1199    pub left_is_uniform: bool,
1200}
1201
1202pub(crate) fn one_input_is_uniform(
1203    model: &TypedModel,
1204    node: &TypedNode,
1205) -> TractResult<Option<OneUniformInput>> {
1206    if let &[a, b] = &*model.node_input_facts(node.id)? {
1207        let uni = if let Some(a) = &a.uniform {
1208            OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
1209        } else if let Some(b) = &b.uniform {
1210            OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
1211        } else {
1212            return Ok(None);
1213        };
1214        let var_fact = [a, b][uni.left_is_uniform as usize];
1215        let uni_fact = [a, b][!uni.left_is_uniform as usize];
1216        if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
1217            return Ok(Some(uni));
1218        }
1219    }
1220    Ok(None)
1221}
1222
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression test for the OptBinUnicast panic seen on Nemotron decoder
    /// CI (cuda-lovelace + Darwin).
    ///
    /// Calling `insert_axis` twice on a 1-D tensor yields the declared shape
    /// `[1, 1, 640]` with strides `[1, 1, 1]` rather than the natural
    /// `[640, 640, 1]`. TensorView::at_prefix then builds a view whose
    /// `len()` reads `strides[1] = 1`, so the unicast kernel sees
    /// `a.len = 1, b.len = 640` and OOBs into the tile buffer.
    ///
    /// Without the natural-strides guard this panics inside
    /// `linalg/src/frame/unicast.rs` with "range end index 640 out of range
    /// for slice of length …". With the guard in `OptBinUnicast::eval`, the
    /// call falls back to `BinMiniOp::eval` and the output is correct.
    #[test]
    fn opt_bin_unicast_falls_back_on_non_natural_strides() {
        const LEN: usize = 640;

        // `a` is built the way the LSTM bias path does it: a 1-D tensor of
        // 640 elements which then receives two leading unit axes.
        let ramp: Vec<f32> = (0..LEN).map(|i| i as f32).collect();
        let mut a = tensor1(&ramp);
        for _ in 0..2 {
            a.insert_axis(0).unwrap();
        }
        assert_eq!(a.shape(), &[1, 1, LEN]);
        assert_eq!(a.strides(), &[1, 1, 1]);
        assert_ne!(a.strides(), &*Tensor::natural_strides(a.shape()));

        // `b` gets the same declared shape, then goes through `into_shape`
        // to reset it to natural strides, so that only the `a` side
        // exercises the broken layout.
        let ones: Vec<f32> = vec![1.0; LEN];
        let mut b = tensor1(&ones);
        for _ in 0..2 {
            b.insert_axis(0).unwrap();
        }
        let b = b.into_shape(&[1, 1, LEN]).unwrap();

        let kernel = tract_linalg::bin_unicast(f32::datum_type(), BinOp::Add)
            .expect("f32 unicast Add kernel available");
        let op = OptBinUnicast { binop: Box::new(Add), eval_fn: Arc::from(kernel) };

        let outputs = op.eval(tvec!(a.into_tvalue(), b.into_tvalue())).unwrap();
        let result = &outputs[0];
        assert_eq!(result.shape(), &[1, 1, LEN]);

        let plain = result.try_as_plain().unwrap();
        let values = plain.as_slice::<f32>().unwrap();
        values.iter().enumerate().for_each(|(i, v)| {
            assert_eq!(*v, i as f32 + 1.0, "mismatch at {i}");
        });
    }
}