vortex_array/pipeline/operators/
compare.rs1use std::any::Any;
5use std::marker::PhantomData;
6use std::rc::Rc;
7
8use itertools::Itertools;
9use vortex_dtype::{NativePType, match_each_native_ptype};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail};
11
12use crate::arrays::ConstantOperator;
13use crate::compute::Operator as BinaryOperator;
14use crate::pipeline::bits::BitView;
15use crate::pipeline::operators::scalar_compare::ScalarCompareOperator;
16use crate::pipeline::operators::{BindContext, Operator};
17use crate::pipeline::types::{Element, VType};
18use crate::pipeline::vec::VectorId;
19use crate::pipeline::view::ViewMut;
20use crate::pipeline::{Kernel, KernelContext};
21
/// Dispatches on a runtime comparison operator value (`crate::compute::Operator`), binding the
/// identifier `$enc` to the matching zero-sized marker type from this module (`Eq`, `NotEq`,
/// `Gt`, `Gte`, `Lt`, `Lte`) and evaluating `$body` with it.
///
/// The variants are matched through `$crate::compute::Operator`, so callers do not need a
/// `BinaryOperator` alias (or any other import of the enum) in scope.
#[macro_export]
macro_rules! match_each_compare_op {
    ($self:expr, | $enc:ident | $body:block) => {{
        match $self {
            $crate::compute::Operator::Eq => {
                type $enc = $crate::pipeline::operators::compare::Eq;
                $body
            }
            $crate::compute::Operator::NotEq => {
                type $enc = $crate::pipeline::operators::compare::NotEq;
                $body
            }
            $crate::compute::Operator::Gt => {
                type $enc = $crate::pipeline::operators::compare::Gt;
                $body
            }
            $crate::compute::Operator::Gte => {
                type $enc = $crate::pipeline::operators::compare::Gte;
                $body
            }
            $crate::compute::Operator::Lt => {
                type $enc = $crate::pipeline::operators::compare::Lt;
                $body
            }
            $crate::compute::Operator::Lte => {
                type $enc = $crate::pipeline::operators::compare::Lte;
                $body
            }
        }
    }};
}
53
54#[derive(Debug, Hash)]
56pub struct CompareOperator {
57 children: [Rc<dyn Operator>; 2],
58 op: BinaryOperator,
59}
60
61impl CompareOperator {
62 pub fn new(lhs: Rc<dyn Operator>, rhs: Rc<dyn Operator>, op: BinaryOperator) -> Self {
63 assert_eq!(lhs.vtype(), rhs.vtype(), "Operands must have the same type");
64 Self {
65 children: [lhs, rhs],
66 op,
67 }
68 }
69}
70
71impl Operator for CompareOperator {
72 fn as_any(&self) -> &dyn Any {
73 self
74 }
75
76 fn vtype(&self) -> VType {
77 VType::Bool
78 }
79
80 fn children(&self) -> &[Rc<dyn Operator>] {
81 &self.children
82 }
83
84 fn with_children(&self, children: Vec<Rc<dyn Operator>>) -> Rc<dyn Operator> {
85 let [lhs, rhs] = children
86 .try_into()
87 .ok()
88 .vortex_expect("Expected 2 children");
89 Rc::new(CompareOperator::new(lhs, rhs, self.op))
90 }
91
92 fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
93 debug_assert_eq!(self.children[0].vtype(), self.children[1].vtype());
94
95 let VType::Primitive(ptype) = self.children[0].vtype() else {
96 vortex_bail!(
97 "Unsupported type for comparison: {}",
98 self.children[0].vtype()
99 )
100 };
101
102 match_each_native_ptype!(ptype, |T| {
103 match_each_compare_op!(self.op, |Op| {
104 Ok(Box::new(ComparePrimitiveKernel::<T, Op> {
105 lhs: ctx.children()[0],
106 rhs: ctx.children()[1],
107 _phantom: PhantomData,
108 }) as Box<dyn Kernel>)
109 })
110 })
111 }
112
113 fn reduce_children(&self, children: &[Rc<dyn Operator>]) -> Option<Rc<dyn Operator>> {
114 let constants = children
115 .iter()
116 .enumerate()
117 .filter_map(|(idx, c)| {
118 c.as_any()
119 .downcast_ref::<ConstantOperator>()
120 .map(|c| (idx, c))
121 })
122 .collect_vec();
123
124 if constants.len() != 1 {
125 return None;
126 }
127 let [(idx, lhs)] = constants
128 .try_into()
129 .ok()
130 .vortex_expect("Expected 1 constant");
131
132 if idx == 0 {
133 Some(Rc::new(ScalarCompareOperator::new(
134 children[1].clone(),
135 self.op.inverse(),
136 lhs.scalar.clone(),
137 )))
138 } else {
139 Some(Rc::new(ScalarCompareOperator::new(
140 children[0].clone(),
141 self.op,
142 lhs.scalar.clone(),
143 )))
144 }
145 }
146}
147
148pub struct ComparePrimitiveKernel<T, Op> {
152 lhs: VectorId,
153 rhs: VectorId,
154 _phantom: PhantomData<(T, Op)>,
155}
156
157impl<T: Element + NativePType, Op: CompareOp<T>> Kernel for ComparePrimitiveKernel<T, Op> {
158 fn step(
159 &mut self,
160 ctx: &KernelContext,
161 selected: BitView,
162 out: &mut ViewMut,
163 ) -> VortexResult<()> {
164 let lhs_vec = ctx.vector(self.lhs);
165 let lhs = lhs_vec.as_slice::<T>();
166 let rhs_vec = ctx.vector(self.rhs);
167 let rhs = rhs_vec.as_slice::<T>();
168 let bools = out.as_slice_mut::<bool>();
169
170 assert_eq!(
171 lhs.len(),
172 rhs.len(),
173 "LHS and RHS must have the same length"
174 );
175
176 lhs.iter()
177 .zip(rhs.iter())
178 .zip(bools)
179 .for_each(|((lhs, rhs), bool)| *bool = Op::compare(lhs, rhs));
180
181 Ok(())
182 }
183}
184
/// A binary comparison between two values of type `T`, chosen at compile time.
///
/// It is implemented by the zero-sized marker types below, so a kernel generic over the
/// operator is monomorphised per comparison with no runtime dispatch in its inner loop.
pub(crate) trait CompareOp<T> {
    /// Returns the result of `lhs <op> rhs`.
    fn compare(lhs: &T, rhs: &T) -> bool;
}

/// Marker for equality: `lhs == rhs`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Eq;

impl<T: PartialEq> CompareOp<T> for Eq {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        lhs == rhs
    }
}

/// Marker for inequality: `lhs != rhs`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotEq;

impl<T: PartialEq> CompareOp<T> for NotEq {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        lhs != rhs
    }
}

/// Marker for strictly-greater-than: `lhs > rhs`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Gt;

impl<T: PartialOrd> CompareOp<T> for Gt {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        lhs > rhs
    }
}

/// Marker for greater-than-or-equal: `lhs >= rhs`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Gte;

impl<T: PartialOrd> CompareOp<T> for Gte {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        lhs >= rhs
    }
}

/// Marker for strictly-less-than: `lhs < rhs`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Lt;

impl<T: PartialOrd> CompareOp<T> for Lt {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        lhs < rhs
    }
}

/// Marker for less-than-or-equal: `lhs <= rhs`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Lte;

impl<T: PartialOrd> CompareOp<T> for Lte {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        lhs <= rhs
    }
}