Skip to main content

tract_core/ops/logic/
comparison.rs

1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4
/// Element-wise comparison operator. Each variant maps to the symbol used as the op name.
///
/// `Eq` is derived alongside `Hash` so that comparison ops can be used as keys in hashed
/// collections (`Hash` is only meaningful together with `Eq`).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Comp {
    /// Equality, `==`.
    Eq,
    /// Inequality, `!=`.
    NE,
    /// Strictly less than, `<`.
    LT,
    /// Strictly greater than, `>`.
    GT,
    /// Greater than or equal, `>=`.
    GTE,
    /// Less than or equal, `<=`.
    LTE,
}
14
15use Comp::*;
16use tract_data::TooEarly;
17
18impl Op for Comp {
19    fn name(&self) -> StaticName {
20        match *self {
21            Eq => "==",
22            NE => "!=",
23            LT => "<",
24            GT => ">",
25            LTE => "<=",
26            GTE => ">=",
27        }
28        .into()
29    }
30
31    op_as_typed_op!();
32}
33
34impl Comp {
35    fn eval<T: Datum + PartialOrd>(&self, a: &Tensor, b: &Tensor) -> TractResult<Tensor> {
36        let a = a.to_dense_array_view::<T>()?;
37        let b = b.to_dense_array_view::<T>()?;
38        let shape = multi_broadcast(&[a.shape(), b.shape()])?;
39        let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
40        let mut c_dense = c.try_as_dense_mut()?;
41        let mut view = c_dense.to_array_view_mut::<bool>()?;
42        let zipped = Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b);
43        match *self {
44            Eq => zipped.for_each(|c, a, b| *c = a == b),
45            NE => zipped.for_each(|c, a, b| *c = a != b),
46            LT => zipped.for_each(|c, a, b| *c = a < b),
47            GT => zipped.for_each(|c, a, b| *c = a > b),
48            LTE => zipped.for_each(|c, a, b| *c = a <= b),
49            GTE => zipped.for_each(|c, a, b| *c = a >= b),
50        }
51        Ok(c)
52    }
53}
54
55impl EvalOp for Comp {
56    fn is_stateless(&self) -> bool {
57        true
58    }
59
60    fn eval_with_session(
61        &self,
62        _node_id: usize,
63        session: &TurnState,
64        inputs: TVec<TValue>,
65    ) -> TractResult<TVec<TValue>> {
66        if inputs[0].datum_type() == TDim::datum_type() {
67            let mut a = inputs[0].clone().into_tensor();
68            let mut b = inputs[1].clone().into_tensor();
69            for a in a.try_as_dense_mut()?.as_slice_mut::<TDim>()? {
70                *a = a.eval(&session.resolved_symbols);
71            }
72            for b in b.try_as_dense_mut()?.as_slice_mut::<TDim>()? {
73                *b = b.eval(&session.resolved_symbols);
74            }
75            if let (Ok(a), Ok(b)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
76                return Ok(tvec!(self.eval::<i64>(&a, &b)?.into_tvalue()));
77            }
78            let a = inputs[0].to_dense_array_view::<TDim>()?;
79            let b = inputs[0].to_dense_array_view::<TDim>()?;
80            let shape = multi_broadcast(&[a.shape(), b.shape()])?;
81            let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
82            let mut c_dense = c.try_as_dense_mut()?;
83            let mut view = c_dense.to_array_view_mut::<bool>()?;
84            let a = a.broadcast(&*shape).unwrap();
85            let b = b.broadcast(&*shape).unwrap();
86            for ixs in tract_ndarray::indices(&*shape) {
87                let (a, b) = (&a[&ixs], &b[&ixs]);
88                let diff = a.clone() - b;
89                view[&ixs] = match *self {
90                    Eq => a == b,
91                    NE => a != b,
92                    GTE => {
93                        if diff.prove_positive_or_zero() {
94                            true
95                        } else if diff.prove_strict_negative() {
96                            false
97                        } else {
98                            bail!(TooEarly::UndeterminedSymbol(diff.to_string()));
99                        }
100                    }
101                    GT => {
102                        if diff.prove_strict_positive() {
103                            true
104                        } else if diff.prove_negative_or_zero() {
105                            false
106                        } else {
107                            bail!(TooEarly::UndeterminedSymbol(diff.to_string()));
108                        }
109                    }
110                    LTE => {
111                        if diff.prove_negative_or_zero() {
112                            true
113                        } else if diff.prove_strict_positive() {
114                            false
115                        } else {
116                            bail!(TooEarly::UndeterminedSymbol(diff.to_string()));
117                        }
118                    }
119                    LT => {
120                        if diff.prove_strict_negative() {
121                            true
122                        } else if diff.prove_negative_or_zero() {
123                            false
124                        } else {
125                            bail!(TooEarly::UndeterminedSymbol(diff.to_string()));
126                        }
127                    }
128                };
129            }
130            Ok(tvec!(c.into_tvalue()))
131        } else if inputs[0].datum_type().is::<String>() {
132            let t = self.eval::<String>(&inputs[0], &inputs[1])?;
133            Ok(tvec!(t.into_tvalue()))
134        } else {
135            let t = dispatch_numbers!(Self::eval(inputs[0].datum_type())(
136                self, &inputs[0], &inputs[1]
137            ))?;
138            Ok(tvec!(t.into_tvalue()))
139        }
140    }
141}
142
143impl TypedOp for Comp {
144    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
145        let shape = multi_broadcast(&[&inputs[0].shape, &inputs[1].shape])?;
146        Ok(tvec!(bool::datum_type().fact(shape)))
147    }
148
149    fn change_axes(
150        &self,
151        model: &TypedModel,
152        node: &TypedNode,
153        _io: InOut,
154        change: &AxisOp,
155    ) -> TractResult<Option<AxisChangeConsequence>> {
156        if let AxisOp::Rm(rm) = change {
157            let (inputs, outputs) = model.node_facts(node.id)?;
158            rule_if!(inputs[0].shape[*rm].is_one());
159            rule_if!(inputs[1].shape[*rm].is_one());
160            rule_if!(outputs[0].shape[*rm].is_one());
161        }
162        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
163    }
164
165    fn slice(
166        &self,
167        patch: &mut TypedModelPatch,
168        _model: &TypedModel,
169        _node: &TypedNode,
170        prefix: &str,
171        inputs: &[OutletId],
172        _output_axis: usize,
173        _start: &TDim,
174        _end: &TDim,
175    ) -> TractResult<Option<TVec<OutletId>>> {
176        Ok(Some(patch.wire_node(prefix, *self, inputs)?))
177    }
178
179    fn axes_mapping(
180        &self,
181        inputs: &[&TypedFact],
182        outputs: &[&TypedFact],
183    ) -> TractResult<AxesMapping> {
184        AxesMapping::natural(inputs, outputs)
185    }
186
187    as_op!();
188}