vortex_array/operator/compare.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::{DType, NativePType, match_each_native_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail};
12
13use crate::arrays::ConstantArray;
14use crate::compute::Operator as Op;
15use crate::operator::{Operator, OperatorEq, OperatorHash, OperatorId, OperatorRef};
16use crate::pipeline::view::ViewMut;
17use crate::pipeline::{BindContext, Element, Kernel, KernelContext, PipelinedOperator, VectorId};
18
/// Operator that compares two child operators element-wise using a comparison operator `op`,
/// producing a boolean result.
#[derive(Debug)]
pub struct CompareOperator {
    /// The left- and right-hand operands, in that order.
    children: [OperatorRef; 2],
    /// The comparison to apply between each pair of elements.
    op: Op,
    /// The result dtype; `try_new` always sets this to `DType::Bool`.
    dtype: DType,
}
25
26impl CompareOperator {
27 pub fn try_new(lhs: OperatorRef, rhs: OperatorRef, op: Op) -> VortexResult<CompareOperator> {
28 if lhs.dtype() != rhs.dtype() {
29 vortex_bail!(
30 "Cannot compare arrays with different dtypes: {} and {}",
31 lhs.dtype(),
32 rhs.dtype()
33 );
34 }
35
36 let lhs_const = lhs.as_any().downcast_ref::<ConstantArray>();
37 let rhs_const = rhs.as_any().downcast_ref::<ConstantArray>();
38 if lhs_const.is_some() && rhs_const.is_some() {
39 // TODO(ngates): we should return the Constant result!
40 }
41
42 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
43 let dtype = DType::Bool(nullability);
44
45 Ok(CompareOperator {
46 children: [lhs, rhs],
47 op,
48 dtype,
49 })
50 }
51
52 pub fn op(&self) -> Op {
53 self.op
54 }
55}
56
57impl OperatorHash for CompareOperator {
58 fn operator_hash<H: Hasher>(&self, state: &mut H) {
59 self.op.hash(state);
60 self.dtype.hash(state);
61 self.children.iter().for_each(|c| c.operator_hash(state));
62 }
63}
64
65impl OperatorEq for CompareOperator {
66 fn operator_eq(&self, other: &Self) -> bool {
67 self.op == other.op
68 && self.dtype == other.dtype
69 && self
70 .children
71 .iter()
72 .zip(other.children.iter())
73 .all(|(a, b)| a.operator_eq(b))
74 }
75}
76
77impl Operator for CompareOperator {
78 fn id(&self) -> OperatorId {
79 OperatorId::from("vortex.compare")
80 }
81
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn dtype(&self) -> &DType {
87 &self.dtype
88 }
89
90 fn len(&self) -> usize {
91 self.children[0].len() & self.children[1].len()
92 }
93
94 fn children(&self) -> &[OperatorRef] {
95 &self.children
96 }
97
98 fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
99 let (lhs, rhs) = children
100 .into_iter()
101 .tuples()
102 .next()
103 .vortex_expect("missing");
104 Ok(Arc::new(CompareOperator {
105 children: [lhs, rhs],
106 op: self.op,
107 dtype: self.dtype.clone(),
108 }))
109 }
110
111 fn as_pipelined(&self) -> Option<&dyn PipelinedOperator> {
112 Some(self)
113 }
114}
115
/// Dispatches on a runtime comparison operator (`crate::compute::Operator`) to a compile-time
/// comparison marker type.
///
/// Inside `$body`, `$enc` is a type alias for the marker (`Eq`, `NotEq`, `Gt`, `Gte`, `Lt` or
/// `Lte`) that matches the runtime operator, so the body is monomorphised once per comparison.
/// The operator variants are matched through `$crate` so the caller does not need any
/// particular alias in scope; the marker names resolve at the call site (this module).
macro_rules! match_each_compare_op {
    ($self:expr, | $enc:ident | $body:block) => {{
        match $self {
            $crate::compute::Operator::Eq => {
                type $enc = Eq;
                $body
            }
            $crate::compute::Operator::NotEq => {
                type $enc = NotEq;
                $body
            }
            $crate::compute::Operator::Gt => {
                type $enc = Gt;
                $body
            }
            $crate::compute::Operator::Gte => {
                type $enc = Gte;
                $body
            }
            $crate::compute::Operator::Lt => {
                type $enc = Lt;
                $body
            }
            $crate::compute::Operator::Lte => {
                type $enc = Lte;
                $body
            }
        }
    }};
}
146
147impl PipelinedOperator for CompareOperator {
148 #[allow(clippy::cognitive_complexity)]
149 fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
150 debug_assert_eq!(self.children[0].dtype(), self.children[1].dtype());
151
152 let DType::Primitive(ptype, _) = self.children[0].dtype() else {
153 vortex_bail!(
154 "Unsupported type for comparison: {}",
155 self.children[0].dtype()
156 )
157 };
158
159 let lhs_const = self.children[0].as_any().downcast_ref::<ConstantArray>();
160 if let Some(lhs_const) = lhs_const {
161 // LHS is constant, use ScalarComparePrimitiveKernel
162 return match_each_native_ptype!(ptype, |T| {
163 match_each_compare_op!(self.op.swap(), |Op| {
164 Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
165 lhs: ctx.children()[1],
166 rhs: lhs_const
167 .scalar()
168 .as_primitive()
169 .typed_value::<T>()
170 .vortex_expect("scalar value not of type T"),
171 _phantom: PhantomData,
172 }) as Box<dyn Kernel>)
173 })
174 });
175 }
176
177 let rhs_const = self.children[1].as_any().downcast_ref::<ConstantArray>();
178 if let Some(rhs_const) = rhs_const {
179 // RHS is constant, use ScalarComparePrimitiveKernel
180 return match_each_native_ptype!(ptype, |T| {
181 match_each_compare_op!(self.op, |Op| {
182 Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
183 lhs: ctx.children()[0],
184 rhs: rhs_const
185 .scalar()
186 .as_primitive()
187 .typed_value::<T>()
188 .vortex_expect("scalar value not of type T"),
189 _phantom: PhantomData,
190 }) as Box<dyn Kernel>)
191 })
192 });
193 }
194
195 match_each_native_ptype!(ptype, |T| {
196 match_each_compare_op!(self.op, |Op| {
197 Ok(Box::new(ComparePrimitiveKernel::<T, Op> {
198 lhs: ctx.children()[0],
199 rhs: ctx.children()[1],
200 _phantom: PhantomData,
201 }) as Box<dyn Kernel>)
202 })
203 })
204 }
205
206 fn vector_children(&self) -> Vec<usize> {
207 vec![0, 1]
208 }
209
210 fn batch_children(&self) -> Vec<usize> {
211 vec![]
212 }
213}
214
/// Kernel that compares two primitive input vectors element-wise using the comparison `Op`,
/// writing one `bool` per element.
pub struct ComparePrimitiveKernel<T, Op> {
    /// The left-hand input vector.
    lhs: VectorId,
    /// The right-hand input vector.
    rhs: VectorId,
    /// Carries the element type `T` and comparison marker `Op` without storing values.
    _phantom: PhantomData<(T, Op)>,
}
223
224impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel for ComparePrimitiveKernel<T, Op> {
225 fn step(&mut self, ctx: &KernelContext, out: &mut ViewMut) -> VortexResult<()> {
226 let lhs_vec = ctx.vector(self.lhs);
227 let lhs = lhs_vec.as_slice::<T>();
228 let rhs_vec = ctx.vector(self.rhs);
229 let rhs = rhs_vec.as_slice::<T>();
230 let bools = out.as_slice_mut::<bool>();
231
232 assert_eq!(
233 lhs.len(),
234 rhs.len(),
235 "LHS and RHS must have the same length"
236 );
237
238 lhs.iter()
239 .zip(rhs.iter())
240 .zip(bools)
241 .for_each(|((lhs, rhs), bool)| *bool = Op::compare(lhs, rhs));
242
243 out.set_len(lhs.len());
244
245 Ok(())
246 }
247}
248
249struct ScalarComparePrimitiveKernel<T: Element + NativePType, Op: CompareOp<T>> {
250 lhs: VectorId,
251 rhs: T,
252 _phantom: PhantomData<Op>,
253}
254
255impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel
256 for ScalarComparePrimitiveKernel<T, Op>
257{
258 fn step(&mut self, ctx: &KernelContext, out: &mut ViewMut) -> VortexResult<()> {
259 let lhs_vec = ctx.vector(self.lhs);
260 let lhs = lhs_vec.as_slice::<T>();
261 let bools = out.as_slice_mut::<bool>();
262
263 // Note we zip only over the shortest iterator which is LHS
264 lhs.iter().zip(bools).for_each(|(lhs, bool)| {
265 *bool = Op::compare(lhs, &self.rhs);
266 });
267 out.set_len(lhs.len());
268
269 Ok(())
270 }
271}
272
/// A comparison between two values of type `T`, selected at compile time.
///
/// Implemented by the zero-sized marker types below (`Eq`, `NotEq`, `Gt`, `Gte`, `Lt`, `Lte`).
/// The compare kernels are generic over it, so each comparison is monomorphised instead of
/// being chosen per element.
pub(crate) trait CompareOp<T> {
    /// Returns the result of comparing `lhs` against `rhs`.
    fn compare(lhs: &T, rhs: &T) -> bool;
}
276
277/// Equality comparison operation.
278pub struct Eq;
279impl<T: PartialEq> CompareOp<T> for Eq {
280 #[inline(always)]
281 fn compare(lhs: &T, rhs: &T) -> bool {
282 lhs == rhs
283 }
284}
285
286/// Not equal comparison operation.
287pub struct NotEq;
288impl<T: PartialEq> CompareOp<T> for NotEq {
289 #[inline(always)]
290 fn compare(lhs: &T, rhs: &T) -> bool {
291 lhs != rhs
292 }
293}
294
295/// Greater than comparison operation.
296pub struct Gt;
297impl<T: PartialOrd> CompareOp<T> for Gt {
298 #[inline(always)]
299 fn compare(lhs: &T, rhs: &T) -> bool {
300 lhs > rhs
301 }
302}
303
304/// Greater than or equal comparison operation.
305pub struct Gte;
306impl<T: PartialOrd> CompareOp<T> for Gte {
307 #[inline(always)]
308 fn compare(lhs: &T, rhs: &T) -> bool {
309 lhs >= rhs
310 }
311}
312
313/// Less than comparison operation.
314pub struct Lt;
315impl<T: PartialOrd> CompareOp<T> for Lt {
316 #[inline(always)]
317 fn compare(lhs: &T, rhs: &T) -> bool {
318 lhs < rhs
319 }
320}
321
322/// Less than or equal comparison operation.
323pub struct Lte;
324impl<T: PartialOrd> CompareOp<T> for Lte {
325 #[inline(always)]
326 fn compare(lhs: &T, rhs: &T) -> bool {
327 lhs <= rhs
328 }
329}
330
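// TODO(ngates): bring these back! They reference `ScalarCompareOperator`, `BinaryOperator` and
// `QueryPlan`, which this module does not define, so they must be ported to `CompareOperator`
// and the current pipeline API before they can be re-enabled.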
332// #[cfg(test)]
333// mod tests {
334// use std::rc::Rc;
335//
336// use vortex_buffer::BufferMut;
337// use vortex_dtype::Nullability;
338// use vortex_scalar::Scalar;
339//
340// use crate::arrays::PrimitiveArray;
341// use crate::operator::bits::BitView;
342//
343// #[test]
344// fn test_scalar_compare_stacked_on_primitive() {
345// // Create input data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
346// let size = 16;
347// let primitive_array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
348// let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
349//
350// // Create scalar compare operator: primitive_value > 10
351// let compare_value = Scalar::primitive(10i32, Nullability::NonNullable);
352// let scalar_compare_op = Rc::new(ScalarCompareOperator::new(
353// primitive_op,
354// BinaryOperator::Gt,
355// compare_value,
356// ));
357//
358// // Create query plan from the stacked operators
359// let plan = QueryPlan::new(scalar_compare_op.as_ref()).unwrap();
360// let mut operator = plan.executable_plan().unwrap();
361//
362// // Create all-true mask for simplicity
363// let mask_data = [usize::MAX; N_WORDS];
364// let mask_view = BitView::new(&mask_data);
365//
366// // Create output buffer for boolean results
367// let mut output = BufferMut::<bool>::with_capacity(N);
368// unsafe { output.set_len(N) };
369// let mut output_view = ViewMut::new(&mut output[..], None);
370//
371// // Execute the operator
372// let result = operator._step(mask_view, &mut output_view);
373// assert!(result.is_ok());
374//
375// // Verify results: values 0-10 should be false, values 11-15 should be true
376// for i in 0..size {
377// let expected = i > 10;
378// assert_eq!(
379// output[i], expected,
380// "Position {}: expected {}, got {}",
381// i, expected, output[i]
382// );
383// }
384// }
385//
386// #[test]
387// fn test_scalar_compare_different_operators() {
388// // Test with different comparison operators
389// let size = 8;
390// let primitive_array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
391//
392// let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
393//
394// // Test Eq: values == 3
395// let compare_value = Scalar::primitive(3i32, Nullability::NonNullable);
396// let eq_op = Rc::new(ScalarCompareOperator::new(
397// primitive_op,
398// BinaryOperator::Eq,
399// compare_value,
400// ));
401//
402// let plan = QueryPlan::new(eq_op.as_ref()).unwrap();
403// let mut operator = plan.executable_plan().unwrap();
404//
405// let mask_data = [usize::MAX; N_WORDS];
406// let mask_view = BitView::new(&mask_data);
407//
408// let mut output = BufferMut::<bool>::with_capacity(N);
409// unsafe { output.set_len(N) };
410// let mut output_view = ViewMut::new(&mut output[..], None);
411//
412// let result = operator._step(mask_view, &mut output_view);
413// assert!(result.is_ok());
414//
415// // Only position 3 should be true
416// for i in 0..size {
417// let expected = i == 3;
418// assert_eq!(
419// output[i], expected,
420// "Eq test - Position {}: expected {}, got {}",
421// i, expected, output[i]
422// );
423// }
424// }
425//
426// #[test]
427// fn test_scalar_compare_with_f32() {
428// // Test with floating-point values
429// let size = 8;
430// let values: Vec<f32> = (0..size).map(|i| i as f32 + 0.5).collect();
431// let primitive_array = values.into_iter().collect::<PrimitiveArray>();
432//
433// let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
434//
435// // Test Lt: values < 3.5
436// let compare_value = Scalar::primitive(3.5f32, Nullability::NonNullable);
437// let lt_op = Rc::new(ScalarCompareOperator::new(
438// primitive_op,
439// BinaryOperator::Lt,
440// compare_value,
441// ));
442//
443// let plan = QueryPlan::new(lt_op.as_ref()).unwrap();
444// let mut operator = plan.executable_plan().unwrap();
445//
446// let mask_data = [usize::MAX; N_WORDS];
447// let mask_view = BitView::new(&mask_data);
448//
449// let mut output = BufferMut::<bool>::with_capacity(N);
450// unsafe { output.set_len(N) };
451// let mut output_view = ViewMut::new(&mut output[..], None);
452//
453// let result = operator._step(mask_view, &mut output_view);
454// assert!(result.is_ok());
455//
456// // Values 0.5, 1.5, 2.5 should be < 3.5 (true), 3.5+ should be false
457// for i in 0..size {
458// let value = i as f32 + 0.5;
459// let expected = value < 3.5;
460// assert_eq!(
461// output[i], expected,
462// "Lt test - Position {}: value {} should be {}, got {}",
463// i, value, expected, output[i]
464// );
465// }
466// }
467// }