1use crate::internal::*;
2use crate::ndarray::Dimension;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt::{self, Debug};
6use tract_data::itertools::izip;
7use tract_itertools::Itertools;
8use tract_linalg::{BinOp, LinalgFn};
9
10use super::math::{Add, Max, Min, Mul, Sub};
11use super::{cast::cast, math::SubF};
12
13pub trait BinMiniOp:
14 fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast
15{
16 fn name(&self) -> &'static str;
17 fn validation(&self) -> Validation {
18 Validation::Accurate
19 }
20 fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
21 a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
22 }
23 fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
24 fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
25 fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;
26
27 fn is_commutative(&self) -> bool {
28 true
29 }
30 fn neutral_element(&self) -> Option<i64> {
31 None
32 }
33 fn absorbing_element(&self) -> Option<i64> {
34 None
35 }
36
37 #[allow(unused_variables)]
38 fn maybe_eval_qbinary_as_float_op(
39 &self,
40 a: &TValue,
41 b: &TValue,
42 c_dt: &DatumType,
43 ) -> TractResult<Option<Tensor>> {
44 Ok(None)
45 }
46
47 fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
48 if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
49 return Ok(tensor);
50 }
51 if c_dt == a.datum_type() && a.shape() == b.shape() {
56 let mut a = a.into_tensor();
57 self.eval_in_a(&mut a, &b)?;
58 return Ok(a);
59 }
60 let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
61 if &*c_shape == a.shape() && c_dt == a.datum_type() {
62 let mut a = a.into_tensor();
63 self.eval_in_a(&mut a, &b)?;
64 Ok(a)
65 } else {
66 let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
67 self.eval_out_of_place(&mut c, &a, &b)?;
68 Ok(c)
69 }
70 }
71 fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
72 self.generic_eval(a, b, c_dt)
73 }
74 #[allow(unused_variables)]
75 fn declutter(
76 &self,
77 model: &TypedModel,
78 node: &TypedNode,
79 ) -> TractResult<Option<TypedModelPatch>> {
80 Ok(None)
81 }
82 #[allow(unused_variables)]
83 fn codegen(
84 &self,
85 model: &TypedModel,
86 node: &TypedNode,
87 ) -> TractResult<Option<TypedModelPatch>> {
88 Ok(None)
89 }
90 #[allow(unused_variables)]
91 fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
92 tvec!()
93 }
94 fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
95 None
96 }
97
98 #[allow(unused_variables)]
100 fn eval_symbolic(
101 &self,
102 session: &TurnState,
103 inputs: TVec<TValue>,
104 ) -> TractResult<Option<TVec<TValue>>> {
105 Ok(None)
106 }
107
108 #[allow(unused_variables)]
110 fn uniform_tdim_comparison(&self, a: &TDim, b: &TDim) -> Option<TDim> {
111 None
112 }
113}
114dyn_clone::clone_trait_object!(BinMiniOp);
115dyn_eq::eq_trait_object!(BinMiniOp);
116downcast_rs::impl_downcast!(BinMiniOp);
117
/// Typed-model op wrapping a [`BinMiniOp`].
///
/// The second field, when `Some`, overrides the output datum type; when `None` the type is
/// computed by the mini-op's `result_datum_type` (see `output_datum_type`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);
120
121impl Op for TypedBinOp {
122 fn name(&self) -> StaticName {
123 self.0.name().into()
124 }
125
126 fn validation(&self) -> Validation {
127 self.0.validation()
128 }
129
130 op_as_typed_op!();
131}
132
133impl TypedBinOp {
134 fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
135 if let Some(dt) = self.1 { Ok(dt) } else { self.0.result_datum_type(a_dt, b_dt) }
136 }
137}
138
139impl EvalOp for TypedBinOp {
140 fn is_stateless(&self) -> bool {
141 true
142 }
143
144 fn eval_with_session(
145 &self,
146 _node_id: usize,
147 session: &TurnState,
148 inputs: TVec<TValue>,
149 ) -> TractResult<TVec<TValue>> {
150 if let Some(result) = self.0.eval_symbolic(session, inputs.clone())? {
151 return Ok(result);
152 }
153 let (a, b) = args_2!(inputs);
154 ensure!(a.rank() == b.rank());
155 let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
156 Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
157 }
158
159 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
160 let (a, b) = args_2!(inputs);
161 ensure!(a.rank() == b.rank());
162 let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
163 Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
164 }
165}
166
167impl TypedBinOp {
168 fn combine_uniform_tdim(&self, a: &TDim, b: &TDim) -> Option<TDim> {
169 if let Some(result) = self.0.uniform_tdim_comparison(a, b) {
171 return Some(result);
172 }
173 let a = tensor0(a.clone()).into_tvalue();
174 let b = tensor0(b.clone()).into_tvalue();
175 let result = self.0.eval(a, b, TDim::datum_type()).ok()?;
176 result
177 .try_as_plain()
178 .ok()
179 .and_then(|d| d.as_slice::<TDim>().ok())
180 .and_then(|s| s.first())
181 .cloned()
182 .map(|d| d.reduce())
183 }
184
185 fn combine_uniform_tdim_with_konst(&self, a: &TDim, konst: &Tensor) -> Option<TDim> {
186 if konst.len() != 1 {
187 return None;
188 }
189 let b_int: Option<i64> =
191 if konst.datum_type().is_integer() || konst.datum_type().is::<bool>() {
192 konst.cast_to_scalar::<i64>().ok()
193 } else if konst.datum_type().is_float() {
194 konst.cast_to_scalar::<f64>().ok().and_then(|f| {
195 if (f - f.round()).abs() < 1e-6 { Some(f.round() as i64) } else { None }
196 })
197 } else {
198 None
199 };
200 if let Some(b) = b_int {
201 return self.combine_uniform_tdim(a, &TDim::Val(b));
202 }
203 if self.0.neutral_element() == Some(1)
205 && let Some(f) = konst.cast_to_scalar::<f64>().ok().filter(|&f| f > 0.0)
206 {
207 let n = (1.0 / f).round() as u64;
208 if n >= 2 && (f * n as f64 - 1.0).abs() < 1e-6 {
209 return Some(TDim::Div(Box::new(a.clone()), n).reduce());
210 }
211 }
212 None
213 }
214}
215
216impl TypedOp for TypedBinOp {
217 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
218 if inputs[0].rank() != inputs[1].rank() {
219 bail!(
220 "Typed ops require rank match. Invalid inputs for {}: {}",
221 self.name(),
222 inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
223 );
224 }
225 let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
226 let mut fact = out_dt.fact(&*crate::broadcast::multi_broadcast(&[
227 &inputs[0].shape.to_tvec(),
228 &inputs[1].shape.to_tvec(),
229 ])?);
230 if let (Some(a), Some(b)) = (&inputs[0].uniform_tdim, &inputs[1].uniform_tdim) {
231 fact.uniform_tdim = self.combine_uniform_tdim(a, b);
232 if fact.uniform_tdim.is_none() && self.0.is::<crate::ops::logic::And>() {
234 fact.uniform_tdim = Some(TDim::Mul(vec![a.clone(), b.clone()]).reduce());
235 }
236 }
237 if fact.uniform_tdim.is_none() {
239 for (expr, konst_fact) in [
240 (inputs[0].uniform_tdim.as_ref(), inputs[1]),
241 (inputs[1].uniform_tdim.as_ref(), inputs[0]),
242 ] {
243 let Some(a) = expr else { continue };
244 let Some(konst) = konst_fact.konst.as_ref() else { continue };
245 fact.uniform_tdim = self.combine_uniform_tdim_with_konst(a, konst);
246 if fact.uniform_tdim.is_some() {
247 break;
248 }
249 }
250 }
251 Ok(tvec!(fact))
252 }
253
254 fn input_roi(
255 &self,
256 model: &TypedModel,
257 node: &TypedNode,
258 ) -> TractResult<Option<TVec<Option<TDim>>>> {
259 if self.0.neutral_element() == Some(1) {
262 for (mask_ix, other_ix) in [(0usize, 1usize), (1, 0)] {
263 let fact = model.outlet_fact(node.inputs[mask_ix])?;
264 if let Some(mask_expr) = &fact.uniform_tdim {
265 let mut rois = tvec![None; node.inputs.len()];
266 rois[other_ix] = Some(mask_expr.clone());
267 return Ok(Some(rois));
268 }
269 }
270 }
271 crate::optim::propagate_roi::bubble_roi(model, node)
273 }
274
275 fn change_axes(
276 &self,
277 model: &TypedModel,
278 node: &TypedNode,
279 _io: InOut,
280 change: &AxisOp,
281 ) -> TractResult<Option<AxisChangeConsequence>> {
282 if let AxisOp::Rm(rm) = change {
283 let (inputs, outputs) = model.node_facts(node.id)?;
284 if inputs.len() >= 2
285 && outputs.len() >= 1
286 && inputs[0].rank() > *rm
287 && inputs[1].rank() > *rm
288 && outputs[0].rank() > *rm
289 {
290 rule_if!(inputs[0].shape[*rm].is_one());
291 rule_if!(inputs[1].shape[*rm].is_one());
292 rule_if!(outputs[0].shape[*rm].is_one());
293 }
294 }
295 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
296 }
297
298 fn axes_mapping(
299 &self,
300 inputs: &[&TypedFact],
301 outputs: &[&TypedFact],
302 ) -> TractResult<AxesMapping> {
303 AxesMapping::natural(inputs, outputs)
304 }
305
306 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
307 let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
308 Ok(self
309 .0
310 .cost_per_element(inputs[0].datum_type)
311 .into_iter()
312 .map(|(c, n)| (c, count.clone() * n))
313 .collect())
314 }
315
316 fn slice(
317 &self,
318 patch: &mut TypedModelPatch,
319 _model: &TypedModel,
320 _node: &TypedNode,
321 prefix: &str,
322 inputs: &[OutletId],
323 _output_axis: usize,
324 _start: &TDim,
325 _end: &TDim,
326 ) -> TractResult<Option<TVec<OutletId>>> {
327 Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
328 }
329
330 fn declutter(
331 &self,
332 model: &TypedModel,
333 node: &TypedNode,
334 ) -> TractResult<Option<TypedModelPatch>> {
335 let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
336 (a.datum_type().unwrap(), b.datum_type().unwrap())
337 } else {
338 unreachable!("TypedBinOp has two inputs.")
339 };
340 if let Some(neutral_patch) =
341 declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
342 {
343 return Ok(Some(neutral_patch));
344 }
345 if let Some(absorbing_patch) = declutter_absorbing(model, node, self.0.as_ref())? {
346 return Ok(Some(absorbing_patch));
347 }
348 if let Some(broadcast_patch) =
349 declutter_broadcasting_operand_1(model, node, self.0.clone())?
350 {
351 return Ok(Some(broadcast_patch));
352 }
353 self.0.declutter(model, node)
354 }
355
356 fn codegen(
357 &self,
358 model: &TypedModel,
359 node: &TypedNode,
360 ) -> TractResult<Option<TypedModelPatch>> {
361 if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
362 let input_facts = model.node_input_facts(node.id)?;
363 let must_swap_inputs =
364 input_facts.iter().collect_tuple().is_some_and(|(a_fact, b_fact)| {
365 (a_fact.shape.volume() - b_fact.shape.volume()).prove_strict_negative()
366 });
367 let (operand_1, operand_2) = if must_swap_inputs {
368 (input_facts[1], input_facts[0])
369 } else {
370 (input_facts[0], input_facts[1])
371 };
372
373 let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
374 find_most_efficient_config(model, node, must_swap_inputs)?;
375
376 let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
378 let op_is_quant = c_dt.is_quantized()
379 || operand_1.datum_type.is_quantized()
380 || operand_2.datum_type.is_quantized();
381
382 let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
384 let c_shape = crate::broadcast::multi_broadcast(&[
385 operand_1.shape.clone(),
386 operand_2.shape.clone(),
387 ])?;
388 let can_eval_in_a =
389 (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);
390
391 let inputs = if must_swap_inputs {
393 let mut swap_input = node.inputs.clone();
394 swap_input.swap(0, 1);
395 swap_input
396 } else {
397 node.inputs.clone()
398 };
399 let actual_linalg_op =
400 if must_swap_inputs { linalg_bin_op.flip() } else { linalg_bin_op };
401 let actual_core_op = core_op_for_linalg_op(&actual_linalg_op);
402
403 let dt = model.node_input_facts(node.id)?[0].datum_type;
404 if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
405 rule_if_some!(func = tract_linalg::bin_by_scalar(dt, actual_linalg_op));
406 let eval_fn = Arc::from(func);
407 return Ok(Some(
408 TypedModelPatch::replace_single_op(
409 model,
410 node,
411 &inputs,
412 OptBinByScalar { binop: actual_core_op, eval_fn },
413 )?
414 .with_context("ByScalar"),
415 ));
416 }
417
418 if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
419 rule_if_some!(func = tract_linalg::bin_unicast(dt, actual_linalg_op));
420 let eval_fn = Arc::from(func);
421 return Ok(Some(
422 TypedModelPatch::replace_single_op(
423 model,
424 node,
425 &inputs,
426 OptBinUnicast { binop: actual_core_op, eval_fn },
427 )?
428 .with_context("Unicast"),
429 ));
430 }
431 }
432
433 Ok(None)
434 }
435 as_op!();
436}
437
438fn core_op_for_linalg_op(linalg: &BinOp) -> Box<dyn BinMiniOp> {
439 match linalg {
440 BinOp::Min => Box::new(Min),
441 BinOp::Max => Box::new(Max),
442 BinOp::Add => Box::new(Add),
443 BinOp::Mul => Box::new(Mul),
444 BinOp::Sub => Box::new(Sub),
445 BinOp::SubF => Box::new(SubF),
446 }
447}
448fn declutter_broadcasting_operand_1(
449 model: &TypedModel,
450 node: &TypedNode,
451 mini_op: Box<dyn BinMiniOp>,
452) -> TractResult<Option<TypedModelPatch>> {
453 let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
454 (a.shape.clone(), b.shape.clone())
455 } else {
456 unreachable!("TypedBinOp has two inputs.")
457 };
458
459 let a_num_elements = a_shape.iter().product::<TDim>();
460 let b_num_elements = b_shape.iter().product::<TDim>();
461 let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
462 if a_should_be_broadcast & mini_op.is_commutative() {
463 let mut swap_input = node.inputs.clone();
464 swap_input.swap(0, 1);
465 return Ok(Some(TypedModelPatch::replace_single_op(
466 model,
467 node,
468 &swap_input,
469 TypedBinOp(mini_op, None),
470 )?));
471 }
472
473 Ok(None)
474}
475
476fn declutter_neutral(
477 model: &TypedModel,
478 node: &TypedNode,
479 mini_op: &dyn BinMiniOp,
480 out_dt: DatumType,
481) -> TractResult<Option<TypedModelPatch>> {
482 if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
483 let is_neutral = mini_op
484 .neutral_element()
485 .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
486 .unwrap_or(false);
487
488 let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;
491
492 if is_neutral && pos_checked {
493 if uniform.uni.datum_type().is_quantized() {
498 return Ok(Some(TypedModelPatch::replace_single_op(
499 model,
500 node,
501 &[node.inputs[0]],
502 cast(out_dt),
503 )?));
504 } else {
506 return Ok(Some(TypedModelPatch::rewire(
507 model,
508 &[uniform.var],
509 &[node.id.into()],
510 &|_, inputs| Ok(inputs.into()),
511 )?));
512 }
513 }
514 }
515 Ok(None)
516}
517
518fn declutter_absorbing(
528 model: &TypedModel,
529 node: &TypedNode,
530 mini_op: &dyn BinMiniOp,
531) -> TractResult<Option<TypedModelPatch>> {
532 if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
533 let is_absorbing = mini_op
534 .absorbing_element()
535 .map(|absorb| tensor0(absorb).close_enough(&uniform.uni, false).is_ok())
536 .unwrap_or(false);
537 if is_absorbing {
538 let output_fact = model.outlet_fact(node.id.into())?;
539 let output_dt = output_fact.datum_type;
540 let output_shape = output_fact.shape.clone();
541 let uni_inlet = if uniform.left_is_uniform { 0 } else { 1 };
542 let uni_input_shape = &model.outlet_fact(node.inputs[uni_inlet])?.shape;
543 if uni_input_shape == &output_shape && uniform.uni.datum_type() == output_dt {
545 return Ok(Some(TypedModelPatch::rewire(
546 model,
547 &[node.inputs[uni_inlet]],
548 &[node.id.into()],
549 &|_, inputs| Ok(inputs.into()),
550 )?));
551 }
552 let absorb_val = mini_op.absorbing_element().unwrap();
556 let absorbing_const =
557 tensor0(absorb_val as f32).cast_to_dt(output_dt)?.into_owned().into_arc_tensor();
558 let mut patch = TypedModelPatch::default();
559 let uni_const =
560 patch.add_const(format!("{}.absorbing_const", node.name), absorbing_const)?;
561 let bcast = patch.wire_node(
562 format!("{}.absorbing_bcast", node.name),
563 crate::ops::array::MultiBroadcastTo { shape: output_shape },
564 &[uni_const],
565 )?[0];
566 patch.shunt_outside(model, node.id.into(), bcast)?;
567 return Ok(Some(patch));
568 }
569 }
570 Ok(None)
571}
572
573fn find_most_efficient_config(
574 model: &TypedModel,
575 node: &TypedNode,
576 swap_input: bool,
577) -> TractResult<(bool, bool)> {
578 if let &[a, b] = &*model.node_input_facts(node.id)? {
579 let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
580 let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };
581
582 let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
583 let num_by_scalar_elements = if by_scalar_is_possible {
584 a_shape
585 .iter()
586 .zip(b_shape.iter())
587 .rev()
588 .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
589 .map(|(rev_a_dim, _)| rev_a_dim)
590 .product::<TDim>()
591 } else {
592 TDim::Val(0)
593 };
594
595 let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
596 let num_unicast_elements = if unicast_is_possible {
597 a_shape
598 .iter()
599 .zip(b_shape.iter())
600 .rev()
601 .take_while(|(a_dim, b_dim)| a_dim == b_dim)
602 .map(|(a_dim, _)| a_dim)
603 .product::<TDim>()
604 } else {
605 TDim::Val(0)
606 };
607
608 let min_num_elements = 32;
609 let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
610 let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
611 return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
612 }
613 Ok((false, false))
614}
615
616pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
617 TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
618}
619
/// In-place binary op produced by `TypedBinOp::codegen`, for operands where `b` has extent
/// 1 on every axis past the leading axes it shares with `a` (see `check_input_shapes`).
#[derive(Clone)]
pub struct OptBinByScalar {
    /// Mini-op used for output typing, cost and the non-contiguous fallback path.
    pub binop: Box<dyn BinMiniOp>,
    /// Linalg kernel applied to each slice of `a` with the matching single-element slice
    /// of `b`.
    eval_fn: Arc<LinalgFn>,
}
625
626impl Debug for OptBinByScalar {
627 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
628 f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
629 }
630}
631
632impl OptBinByScalar {
633 fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
634 if a_shape.len() != b_shape.len() {
635 return false;
636 };
637
638 a_shape
639 .iter()
640 .zip(b_shape.iter())
641 .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
642 .all(|(_, b_dim)| *b_dim == 1.to_dim())
643 }
644}
645
646impl PartialEq for OptBinByScalar {
647 fn eq(&self, other: &Self) -> bool {
648 *self.binop == *other.binop
649 }
650}
651impl Eq for OptBinByScalar {}
652
653impl Op for OptBinByScalar {
654 fn name(&self) -> StaticName {
655 format!("Opt{}ByScalar", self.binop.name()).into()
656 }
657
658 op_as_typed_op!();
659}
660
661impl EvalOp for OptBinByScalar {
662 fn is_stateless(&self) -> bool {
663 true
664 }
665
666 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
667 let (a, b) = args_2!(inputs);
668 let a_natural = a.len() == a.shape().iter().product::<usize>()
674 && a.strides() == &*Tensor::natural_strides(a.shape());
675 let b_natural = b.len() == b.shape().iter().product::<usize>()
676 && b.strides() == &*Tensor::natural_strides(b.shape());
677 if !a_natural || !b_natural {
678 let c_dt = self.binop.result_datum_type(a.datum_type(), b.datum_type())?;
679 return Ok(tvec!(self.binop.eval(a, b, c_dt)?.into_tvalue()));
680 }
681
682 let a = a.into_tensor();
685 let b_shape = b.shape();
686
687 let first_unary_axis = b_shape
688 .iter()
689 .enumerate()
690 .rev()
691 .take_while(|&(_, &dim)| dim == 1)
692 .map(|(i, _)| i)
693 .last()
694 .context("Cannot use by_scalar when no trailing dimensions are unary")?;
695
696 let iterating_shape = &a.shape()[..first_unary_axis];
697 if !iterating_shape.is_empty() {
698 for it_coords in tract_ndarray::indices(iterating_shape) {
699 let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
700 let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
701 debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
702 (self.eval_fn)(&mut view, &b_view)?;
703 }
704 } else {
705 let mut view = a.view();
706 let b_view = b.view();
707 debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
708 (self.eval_fn)(&mut view, &b_view)?;
709 }
710 Ok(tvec!(a.into_tvalue()))
711 }
712}
713
714impl TypedOp for OptBinByScalar {
715 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
716 ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
717 let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
718 let out_shape = inputs[0].shape.clone();
719 Ok(tvec!(out_dt.fact(out_shape)))
720 }
721
722 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
723 let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
724 Ok(self
725 .binop
726 .cost_per_element(inputs[0].datum_type)
727 .into_iter()
728 .map(|(c, n)| (c, count.clone() * n))
729 .collect())
730 }
731
732 as_op!();
733}
734
/// In-place binary op produced by `TypedBinOp::codegen`, for operands where `b` has extent
/// 1 on its leading axes and matches `a` on the remaining ones, with a vector-aligned
/// element count per iteration (see `check_input_shapes`).
#[derive(Clone)]
pub struct OptBinUnicast {
    /// Mini-op used for output typing, cost and the non-contiguous fallback path.
    pub binop: Box<dyn BinMiniOp>,
    /// Linalg kernel applied to each slice of `a` over `b`'s leading unit axes, with the
    /// whole of `b`.
    eval_fn: Arc<LinalgFn>,
}
740
741impl Debug for OptBinUnicast {
742 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
743 f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
744 }
745}
746
747impl OptBinUnicast {
748 fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
749 let num_iterations: TDim = a_shape
750 .iter()
751 .zip(b_shape.iter())
752 .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
753 .map(|(a_dim, _)| a_dim)
754 .product();
755
756 if num_iterations.is_one() {
757 return true;
758 }
759
760 let elements_per_iteration: TDim = a_shape
761 .iter()
762 .zip(b_shape.iter())
763 .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
764 .map(|(_, b_dim)| b_dim)
765 .product();
766
767 if let Ok(num_element) = elements_per_iteration.to_i64() {
768 let required_alignment = vector_size();
769 (num_element as usize).is_multiple_of(required_alignment)
770 } else {
771 false
772 }
773 }
774 fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
775 if a_shape.len() != b_shape.len() {
776 return false;
777 };
778
779 let unicast_possible = a_shape
780 .iter()
781 .zip(b_shape.iter())
782 .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
783 .all(|(a_dim, b_dim)| a_dim == b_dim);
784 let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);
785
786 unicast_possible && unicast_is_aligned
787 }
788}
789
790impl PartialEq for OptBinUnicast {
791 fn eq(&self, other: &Self) -> bool {
792 *self.binop == *other.binop
793 }
794}
795impl Eq for OptBinUnicast {}
796
797impl Op for OptBinUnicast {
798 fn name(&self) -> StaticName {
799 format!("Opt{}Unicast", self.binop.name()).into()
800 }
801
802 op_as_typed_op!();
803}
804
805impl EvalOp for OptBinUnicast {
806 fn is_stateless(&self) -> bool {
807 true
808 }
809
810 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
811 let (a, b) = args_2!(inputs);
812 let a_natural = a.len() == a.shape().iter().product::<usize>()
823 && a.strides() == &*Tensor::natural_strides(a.shape());
824 let b_natural = b.len() == b.shape().iter().product::<usize>()
825 && b.strides() == &*Tensor::natural_strides(b.shape());
826 if !a_natural || !b_natural {
827 let c_dt = self.binop.result_datum_type(a.datum_type(), b.datum_type())?;
828 return Ok(tvec!(self.binop.eval(a, b, c_dt)?.into_tvalue()));
829 }
830
831 let a = a.into_tensor();
834 let b_shape = b.shape();
835 let b_view = b.view();
836 let first_non_unary_axis =
837 b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();
838
839 if let Some(first_non_unary_axis) = first_non_unary_axis {
840 let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
842 for it_coords in tract_ndarray::indices(iterating_shape) {
843 let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
844 debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
845 (self.eval_fn)(&mut view, &b_view)?;
846 }
847 } else {
848 let mut view = a.view();
849 debug_assert_eq!(view.shape(), b_view.shape());
850 (self.eval_fn)(&mut view, &b_view)?;
851 }
852
853 Ok(tvec!(a.into_tvalue()))
854 }
855}
856
857impl TypedOp for OptBinUnicast {
858 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
859 ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
860 let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
861 let out_shape = inputs[0].shape.clone();
862 Ok(tvec!(out_dt.fact(out_shape)))
863 }
864
865 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
866 let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
867 Ok(self
868 .binop
869 .cost_per_element(inputs[0].datum_type)
870 .into_iter()
871 .map(|(c, n)| (c, count.clone() * n))
872 .collect())
873 }
874
875 as_op!();
876}
877
878#[macro_export]
879macro_rules! bin_to_super_type {
880 ($func:ident, $Op:ident,
881 $(codegen: $codegen:expr,)?
882 $(cost: $cost:expr,)?
883 $(declutter: $declutter:expr,)?
884 $(eval_in_a: $eval_in_a:expr,)?
885 $(eval_override: $eval_override: expr,)?
886 $(linalg: $linalg:ident,)?
887 $(operating_datum_type: $operating_datum_type:expr,)?
888 $(is_commutative: $is_commutative:expr,)?
889 $(neutral_element: $neutral_element:expr,)?
890 $(absorbing_element: $absorbing_element:expr,)?
891 $(out_of_place: $out_of_place:expr,)?
892 $(validation: $validation:expr,)?
893 $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
894 $(q_op_on_f32: $q_op_on_f32:expr,)?
895 $( [$($typ:ident),*] => $cab:expr),*) => {
896 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
897 pub struct $Op;
898 #[allow(clippy::redundant_closure_call)]
899 impl $crate::ops::binary::BinMiniOp for $Op {
900 fn name(&self) -> &'static str {
901 stringify!($Op)
902 }
903
904 fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
905 $(if $out_of_place(c, a, b)? { return Ok(()) } )?
906 if c.shape() == a.shape() && a.shape() == b.shape() {
910 $(
911 $(if c.datum_type() == $typ::datum_type() {
912 let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
913 let a_plain = a.try_as_plain()?;
914 let a_slice = a_plain.as_slice::<$typ>()?;
915 let b_plain = b.try_as_plain()?;
916 let b_slice = b_plain.as_slice::<$typ>()?;
917 let mut c_plain = c.try_as_plain_mut()?;
918 let c_slice = c_plain.as_slice_mut::<$typ>()?;
919 debug_assert_eq!(c_slice.len(), a_slice.len());
920 debug_assert_eq!(c_slice.len(), b_slice.len());
921 for ((cv, av), bv) in c_slice.iter_mut().zip(a_slice.iter()).zip(b_slice.iter()) {
922 cab(cv, av, bv);
923 }
924 return Ok(())
925 })*
926 )*
927 $(
928 $(
929 $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
930 let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
931 let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
932 let a_plain = a.try_as_plain()?;
933 let a_slice = a_plain.as_slice::<$typ_dt>()?;
934 let b_plain = b.try_as_plain()?;
935 let b_slice = b_plain.as_slice::<$typ_dt>()?;
936 let mut c_plain = c.try_as_plain_mut()?;
937 let c_slice = c_plain.as_slice_mut::<$typ_dt>()?;
938 for ((cv, av), bv) in c_slice.iter_mut().zip(a_slice.iter()).zip(b_slice.iter()) {
939 cab(cv, av, bv, zp, scale);
940 }
941 return Ok(())
942 })*
943 )*
944 )?
945 }
946 $(
947 $(if c.datum_type() == $typ::datum_type() {
948 let a = a.to_plain_array_view::<$typ>()?;
949 let b = b.to_plain_array_view::<$typ>()?;
950 let mut c_plain = c.try_as_plain_mut()?;
951 let mut c = c_plain.to_array_view_mut::<$typ>()?;
952 $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
953 return Ok(())
954 })*
955 )*
956 $(
957 $(
958 $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
959 let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
960 let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
961 let a = a.to_plain_array_view::<$typ_dt>()?;
962 let b = b.to_plain_array_view::<$typ_dt>()?;
963 let mut c_plain = c.try_as_plain_mut()?;
964 let mut c = c_plain.to_array_view_mut::<$typ_dt>()?;
965 $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
966 return Ok(())
967 }
968 )*
969 )*
970 )?
971 bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
972 }
973
974 $(fn is_commutative(&self) -> bool {
975 $is_commutative
976 })?
977 $(fn neutral_element(&self) -> Option<i64> {
978 Some($neutral_element)
979 })?
980 $(fn absorbing_element(&self) -> Option<i64> {
981 Some($absorbing_element)
982 })?
983 fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
984 $(if $eval_in_a(a, b)? { return Ok(()) } )?
986 if a.shape() == b.shape() {
989 $(
990 $(if b.datum_type() == $typ::datum_type() {
991 let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
992 let b_plain = b.try_as_plain()?;
993 let b_slice = b_plain.as_slice::<$typ>()?;
994 let mut a_plain = a.try_as_plain_mut()?;
995 let a_slice = a_plain.as_slice_mut::<$typ>()?;
996 debug_assert_eq!(a_slice.len(), b_slice.len());
997 for (av, bv) in a_slice.iter_mut().zip(b_slice.iter()) {
998 cab(av, &av.clone(), bv);
999 }
1000 return Ok(())
1001 })*
1002 )*
1003 $(
1004 $(
1005 $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
1006 let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
1007 let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
1008 let b_plain = b.try_as_plain()?;
1009 let b_slice = b_plain.as_slice::<$typ_dt>()?;
1010 let mut a_plain = a.try_as_plain_mut()?;
1011 let a_slice = a_plain.as_slice_mut::<$typ_dt>()?;
1012 for (av, bv) in a_slice.iter_mut().zip(b_slice.iter()) {
1013 cab(av, &(av.clone()), bv, zp, scale);
1014 }
1015 return Ok(())
1016 })*
1017 )*
1018 )?
1019 }
1020 $(
1021 $(if b.datum_type() == $typ::datum_type() {
1022 let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
1023 let b = b.to_plain_array_view::<$typ>()?;
1024 let mut a_plain = a.try_as_plain_mut()?;
1025 let mut a = a_plain.to_array_view_mut::<$typ>()?;
1026 $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
1027 return Ok(())
1028 })*
1029 )*
1030 $(
1031 $(
1032 $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
1033 let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
1034 let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
1035 let mut a_plain = a.try_as_plain_mut()?;
1036 let mut a = a_plain.to_array_view_mut::<$typ_dt>()?;
1037 let b = b.to_plain_array_view::<$typ_dt>()?;
1038 $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
1039 cab(a, &(a.clone()), b, zp, scale)
1040 });
1041 return Ok(())
1042 })*
1043 )*
1044 )?
1045 bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
1046 }
1047
1048 $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
1049 $eval_override(a, b, c_dt)
1050 })?
1051
1052 fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
1053 if a.unquantized() == b.unquantized() {
1054 if a.is_quantized() || !b.is_quantized() {
1055 return Ok(a)
1056 }
1057 else {
1058 return Ok(b)
1059 }
1060 }
1061 self.operating_datum_type(a, b)
1062 }
1063
1064 $(
1065 fn declutter(
1066 &self,
1067 model: &TypedModel,
1068 node: &TypedNode,
1069 ) -> TractResult<Option<TypedModelPatch>> {
1070 ($declutter)(self, model, node)
1071 }
1072 )?
1073 $(
1074 fn codegen(
1075 &self,
1076 model: &TypedModel,
1077 node: &TypedNode,
1078 a: &Arc<Tensor>,
1079 ) -> TractResult<Option<TypedModelPatch>> {
1080 ($codegen)(self, model, node, a)
1081 }
1082 )?
1083 $(
1084 fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
1085 ($cost)(dt)
1086 }
1087 )?
1088 $(
1089 fn validation(&self) -> Validation {
1090 $validation
1091 }
1092 )?
1093 $(
1094 fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
1095 Some(tract_linalg::BinOp::$linalg)
1096 }
1097 )?
1098 $(
1099 fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
1100 ($operating_datum_type)(a, b)
1101 })?
1102
1103
1104 #[allow(unused_variables)]
1108 fn maybe_eval_qbinary_as_float_op(
1109 &self,
1110 a: &TValue,
1111 b: &TValue,
1112 c_dt: &DatumType,
1113 ) -> TractResult<Option<Tensor>> {
1114 $(
1115 fn memory_optimised_q_binary_as_float_op(
1119 a: &TValue,
1120 b: &TValue,
1121 c_dt: &DatumType,
1122 ) -> TractResult<Option<Tensor>> {
1123 if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
1124 DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
1125 DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
1126 (a.datum_type(), b.datum_type(), c_dt)
1127 {
1128 let c_inv_scale = 1.0 / c_scale;
1129 let a = a.to_plain_array_view::<u8>()?;
1130 let b = b.to_plain_array_view::<u8>()?;
1131 let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
1132 let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
1133 let mut c_plain = c.try_as_plain_mut()?;
1134 let view = c_plain.to_array_view_mut::<u8>()?;
1135 $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
1136 *c = (scale_by($q_op_on_f32(
1137 ((*a as i32 - a_zp as i32) as f32 * a_scale),
1138 ((*b as i32 - b_zp as i32) as f32 * b_scale),
1139 ), c_inv_scale) as i32
1140 + *c_zp as i32)
1141 .clamp_cast()
1142 });
1143 return Ok(Some(c));
1144 }
1145 Ok(None)
1146 }
1147
1148 fn generic_q_binary_as_float_op(
1152 a: &TValue,
1153 b: &TValue,
1154 c_dt: &DatumType,
1155 accumulator_dt: DatumType
1156 ) -> TractResult<Option<Tensor>> {
1157 if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
1158 let a = a.cast_to_dt(accumulator_dt)?.into_owned();
1159 let b = b.cast_to_dt(accumulator_dt)?.into_owned();
1160 let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
1161 let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
1162 match accumulator_dt {
1163 DatumType::F32 => {
1164 let mut c_plain = c.try_as_plain_mut()?;
1165 let view = c_plain.to_array_view_mut::<f32>()?;
1166 $crate::ndarray::Zip::from(view).and_broadcast(a.try_as_plain()?.to_array_view()?).and_broadcast(b.try_as_plain()?.to_array_view()?).for_each(|c, a, b| {
1167 *c = $q_op_on_f32(*a,*b);
1168 })
1169 },
1170 other => bail!("unexpected accumulator data type as {:?}", other)
1171 };
1172
1173 return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
1174 }
1175 Ok(None)
1176 }
1177
1178 if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
1179 return Ok(Some(c));
1180 }
1181 if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
1182 return Ok(Some(d));
1183 }
1184 )?
1185 Ok(None)
1186 }
1187 }
1188
1189 pub fn $func() -> $crate::ops::binary::TypedBinOp {
1190 $crate::ops::binary::TypedBinOp(Box::new($Op), None)
1191 }
1192 };
1193}
1194
/// Result of [`one_input_is_uniform`]: describes a two-input node where one of the inputs'
/// facts carries a `uniform` value.
#[derive(Debug)]
pub(crate) struct OneUniformInput {
    /// The uniform value, cloned from the `uniform` field of the uniform input's fact.
    pub uni: Arc<Tensor>,
    /// The outlet feeding the other operand (the one not reported as uniform).
    pub var: OutletId,
    /// `true` when the uniform operand is the left (first) input, `false` when it is the
    /// right (second) input.
    pub left_is_uniform: bool,
}
1201
1202pub(crate) fn one_input_is_uniform(
1203 model: &TypedModel,
1204 node: &TypedNode,
1205) -> TractResult<Option<OneUniformInput>> {
1206 if let &[a, b] = &*model.node_input_facts(node.id)? {
1207 let uni = if let Some(a) = &a.uniform {
1208 OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
1209 } else if let Some(b) = &b.uniform {
1210 OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
1211 } else {
1212 return Ok(None);
1213 };
1214 let var_fact = [a, b][uni.left_is_uniform as usize];
1215 let uni_fact = [a, b][!uni.left_is_uniform as usize];
1216 if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
1217 return Ok(Some(uni));
1218 }
1219 }
1220 Ok(None)
1221}
1222
#[cfg(test)]
mod tests {
    use super::*;

