Skip to main content

tract_core/ops/logic/
comparison.rs

1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4use crate::ops::binary::BinMiniOp;
5
6use tract_data::TooEarly;
7
8// Helper for eval_out_of_place dispatch
9fn eval_comp_oop<T: Datum + PartialOrd>(
10    a: &Tensor,
11    b: &Tensor,
12    f: impl Fn(&T, &T) -> bool,
13) -> TractResult<Tensor> {
14    let a = a.to_plain_array_view::<T>()?;
15    let b = b.to_plain_array_view::<T>()?;
16    let shape = multi_broadcast(&[a.shape(), b.shape()])?;
17    let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
18    let mut c_plain = c.try_as_plain_mut()?;
19    let mut view = c_plain.to_array_view_mut::<bool>()?;
20    Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b).for_each(|c, a, b| *c = f(a, b));
21    Ok(c)
22}
23
24// Helper for TDim symbolic eval
25fn eval_tdim_symbolic(
26    session: &TurnState,
27    inputs: &TVec<TValue>,
28    prove: impl Fn(&TDim, &TDim) -> TractResult<bool>,
29) -> TractResult<Option<TVec<TValue>>> {
30    if inputs[0].datum_type() != TDim::datum_type() {
31        return Ok(None);
32    }
33    let mut a = inputs[0].clone().into_tensor();
34    let mut b = inputs[1].clone().into_tensor();
35    for a in a.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
36        *a = a.eval(&session.resolved_symbols);
37    }
38    for b in b.try_as_plain_mut()?.as_slice_mut::<TDim>()? {
39        *b = b.eval(&session.resolved_symbols);
40    }
41    if let (Ok(a_i64), Ok(b_i64)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
42        let result = eval_comp_oop::<i64>(&a_i64, &b_i64, |a, b| {
43            prove(&(*a).into(), &(*b).into()).unwrap_or(false)
44        })?;
45        return Ok(Some(tvec!(result.into_tvalue())));
46    }
47    let a_view = inputs[0].to_plain_array_view::<TDim>()?;
48    let b_view = inputs[1].to_plain_array_view::<TDim>()?;
49    let shape = multi_broadcast(&[a_view.shape(), b_view.shape()])?;
50    let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
51    let mut c_plain = c.try_as_plain_mut()?;
52    let mut view = c_plain.to_array_view_mut::<bool>()?;
53    let a_bc = a_view.broadcast(&*shape).unwrap();
54    let b_bc = b_view.broadcast(&*shape).unwrap();
55    for ixs in tract_ndarray::indices(&*shape) {
56        view[&ixs] = prove(&a_bc[&ixs], &b_bc[&ixs])?;
57    }
58    Ok(Some(tvec!(c.into_tvalue())))
59}
60
// Generates a unit struct `$Op` implementing `BinMiniOp` for one comparison.
//
// Arguments:
// - `$Op`: name of the generated struct.
// - `$name`: string returned by `BinMiniOp::name`.
// - `$cmp`: the Rust comparison operator applied to each element pair.
// - `$prove_tdim`: `Fn(&TDim, &TDim) -> TractResult<bool>` used for symbolic
//   evaluation of `TDim` operands.
// - `$uniform_tdim`: `Fn(&TDim, &TDim) -> TDim` building the comparison as a
//   `TDim` expression.
macro_rules! comp_bin_mini_op {
    ($Op:ident, $name:literal, $cmp:tt, $prove_tdim:expr, $uniform_tdim:expr) => {
        /// Elementwise comparison operator generated by `comp_bin_mini_op!`.
        #[derive(Debug, Clone, Hash, PartialEq, Eq)]
        pub struct $Op;

        impl BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                $name
            }

            // The result is always a boolean tensor, whatever the operand types.
            fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
                Ok(bool::datum_type())
            }

            fn is_commutative(&self) -> bool {
                false
            }

            // In-place evaluation is rejected: the output type differs from the inputs.
            fn eval_in_a(&self, _a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
                bail!("Comparison changes datum type, eval_in_a not supported")
            }

            // Writes the broadcast elementwise comparison of `a` and `b` into `c`.
            // The element type is taken from `a`; `b` is viewed with the same type.
            fn eval_out_of_place(
                &self,
                c: &mut Tensor,
                a: &Tensor,
                b: &Tensor,
            ) -> TractResult<()> {
                // Typed worker shared by the string and the numeric paths.
                fn compare_into<T: Datum + PartialOrd>(
                    out: &mut Tensor,
                    lhs: &Tensor,
                    rhs: &Tensor,
                    pred: impl Fn(&T, &T) -> bool,
                ) -> TractResult<()> {
                    let lhs = lhs.to_plain_array_view::<T>()?;
                    let rhs = rhs.to_plain_array_view::<T>()?;
                    let mut out_plain = out.try_as_plain_mut()?;
                    let mut out_view = out_plain.to_array_view_mut::<bool>()?;
                    Zip::from(&mut out_view)
                        .and_broadcast(&lhs)
                        .and_broadcast(&rhs)
                        .for_each(|dst, x, y| *dst = pred(x, y));
                    Ok(())
                }

                let operand_dt = a.datum_type();
                // Strings are handled explicitly before the numeric dispatch.
                if operand_dt == String::datum_type() {
                    return compare_into::<String>(c, a, b, |x, y| x $cmp y);
                }
                dispatch_numbers!(compare_into(operand_dt)(c, a, b, |x: &_, y: &_| x $cmp y))
            }

            // Allocates the output for the broadcast shape, then fills it.
            fn eval(&self, a: TValue, b: TValue, c_dt: DatumType) -> TractResult<Tensor> {
                let out_shape = crate::broadcast::multi_broadcast(&[a.shape(), b.shape()])?;
                // SAFETY: `eval_out_of_place` assigns every element of the
                // output view before the tensor is returned.
                let mut out = unsafe { Tensor::uninitialized_dt(c_dt, &out_shape)? };
                self.eval_out_of_place(&mut out, &a, &b)?;
                Ok(out)
            }

            // Symbolic evaluation is delegated to the shared TDim helper.
            fn eval_symbolic(
                &self,
                session: &TurnState,
                inputs: TVec<TValue>,
            ) -> TractResult<Option<TVec<TValue>>> {
                eval_tdim_symbolic(session, &inputs, $prove_tdim)
            }

            // Builds the comparison of `a` and `b` as a `TDim` expression.
            fn uniform_tdim_comparison(
                &self,
                a: &TDim,
                b: &TDim,
            ) -> Option<TDim> {
                let relation = ($uniform_tdim)(a, b);
                Some(relation)
            }
        }
    };
}
136
137fn prove_eq(a: &TDim, b: &TDim) -> TractResult<bool> {
138    Ok(a == b)
139}
140
141fn prove_ne(a: &TDim, b: &TDim) -> TractResult<bool> {
142    Ok(a != b)
143}
144
145fn prove_gte(a: &TDim, b: &TDim) -> TractResult<bool> {
146    let diff = a.clone() - b;
147    if diff.prove_positive_or_zero() {
148        Ok(true)
149    } else if diff.prove_strict_negative() {
150        Ok(false)
151    } else {
152        bail!(TooEarly::UndeterminedSymbol(diff.to_string()))
153    }
154}
155
156fn prove_gt(a: &TDim, b: &TDim) -> TractResult<bool> {
157    let diff = a.clone() - b;
158    if diff.prove_strict_positive() {
159        Ok(true)
160    } else if diff.prove_negative_or_zero() {
161        Ok(false)
162    } else {
163        bail!(TooEarly::UndeterminedSymbol(diff.to_string()))
164    }
165}
166
167fn prove_lte(a: &TDim, b: &TDim) -> TractResult<bool> {
168    prove_gte(b, a)
169}
170
171fn prove_lt(a: &TDim, b: &TDim) -> TractResult<bool> {
172    prove_gt(b, a)
173}
174
175comp_bin_mini_op!(CompEq, "Eq", ==, prove_eq, |a: &TDim, b: &TDim|
176    TDim::Eq(Box::new(a.clone()), Box::new(b.clone())).reduce()
177);
178
179comp_bin_mini_op!(CompNE, "NE", !=, prove_ne, |a: &TDim, b: &TDim|
180    (TDim::Val(1) - TDim::Eq(Box::new(a.clone()), Box::new(b.clone()))).reduce()
181);
182
183comp_bin_mini_op!(CompLT, "LT", <, prove_lt, |a: &TDim, b: &TDim|
184    TDim::Ge(Box::new(b.clone()), Box::new((a.clone() + TDim::Val(1)).reduce())).reduce()
185);
186
187comp_bin_mini_op!(CompGT, "GT", >, prove_gt, |a: &TDim, b: &TDim|
188    TDim::Ge(Box::new((a.clone() + TDim::Val(1)).reduce()), Box::new(b.clone())).reduce()
189);
190
191comp_bin_mini_op!(CompLTE, "LTE", <=, prove_lte, |a: &TDim, b: &TDim|
192    TDim::Ge(Box::new(b.clone()), Box::new(a.clone())).reduce()
193);
194
195comp_bin_mini_op!(CompGTE, "GTE", >=, prove_gte, |a: &TDim, b: &TDim|
196    TDim::Ge(Box::new(a.clone()), Box::new(b.clone())).reduce()
197);
198
199// Factory functions
200pub fn comp_eq() -> Box<dyn BinMiniOp> {
201    Box::new(CompEq)
202}
203pub fn comp_ne() -> Box<dyn BinMiniOp> {
204    Box::new(CompNE)
205}
206pub fn comp_lt() -> Box<dyn BinMiniOp> {
207    Box::new(CompLT)
208}
209pub fn comp_gt() -> Box<dyn BinMiniOp> {
210    Box::new(CompGT)
211}
212pub fn comp_lte() -> Box<dyn BinMiniOp> {
213    Box::new(CompLTE)
214}
215pub fn comp_gte() -> Box<dyn BinMiniOp> {
216    Box::new(CompGTE)
217}