Skip to main content

tract_core/ops/logic/comparison.rs

1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4use crate::ops::binary::BinMiniOp;
5
6use tract_data::TooEarly;
7
8// Helper for eval_out_of_place dispatch
9fn eval_comp_oop<T: Datum + PartialOrd>(
10    a: &Tensor,
11    b: &Tensor,
12    f: impl Fn(&T, &T) -> bool,
13) -> TractResult<Tensor> {
14    let a = a.to_plain_array_view::<T>()?;
15    let b = b.to_plain_array_view::<T>()?;
16    let shape = multi_broadcast(&[a.shape(), b.shape()])?;
17    let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
18    let mut c_plain = c.try_as_plain_mut()?;
19    let mut view = c_plain.to_array_view_mut::<bool>()?;
20    Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b).for_each(|c, a, b| *c = f(a, b));
21    Ok(c)
22}
23
24// Helper for TDim symbolic eval
25fn eval_tdim_symbolic(
26    session: &TurnState,
27    inputs: &TVec<TValue>,
28    prove: impl Fn(&TDim, &TDim) -> TractResult<bool>,
29) -> TractResult<Option<TVec<TValue>>> {
30    rule_if!(inputs[0].datum_type() == TDim::datum_type());
31    let mut a = inputs[0].clone().into_tensor();
32    let mut b = inputs[1].clone().into_tensor();
33    for a in a.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
34        *a = a.eval(&session.resolved_symbols);
35    }
36    for b in b.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
37        *b = b.eval(&session.resolved_symbols);
38    }
39    if let (Ok(a_i64), Ok(b_i64)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
40        let result = eval_comp_oop::<i64>(&a_i64, &b_i64, |a, b| {
41            prove(&(*a).into(), &(*b).into()).unwrap_or(false)
42        })?;
43        return Ok(Some(tvec!(result.into_tvalue())));
44    }
45    let a_view = inputs[0].to_plain_array_view::<TDim>()?;
46    let b_view = inputs[1].to_plain_array_view::<TDim>()?;
47    let shape = multi_broadcast(&[a_view.shape(), b_view.shape()])?;
48    let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
49    let mut c_plain = c.try_as_plain_mut()?;
50    let mut view = c_plain.to_array_view_mut::<bool>()?;
51    let a_bc = a_view.broadcast(&*shape).unwrap();
52    let b_bc = b_view.broadcast(&*shape).unwrap();
53    for ixs in tract_ndarray::indices(&*shape) {
54        view[&ixs] = prove(&a_bc[&ixs], &b_bc[&ixs])?;
55    }
56    Ok(Some(tvec!(c.into_tvalue())))
57}
58
/// Generates a unit struct implementing [`BinMiniOp`] for one comparison.
///
/// Arguments, in order:
/// - `$op`: name of the generated struct;
/// - `$label`: string returned by `BinMiniOp::name`;
/// - `$cmp`: the Rust comparison operator token applied element-wise;
/// - `$prove`: `Fn(&TDim, &TDim) -> TractResult<bool>` handed to
///   `eval_tdim_symbolic` by `eval_symbolic`;
/// - `$uniform`: closure building the `TDim` expression returned by
///   `uniform_tdim_comparison`.
///
/// Every generated op outputs `bool`, reports itself as non-commutative and
/// refuses in-place evaluation.
macro_rules! comp_bin_mini_op {
    ($op:ident, $label:literal, $cmp:tt, $prove:expr, $uniform:expr) => {
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        pub struct $op;

        impl BinMiniOp for $op {
            fn name(&self) -> &'static str {
                $label
            }

            // Comparisons always produce booleans, whatever the operand type.
            fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
                Ok(bool::datum_type())
            }

            fn is_commutative(&self) -> bool {
                false
            }

            // The output type (bool) differs from the operand type, so the
            // result can never be stored back into `a`.
            fn eval_in_a(&self, _a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
                bail!("Comparison changes datum type, eval_in_a not supported")
            }

            fn eval_out_of_place(
                &self,
                c: &mut Tensor,
                a: &Tensor,
                b: &Tensor,
            ) -> TractResult<()> {
                // Writes `op(lhs, rhs)` at every broadcast position of `out`.
                fn zip_compare<T: Datum + PartialOrd>(
                    out: &mut Tensor,
                    lhs: &Tensor,
                    rhs: &Tensor,
                    op: impl Fn(&T, &T) -> bool,
                ) -> TractResult<()> {
                    let lhs = lhs.to_plain_array_view::<T>()?;
                    let rhs = rhs.to_plain_array_view::<T>()?;
                    let mut out_plain = out.try_as_plain_mut()?;
                    let mut out_view = out_plain.to_array_view_mut::<bool>()?;
                    Zip::from(&mut out_view)
                        .and_broadcast(&lhs)
                        .and_broadcast(&rhs)
                        .for_each(|dst, x, y| *dst = op(x, y));
                    Ok(())
                }

                let dt = a.datum_type();
                if dt == String::datum_type() {
                    // Strings get their own arm: the numeric dispatch below
                    // only covers number types.
                    zip_compare::<String>(c, a, b, |x, y| x $cmp y)
                } else {
                    dispatch_numbers!(zip_compare(dt)(c, a, b, |x: &_, y: &_| x $cmp y))
                }
            }

            fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                let out_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                // NOTE(review): relies on `eval_out_of_place` overwriting every
                // element of `out` before it is read.
                let mut out = unsafe { Tensor::uninitialized_dt(c_dt, &out_shape)? };
                self.eval_out_of_place(&mut out, &a, &b)?;
                Ok(out)
            }

            fn eval_symbolic(
                &self,
                session: &TurnState,
                inputs: TVec<TValue>,
            ) -> TractResult<Option<TVec<TValue>>> {
                eval_tdim_symbolic(session, &inputs, $prove)
            }

            fn uniform_tdim_comparison(&self, a: &TDim, b: &TDim) -> Option<TDim> {
                let build = $uniform;
                Some(build(a, b))
            }
        }
    };
}
134
135fn prove_eq(a: &TDim, b: &TDim) -> TractResult<bool> {
136    Ok(a == b)
137}
138
139fn prove_ne(a: &TDim, b: &TDim) -> TractResult<bool> {
140    Ok(a != b)
141}
142
143fn prove_gte(a: &TDim, b: &TDim) -> TractResult<bool> {
144    let diff = a.clone() - b;
145    if diff.prove_positive_or_zero() {
146        Ok(true)
147    } else if diff.prove_strict_negative() {
148        Ok(false)
149    } else {
150        bail!(TooEarly::UndeterminedSymbol(diff.to_string()))
151    }
152}
153
154fn prove_gt(a: &TDim, b: &TDim) -> TractResult<bool> {
155    let diff = a.clone() - b;
156    if diff.prove_strict_positive() {
157        Ok(true)
158    } else if diff.prove_negative_or_zero() {
159        Ok(false)
160    } else {
161        bail!(TooEarly::UndeterminedSymbol(diff.to_string()))
162    }
163}
164
165fn prove_lte(a: &TDim, b: &TDim) -> TractResult<bool> {
166    prove_gte(b, a)
167}
168
169fn prove_lt(a: &TDim, b: &TDim) -> TractResult<bool> {
170    prove_gt(b, a)
171}
172
173comp_bin_mini_op!(CompEq, "Eq", ==, prove_eq, |a: &TDim, b: &TDim|
174    TDim::Eq(Box::new(a.clone()), Box::new(b.clone())).reduce()
175);
176
177comp_bin_mini_op!(CompNE, "NE", !=, prove_ne, |a: &TDim, b: &TDim|
178    (TDim::Val(1) - TDim::Eq(Box::new(a.clone()), Box::new(b.clone()))).reduce()
179);
180
181comp_bin_mini_op!(CompLT, "LT", <, prove_lt, |a: &TDim, b: &TDim|
182    TDim::Ge(Box::new(b.clone()), Box::new((a.clone() + TDim::Val(1)).reduce())).reduce()
183);
184
185comp_bin_mini_op!(CompGT, "GT", >, prove_gt, |a: &TDim, b: &TDim|
186    TDim::Ge(Box::new((a.clone() + TDim::Val(1)).reduce()), Box::new(b.clone())).reduce()
187);
188
189comp_bin_mini_op!(CompLTE, "LTE", <=, prove_lte, |a: &TDim, b: &TDim|
190    TDim::Ge(Box::new(b.clone()), Box::new(a.clone())).reduce()
191);
192
193comp_bin_mini_op!(CompGTE, "GTE", >=, prove_gte, |a: &TDim, b: &TDim|
194    TDim::Ge(Box::new(a.clone()), Box::new(b.clone())).reduce()
195);
196
197// Factory functions
198pub fn comp_eq() -> Box<dyn BinMiniOp> {
199    Box::new(CompEq)
200}
201pub fn comp_ne() -> Box<dyn BinMiniOp> {
202    Box::new(CompNE)
203}
204pub fn comp_lt() -> Box<dyn BinMiniOp> {
205    Box::new(CompLT)
206}
207pub fn comp_gt() -> Box<dyn BinMiniOp> {
208    Box::new(CompGT)
209}
210pub fn comp_lte() -> Box<dyn BinMiniOp> {
211    Box::new(CompLTE)
212}
213pub fn comp_gte() -> Box<dyn BinMiniOp> {
214    Box::new(CompGTE)
215}