    /// `OptBinUnicast` must still compute correct results when an input's strides differ
    /// from the natural ones for its shape, as happens after `insert_axis` on unit axes.
    #[test]
    fn opt_bin_unicast_falls_back_on_non_natural_strides() {
        // Left operand: 0..640 lifted to rank 3 by prepending two unit axes. Its strides
        // are asserted below to differ from the natural strides for that shape.
        let mut lhs = tensor1(&(0..640).map(|i| i as f32).collect::<Vec<f32>>());
        for _ in 0..2 {
            lhs.insert_axis(0).unwrap();
        }
        assert_eq!(lhs.shape(), &[1, 1, 640]);
        assert_eq!(lhs.strides(), &[1, 1, 1]);
        assert_ne!(lhs.strides(), &*Tensor::natural_strides(lhs.shape()));

        // Right operand: 640 ones, lifted the same way, then passed through `into_shape`.
        let mut rhs = tensor1(&vec![1.0f32; 640]);
        for _ in 0..2 {
            rhs.insert_axis(0).unwrap();
        }
        let rhs = rhs.into_shape(&[1, 1, 640]).unwrap();

        let kernel = tract_linalg::bin_unicast(f32::datum_type(), BinOp::Add)
            .expect("f32 unicast Add kernel available");
        let op = OptBinUnicast { binop: Box::new(Add), eval_fn: Arc::from(kernel) };

        let outputs = op.eval(tvec!(lhs.into_tvalue(), rhs.into_tvalue())).unwrap();
        let result = &outputs[0];
        assert_eq!(result.shape(), &[1, 1, 640]);
        let plain = result.try_as_plain().unwrap();
        let values = plain.as_slice::<f32>().unwrap();
        for (i, &v) in values.iter().enumerate() {
            assert_eq!(v, i as f32 + 1.0, "mismatch at {i}");
        }
    }
}