1use crate::internal::*;
2use crate::ndarray::Dimension;
3use downcast_rs::Downcast;
4use dyn_eq::DynEq;
5use std::fmt::{self, Debug};
6use tract_data::itertools::izip;
7use tract_itertools::Itertools;
8use tract_linalg::{BinOp, LinalgFn};
9
10use super::math::{Add, Max, Min, Mul, Sub};
11use super::{cast::cast, math::SubF};
12
/// The scalar/elementwise kernel behind a `TypedBinOp`: defines how a binary
/// operation (Add, Mul, Max, ...) computes its result, plus metadata the
/// optimizer uses (commutativity, neutral/absorbing elements, linalg mapping).
pub trait BinMiniOp:
    fmt::Debug + dyn_clone::DynClone + dyn_eq::DynEq + Send + Sync + 'static + Downcast
{
    /// Operator name, used for display and patch labelling.
    fn name(&self) -> &'static str;
    /// Numerical accuracy contract of this implementation.
    fn validation(&self) -> Validation {
        Validation::Accurate
    }
    /// Datum type used to perform the computation: defaults to the common
    /// super type of both operand types.
    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
        a.common_super_type(b).with_context(|| format_err!("No super type for {:?} and {:?}", a, b))
    }
    /// Datum type of the result for operand types `a` and `b`.
    fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType>;
    /// Evaluate in place, accumulating the result into `a` (`a = a op b`).
    fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
    /// Evaluate out of place into the preallocated output tensor `c`.
    fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()>;

    /// True when `a op b == b op a`; used to legitimize operand swapping.
    fn is_commutative(&self) -> bool {
        true
    }
    /// Element `e` such that `a op e == a`, if any (e.g. 0 for Add, 1 for Mul).
    fn neutral_element(&self) -> Option<i64> {
        None
    }
    /// Element `e` such that `a op e == e`, if any (e.g. 0 for Mul).
    fn absorbing_element(&self) -> Option<i64> {
        None
    }

    /// Hook letting quantized implementations evaluate through a float
    /// approximation. Default: no shortcut.
    #[allow(unused_variables)]
    fn maybe_eval_qbinary_as_float_op(
        &self,
        a: &TValue,
        b: &TValue,
        c_dt: &DatumType,
    ) -> TractResult<Option<Tensor>> {
        Ok(None)
    }

    /// Generic evaluation: tries the quantized-as-float shortcut first, then
    /// evaluates in place in `a`'s buffer when shapes/types allow, otherwise
    /// broadcasts out of place into a fresh tensor.
    fn generic_eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        if let Some(tensor) = self.maybe_eval_qbinary_as_float_op(&a, &b, &c_dt)? {
            Ok(tensor)
        } else {
            let c_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
            if &*c_shape == a.shape() && c_dt == a.datum_type() {
                // Output matches a's shape and type: reuse a's buffer.
                let mut a = a.into_tensor();
                self.eval_in_a(&mut a, &b)?;
                Ok(a)
            } else {
                // c is fully overwritten by eval_out_of_place before any read.
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }
        }
    }
    /// Evaluation entry point; overridable, defaults to `generic_eval`.
    fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
        self.generic_eval(a, b, c_dt)
    }
    /// Operator-specific graph simplification hook (declutter phase).
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    /// Operator-specific code generation hook (codegen phase).
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }
    /// Cost model: per-element cost entries for the given datum type.
    #[allow(unused_variables)]
    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
        tvec!()
    }
    /// Mapping to a tract-linalg binary kernel, when one exists.
    fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
        None
    }

    /// Hook for symbolic evaluation against session state. `None` means
    /// "no symbolic shortcut, evaluate normally".
    #[allow(unused_variables)]
    fn eval_symbolic(
        &self,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<Option<TVec<TValue>>> {
        Ok(None)
    }

    /// Hook combining two uniform symbolic dims into the output's uniform
    /// dim when the operator knows how to; `None` means "unknown".
    #[allow(unused_variables)]
    fn uniform_tdim_comparison(&self, a: &TDim, b: &TDim) -> Option<TDim> {
        None
    }
}
// Make boxed `BinMiniOp` trait objects cloneable, comparable and downcastable.
dyn_clone::clone_trait_object!(BinMiniOp);
dyn_eq::eq_trait_object!(BinMiniOp);
downcast_rs::impl_downcast!(BinMiniOp);
109
/// Graph-level binary operator: the mini-op kernel plus an optional forced
/// output datum type that overrides the kernel's result type when set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypedBinOp(pub Box<dyn BinMiniOp>, pub Option<DatumType>);
112
impl Op for TypedBinOp {
    // Name and accuracy contract are delegated to the inner mini-op.
    fn name(&self) -> StaticName {
        self.0.name().into()
    }

