tract_core/ops/logic/
comparison.rs1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4
/// Element-wise comparison operators; each one produces a `bool` tensor.
///
/// `Eq` is derived alongside `Hash` so the type honours the `Hash`/`Eq` contract and can be
/// used as a key in hashed collections.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Comp {
    /// `a == b`
    Eq,
    /// `a != b`
    NE,
    /// `a < b`
    LT,
    /// `a > b`
    GT,
    /// `a >= b`
    GTE,
    /// `a <= b`
    LTE,
}
14
15use Comp::*;
16use tract_data::TooEarly;
17
18impl Op for Comp {
19 fn name(&self) -> StaticName {
20 match *self {
21 Eq => "==",
22 NE => "!=",
23 LT => "<",
24 GT => ">",
25 LTE => "<=",
26 GTE => ">=",
27 }
28 .into()
29 }
30
31 op_as_typed_op!();
32}
33
34impl Comp {
35 fn eval<T: Datum + PartialOrd>(&self, a: &Tensor, b: &Tensor) -> TractResult<Tensor> {
36 let a = a.to_dense_array_view::<T>()?;
37 let b = b.to_dense_array_view::<T>()?;
38 let shape = multi_broadcast(&[a.shape(), b.shape()])?;
39 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
40 let mut c_dense = c.try_as_dense_mut()?;
41 let mut view = c_dense.to_array_view_mut::<bool>()?;
42 let zipped = Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b);
43 match *self {
44 Eq => zipped.for_each(|c, a, b| *c = a == b),
45 NE => zipped.for_each(|c, a, b| *c = a != b),
46 LT => zipped.for_each(|c, a, b| *c = a < b),
47 GT => zipped.for_each(|c, a, b| *c = a > b),
48 LTE => zipped.for_each(|c, a, b| *c = a <= b),
49 GTE => zipped.for_each(|c, a, b| *c = a >= b),
50 }
51 Ok(c)
52 }
53}
54
55impl EvalOp for Comp {
56 fn is_stateless(&self) -> bool {
57 true
58 }
59
60 fn eval_with_session(
61 &self,
62 _node_id: usize,
63 session: &TurnState,
64 inputs: TVec<TValue>,
65 ) -> TractResult<TVec<TValue>> {
66 if inputs[0].datum_type() == TDim::datum_type() {
67 let mut a = inputs[0].clone().into_tensor();
68 let mut b = inputs[1].clone().into_tensor();
69 for a in a.try_as_dense_mut()?.as_slice_mut::<TDim>()? {
70 *a = a.eval(&session.resolved_symbols);
71 }
72 for b in b.try_as_dense_mut()?.as_slice_mut::<TDim>()? {
73 *b = b.eval(&session.resolved_symbols);
74 }
75 if let (Ok(a), Ok(b)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
76 return Ok(tvec!(self.eval::<i64>(&a, &b)?.into_tvalue()));
77 }
78 let a = inputs[0].to_dense_array_view::<TDim>()?;
79 let b = inputs[0].to_dense_array_view::<TDim>()?;
80 let shape = multi_broadcast(&[a.shape(), b.shape()])?;
81 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
82 let mut c_dense = c.try_as_dense_mut()?;
83 let mut view = c_dense.to_array_view_mut::<bool>()?;
84 let a = a.broadcast(&*shape).unwrap();
85 let b = b.broadcast(&*shape).unwrap();
86 for ixs in tract_ndarray::indices(&*shape) {
87 let (a, b) = (&a[&ixs], &b[&ixs]);
88 let diff = a.clone() - b;
89 view[&ixs] = match *self {
90 Eq => a == b,
91 NE => a != b,
92 GTE => {
93 if diff.prove_positive_or_zero() {
94 true
95 } else if diff.prove_strict_negative() {
96 false
97 } else {
98 bail!(TooEarly::UndeterminedSymbol(diff.to_string()));
99 }
100 }
101 GT => {
102 if diff.prove_strict_positive() {
103 true
104 } else if diff.prove_negative_or_zero() {
105 false
106 } else {
107 bail!(TooEarly::UndeterminedSymbol(diff.to_string()));
108 }
109 }
110 LTE => {
111 if diff.prove_negative_or_zero() {
112 true
113 } else if diff.prove_strict_positive() {
114 false
115 } else {
116 bail!(TooEarly::UndeterminedSymbol(diff.to_string()));
117 }
118 }
119 LT => {
120 if diff.prove_strict_negative() {
121 true
122 } else if diff.prove_negative_or_zero() {
123 false
124 } else {
125 bail!(TooEarly::UndeterminedSymbol(diff.to_string()));
126 }
127 }
128 };
129 }
130 Ok(tvec!(c.into_tvalue()))
131 } else if inputs[0].datum_type().is::<String>() {
132 let t = self.eval::<String>(&inputs[0], &inputs[1])?;
133 Ok(tvec!(t.into_tvalue()))
134 } else {
135 let t = dispatch_numbers!(Self::eval(inputs[0].datum_type())(
136 self, &inputs[0], &inputs[1]
137 ))?;
138 Ok(tvec!(t.into_tvalue()))
139 }
140 }
141}
142
143impl TypedOp for Comp {
144 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
145 let shape = multi_broadcast(&[&inputs[0].shape, &inputs[1].shape])?;
146 Ok(tvec!(bool::datum_type().fact(shape)))
147 }
148
149 fn change_axes(
150 &self,
151 model: &TypedModel,
152 node: &TypedNode,
153 _io: InOut,
154 change: &AxisOp,
155 ) -> TractResult<Option<AxisChangeConsequence>> {
156 if let AxisOp::Rm(rm) = change {
157 let (inputs, outputs) = model.node_facts(node.id)?;
158 rule_if!(inputs[0].shape[*rm].is_one());
159 rule_if!(inputs[1].shape[*rm].is_one());
160 rule_if!(outputs[0].shape[*rm].is_one());
161 }
162 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
163 }
164
165 fn slice(
166 &self,
167 patch: &mut TypedModelPatch,
168 _model: &TypedModel,
169 _node: &TypedNode,
170 prefix: &str,
171 inputs: &[OutletId],
172 _output_axis: usize,
173 _start: &TDim,
174 _end: &TDim,
175 ) -> TractResult<Option<TVec<OutletId>>> {
176 Ok(Some(patch.wire_node(prefix, *self, inputs)?))
177 }
178
179 fn axes_mapping(
180 &self,
181 inputs: &[&TypedFact],
182 outputs: &[&TypedFact],
183 ) -> TractResult<AxesMapping> {
184 AxesMapping::natural(inputs, outputs)
185 }
186
187 as_op!();
188}