tract_core/ops/logic/comparison.rs

1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4
/// Element-wise binary comparison operator; evaluating it yields a `bool` tensor.
///
/// The variant order is significant: the derived `Hash` covers the discriminant.
#[derive(Debug, Clone, Copy, Hash)]
pub enum Comp {
    /// Equality, `a == b`.
    Eq,
    /// Inequality, `a != b`.
    NE,
    /// Strictly less than, `a < b`.
    LT,
    /// Strictly greater than, `a > b`.
    GT,
    /// Greater than or equal, `a >= b`.
    GTE,
    /// Less than or equal, `a <= b`.
    LTE,
}
14
15use tract_data::TooEarly;
16use Comp::*;
17
18impl Op for Comp {
19    fn name(&self) -> StaticName {
20        match *self {
21            Eq => "==",
22            NE => "!=",
23            LT => "<",
24            GT => ">",
25            LTE => "<=",
26            GTE => ">=",
27        }
28        .into()
29    }
30
31    op_as_typed_op!();
32}
33
34impl Comp {
35    fn eval<T: Datum + PartialOrd>(&self, a: &Tensor, b: &Tensor) -> TractResult<Tensor> {
36        let a = a.to_array_view::<T>()?;
37        let b = b.to_array_view::<T>()?;
38        let shape = multi_broadcast(&[a.shape(), b.shape()])?;
39        let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
40        let mut view = c.to_array_view_mut::<bool>()?;
41        let zipped = Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b);
42        match *self {
43            Eq => zipped.for_each(|c, a, b| *c = a == b),
44            NE => zipped.for_each(|c, a, b| *c = a != b),
45            LT => zipped.for_each(|c, a, b| *c = a < b),
46            GT => zipped.for_each(|c, a, b| *c = a > b),
47            LTE => zipped.for_each(|c, a, b| *c = a <= b),
48            GTE => zipped.for_each(|c, a, b| *c = a >= b),
49        }
50        Ok(c)
51    }
52}
53
54impl EvalOp for Comp {
55    fn is_stateless(&self) -> bool {
56        true
57    }
58
59    fn eval_with_session(
60        &self,
61        _node_id: usize,
62        session: &SessionState,
63        inputs: TVec<TValue>,
64    ) -> TractResult<TVec<TValue>> {
65        if inputs[0].datum_type() == TDim::datum_type() {
66            let mut a = inputs[0].clone().into_tensor();
67            let mut b = inputs[1].clone().into_tensor();
68            for a in a.as_slice_mut::<TDim>()? {
69                *a = a.eval(&session.resolved_symbols);
70            }
71            for b in b.as_slice_mut::<TDim>()? {
72                *b = b.eval(&session.resolved_symbols);
73            }
74            if let (Ok(a), Ok(b)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
75                return Ok(tvec!(self.eval::<i64>(&a, &b)?.into_tvalue()));
76            }
77            let a = inputs[0].to_array_view::<TDim>()?;
78            let b = inputs[0].to_array_view::<TDim>()?;
79            let shape = multi_broadcast(&[a.shape(), b.shape()])?;
80            let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
81            let mut view = c.to_array_view_mut::<bool>()?;
82            let a = a.broadcast(&*shape).unwrap();
83            let b = b.broadcast(&*shape).unwrap();
84            for ixs in tract_ndarray::indices(&*shape) {
85                let (a, b) = (&a[&ixs], &b[&ixs]);
86                let diff = a.clone() - b;
87                view[&ixs] = match *self {
88                    Eq => a == b,
89                    NE => a != b,
90                    GTE => {
91                        if diff.prove_positive_or_zero() {
92                            true
93                        } else if diff.prove_strict_negative() {
94                            false
95                        } else {
96                            bail!(TooEarly::UndeterminedSymbol(diff));
97                        }
98                    }
99                    GT => {
100                        if diff.prove_strict_positive() {
101                            true
102                        } else if diff.prove_negative_or_zero() {
103                            false
104                        } else {
105                            bail!(TooEarly::UndeterminedSymbol(diff));
106                        }
107                    }
108                    LTE => {
109                        if diff.prove_negative_or_zero() {
110                            true
111                        } else if diff.prove_strict_positive() {
112                            false
113                        } else {
114                            bail!(TooEarly::UndeterminedSymbol(diff));
115                        }
116                    }
117                    LT => {
118                        if diff.prove_strict_negative() {
119                            true
120                        } else if diff.prove_negative_or_zero() {
121                            false
122                        } else {
123                            bail!(TooEarly::UndeterminedSymbol(diff));
124                        }
125                    }
126                };
127            }
128            Ok(tvec!(c.into_tvalue()))
129        } else {
130            let t = dispatch_numbers!(Self::eval(inputs[0].datum_type())(
131                self, &inputs[0], &inputs[1]
132            ))?;
133            Ok(tvec!(t.into_tvalue()))
134        }
135    }
136}
137
138impl TypedOp for Comp {
139    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
140        let shape = multi_broadcast(&[&inputs[0].shape, &inputs[1].shape])?;
141        Ok(tvec!(bool::datum_type().fact(shape)))
142    }
143
144    fn change_axes(
145        &self,
146        model: &TypedModel,
147        node: &TypedNode,
148        _io: InOut,
149        change: &AxisOp,
150    ) -> TractResult<Option<AxisChangeConsequence>> {
151        if let AxisOp::Rm(rm) = change {
152            let (inputs, outputs) = model.node_facts(node.id)?;
153            if !inputs[0].shape[*rm].is_one()
154                || !inputs[0].shape[*rm].is_one()
155                || !outputs[0].shape[*rm].is_one()
156            {
157                return Ok(None);
158            }
159        }
160        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
161    }
162
163    fn slice(
164        &self,
165        patch: &mut TypedModelPatch,
166        _model: &TypedModel,
167        _node: &TypedNode,
168        prefix: &str,
169        inputs: &[OutletId],
170        _output_axis: usize,
171        _start: &TDim,
172        _end: &TDim,
173    ) -> TractResult<Option<TVec<OutletId>>> {
174        Ok(Some(patch.wire_node(prefix, *self, inputs)?))
175    }
176
177    fn axes_mapping(
178        &self,
179        inputs: &[&TypedFact],
180        outputs: &[&TypedFact],
181    ) -> TractResult<AxesMapping> {
182        AxesMapping::natural(inputs, outputs)
183    }
184
185    as_op!();
186}