    fn validation(&self) -> Validation {
        self.0.validation()
    }

    op_as_typed_op!();
}
124
125impl TypedBinOp {
126 fn output_datum_type(&self, a_dt: DatumType, b_dt: DatumType) -> TractResult<DatumType> {
127 if let Some(dt) = self.1 { Ok(dt) } else { self.0.result_datum_type(a_dt, b_dt) }
128 }
129}
130
131impl EvalOp for TypedBinOp {
132 fn is_stateless(&self) -> bool {
133 true
134 }
135
136 fn eval_with_session(
137 &self,
138 _node_id: usize,
139 session: &TurnState,
140 inputs: TVec<TValue>,
141 ) -> TractResult<TVec<TValue>> {
142 if let Some(result) = self.0.eval_symbolic(session, inputs.clone())? {
143 return Ok(result);
144 }
145 let (a, b) = args_2!(inputs);
146 ensure!(a.rank() == b.rank());
147 let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
148 Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
149 }
150
151 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
152 let (a, b) = args_2!(inputs);
153 ensure!(a.rank() == b.rank());
154 let c_dt = self.output_datum_type(a.datum_type(), b.datum_type())?;
155 Ok(tvec!(self.0.eval(a, b, c_dt)?.into_tvalue()))
156 }
157}
158
impl TypedBinOp {
    /// Combine the uniform symbolic dims of both inputs into the output's
    /// uniform dim: first through the mini-op's dedicated hook, then by
    /// actually evaluating the op on scalar TDim tensors.
    fn combine_uniform_tdim(&self, a: &TDim, b: &TDim) -> Option<TDim> {
        if let Some(result) = self.0.uniform_tdim_comparison(a, b) {
            return Some(result);
        }
        // Fallback: run the op on two rank-0 TDim tensors and read back the
        // single resulting dim.
        let a = tensor0(a.clone()).into_tvalue();
        let b = tensor0(b.clone()).into_tvalue();
        let result = self.0.eval(a, b, TDim::datum_type()).ok()?;
        result
            .try_as_plain()
            .ok()
            .and_then(|d| d.as_slice::<TDim>().ok())
            .and_then(|s| s.first())
            .cloned()
            .map(|d| d.reduce())
    }

