tract_core/ops/logic/
comparison.rs1use crate::broadcast::multi_broadcast;
2use crate::internal::*;
3use crate::ndarray::Zip;
4
/// Element-wise comparison operator.
///
/// Each variant compares its two operands element by element and yields a `bool`.
#[derive(Debug, Clone, Copy, Hash)]
pub enum Comp {
    /// `a == b`
    Eq,
    /// `a != b`
    NE,
    /// `a < b`
    LT,
    /// `a > b`
    GT,
    /// `a >= b`
    GTE,
    /// `a <= b`
    LTE,
}
14
15use tract_data::TooEarly;
16use Comp::*;
17
18impl Op for Comp {
19 fn name(&self) -> StaticName {
20 match *self {
21 Eq => "==",
22 NE => "!=",
23 LT => "<",
24 GT => ">",
25 LTE => "<=",
26 GTE => ">=",
27 }
28 .into()
29 }
30
31 op_as_typed_op!();
32}
33
34impl Comp {
35 fn eval<T: Datum + PartialOrd>(&self, a: &Tensor, b: &Tensor) -> TractResult<Tensor> {
36 let a = a.to_array_view::<T>()?;
37 let b = b.to_array_view::<T>()?;
38 let shape = multi_broadcast(&[a.shape(), b.shape()])?;
39 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
40 let mut view = c.to_array_view_mut::<bool>()?;
41 let zipped = Zip::from(&mut view).and_broadcast(&a).and_broadcast(&b);
42 match *self {
43 Eq => zipped.for_each(|c, a, b| *c = a == b),
44 NE => zipped.for_each(|c, a, b| *c = a != b),
45 LT => zipped.for_each(|c, a, b| *c = a < b),
46 GT => zipped.for_each(|c, a, b| *c = a > b),
47 LTE => zipped.for_each(|c, a, b| *c = a <= b),
48 GTE => zipped.for_each(|c, a, b| *c = a >= b),
49 }
50 Ok(c)
51 }
52}
53
54impl EvalOp for Comp {
55 fn is_stateless(&self) -> bool {
56 true
57 }
58
59 fn eval_with_session(
60 &self,
61 _node_id: usize,
62 session: &SessionState,
63 inputs: TVec<TValue>,
64 ) -> TractResult<TVec<TValue>> {
65 if inputs[0].datum_type() == TDim::datum_type() {
66 let mut a = inputs[0].clone().into_tensor();
67 let mut b = inputs[1].clone().into_tensor();
68 for a in a.as_slice_mut::<TDim>()? {
69 *a = a.eval(&session.resolved_symbols);
70 }
71 for b in b.as_slice_mut::<TDim>()? {
72 *b = b.eval(&session.resolved_symbols);
73 }
74 if let (Ok(a), Ok(b)) = (a.cast_to::<i64>(), b.cast_to::<i64>()) {
75 return Ok(tvec!(self.eval::<i64>(&a, &b)?.into_tvalue()));
76 }
77 let a = inputs[0].to_array_view::<TDim>()?;
78 let b = inputs[0].to_array_view::<TDim>()?;
79 let shape = multi_broadcast(&[a.shape(), b.shape()])?;
80 let mut c = unsafe { Tensor::uninitialized::<bool>(&shape)? };
81 let mut view = c.to_array_view_mut::<bool>()?;
82 let a = a.broadcast(&*shape).unwrap();
83 let b = b.broadcast(&*shape).unwrap();
84 for ixs in tract_ndarray::indices(&*shape) {
85 let (a, b) = (&a[&ixs], &b[&ixs]);
86 let diff = a.clone() - b;
87 view[&ixs] = match *self {
88 Eq => a == b,
89 NE => a != b,
90 GTE => {
91 if diff.prove_positive_or_zero() {
92 true
93 } else if diff.prove_strict_negative() {
94 false
95 } else {
96 bail!(TooEarly::UndeterminedSymbol(diff));
97 }
98 }
99 GT => {
100 if diff.prove_strict_positive() {
101 true
102 } else if diff.prove_negative_or_zero() {
103 false
104 } else {
105 bail!(TooEarly::UndeterminedSymbol(diff));
106 }
107 }
108 LTE => {
109 if diff.prove_negative_or_zero() {
110 true
111 } else if diff.prove_strict_positive() {
112 false
113 } else {
114 bail!(TooEarly::UndeterminedSymbol(diff));
115 }
116 }
117 LT => {
118 if diff.prove_strict_negative() {
119 true
120 } else if diff.prove_negative_or_zero() {
121 false
122 } else {
123 bail!(TooEarly::UndeterminedSymbol(diff));
124 }
125 }
126 };
127 }
128 Ok(tvec!(c.into_tvalue()))
129 } else {
130 let t = dispatch_numbers!(Self::eval(inputs[0].datum_type())(
131 self, &inputs[0], &inputs[1]
132 ))?;
133 Ok(tvec!(t.into_tvalue()))
134 }
135 }
136}
137
138impl TypedOp for Comp {
139 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
140 let shape = multi_broadcast(&[&inputs[0].shape, &inputs[1].shape])?;
141 Ok(tvec!(bool::datum_type().fact(shape)))
142 }
143
144 fn change_axes(
145 &self,
146 model: &TypedModel,
147 node: &TypedNode,
148 _io: InOut,
149 change: &AxisOp,
150 ) -> TractResult<Option<AxisChangeConsequence>> {
151 if let AxisOp::Rm(rm) = change {
152 let (inputs, outputs) = model.node_facts(node.id)?;
153 if !inputs[0].shape[*rm].is_one()
154 || !inputs[0].shape[*rm].is_one()
155 || !outputs[0].shape[*rm].is_one()
156 {
157 return Ok(None);
158 }
159 }
160 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
161 }
162
163 fn slice(
164 &self,
165 patch: &mut TypedModelPatch,
166 _model: &TypedModel,
167 _node: &TypedNode,
168 prefix: &str,
169 inputs: &[OutletId],
170 _output_axis: usize,
171 _start: &TDim,
172 _end: &TDim,
173 ) -> TractResult<Option<TVec<OutletId>>> {
174 Ok(Some(patch.wire_node(prefix, *self, inputs)?))
175 }
176
177 fn axes_mapping(
178 &self,
179 inputs: &[&TypedFact],
180 outputs: &[&TypedFact],
181 ) -> TractResult<AxesMapping> {
182 AxesMapping::natural(inputs, outputs)
183 }
184
185 as_op!();
186}