tract_core/ops/logic/
comparison.rs1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4
/// The binary comparison operators implemented by this op.
///
/// Every variant compares its two operands element by element and yields a
/// boolean result.
#[derive(Debug, Clone, Copy, Hash)]
pub enum Comp {
    /// Equality: `a == b`.
    Eq,
    /// Inequality: `a != b`.
    NE,
    /// Strictly less than: `a < b`.
    LT,
    /// Strictly greater than: `a > b`.
    GT,
    /// Greater than or equal: `a >= b`.
    GTE,
    /// Less than or equal: `a <= b`.
    LTE,
}
14
15use tract_data::TooEarly;
16use Comp::*;
17
18impl Op for Comp {
19 fn name(&self) -> Cow<str> {
20 match *self {
21 Eq => "==",
22 NE => "!=",
23 LT => "<",
24 GT => ">",
25 LTE => "<=",
26 GTE => ">=",
27 }
28 .into()
29 }
30
31 op_as_typed_op!();
32}
33
34impl Comp {
35 fn eval<T: Datum + PartialOrd>(&self, a: &Tensor, b: &Tensor) -> TractResult<Tensor> {
36 let a = a.to_array_view::<T>()?;
37 let b = b.to_array_view::<T>()?;
38 let shape = multi_broadcast(&[a.shape(), b.shape()])?;
39 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
40 let mut view = c.to_array_view_mut::<bool>()?;
41 let zipped = Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b);
42 match *self {
43 Eq => zipped.for_each(|c, a, b| *c = a == b),
44 NE => zipped.for_each(|c, a, b| *c = a != b),
45 LT => zipped.for_each(|c, a, b| *c = a < b),
46 GT => zipped.for_each(|c, a, b| *c = a > b),
47 LTE => zipped.for_each(|c, a, b| *c = a <= b),
48 GTE => zipped.for_each(|c, a, b| *c = a >= b),
49 }
50 Ok(c)
51 }
52}
53
54impl EvalOp for Comp {
55 fn is_stateless(&self) -> bool {
56 true
57 }
58
59 fn eval_with_session(
60 &self,
61 session: &SessionState,
62 inputs: TVec<TValue>,
63 ) -> TractResult<TVec<TValue>> {
64 if inputs[0].datum_type() == TDim::datum_type() {
65 let mut a = inputs[0].clone().into_tensor();
66 let mut b = inputs[1].clone().into_tensor();
67 for a in a.as_slice_mut::<TDim>()? {
68 *a = a.eval(&session.resolved_symbols);
69 }
70 for b in b.as_slice_mut::<TDim>()? {
71 *b = b.eval(&session.resolved_symbols);
72 }
73 if let (Ok(a), Ok(b)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
74 return Ok(tvec!(self.eval::<i64>(&a, &b)?.into_tvalue()));
75 }
76 let a = inputs[0].to_array_view::<TDim>()?;
77 let b = inputs[0].to_array_view::<TDim>()?;
78 let shape = multi_broadcast(&[a.shape(), b.shape()])?;
79 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
80 let mut view = c.to_array_view_mut::<bool>()?;
81 let a = a.broadcast(&*shape).unwrap();
82 let b = b.broadcast(&*shape).unwrap();
83 for ixs in tract_ndarray::indices(&*shape) {
84 let (a, b) = (&a[&ixs], &b[&ixs]);
85 let diff = a.clone() - b;
86 view[&ixs] = match *self {
87 Eq => a == b,
88 NE => a != b,
89 GTE => {
90 if diff.prove_positive_or_zero() {
91 true
92 } else if diff.prove_strict_negative() {
93 false
94 } else {
95 bail!(TooEarly::UndeterminedSymbol(diff));
96 }
97 }
98 GT => {
99 if diff.prove_strict_positive() {
100 true
101 } else if diff.prove_negative_or_zero() {
102 false
103 } else {
104 bail!(TooEarly::UndeterminedSymbol(diff));
105 }
106 }
107 LTE => {
108 if diff.prove_negative_or_zero() {
109 true
110 } else if diff.prove_strict_positive() {
111 false
112 } else {
113 bail!(TooEarly::UndeterminedSymbol(diff));
114 }
115 }
116 LT => {
117 if diff.prove_strict_negative() {
118 true
119 } else if diff.prove_negative_or_zero() {
120 false
121 } else {
122 bail!(TooEarly::UndeterminedSymbol(diff));
123 }
124 }
125 };
126 }
127 Ok(tvec!(c.into_tvalue()))
128 } else {
129 let t = dispatch_numbers!(Self::eval(inputs[0].datum_type())(
130 self, &inputs[0], &inputs[1]
131 ))?;
132 Ok(tvec!(t.into_tvalue()))
133 }
134 }
135}
136
137impl TypedOp for Comp {
138 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
139 let shape = multi_broadcast(&[&inputs[0].shape, &inputs[1].shape])?;
140 Ok(tvec!(bool::datum_type().fact(shape)))
141 }
142
143 fn change_axes(
144 &self,
145 model: &TypedModel,
146 node: &TypedNode,
147 _io: InOut,
148 change: &AxisOp,
149 ) -> TractResult<Option<AxisChangeConsequence>> {
150 if let AxisOp::Rm(rm) = change {
151 let (inputs, outputs) = model.node_facts(node.id)?;
152 if !inputs[0].shape[*rm].is_one()
153 || !inputs[0].shape[*rm].is_one()
154 || !outputs[0].shape[*rm].is_one()
155 {
156 return Ok(None);
157 }
158 }
159 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
160 }
161
162 fn slice(
163 &self,
164 patch: &mut TypedModelPatch,
165 _model: &TypedModel,
166 _node: &TypedNode,
167 prefix: &str,
168 inputs: &[OutletId],
169 _output_axis: usize,
170 _start: usize,
171 _end: usize,
172 ) -> TractResult<Option<TVec<OutletId>>> {
173 Ok(Some(patch.wire_node(prefix, *self, inputs)?))
174 }
175
176 fn axes_mapping(
177 &self,
178 inputs: &[&TypedFact],
179 outputs: &[&TypedFact],
180 ) -> TractResult<AxesMapping> {
181 AxesMapping::natural(inputs, outputs)
182 }
183
184 as_op!();
185}