    /// Combine a uniform symbolic dim with a single-element constant tensor.
    /// Integer-valued constants go through `combine_uniform_tdim`; for ops
    /// with neutral element 1 (multiplication-like), a constant close to 1/n
    /// maps to a symbolic division by n.
    fn combine_uniform_tdim_with_konst(&self, a: &TDim, konst: &Tensor) -> Option<TDim> {
        if konst.len() != 1 {
            return None;
        }
        // Extract an exact integer from the constant if possible (floats are
        // accepted when within 1e-6 of an integer).
        let b_int: Option<i64> =
            if konst.datum_type().is_integer() || konst.datum_type().is::<bool>() {
                konst.cast_to_scalar::<i64>().ok()
            } else if konst.datum_type().is_float() {
                konst.cast_to_scalar::<f64>().ok().and_then(|f| {
                    if (f - f.round()).abs() < 1e-6 { Some(f.round() as i64) } else { None }
                })
            } else {
                None
            };
        if let Some(b) = b_int {
            return self.combine_uniform_tdim(a, &TDim::Val(b));
        }
        // Scaling by 1/n (n integer >= 2) becomes a symbolic division by n.
        if self.0.neutral_element() == Some(1)
            && let Some(f) = konst.cast_to_scalar::<f64>().ok().filter(|&f| f > 0.0)
        {
            let n = (1.0 / f).round() as u64;
            if n >= 2 && (f * n as f64 - 1.0).abs() < 1e-6 {
                return Some(TDim::Div(Box::new(a.clone()), n).reduce());
            }
        }
        None
    }
}
207
208impl TypedOp for TypedBinOp {
209 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
210 if inputs[0].rank() != inputs[1].rank() {
211 bail!(
212 "Typed ops require rank match. Invalid inputs for {}: {}",
213 self.name(),
214 inputs.iter().map(|s| format!("{s:?}")).join(" ; ")
215 );
216 }
217 let out_dt = self.output_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
218 let mut fact = out_dt.fact(&*crate::broadcast::multi_broadcast(&[
219 &inputs[0].shape.to_tvec(),
220 &inputs[1].shape.to_tvec(),
221 ])?);
222 if let (Some(a), Some(b)) = (&inputs[0].uniform_tdim, &inputs[1].uniform_tdim) {
223 fact.uniform_tdim = self.combine_uniform_tdim(a, b);
224 if fact.uniform_tdim.is_none() && self.0.is::<crate::ops::logic::And>() {
226 fact.uniform_tdim = Some(TDim::Mul(vec![a.clone(), b.clone()]).reduce());
227 }
228 }
229 if fact.uniform_tdim.is_none() {
231 for (expr, konst_fact) in [
232 (inputs[0].uniform_tdim.as_ref(), inputs[1]),
233 (inputs[1].uniform_tdim.as_ref(), inputs[0]),
234 ] {
235 let Some(a) = expr else { continue };
236 let Some(konst) = konst_fact.konst.as_ref() else { continue };
237 fact.uniform_tdim = self.combine_uniform_tdim_with_konst(a, konst);
238 if fact.uniform_tdim.is_some() {
239 break;
240 }
241 }
242 }
243 Ok(tvec!(fact))
244 }
245
246 fn input_roi(
247 &self,
248 model: &TypedModel,
249 node: &TypedNode,
250 ) -> TractResult<Option<TVec<Option<TDim>>>> {
251 if self.0.neutral_element() == Some(1) {
254 for (mask_ix, other_ix) in [(0usize, 1usize), (1, 0)] {
255 let fact = model.outlet_fact(node.inputs[mask_ix])?;
256 if let Some(mask_expr) = &fact.uniform_tdim {
257 let mut rois = tvec![None; node.inputs.len()];
258 rois[other_ix] = Some(mask_expr.clone());
259 return Ok(Some(rois));
260 }
261 }
262 }
263 crate::optim::propagate_roi::bubble_roi(model, node)
265 }
266
267 fn change_axes(
268 &self,
269 model: &TypedModel,
270 node: &TypedNode,
271 _io: InOut,
272 change: &AxisOp,
273 ) -> TractResult<Option<AxisChangeConsequence>> {
274 if let AxisOp::Rm(rm) = change {
275 let (inputs, outputs) = model.node_facts(node.id)?;
276 if inputs.len() >= 2
277 && outputs.len() >= 1
278 && inputs[0].rank() > *rm
279 && inputs[1].rank() > *rm
280 && outputs[0].rank() > *rm
281 {
282 rule_if!(inputs[0].shape[*rm].is_one());
283 rule_if!(inputs[1].shape[*rm].is_one());
284 rule_if!(outputs[0].shape[*rm].is_one());
285 }
286 }
287 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
288 }
289
290 fn axes_mapping(
291 &self,
292 inputs: &[&TypedFact],
293 outputs: &[&TypedFact],
294 ) -> TractResult<AxesMapping> {
295 AxesMapping::natural(inputs, outputs)
296 }
297
298 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
299 let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
300 Ok(self
301 .0
302 .cost_per_element(inputs[0].datum_type)
303 .into_iter()
304 .map(|(c, n)| (c, count.clone() * n))
305 .collect())
306 }
307
308 fn slice(
309 &self,
310 patch: &mut TypedModelPatch,
311 _model: &TypedModel,
312 _node: &TypedNode,
313 prefix: &str,
314 inputs: &[OutletId],
315 _output_axis: usize,
316 _start: &TDim,
317 _end: &TDim,
318 ) -> TractResult<Option<TVec<OutletId>>> {
319 Ok(Some(patch.wire_node(prefix, self.clone(), inputs)?))
320 }
321
322 fn declutter(
323 &self,
324 model: &TypedModel,
325 node: &TypedNode,
326 ) -> TractResult<Option<TypedModelPatch>> {
327 let (a_dt, b_dt) = if let &[a, b] = &*model.node_input_facts(node.id)? {
328 (a.datum_type().unwrap(), b.datum_type().unwrap())
329 } else {
330 unreachable!("TypedBinOp has two inputs.")
331 };
332 if let Some(neutral_patch) =
333 declutter_neutral(model, node, self.0.as_ref(), self.output_datum_type(a_dt, b_dt)?)?
334 {
335 return Ok(Some(neutral_patch));
336 }
337 if let Some(absorbing_patch) = declutter_absorbing(model, node, self.0.as_ref())? {
338 return Ok(Some(absorbing_patch));
339 }
340 if let Some(broadcast_patch) =
341 declutter_broadcasting_operand_1(model, node, self.0.clone())?
342 {
343 return Ok(Some(broadcast_patch));
344 }
345 self.0.declutter(model, node)
346 }
347
348 fn codegen(
349 &self,
350 model: &TypedModel,
351 node: &TypedNode,
352 ) -> TractResult<Option<TypedModelPatch>> {
353 if let Some(linalg_bin_op) = self.0.as_linalg_binop() {
354 let input_facts = model.node_input_facts(node.id)?;
355 let must_swap_inputs =
356 input_facts.iter().collect_tuple().is_some_and(|(a_fact, b_fact)| {
357 (a_fact.shape.volume() - b_fact.shape.volume()).prove_strict_negative()
358 });
359 let (operand_1, operand_2) = if must_swap_inputs {
360 (input_facts[1], input_facts[0])
361 } else {
362 (input_facts[0], input_facts[1])
363 };
364
365 let (by_scalar_should_be_efficient, unicast_should_be_efficient) =
366 find_most_efficient_config(model, node, must_swap_inputs)?;
367
368 let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
370 let op_is_quant = c_dt.is_quantized()
371 || operand_1.datum_type.is_quantized()
372 || operand_2.datum_type.is_quantized();
373
374 let c_dt = self.output_datum_type(operand_1.datum_type, operand_2.datum_type)?;
376 let c_shape = crate::broadcast::multi_broadcast(&[
377 operand_1.shape.clone(),
378 operand_2.shape.clone(),
379 ])?;
380 let can_eval_in_a =
381 (c_shape.to_vec() == operand_1.shape.to_vec()) && (c_dt == operand_1.datum_type);
382
383 let inputs = if must_swap_inputs {
385 let mut swap_input = node.inputs.clone();
386 swap_input.swap(0, 1);
387 swap_input
388 } else {
389 node.inputs.clone()
390 };
391 let actual_linalg_op =
392 if must_swap_inputs { linalg_bin_op.flip() } else { linalg_bin_op };
393 let actual_core_op = core_op_for_linalg_op(&actual_linalg_op);
394
395 let dt = model.node_input_facts(node.id)?[0].datum_type;
396 if by_scalar_should_be_efficient & can_eval_in_a & !op_is_quant {
397 rule_if_some!(func = tract_linalg::bin_by_scalar(dt, actual_linalg_op));
398 let eval_fn = Arc::from(func);
399 return Ok(Some(
400 TypedModelPatch::replace_single_op(
401 model,
402 node,
403 &inputs,
404 OptBinByScalar { binop: actual_core_op, eval_fn },
405 )?
406 .with_context("ByScalar"),
407 ));
408 }
409
410 if unicast_should_be_efficient & can_eval_in_a & !op_is_quant {
411 rule_if_some!(func = tract_linalg::bin_unicast(dt, actual_linalg_op));
412 let eval_fn = Arc::from(func);
413 return Ok(Some(
414 TypedModelPatch::replace_single_op(
415 model,
416 node,
417 &inputs,
418 OptBinUnicast { binop: actual_core_op, eval_fn },
419 )?
420 .with_context("Unicast"),
421 ));
422 }
423 }
424
425 Ok(None)
426 }
427 as_op!();
428}
429
430fn core_op_for_linalg_op(linalg: &BinOp) -> Box<dyn BinMiniOp> {
431 match linalg {
432 BinOp::Min => Box::new(Min),
433 BinOp::Max => Box::new(Max),
434 BinOp::Add => Box::new(Add),
435 BinOp::Mul => Box::new(Mul),
436 BinOp::Sub => Box::new(Sub),
437 BinOp::SubF => Box::new(SubF),
438 }
439}
440fn declutter_broadcasting_operand_1(
441 model: &TypedModel,
442 node: &TypedNode,
443 mini_op: Box<dyn BinMiniOp>,
444) -> TractResult<Option<TypedModelPatch>> {
445 let (a_shape, b_shape) = if let &[a, b] = &*model.node_input_facts(node.id)? {
446 (a.shape.clone(), b.shape.clone())
447 } else {
448 unreachable!("TypedBinOp has two inputs.")
449 };
450
451 let a_num_elements = a_shape.iter().product::<TDim>();
452 let b_num_elements = b_shape.iter().product::<TDim>();
453 let a_should_be_broadcast = (a_num_elements - b_num_elements).prove_strict_negative();
454 if a_should_be_broadcast & mini_op.is_commutative() {
455 let mut swap_input = node.inputs.clone();
456 swap_input.swap(0, 1);
457 return Ok(Some(TypedModelPatch::replace_single_op(
458 model,
459 node,
460 &swap_input,
461 TypedBinOp(mini_op, None),
462 )?));
463 }
464
465 Ok(None)
466}
467
/// Replace `x op k` with `x` (or a cast of `x`) when `k` is the op's
/// neutral element (e.g. `x + 0`, `x * 1`).
fn declutter_neutral(
    model: &TypedModel,
    node: &TypedNode,
    mini_op: &dyn BinMiniOp,
    out_dt: DatumType,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        // The uniform constant must match the op's neutral element (if any).
        let is_neutral = mini_op
            .neutral_element()
            .map(|neutral| tensor0(neutral).close_enough(&uniform.uni, false).is_ok())
            .unwrap_or(false);

        // For non-commutative ops the neutral element only simplifies when it
        // sits on the right-hand side.
        let pos_checked = mini_op.is_commutative() || !uniform.left_is_uniform;

        if is_neutral && pos_checked {
            if uniform.uni.datum_type().is_quantized() {
                // Quantized: keep a cast so the declared output type is honored.
                // NOTE(review): this wires inputs[0], which is the *constant*
                // when the uniform input is on the left — presumably
                // `uniform.var` was intended; verify against callers.
                return Ok(Some(TypedModelPatch::replace_single_op(
                    model,
                    node,
                    &[node.inputs[0]],
                    cast(out_dt),
                )?));
            } else {
                // Plain case: wire the variable input straight to the output.
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|_, inputs| Ok(inputs.into()),
                )?));
            }
        }
    }
    Ok(None)
}
509
510fn declutter_absorbing(
513 model: &TypedModel,
514 node: &TypedNode,
515 mini_op: &dyn BinMiniOp,
516) -> TractResult<Option<TypedModelPatch>> {
517 if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
518 let is_absorbing = mini_op
519 .absorbing_element()
520 .map(|absorb| tensor0(absorb).close_enough(&uniform.uni, false).is_ok())
521 .unwrap_or(false);
522 if is_absorbing {
523 let uni_inlet = if uniform.left_is_uniform { 0 } else { 1 };
524 return Ok(Some(TypedModelPatch::rewire(
525 model,
526 &[node.inputs[uni_inlet]],
527 &[node.id.into()],
528 &|_, inputs| Ok(inputs.into()),
529 )?));
530 }
531 }
532 Ok(None)
533}
534
/// Estimate whether the by-scalar and/or unicast optimized kernels are
/// worthwhile for this node, returning
/// `(by_scalar_should_be_efficient, unicast_should_be_efficient)`.
/// `swap_input` mirrors the operand swap already decided by the caller.
fn find_most_efficient_config(
    model: &TypedModel,
    node: &TypedNode,
    swap_input: bool,
) -> TractResult<(bool, bool)> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        let a_shape = if swap_input { b.shape.clone() } else { a.shape.clone() };
        let b_shape = if swap_input { a.shape.clone() } else { b.shape.clone() };

        // By-scalar inner-loop length: product of a's trailing dims where b's
        // dim is 1.
        let by_scalar_is_possible = OptBinByScalar::check_input_shapes(&a_shape, &b_shape);
        let num_by_scalar_elements = if by_scalar_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(_, rev_b_dim)| **rev_b_dim == TDim::Val(1))
                .map(|(rev_a_dim, _)| rev_a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        // Unicast inner-loop length: product of trailing dims where a and b
        // agree.
        let unicast_is_possible = OptBinUnicast::check_input_shapes(&a_shape, &b_shape);
        let num_unicast_elements = if unicast_is_possible {
            a_shape
                .iter()
                .zip(b_shape.iter())
                .rev()
                .take_while(|(a_dim, b_dim)| a_dim == b_dim)
                .map(|(a_dim, _)| a_dim)
                .product::<TDim>()
        } else {
            TDim::Val(0)
        };

        // Threshold: the optimized kernels only pay off when the inner loop is
        // provably at least this long.
        let min_num_elements = 32;
        let by_scalar_should_be_efficient = gt_tdim(num_by_scalar_elements, min_num_elements);
        let unicast_should_be_efficient = gt_tdim(num_unicast_elements, min_num_elements);
        return Ok((by_scalar_should_be_efficient, unicast_should_be_efficient));
    }
    Ok((false, false))
}
577
/// True when `x` is provably at least `min_val`. Despite the name this is a
/// `>=` test: min(min_val, x) == min_val, which requires the symbolic min to
/// resolve to a concrete i64.
pub fn gt_tdim(x: TDim, min_val: i64) -> bool {
    TDim::Val(min_val).mini(x).to_i64().is_ok_and(|v| v == min_val)
}
581
/// Optimized binary op for the "b is a (per-slice) scalar" pattern, backed by
/// a tract-linalg kernel. `eval_fn` is the resolved kernel closure.
#[derive(Clone)]
pub struct OptBinByScalar {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}
587
impl Debug for OptBinByScalar {
    // eval_fn is an opaque closure: only the mini-op is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinByScalar").field("binop", &self.binop).finish()
    }
}
593
594impl OptBinByScalar {
595 fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
596 if a_shape.len() != b_shape.len() {
597 return false;
598 };
599
600 a_shape
601 .iter()
602 .zip(b_shape.iter())
603 .skip_while(|(a_dim, b_dim)| a_dim == b_dim)
604 .all(|(_, b_dim)| *b_dim == 1.to_dim())
605 }
606}
607
// Equality is based solely on the wrapped mini-op; the eval closure is not
// comparable and is ignored.
impl PartialEq for OptBinByScalar {
    fn eq(&self, other: &Self) -> bool {
        *self.binop == *other.binop
    }
}
impl Eq for OptBinByScalar {}
614
impl Op for OptBinByScalar {
    // e.g. "OptAddByScalar" when wrapping the Add mini-op.
    fn name(&self) -> StaticName {
        format!("Opt{}ByScalar", self.binop.name()).into()
    }

