1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4use crate::ops::binary::BinMiniOp;
5
6use tract_data::TooEarly;
7
8fn eval_comp_oop<T: Datum + PartialOrd>(
10 a: &Tensor,
11 b: &Tensor,
12 f: impl Fn(&T, &T) -> bool,
13) -> TractResult<Tensor> {
14 let a = a.to_plain_array_view::<T>()?;
15 let b = b.to_plain_array_view::<T>()?;
16 let shape = multi_broadcast(&[a.shape(), b.shape()])?;
17 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
18 let mut c_plain = c.try_as_plain_mut()?;
19 let mut view = c_plain.to_array_view_mut::<bool>()?;
20 Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b).for_each(|c, a, b| *c = f(a, b));
21 Ok(c)
22}
23
24fn eval_tdim_symbolic(
26 session: &TurnState,
27 inputs: &TVec<TValue>,
28 prove: impl Fn(&TDim, &TDim) -> TractResult<bool>,
29) -> TractResult<Option<TVec<TValue>>> {
30 rule_if!(inputs[0].datum_type() == TDim::datum_type());
31 let mut a = inputs[0].clone().into_tensor();
32 let mut b = inputs[1].clone().into_tensor();
33 for a in a.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
34 *a = a.eval(&session.resolved_symbols);
35 }
36 for b in b.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
37 *b = b.eval(&session.resolved_symbols);
38 }
39 if let (Ok(a_i64), Ok(b_i64)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
40 let result = eval_comp_oop::<i64>(&a_i64, &b_i64, |a, b| {
41 prove(&(*a).into(), &(*b).into()).unwrap_or(false)
42 })?;
43 return Ok(Some(tvec!(result.into_tvalue())));
44 }
45 let a_view = inputs[0].to_plain_array_view::<TDim>()?;
46 let b_view = inputs[1].to_plain_array_view::<TDim>()?;
47 let shape = multi_broadcast(&[a_view.shape(), b_view.shape()])?;
48 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
49 let mut c_plain = c.try_as_plain_mut()?;
50 let mut view = c_plain.to_array_view_mut::<bool>()?;
51 let a_bc = a_view.broadcast(&*shape).unwrap();
52 let b_bc = b_view.broadcast(&*shape).unwrap();
53 for ixs in tract_ndarray::indices(&*shape) {
54 view[&ixs] = prove(&a_bc[&ixs], &b_bc[&ixs])?;
55 }
56 Ok(Some(tvec!(c.into_tvalue())))
57}
58
/// Generates a unit-struct comparison operator together with its `BinMiniOp` implementation.
///
/// Arguments:
/// * `$Op` — name of the generated struct.
/// * `$name` — string literal returned by `name()`.
/// * `$cmp` — comparison operator token applied to each pair of elements (`==`, `<`, ...).
/// * `$prove_tdim` — `Fn(&TDim, &TDim) -> TractResult<bool>` forwarded to `eval_tdim_symbolic`.
/// * `$uniform_tdim` — callable taking `(&TDim, &TDim)` and returning the `TDim` produced by
///   `uniform_tdim_comparison`.
macro_rules! comp_bin_mini_op {
    ($Op:ident, $name:literal, $cmp:tt, $prove_tdim:expr, $uniform_tdim:expr) => {
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        pub struct $Op;

        impl BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                $name
            }

            // The output is boolean whatever the operand types are.
            fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
                Ok(bool::datum_type())
            }

            // Every generated operator reports itself as non-commutative.
            fn is_commutative(&self) -> bool {
                false
            }

            // In-place evaluation is refused: the result type differs from the operand type.
            fn eval_in_a(&self, _a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
                bail!("Comparison changes datum type, eval_in_a not supported")
            }

            // Writes `a $cmp b`, element-wise with broadcasting, into the bool tensor `c`.
            fn eval_out_of_place(
                &self,
                c: &mut Tensor,
                a: &Tensor,
                b: &Tensor,
            ) -> TractResult<()> {
                // Shared kernel for strings and numbers: view both operands as `T` and
                // store the predicate result for each broadcast pair into `out`.
                fn compare_into<T: Datum + PartialOrd>(
                    out: &mut Tensor,
                    lhs: &Tensor,
                    rhs: &Tensor,
                    pred: impl Fn(&T, &T) -> bool,
                ) -> TractResult<()> {
                    let lhs = lhs.to_plain_array_view::<T>()?;
                    let rhs = rhs.to_plain_array_view::<T>()?;
                    let mut out_plain = out.try_as_plain_mut()?;
                    let mut out_view = out_plain.to_array_view_mut::<bool>()?;
                    Zip::from(&mut out_view)
                        .and_broadcast(&lhs)
                        .and_broadcast(&rhs)
                        .for_each(|slot, x, y| *slot = pred(x, y));
                    Ok(())
                }

                // The element type is taken from the left operand only.
                let operand_dt = a.datum_type();
                if operand_dt == String::datum_type() {
                    // Strings are handled explicitly, ahead of the numeric dispatch.
                    return compare_into::<String>(c, a, b, |x, y| x $cmp y);
                }
                dispatch_numbers!(compare_into(operand_dt)(c, a, b, |x: &_, y: &_| x $cmp y))
            }

            // Allocates the broadcast-shaped output and fills it through `eval_out_of_place`.
            fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                let out_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                // SAFETY: the buffer starts uninitialised; `eval_out_of_place` either assigns
                // every element or returns an error, in which case the tensor is dropped.
                let mut out = unsafe { Tensor::uninitialized_dt(c_dt, &out_shape)? };
                self.eval_out_of_place(&mut out, &a, &b)?;
                Ok(out)
            }

            // Symbolic (`TDim`) evaluation is delegated to the shared helper.
            fn eval_symbolic(
                &self,
                session: &TurnState,
                inputs: TVec<TValue>,
            ) -> TractResult<Option<TVec<TValue>>> {
                eval_tdim_symbolic(session, &inputs, $prove_tdim)
            }

            // Builds the symbolic expression describing this comparison of `a` and `b`.
            fn uniform_tdim_comparison(
                &self,
                a: &TDim,
                b: &TDim,
            ) -> Option<TDim> {
                let expr: TDim = ($uniform_tdim)(a, b);
                Some(expr)
            }
        }
    };
}
134
135fn prove_eq(a: &TDim, b: &TDim) -> TractResult<bool> {
136 Ok(a == b)
137}
138
139fn prove_ne(a: &TDim, b: &TDim) -> TractResult<bool> {
140 Ok(a != b)
141}
142
143fn prove_gte(a: &TDim, b: &TDim) -> TractResult<bool> {
144 let diff = a.clone() - b;
145 if diff.prove_positive_or_zero() {
146 Ok(true)
147 } else if diff.prove_strict_negative() {
148 Ok(false)
149 } else {
150 bail!(TooEarly::UndeterminedSymbol(diff.to_string()))
151 }
152}
153
154fn prove_gt(a: &TDim, b: &TDim) -> TractResult<bool> {
155 let diff = a.clone() - b;
156 if diff.prove_strict_positive() {
157 Ok(true)
158 } else if diff.prove_negative_or_zero() {
159 Ok(false)
160 } else {
161 bail!(TooEarly::UndeterminedSymbol(diff.to_string()))
162 }
163}
164
165fn prove_lte(a: &TDim, b: &TDim) -> TractResult<bool> {
166 prove_gte(b, a)
167}
168
169fn prove_lt(a: &TDim, b: &TDim) -> TractResult<bool> {
170 prove_gt(b, a)
171}
172
173comp_bin_mini_op!(CompEq, "Eq", ==, prove_eq, |a: &TDim, b: &TDim|
174 TDim::Eq(Box::new(a.clone()), Box::new(b.clone())).reduce()
175);
176
177comp_bin_mini_op!(CompNE, "NE", !=, prove_ne, |a: &TDim, b: &TDim|
178 (TDim::Val(1) - TDim::Eq(Box::new(a.clone()), Box::new(b.clone()))).reduce()
179);
180
181comp_bin_mini_op!(CompLT, "LT", <, prove_lt, |a: &TDim, b: &TDim|
182 TDim::Ge(Box::new(b.clone()), Box::new((a.clone() + TDim::Val(1)).reduce())).reduce()
183);
184
185comp_bin_mini_op!(CompGT, "GT", >, prove_gt, |a: &TDim, b: &TDim|
186 TDim::Ge(Box::new((a.clone() + TDim::Val(1)).reduce()), Box::new(b.clone())).reduce()
187);
188
189comp_bin_mini_op!(CompLTE, "LTE", <=, prove_lte, |a: &TDim, b: &TDim|
190 TDim::Ge(Box::new(b.clone()), Box::new(a.clone())).reduce()
191);
192
193comp_bin_mini_op!(CompGTE, "GTE", >=, prove_gte, |a: &TDim, b: &TDim|
194 TDim::Ge(Box::new(a.clone()), Box::new(b.clone())).reduce()
195);
196
197pub fn comp_eq() -> Box<dyn BinMiniOp> {
199 Box::new(CompEq)
200}
201pub fn comp_ne() -> Box<dyn BinMiniOp> {
202 Box::new(CompNE)
203}
204pub fn comp_lt() -> Box<dyn BinMiniOp> {
205 Box::new(CompLT)
206}
207pub fn comp_gt() -> Box<dyn BinMiniOp> {
208 Box::new(CompGT)
209}
210pub fn comp_lte() -> Box<dyn BinMiniOp> {
211 Box::new(CompLTE)
212}
213pub fn comp_gte() -> Box<dyn BinMiniOp> {
214 Box::new(CompGTE)
215}