1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4use crate::ops::binary::BinMiniOp;
5
6use tract_data::TooEarly;
7
8fn eval_comp_oop<T: Datum + PartialOrd>(
10 a: &Tensor,
11 b: &Tensor,
12 f: impl Fn(&T, &T) -> bool,
13) -> TractResult<Tensor> {
14 let a = a.to_plain_array_view::<T>()?;
15 let b = b.to_plain_array_view::<T>()?;
16 let shape = multi_broadcast(&[a.shape(), b.shape()])?;
17 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
18 let mut c_plain = c.try_as_plain_mut()?;
19 let mut view = c_plain.to_array_view_mut::<bool>()?;
20 Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b).for_each(|c, a, b| *c = f(a, b));
21 Ok(c)
22}
23
24fn eval_tdim_symbolic(
26 session: &TurnState,
27 inputs: &TVec<TValue>,
28 prove: impl Fn(&TDim, &TDim) -> TractResult<bool>,
29) -> TractResult<Option<TVec<TValue>>> {
30 if inputs[0].datum_type() != TDim::datum_type() {
31 return Ok(None);
32 }
33 let mut a = inputs[0].clone().into_tensor();
34 let mut b = inputs[1].clone().into_tensor();
35 for a in a.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
36 *a = a.eval(&session.resolved_symbols);
37 }
38 for b in b.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
39 *b = b.eval(&session.resolved_symbols);
40 }
41 if let (Ok(a_i64), Ok(b_i64)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
42 let result = eval_comp_oop::<i64>(&a_i64, &b_i64, |a, b| {
43 prove(&(*a).into(), &(*b).into()).unwrap_or(false)
44 })?;
45 return Ok(Some(tvec!(result.into_tvalue())));
46 }
47 let a_view = inputs[0].to_plain_array_view::<TDim>()?;
48 let b_view = inputs[1].to_plain_array_view::<TDim>()?;
49 let shape = multi_broadcast(&[a_view.shape(), b_view.shape()])?;
50 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
51 let mut c_plain = c.try_as_plain_mut()?;
52 let mut view = c_plain.to_array_view_mut::<bool>()?;
53 let a_bc = a_view.broadcast(&*shape).unwrap();
54 let b_bc = b_view.broadcast(&*shape).unwrap();
55 for ixs in tract_ndarray::indices(&*shape) {
56 view[&ixs] = prove(&a_bc[&ixs], &b_bc[&ixs])?;
57 }
58 Ok(Some(tvec!(c.into_tvalue())))
59}
60
/// Defines a unit struct `$Op` implementing [`BinMiniOp`] for a single comparison operator.
///
/// * `$name`: the string returned by `BinMiniOp::name`;
/// * `$cmp`: the Rust comparison operator applied element-wise to numbers and `String`s;
/// * `$prove_tdim`: decision function handed to `eval_tdim_symbolic` by `eval_symbolic`;
/// * `$uniform_tdim`: closure building the `TDim` expression returned by
///   `uniform_tdim_comparison`.
macro_rules! comp_bin_mini_op {
    ($Op:ident, $name:literal, $cmp:tt, $prove_tdim:expr, $uniform_tdim:expr) => {
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        pub struct $Op;

        impl BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                $name
            }

            // The output is boolean whatever the operand types.
            fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
                Ok(bool::datum_type())
            }

            // Returns `false` for every op generated by this macro, Eq/NE included.
            fn is_commutative(&self) -> bool {
                false
            }

            // The result type (bool) differs from the operand type, so it cannot be written
            // back into `a`.
            fn eval_in_a(&self, _a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
                bail!("Comparison changes datum type, eval_in_a not supported")
            }

            fn eval_out_of_place(
                &self,
                c: &mut Tensor,
                a: &Tensor,
                b: &Tensor,
            ) -> TractResult<()> {
                // Writes `f(x, y)` element-wise (broadcasting `a` and `b`) into the bool tensor
                // `c`.
                fn compare_into<T: Datum + PartialOrd>(
                    c: &mut Tensor,
                    a: &Tensor,
                    b: &Tensor,
                    f: impl Fn(&T, &T) -> bool,
                ) -> TractResult<()> {
                    let lhs = a.to_plain_array_view::<T>()?;
                    let rhs = b.to_plain_array_view::<T>()?;
                    let mut c_plain = c.try_as_plain_mut()?;
                    let mut out = c_plain.to_array_view_mut::<bool>()?;
                    Zip::from(&mut out)
                        .and_broadcast(&lhs)
                        .and_broadcast(&rhs)
                        .for_each(|dst, x, y| *dst = f(x, y));
                    Ok(())
                }

                let dt = a.datum_type();
                // `dispatch_numbers!` does not cover strings, so they are routed explicitly.
                if dt == String::datum_type() {
                    return compare_into::<String>(c, a, b, |x, y| x $cmp y);
                }
                dispatch_numbers!(compare_into(dt)(c, a, b, |x: &_, y: &_| x $cmp y))
            }

            fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                let out_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                // SAFETY: `eval_out_of_place` writes every element of `c` when it succeeds.
                let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &out_shape)? };
                self.eval_out_of_place(&mut c, &a, &b)?;
                Ok(c)
            }

            fn eval_symbolic(
                &self,
                session: &TurnState,
                inputs: TVec<TValue>,
            ) -> TractResult<Option<TVec<TValue>>> {
                eval_tdim_symbolic(session, &inputs, $prove_tdim)
            }

            fn uniform_tdim_comparison(&self, a: &TDim, b: &TDim) -> Option<TDim> {
                Some(($uniform_tdim)(a, b))
            }
        }
    };
}
136
137fn prove_eq(a: &TDim, b: &TDim) -> TractResult<bool> {
138 Ok(a == b)
139}
140
141fn prove_ne(a: &TDim, b: &TDim) -> TractResult<bool> {
142 Ok(a != b)
143}
144
145fn prove_gte(a: &TDim, b: &TDim) -> TractResult<bool> {
146 let diff = a.clone() - b;
147 if diff.prove_positive_or_zero() {
148 Ok(true)
149 } else if diff.prove_strict_negative() {
150 Ok(false)
151 } else {
152 bail!(TooEarly::UndeterminedSymbol(diff.to_string()))
153 }
154}
155
156fn prove_gt(a: &TDim, b: &TDim) -> TractResult<bool> {
157 let diff = a.clone() - b;
158 if diff.prove_strict_positive() {
159 Ok(true)
160 } else if diff.prove_negative_or_zero() {
161 Ok(false)
162 } else {
163 bail!(TooEarly::UndeterminedSymbol(diff.to_string()))
164 }
165}
166
167fn prove_lte(a: &TDim, b: &TDim) -> TractResult<bool> {
168 prove_gte(b, a)
169}
170
171fn prove_lt(a: &TDim, b: &TDim) -> TractResult<bool> {
172 prove_gt(b, a)
173}
174
175comp_bin_mini_op!(CompEq, "Eq", ==, prove_eq, |a: &TDim, b: &TDim|
176 TDim::Eq(Box::new(a.clone()), Box::new(b.clone())).reduce()
177);
178
179comp_bin_mini_op!(CompNE, "NE", !=, prove_ne, |a: &TDim, b: &TDim|
180 (TDim::Val(1) - TDim::Eq(Box::new(a.clone()), Box::new(b.clone()))).reduce()
181);
182
183comp_bin_mini_op!(CompLT, "LT", <, prove_lt, |a: &TDim, b: &TDim|
184 TDim::Ge(Box::new(b.clone()), Box::new((a.clone() + TDim::Val(1)).reduce())).reduce()
185);
186
187comp_bin_mini_op!(CompGT, "GT", >, prove_gt, |a: &TDim, b: &TDim|
188 TDim::Ge(Box::new((a.clone() + TDim::Val(1)).reduce()), Box::new(b.clone())).reduce()
189);
190
191comp_bin_mini_op!(CompLTE, "LTE", <=, prove_lte, |a: &TDim, b: &TDim|
192 TDim::Ge(Box::new(b.clone()), Box::new(a.clone())).reduce()
193);
194
195comp_bin_mini_op!(CompGTE, "GTE", >=, prove_gte, |a: &TDim, b: &TDim|
196 TDim::Ge(Box::new(a.clone()), Box::new(b.clone())).reduce()
197);
198
199pub fn comp_eq() -> Box<dyn BinMiniOp> {
201 Box::new(CompEq)
202}
203pub fn comp_ne() -> Box<dyn BinMiniOp> {
204 Box::new(CompNE)
205}
206pub fn comp_lt() -> Box<dyn BinMiniOp> {
207 Box::new(CompLT)
208}
209pub fn comp_gt() -> Box<dyn BinMiniOp> {
210 Box::new(CompGT)
211}
212pub fn comp_lte() -> Box<dyn BinMiniOp> {
213 Box::new(CompLTE)
214}
215pub fn comp_gte() -> Box<dyn BinMiniOp> {
216 Box::new(CompGTE)
217}