1#![allow(clippy::bool_comparison)]
2#![allow(clippy::unnecessary_cast)]
3
4mod comparison;
5mod ite;
6pub use comparison::{CompEq, CompGT, CompGTE, CompLT, CompLTE, CompNE};
7pub use comparison::{comp_eq, comp_gt, comp_gte, comp_lt, comp_lte, comp_ne};
8pub use ite::IfThenElse;
9
10use ndarray::*;
11
12use crate::broadcast::multi_broadcast;
13use crate::internal::*;
14
// Logical connectives as broadcasting binary ops, generated by the project's
// `bin_to_super_type!` macro. Integer inputs are interpreted C-style: any
// non-zero value counts as true (hence the `as i64 != 0` normalization).
// `neutral_element` / `absorbing_element` are declared for the macro's generic
// declutter machinery (x AND 1 == x, x AND 0 == 0; x OR 0 == x, x OR 1 == 1).
bin_to_super_type!(and, And,
    neutral_element: 1,
    absorbing_element: 0,
    [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 && b as i64 != 0) as _);
bin_to_super_type!(or, Or,
    neutral_element: 0,
    absorbing_element: 1,
    [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = (a as i64 != 0 || b as i64 != 0) as _);
// Logical xor is bool-only; `declutter_xor` below rewrites `x ^ true` to `!x`.
bin_to_super_type!(xor, Xor, declutter: declutter_xor, neutral_element: 0, [bool] => |c, &a, &b| *c = a ^ b);
24
25fn declutter_xor(
26 _op: &Xor,
27 model: &TypedModel,
28 node: &TypedNode,
29) -> TractResult<Option<TypedModelPatch>> {
30 if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
32 if tensor0(1i64).close_enough(&uniform.uni, false).is_ok() {
33 return Ok(Some(TypedModelPatch::replace_single_op(
34 model,
35 node,
36 &[uniform.var],
37 crate::ops::element_wise::ElementWiseOp(Box::new(Not {}), None),
38 )?));
39 }
40 }
41 Ok(None)
42}
43
// Element-wise logical negation, defined for bool only; flips the buffer in place.
element_wise!(not, Not, [bool] => |_, vs| {
    vs.iter_mut().for_each(|a| *a = !*a);
    Ok(())
});
48
/// Ternary element-wise select: takes a boolean condition tensor and two value
/// tensors, and picks each output element from the second input where the
/// condition is true, from the third otherwise, broadcasting all three inputs
/// to a common shape (see `EvalOp::eval` below).
#[derive(Debug, Clone, new, Default, Hash, PartialEq, Eq)]
pub struct Iff;
51
impl Iff {
    /// Monomorphic kernel for the select: writes `t[i]` into `out[i]` where
    /// `cond[i]` is true, `f[i]` otherwise, broadcasting `cond`, `t` and `f`
    /// over `out`'s shape.
    ///
    /// # Safety
    /// `out`, `t` and `f` must actually hold elements of type `T`, and `out`'s
    /// shape must be the broadcast of the three operand shapes: the unchecked
    /// array views perform no datum-type validation, and `out` may be
    /// uninitialized (every element is written exactly once by the zip).
    pub unsafe fn eval_t<T: Datum>(
        cond: &ArrayViewD<bool>,
        out: &mut Tensor,
        t: &Tensor,
        f: &Tensor,
    ) {
        unsafe {
            Zip::from(out.to_array_view_mut_unchecked::<T>())
                .and_broadcast(cond)
                .and_broadcast(t.to_array_view_unchecked::<T>())
                .and_broadcast(f.to_array_view_unchecked::<T>())
                .for_each(|r, c, t, f| *r = if *c { t.clone() } else { f.clone() })
        }
    }
}
68
impl Op for Iff {
    // Display name used in dumps and error messages.
    fn name(&self) -> StaticName {
        "Iff".into()
    }
    op_as_typed_op!();
}
75
impl EvalOp for Iff {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Runs the select on concrete tensors. The two value branches must share
    /// a datum type; the output shape is the multi-directional broadcast of
    /// the three input shapes.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (cond, t, f) = args_3!(inputs);
        anyhow::ensure!(t.datum_type() == f.datum_type());
        let shape: TVec<usize> = multi_broadcast(&[cond.shape(), t.shape(), f.shape()])?;
        unsafe {
            // SAFETY: `result` is allocated with the branches' datum type and
            // the broadcast shape; `eval_t` (dispatched on that same datum
            // type) fills every element before the tensor is returned.
            let mut result = Tensor::uninitialized_dt(t.datum_type(), &shape)?;
            let cond = cond.to_plain_array_view::<bool>()?;
            dispatch_datum_by_size!(Self::eval_t(t.datum_type())(&cond, &mut result, &t, &f));
            Ok(tvec!(result.into_tvalue()))
        }
    }
}
93
94pub fn sym_to_coord_axis(sym: &Symbol) -> Option<usize> {
95 format!("{sym}").strip_prefix("🎯")?.parse::<usize>().ok()
96}
97
98pub(crate) fn coord_bound_assertions(expr: &TDim, shape: &ShapeFact) -> Vec<Assertion> {
99 expr.symbols()
100 .into_iter()
101 .filter_map(|s| sym_to_coord_axis(&s).filter(|k| *k < shape.rank()).map(|k| (k, s)))
102 .flat_map(|(k, sym)| {
103 [
104 Assertion::GTE(TDim::Sym(sym.clone()), TDim::Val(0)),
105 Assertion::LTE(TDim::Sym(sym), shape[k].clone() - TDim::Val(1)),
106 ]
107 })
108 .collect()
109}
110
111pub(crate) fn is_provably_all_false(expr: &TDim, shape: &ShapeFact) -> bool {
112 let extra = coord_bound_assertions(expr, shape);
113 expr.clone().simplify_with_extra_assertions(&extra) == TDim::Val(0)
114}
115
116pub(crate) fn is_provably_all_true(expr: &TDim, shape: &ShapeFact) -> bool {
117 let extra = coord_bound_assertions(expr, shape);
118 expr.clone().simplify_with_extra_assertions(&extra) == TDim::Val(1)
119}
120
121#[derive(Debug, Clone)]
133pub(crate) struct TrueRange {
134 pub axis: usize,
135 pub start: Option<TDim>, pub end: Option<TDim>, }
138
139impl TrueRange {
140 pub fn is_full(&self) -> bool {
142 self.start.is_none() && self.end.is_none()
143 }
144 pub fn is_empty(&self) -> bool {
146 match (&self.start, &self.end) {
147 (None, Some(e)) => *e == TDim::Val(0),
148 (Some(s), Some(e)) => s == e,
149 _ => false,
150 }
151 }
152}
153
/// Tries to recognize `expr` — a 0/1 predicate over coordinate symbols — as a
/// contiguous "true" range along a single axis.
///
/// Recognized shapes, tried in order:
/// * provably always false  -> empty range (`end == 0`);
/// * provably always true   -> full range (unbounded both sides);
/// * `🎯k >= split`         -> true on `[split, ..)` of axis `k`;
/// * `1 - expr` is a `>=`   -> `expr` is `🎯k < split`, true on `[.., split)`.
///
/// Returns `None` when the predicate matches none of these.
pub(crate) fn classify_true_range(expr: &TDim, shape: &ShapeFact) -> Option<TrueRange> {
    // Matches `🎯k >= rhs` where `k` is a valid axis of `shape` and `rhs`
    // does not itself mention the same coordinate symbol.
    fn try_ge(ge: &TDim, shape: &ShapeFact) -> Option<(usize, TDim)> {
        if let TDim::Ge(lhs, rhs) = ge {
            if let TDim::Sym(sym) = &**lhs {
                let k = sym_to_coord_axis(sym)?;
                if k < shape.rank() && !rhs.symbols().contains(sym) {
                    return Some((k, *rhs.clone()));
                }
            }
        }
        None
    }

    let simplified = expr.clone().simplify();
    // Degenerate cases first: a constant predicate needs no axis analysis
    // (axis 0 is arbitrary for the empty range).
    if simplified == TDim::Val(0) || is_provably_all_false(&simplified, shape) {
        return Some(TrueRange { axis: 0, start: None, end: Some(TDim::Val(0)) });
    }
    if simplified == TDim::Val(1) || is_provably_all_true(&simplified, shape) {
        return Some(TrueRange { axis: 0, start: None, end: None });
    }
    // Direct form: predicate is `🎯k >= split`.
    if let Some((axis, split)) = try_ge(&simplified, shape) {
        return Some(TrueRange { axis, start: Some(split), end: None });
    }
    // Complement form: `1 - expr` being a `>=` means `expr` was `🎯k < split`.
    let flipped = (TDim::Val(1) - simplified).simplify();
    if let Some((axis, split)) = try_ge(&flipped, shape) {
        return Some(TrueRange { axis, start: None, end: Some(split) });
    }
    None
}
187
188impl TypedOp for Iff {
189 as_op!();
190
191 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
192 ensure!(inputs.len() == 3, "Iff expects 3 intputs.");
193 ensure!(inputs[1].datum_type == inputs[2].datum_type);
194 ensure!(inputs[0].datum_type.is::<bool>());
195 ensure!(inputs[0].rank() == inputs[1].rank());
196 ensure!(inputs[0].rank() == inputs[2].rank());
197 let shape = multi_broadcast(&[
198 inputs[0].shape.to_tvec(),
199 inputs[1].shape.to_tvec(),
200 inputs[2].shape.to_tvec(),
201 ])
202 .unwrap();
203 let mut fact = inputs[1].datum_type.fact(shape);
204 fact.uniform_tdim = match inputs[0].uniform_tdim.as_ref().map(|d| d.clone().simplify()) {
206 Some(TDim::Val(0)) => inputs[2].uniform_tdim.clone(), Some(TDim::Val(_)) => inputs[1].uniform_tdim.clone(), _ => None,
209 };
210 Ok(tvec!(fact))
211 }
212
213 fn input_roi(
214 &self,
215 model: &TypedModel,
216 node: &TypedNode,
217 ) -> TractResult<Option<TVec<Option<TDim>>>> {
218 let cond_fact = model.outlet_fact(node.inputs[0])?;
222 if let Some(cond_expr) = &cond_fact.uniform_tdim {
223 let cond = cond_expr.clone().simplify();
224 let not_cond = TDim::Eq(Box::new(cond.clone()), Box::new(TDim::Val(0))).simplify();
225 return Ok(Some(tvec![None, Some(cond), Some(not_cond)]));
226 }
227 crate::optim::propagate_roi::bubble_roi(model, node)
229 }
230
231 fn declutter(
232 &self,
233 model: &TypedModel,
234 node: &TypedNode,
235 ) -> TractResult<Option<TypedModelPatch>> {
236 let cond_fact = model.outlet_fact(node.inputs[0])?;
240 rule_if_some!(uniform = &cond_fact.uniform);
241 let Ok(cond_val) = uniform.cast_to_scalar::<bool>() else { return Ok(None) };
242 let branch = if cond_val { node.inputs[1] } else { node.inputs[2] };
243 let mut patch = TypedModelPatch::default();
244 let wire = patch.tap_model(model, branch)?;
245 patch.shunt_outside(model, node.id.into(), wire)?;
246 Ok(Some(patch))
247 }
248
249 fn axes_mapping(
250 &self,
251 inputs: &[&TypedFact],
252 outputs: &[&TypedFact],
253 ) -> TractResult<AxesMapping> {
254 AxesMapping::natural(inputs, outputs)
255 }
256}
257
// Bitwise (as opposed to logical) binary ops over bool and integer types.
// Note: `bitand` declares no neutral_element — the all-ones neutral of `&`
// differs per integer width, so only the absorbing 0 is declared.
bin_to_super_type!(bitand, BitAnd,
    absorbing_element: 0,
    [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a & b);
bin_to_super_type!(bitor, BitOr,
    neutral_element: 0,
    [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a | b);
// `declutter_bitxor` below rewrites xor-with-all-ones into BitNot.
bin_to_super_type!(bitxor, BitXor,
    declutter: declutter_bitxor,
    neutral_element: 0,
    [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |c, &a, &b| *c = a ^ b);
268
/// Declutter rule for `bitxor`: xor with an all-ones pattern is bitwise
/// negation, so when one input is uniformly all-ones replace the node with an
/// element-wise `BitNot` on the other input.
fn declutter_bitxor(
    _op: &BitXor,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let var_dt = model.outlet_fact(uniform.var)?.datum_type;
        // "All ones" is 1 (true) for bool and -1 (two's-complement all-ones)
        // for integers. NOTE(review): for unsigned dtypes this relies on
        // `close_enough` matching -1i64 against e.g. 255u8 — confirm that the
        // comparison casts as intended for u8..u64.
        let is_all_ones = if var_dt.is::<bool>() {
            tensor0(1i64).close_enough(&uniform.uni, false).is_ok()
        } else {
            tensor0(-1i64).close_enough(&uniform.uni, false).is_ok()
        };
        if is_all_ones {
            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &[uniform.var],
                crate::ops::element_wise::ElementWiseOp(Box::new(BitNot {}), None),
            )?));
        }
    }
    Ok(None)
}
293
// Element-wise bitwise complement (acts as logical not on bool); in place.
element_wise!(bitnot, BitNot, [bool, u8, u16, u32, u64, i8, i16, i32, i64] => |_, xs| {
    xs.iter_mut().for_each(|x| *x = !*x);
    Ok(())
});
298
299#[cfg(test)]
300mod tests {
301 use super::*;
302 use crate::ops::array::TypedConcat;
303 use crate::ops::binary::TypedBinOp;
304 use crate::ops::change_axes::AxisOp;
305
    // Builds `Iff(T == 0, 0, data)` with the assertion `T >= 1`, under which
    // the condition is provably false; decluttering must fold the Iff away
    // (the output reduces to `data`).
    #[test]
    fn iff_fold_case1_eq_t_zero() -> TractResult<()> {
        let mut model = TypedModel::default();
        model.symbols.add_assertion("T >= 1")?;
        let t_sym = model.symbols.sym("T");
        let t_dim = TDim::Sym(t_sym.clone());

        // Scalar TDim constant holding the symbol T itself.
        let t_wire = model.wire_node(
            "T",
            crate::ops::konst::Const::new(tensor0(t_dim.clone()).into_arc_tensor())?,
            &[],
        )?[0];

        let zero_wire = model.wire_node(
            "zero",
            crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
            &[],
        )?[0];

        // Condition: T == 0 — contradicts the T >= 1 assertion.
        let eq_wire = model.wire_node("eq", TypedBinOp(comp_eq(), None), &[t_wire, zero_wire])?[0];

        let data_wire = model.add_source("data", TDim::datum_type().scalar_fact())?;

        let iff_wire = model.wire_node("iff", Iff, &[eq_wire, zero_wire, data_wire])?[0];
        model.select_output_outlets(&[iff_wire])?;

        let model = model.into_decluttered()?;

        // No Iff node should survive decluttering.
        let iff_count = model.nodes().iter().filter(|n| n.op_as::<Iff>().is_some()).count();
        assert_eq!(iff_count, 0, "Expected Iff to be folded, but found {iff_count} Iff nodes");
        Ok(())
    }
346
    // Builds `Iff(!(range(T) < T), zeros, data)`: the condition is a bitnot of
    // `coord < T`, which is true everywhere given `T >= 1`, so the negated
    // condition is provably all-false and the Iff must fold to `data`.
    #[test]
    fn iff_fold_case2_not_lt_x1_t() -> TractResult<()> {
        use crate::ops::array::Range;

        let mut model = TypedModel::default();
        model.symbols.add_assertion("T >= 1")?;
        let t_sym = model.symbols.sym("T");
        let t_dim = TDim::Sym(t_sym.clone());

        // range(0, T, 1): the coordinate vector along the time axis.
        let start = model.wire_node(
            "start",
            crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
            &[],
        )?[0];
        let step = model.wire_node(
            "step",
            crate::ops::konst::Const::new(tensor0(TDim::Val(1)).into_arc_tensor())?,
            &[],
        )?[0];
        let end = model.add_source("T_dyn", TDim::datum_type().scalar_fact())?;

        let range = model.wire_node("range", Range::new(t_dim.clone()), &[start, end, step])?[0];

        let range_unsq = model.wire_node("range_unsq", AxisOp::Add(0), &[range])?[0];

        // Broadcastable constant T, unsqueezed to rank 2 to match range_unsq.
        let t_const = model.wire_node(
            "T_const",
            crate::ops::konst::Const::new(tensor0(t_dim.clone()).into_arc_tensor())?,
            &[],
        )?[0];
        let t_unsq = model.wire_node("T_unsq", AxisOp::Add(0), &[t_const])?[0];
        let t_unsq2 = model.wire_node("T_unsq2", AxisOp::Add(0), &[t_unsq])?[0];

        // coord < T — true for every coordinate of the axis.
        let lt = model.wire_node("lt", TypedBinOp(comp_lt(), None), &[range_unsq, t_unsq2])?[0];

        // Negated via bitnot: provably all-false condition.
        let bn = model.wire_node("bitnot", bitnot(), &[lt])?[0];

        let data_shape = tvec![TDim::Val(1), t_dim.clone()];
        let data = model.add_source("data", TDim::datum_type().fact(data_shape.clone()))?;

        // True branch: zeros broadcast to the data shape.
        let zero_scalar = model.wire_node(
            "zero_s",
            crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
            &[],
        )?[0];
        let zeros = model.wire_node(
            "zeros",
            crate::ops::array::MultiBroadcastTo {
                shape: ShapeFact::from_dims(data_shape.iter().cloned()),
            },
            &[zero_scalar],
        )?[0];

        let iff = model.wire_node("iff", Iff, &[bn, zeros, data])?[0];
        model.select_output_outlets(&[iff])?;

        let model = model.into_decluttered()?;

        // No Iff node should survive decluttering.
        let iff_count = model.nodes().iter().filter(|n| n.op_as::<Iff>().is_some()).count();
        assert_eq!(iff_count, 0, "Expected Iff to be folded, but found {iff_count} Iff nodes");
        Ok(())
    }
425
    // Builds `Iff(coord >= T/160, true_b, false_b)`: a half-open condition
    // along one axis. Decluttering is expected to turn the Iff into slices of
    // the two branches stitched back with a Concat (no Iff left, at least one
    // TypedConcat present).
    #[test]
    fn iff_split_to_slice_concat() -> TractResult<()> {
        use crate::ops::array::Range;

        let mut model = TypedModel::default();
        model.symbols.add_assertion("T >= 160")?;
        let t_sym = model.symbols.sym("T");
        let t_dim = TDim::Sym(t_sym.clone());

        // Split point T/160 inside an axis of length 1 + T/160.
        let split = t_dim.clone() / 160;
        let out_len = TDim::Val(1) + split.clone();

        // range(0, out_len, 1): coordinate vector along the last axis.
        let start = model.wire_node(
            "start",
            crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
            &[],
        )?[0];
        let step = model.wire_node(
            "step",
            crate::ops::konst::Const::new(tensor0(TDim::Val(1)).into_arc_tensor())?,
            &[],
        )?[0];
        let end_val = model.wire_node(
            "end_val",
            crate::ops::konst::Const::new(tensor0(out_len.clone()).into_arc_tensor())?,
            &[],
        )?[0];
        let range =
            model.wire_node("range", Range::new(out_len.clone()), &[start, end_val, step])?[0];
        let r1 = model.wire_node("r1", AxisOp::Add(0), &[range])?[0];
        let r2 = model.wire_node("r2", AxisOp::Add(0), &[r1])?[0];

        // Split-point constant unsqueezed to the same rank (3) as r2.
        let split_const = model.wire_node(
            "split_const",
            crate::ops::konst::Const::new(tensor0(split.clone()).into_arc_tensor())?,
            &[],
        )?[0];
        let sc1 = model.wire_node("sc1", AxisOp::Add(0), &[split_const])?[0];
        let sc2 = model.wire_node("sc2", AxisOp::Add(0), &[sc1])?[0];
        let sc2 = model.wire_node("sc3", AxisOp::Add(0), &[sc2])?[0];

        // Condition: coord >= split.
        let cond = model.wire_node("cond", TypedBinOp(comp_gte(), None), &[r2, sc2])?[0];

        let true_branch = model.add_source(
            "true_b",
            TDim::datum_type().fact(tvec![TDim::Val(1), TDim::Val(1), out_len.clone()]),
        )?;
        let false_branch = model.add_source(
            "false_b",
            TDim::datum_type().fact(tvec![TDim::Val(1), TDim::Val(1), out_len.clone()]),
        )?;

        let iff = model.wire_node("iff", Iff, &[cond, true_branch, false_branch])?[0];
        model.select_output_outlets(&[iff])?;

        let model = model.into_decluttered()?;

        let iff_count = model.nodes().iter().filter(|n| n.op_as::<Iff>().is_some()).count();
        assert_eq!(iff_count, 0, "Expected no Iff nodes after declutter, found {iff_count}");

        // The split rewrite materializes as slice + concat.
        let concat_count =
            model.nodes().iter().filter(|n| n.op_as::<TypedConcat>().is_some()).count();
        assert!(concat_count > 0, "Expected at least one Concat node after declutter");

        Ok(())
    }
506
    // Sanity check that `uniform_tdim` facts flow through the ops the folding
    // tests rely on: Range produces one, AxisOp::Add preserves it, and the
    // comparison op combines two of them into a predicate expression.
    #[test]
    fn verify_uniform_tdim_propagation() -> TractResult<()> {
        use crate::ops::array::Range;

        let mut model = TypedModel::default();
        model.symbols.add_assertion("T >= 1")?;
        let t_sym = model.symbols.sym("T");
        let t_dim = TDim::Sym(t_sym.clone());

        let start = model.wire_node(
            "start",
            crate::ops::konst::Const::new(tensor0(TDim::Val(0)).into_arc_tensor())?,
            &[],
        )?[0];
        let step = model.wire_node(
            "step",
            crate::ops::konst::Const::new(tensor0(TDim::Val(1)).into_arc_tensor())?,
            &[],
        )?[0];
        let end = model.add_source("T_dyn", TDim::datum_type().scalar_fact())?;
        let range = model.wire_node("range", Range::new(t_dim.clone()), &[start, end, step])?[0];
        let range_unsq = model.wire_node("range_unsq", AxisOp::Add(0), &[range])?[0];
        let t_const = model.wire_node(
            "T_const",
            crate::ops::konst::Const::new(tensor0(t_dim.clone()).into_arc_tensor())?,
            &[],
        )?[0];
        let t_unsq = model.wire_node("T_unsq", AxisOp::Add(0), &[t_const])?[0];
        let t_unsq2 = model.wire_node("T_unsq2", AxisOp::Add(0), &[t_unsq])?[0];
        let lt = model.wire_node("lt", TypedBinOp(comp_lt(), None), &[range_unsq, t_unsq2])?[0];

        let range_fact = model.outlet_fact(range)?;
        let range_unsq_fact = model.outlet_fact(range_unsq)?;
        let t_unsq_fact = model.outlet_fact(t_unsq)?;
        let lt_fact = model.outlet_fact(lt)?;

        // Each stage of the pipeline should carry a uniform_tdim expression.
        assert!(range_fact.uniform_tdim.is_some(), "range should have uniform_tdim");
        assert!(range_unsq_fact.uniform_tdim.is_some(), "range_unsq should have uniform_tdim");
        assert!(t_unsq_fact.uniform_tdim.is_some(), "t_unsq should have uniform_tdim");
        assert!(lt_fact.uniform_tdim.is_some(), "lt should have uniform_tdim");

        Ok(())
    }
551}