// tract_core/ops/logic/comparison.rs

1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4
/// Element-wise comparison operators.
///
/// Each variant compares two (broadcast-compatible) inputs and produces a
/// `bool` tensor. The printed operator symbol is returned by `Op::name`.
#[derive(Clone, Copy, Debug, Hash)]
pub enum Comp {
    /// `a == b`
    Eq,
    /// `a != b`
    NE,
    /// `a < b`
    LT,
    /// `a > b`
    GT,
    /// `a >= b`
    GTE,
    /// `a <= b`
    LTE,
}
14
15use tract_data::TooEarly;
16use Comp::*;
17
18impl Op for Comp {
19    fn name(&self) -> Cow<str> {
20        match *self {
21            Eq => "==",
22            NE => "!=",
23            LT => "<",
24            GT => ">",
25            LTE => "<=",
26            GTE => ">=",
27        }
28        .into()
29    }
30
31    op_as_typed_op!();
32}
33
34impl Comp {
35    fn eval<T: Datum + PartialOrd>(&self, a: &Tensor, b: &Tensor) -> TractResult<Tensor> {
36        let a = a.to_array_view::<T>()?;
37        let b = b.to_array_view::<T>()?;
38        let shape = multi_broadcast(&[a.shape(), b.shape()])?;
39        let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
40        let mut view = c.to_array_view_mut::<bool>()?;
41        let zipped = Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b);
42        match *self {
43            Eq => zipped.for_each(|c, a, b| *c = a == b),
44            NE => zipped.for_each(|c, a, b| *c = a != b),
45            LT => zipped.for_each(|c, a, b| *c = a < b),
46            GT => zipped.for_each(|c, a, b| *c = a > b),
47            LTE => zipped.for_each(|c, a, b| *c = a <= b),
48            GTE => zipped.for_each(|c, a, b| *c = a >= b),
49        }
50        Ok(c)
51    }
52}
53
54impl EvalOp for Comp {
55    fn is_stateless(&self) -> bool {
56        true
57    }
58
59    fn eval_with_session(
60        &self,
61        session: &SessionState,
62        inputs: TVec<TValue>,
63    ) -> TractResult<TVec<TValue>> {
64        if inputs[0].datum_type() == TDim::datum_type() {
65            let mut a = inputs[0].clone().into_tensor();
66            let mut b = inputs[1].clone().into_tensor();
67            for a in a.as_slice_mut::<TDim>()? {
68                *a = a.eval(&session.resolved_symbols);
69            }
70            for b in b.as_slice_mut::<TDim>()? {
71                *b = b.eval(&session.resolved_symbols);
72            }
73            if let (Ok(a), Ok(b)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
74                return Ok(tvec!(self.eval::<i64>(&a, &b)?.into_tvalue()));
75            }
76            let a = inputs[0].to_array_view::<TDim>()?;
77            let b = inputs[0].to_array_view::<TDim>()?;
78            let shape = multi_broadcast(&[a.shape(), b.shape()])?;
79            let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
80            let mut view = c.to_array_view_mut::<bool>()?;
81            let a = a.broadcast(&*shape).unwrap();
82            let b = b.broadcast(&*shape).unwrap();
83            for ixs in tract_ndarray::indices(&*shape) {
84                let (a, b) = (&a[&ixs], &b[&ixs]);
85                let diff = a.clone() - b;
86                view[&ixs] = match *self {
87                    Eq => a == b,
88                    NE => a != b,
89                    GTE => {
90                        if diff.prove_positive_or_zero() {
91                            true
92                        } else if diff.prove_strict_negative() {
93                            false
94                        } else {
95                            bail!(TooEarly::UndeterminedSymbol(diff));
96                        }
97                    }
98                    GT => {
99                        if diff.prove_strict_positive() {
100                            true
101                        } else if diff.prove_negative_or_zero() {
102                            false
103                        } else {
104                            bail!(TooEarly::UndeterminedSymbol(diff));
105                        }
106                    }
107                    LTE => {
108                        if diff.prove_negative_or_zero() {
109                            true
110                        } else if diff.prove_strict_positive() {
111                            false
112                        } else {
113                            bail!(TooEarly::UndeterminedSymbol(diff));
114                        }
115                    }
116                    LT => {
117                        if diff.prove_strict_negative() {
118                            true
119                        } else if diff.prove_negative_or_zero() {
120                            false
121                        } else {
122                            bail!(TooEarly::UndeterminedSymbol(diff));
123                        }
124                    }
125                };
126            }
127            Ok(tvec!(c.into_tvalue()))
128        } else {
129            let t = dispatch_numbers!(Self::eval(inputs[0].datum_type())(
130                self, &inputs[0], &inputs[1]
131            ))?;
132            Ok(tvec!(t.into_tvalue()))
133        }
134    }
135}
136
137impl TypedOp for Comp {
138    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
139        let shape = multi_broadcast(&[&inputs[0].shape, &inputs[1].shape])?;
140        Ok(tvec!(bool::datum_type().fact(shape)))
141    }
142
143    fn change_axes(
144        &self,
145        model: &TypedModel,
146        node: &TypedNode,
147        _io: InOut,
148        change: &AxisOp,
149    ) -> TractResult<Option<AxisChangeConsequence>> {
150        if let AxisOp::Rm(rm) = change {
151            let (inputs, outputs) = model.node_facts(node.id)?;
152            if !inputs[0].shape[*rm].is_one()
153                || !inputs[0].shape[*rm].is_one()
154                || !outputs[0].shape[*rm].is_one()
155            {
156                return Ok(None);
157            }
158        }
159        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
160    }
161
162    fn slice(
163        &self,
164        patch: &mut TypedModelPatch,
165        _model: &TypedModel,
166        _node: &TypedNode,
167        prefix: &str,
168        inputs: &[OutletId],
169        _output_axis: usize,
170        _start: usize,
171        _end: usize,
172    ) -> TractResult<Option<TVec<OutletId>>> {
173        Ok(Some(patch.wire_node(prefix, *self, inputs)?))
174    }
175
176    fn axes_mapping(
177        &self,
178        inputs: &[&TypedFact],
179        outputs: &[&TypedFact],
180    ) -> TractResult<AxesMapping> {
181        AxesMapping::natural(inputs, outputs)
182    }
183
184    as_op!();
185}