    op_as_typed_op!();
}
622
impl EvalOp for OptBinByScalar {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // The result is accumulated in place in a's buffer.
        let a = a.into_tensor();
        let b_shape = b.shape();

        // First axis of b's trailing run of unary (==1) dims; errors out if b
        // has no trailing unary dim at all.
        let first_unary_axis = b_shape
            .iter()
            .enumerate()
            .rev()
            .take_while(|&(_, &dim)| dim == 1)
            .map(|(i, _)| i)
            .last()
            .context("Cannot use by_scalar when no trailing dimensions are unary")?;

        // Iterate over the leading axes; each prefix selects one scalar in b.
        let iterating_shape = &a.shape()[..first_unary_axis];
        if !iterating_shape.is_empty() {
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                let b_view = TensorView::at_prefix(&b, it_coords.slice())?;
                debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            // Degenerate case: b is a single scalar for the whole of a.
            let mut view = a.view();
            let b_view = b.view();
            debug_assert_eq!(b_view.shape().iter().product::<usize>(), 1);
            (self.eval_fn)(&mut view, &b_view)?;
        }
        Ok(tvec!(a.into_tvalue()))
    }
}
661
impl TypedOp for OptBinByScalar {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // Re-check the by-scalar eligibility invariant at fact time.
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        // Output takes a's shape (b only contributes degenerate dims).
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        // Total cost = per-element cost of the mini-op × output element count.
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}
682
/// Optimized binary op for the "b unicast over slices of a" pattern, backed
/// by a tract-linalg kernel. `eval_fn` is the resolved kernel closure.
#[derive(Clone)]
pub struct OptBinUnicast {
    pub binop: Box<dyn BinMiniOp>,
    eval_fn: Arc<LinalgFn>,
}
688
impl Debug for OptBinUnicast {
    // eval_fn is an opaque closure: only the mini-op is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        f.debug_struct("OptBinUnicast").field("binop", &self.binop).finish()
    }
}
694
695impl OptBinUnicast {
696 fn check_b_alignement(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
697 let num_iterations: TDim = a_shape
698 .iter()
699 .zip(b_shape.iter())
700 .take_while(|(_, b_dim)| **b_dim == 1.to_dim())
701 .map(|(a_dim, _)| a_dim)
702 .product();
703
704 if num_iterations.is_one() {
705 return true;
706 }
707
708 let elements_per_iteration: TDim = a_shape
709 .iter()
710 .zip(b_shape.iter())
711 .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
712 .map(|(_, b_dim)| b_dim)
713 .product();
714
715 if let Ok(num_element) = elements_per_iteration.to_i64() {
716 let required_alignment = vector_size();
717 (num_element as usize).is_multiple_of(required_alignment)
718 } else {
719 false
720 }
721 }
722 fn check_input_shapes(a_shape: &[TDim], b_shape: &[TDim]) -> bool {
723 if a_shape.len() != b_shape.len() {
724 return false;
725 };
726
727 let unicast_possible = a_shape
728 .iter()
729 .zip(b_shape.iter())
730 .skip_while(|(_, b_dim)| **b_dim == 1.to_dim())
731 .all(|(a_dim, b_dim)| a_dim == b_dim);
732 let unicast_is_aligned = Self::check_b_alignement(a_shape, b_shape);
733
734 unicast_possible && unicast_is_aligned
735 }
736}
737
// Equality is based solely on the wrapped mini-op; the eval closure is not
// comparable and is ignored.
impl PartialEq for OptBinUnicast {
    fn eq(&self, other: &Self) -> bool {
        *self.binop == *other.binop
    }
}
impl Eq for OptBinUnicast {}
744
impl Op for OptBinUnicast {
    // e.g. "OptAddUnicast" when wrapping the Add mini-op.
    fn name(&self) -> StaticName {
        format!("Opt{}Unicast", self.binop.name()).into()
    }

    op_as_typed_op!();
}
752
impl EvalOp for OptBinUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a, b) = args_2!(inputs);
        // The result is accumulated in place in a's buffer.
        let a = a.into_tensor();
        let b_shape = b.shape();
        let b_view = b.view();
        // Index right past b's leading run of unary dims; None when b starts
        // with a non-unary dim.
        let first_non_unary_axis =
            b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i + 1).last();

        if let Some(first_non_unary_axis) = first_non_unary_axis {
            // Iterate over a's leading axes, applying b unicast on each slice.
            let iterating_shape = a.shape()[..first_non_unary_axis].to_vec();
            for it_coords in tract_ndarray::indices(iterating_shape) {
                let mut view = TensorView::at_prefix(&a, it_coords.slice())?;
                debug_assert_eq!(view.shape(), &b_view.shape()[it_coords.slice().len()..]);
                (self.eval_fn)(&mut view, &b_view)?;
            }
        } else {
            // Same shape: single full-tensor application.
            let mut view = a.view();
            debug_assert_eq!(view.shape(), b_view.shape());
            (self.eval_fn)(&mut view, &b_view)?;
        }

        Ok(tvec!(a.into_tvalue()))
    }
}
785
impl TypedOp for OptBinUnicast {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // Re-check the unicast eligibility invariant at fact time.
        ensure!(Self::check_input_shapes(&inputs[0].shape, &inputs[1].shape));
        let out_dt = self.binop.result_datum_type(inputs[0].datum_type, inputs[1].datum_type)?;
        // Output takes a's shape (b only adds leading unary dims).
        let out_shape = inputs[0].shape.clone();
        Ok(tvec!(out_dt.fact(out_shape)))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        // Total cost = per-element cost of the mini-op × output element count.
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .binop
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    as_op!();
}
806
/// Declares a binary operator: generates a unit struct `$Op` implementing
/// `BinMiniOp` (with per-type kernels, optional quantized paths and optimizer
/// hooks) plus a constructor function `$func()` returning a `TypedBinOp`.
/// All the `key:` sections are optional and override the trait defaults.
#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_in_a: $eval_in_a:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(is_commutative: $is_commutative:expr,)?
     $(neutral_element: $neutral_element:expr,)?
     $(absorbing_element: $absorbing_element:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $(q_op_on_f32: $q_op_on_f32:expr,)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        pub struct $Op;
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                // Optional user-provided out-of-place shortcut.
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                // Plain (non-quantized) per-type kernels.
                $(
                    $(if c.datum_type() == $typ::datum_type() {
                        let a = a.to_plain_array_view::<$typ>()?;
                        let b = b.to_plain_array_view::<$typ>()?;
                        let mut c_plain = c.try_as_plain_mut()?;
                        let mut c = c_plain.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                        return Ok(())
                    })*
                )*
                // Quantized per-type kernels: closures also get a's zero-point
                // and scale.
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let a = a.to_plain_array_view::<$typ_dt>()?;
                            let b = b.to_plain_array_view::<$typ_dt>()?;
                            let mut c_plain = c.try_as_plain_mut()?;
                            let mut c = c_plain.to_array_view_mut::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                            return Ok(())
                        }
                        )*
                    )*
                )?
                bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            $(fn is_commutative(&self) -> bool {
                $is_commutative
            })?
            $(fn neutral_element(&self) -> Option<i64> {
                Some($neutral_element)
            })?
            $(fn absorbing_element(&self) -> Option<i64> {
                Some($absorbing_element)
            })?
            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                // Optional user-provided in-place shortcut.
                $(if $eval_in_a(a, b)? { return Ok(()) } )?
                // Plain per-type kernels, accumulating into a.
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let b = b.to_plain_array_view::<$typ>()?;
                        let mut a_plain = a.try_as_plain_mut()?;
                        let mut a = a_plain.to_array_view_mut::<$typ>()?;
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                )*
                // Quantized per-type kernels, accumulating into a.
                $(
                    $(
                        $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                            let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                            let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                            let mut a_plain = a.try_as_plain_mut()?;
                            let mut a = a_plain.to_array_view_mut::<$typ_dt>()?;
                            let b = b.to_plain_array_view::<$typ_dt>()?;
                            $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| {
                                cab(a, &(a.clone()), b, zp, scale)
                            });
                            return Ok(())
                        })*
                    )*
                )?
                bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
            }

            $(fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                $eval_override(a, b, c_dt)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                // Same unquantized type on both sides: prefer the quantized one.
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    }
                    else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

            $(
                fn declutter(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                ) -> TractResult<Option<TypedModelPatch>> {
                    ($declutter)(self, model, node)
                }
            )?
            $(
                fn codegen(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                    a: &Arc<Tensor>,
                ) -> TractResult<Option<TypedModelPatch>> {
                    ($codegen)(self, model, node, a)
                }
            )?
            $(
                fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                    ($cost)(dt)
                }
            )?
            $(
                fn validation(&self) -> Validation {
                    $validation
                }
            )?
            $(
                fn as_linalg_binop(&self) -> Option<tract_linalg::BinOp> {
                    Some(tract_linalg::BinOp::$linalg)
                }
            )?
            $(
                fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                    ($operating_datum_type)(a, b)
                })?


            #[allow(unused_variables)]
            fn maybe_eval_qbinary_as_float_op(
                &self,
                a: &TValue,
                b: &TValue,
                c_dt: &DatumType,
            ) -> TractResult<Option<Tensor>> {
                $(
                    // Fast path for QU8 with explicit zero-point/scale:
                    // dequantize, apply the f32 op and requantize elementwise,
                    // without materializing float tensors.
                    fn memory_optimised_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                    ) -> TractResult<Option<Tensor>> {
                        if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
                            DatumType::QU8(QParams::ZpScale {zero_point: b_zp, scale: b_scale}),
                            DatumType::QU8(QParams::ZpScale {zero_point: c_zp, scale: c_scale})) =
                            (a.datum_type(), b.datum_type(), c_dt)
                        {
                            let c_inv_scale = 1.0 / c_scale;
                            let a = a.to_plain_array_view::<u8>()?;
                            let b = b.to_plain_array_view::<u8>()?;
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(*c_dt, &c_shape)?;
                            let mut c_plain = c.try_as_plain_mut()?;
                            let view = c_plain.to_array_view_mut::<u8>()?;
                            $crate::ndarray::Zip::from(view).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| {
                                *c = (scale_by($q_op_on_f32(
                                    ((*a as i32 - a_zp as i32) as f32 * a_scale),
                                    ((*b as i32 - b_zp as i32) as f32 * b_scale),
                                ), c_inv_scale) as i32
                                + *c_zp as i32)
                                .clamp_cast()
                            });
                            return Ok(Some(c));
                        }
                        Ok(None)
                    }

                    // Generic quantized path: cast both operands to the float
                    // accumulator type, apply the op, requantize the result.
                    fn generic_q_binary_as_float_op(
                        a: &TValue,
                        b: &TValue,
                        c_dt: &DatumType,
                        accumulator_dt: DatumType
                    ) -> TractResult<Option<Tensor>> {
                        if a.datum_type().is_quantized() && b.datum_type().is_quantized() && c_dt.is_quantized() {
                            let a = a.cast_to_dt(accumulator_dt)?.into_owned();
                            let b = b.cast_to_dt(accumulator_dt)?.into_owned();
                            let c_shape = $crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                            let mut c = Tensor::zero_dt(accumulator_dt, &c_shape)?;
                            match accumulator_dt {
                                DatumType::F32 => {
                                    let mut c_plain = c.try_as_plain_mut()?;
                                    let view = c_plain.to_array_view_mut::<f32>()?;
                                    $crate::ndarray::Zip::from(view).and_broadcast(a.try_as_plain()?.to_array_view()?).and_broadcast(b.try_as_plain()?.to_array_view()?).for_each(|c, a, b| {
                                        *c = $q_op_on_f32(*a,*b);
                                    })
                                },
                                other => bail!("unexpected accumulator data type as {:?}", other)
                            };

                            return Ok(Some(c.cast_to_dt(*c_dt)?.into_owned()));
                        }
                        Ok(None)
                    }

                    // Try the specialized QU8 path first, then the generic one.
                    if let Some(c) = memory_optimised_q_binary_as_float_op(a, b, c_dt)? {
                        return Ok(Some(c));
                    }
                    if let Some(d) = generic_q_binary_as_float_op(a, b, c_dt, DatumType::F32)? {
                        return Ok(Some(d));
                    }
                )?
                Ok(None)
            }
        }

        /// Constructor: the op wrapped as a `TypedBinOp` with no forced
        /// output datum type.
        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op), None)
        }
    };
}
1049
/// Result of `one_input_is_uniform`: the uniform constant, the outlet of the
/// other (variable) input, and which side the uniform one was on.
#[derive(Debug)]
pub(crate) struct OneUniformInput {
    pub uni: Arc<Tensor>,
    pub var: OutletId,
    pub left_is_uniform: bool,
}
1056
/// Detect the case where exactly one of the node's two inputs is a uniform
/// (single-valued) constant whose shape does not force any broadcast of the
/// variable input.
pub(crate) fn one_input_is_uniform(
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<OneUniformInput>> {
    if let &[a, b] = &*model.node_input_facts(node.id)? {
        // Left input wins when both are uniform.
        let uni = if let Some(a) = &a.uniform {
            OneUniformInput { uni: a.clone(), var: node.inputs[1], left_is_uniform: true }
        } else if let Some(b) = &b.uniform {
            OneUniformInput { uni: b.clone(), var: node.inputs[0], left_is_uniform: false }
        } else {
            return Ok(None);
        };
        let var_fact = [a, b][uni.left_is_uniform as usize];
        let uni_fact = [a, b][!uni.left_is_uniform as usize];
        // Every dim of the uniform side must be 1 or match the variable side.
        if izip!(var_fact.shape.iter(), uni_fact.shape.iter()).all(|(v, u)| u.is_one() || u == v) {
            return Ok(Some(uni));
        }
    }
    Ok(None)
}