vortex_array/operator/compare.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::{DType, NativePType, match_each_native_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail};
12
13use crate::arrays::ConstantArray;
14use crate::compute::Operator as Op;
15use crate::operator::{LengthBounds, Operator, OperatorEq, OperatorHash, OperatorId, OperatorRef};
16use crate::pipeline::bits::BitView;
17use crate::pipeline::vec::Selection;
18use crate::pipeline::view::ViewMut;
19use crate::pipeline::{
20 BindContext, Element, Kernel, KernelContext, PipelinedOperator, RowSelection, VectorId,
21};
22
/// Element-wise comparison (`Eq`, `NotEq`, `Gt`, `Gte`, `Lt`, `Lte`) of two child operators.
///
/// Both children must share the same [`DType`] (enforced by [`CompareOperator::try_new`]); the
/// result is `DType::Bool` whose nullability combines the children's nullabilities.
#[derive(Debug)]
pub struct CompareOperator {
    /// The operands, stored in `[lhs, rhs]` order.
    children: [OperatorRef; 2],
    /// The comparison applied as `lhs <op> rhs`.
    op: Op,
    /// The result dtype, computed once at construction time.
    dtype: DType,
}
29
30impl CompareOperator {
31 pub fn try_new(lhs: OperatorRef, rhs: OperatorRef, op: Op) -> VortexResult<CompareOperator> {
32 if lhs.dtype() != rhs.dtype() {
33 vortex_bail!(
34 "Cannot compare arrays with different dtypes: {} and {}",
35 lhs.dtype(),
36 rhs.dtype()
37 );
38 }
39
40 let lhs_const = lhs.as_any().downcast_ref::<ConstantArray>();
41 let rhs_const = rhs.as_any().downcast_ref::<ConstantArray>();
42 if lhs_const.is_some() && rhs_const.is_some() {
43 // TODO(ngates): we should return the Constant result!
44 }
45
46 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
47 let dtype = DType::Bool(nullability);
48
49 Ok(CompareOperator {
50 children: [lhs, rhs],
51 op,
52 dtype,
53 })
54 }
55
56 pub fn op(&self) -> Op {
57 self.op
58 }
59}
60
61impl OperatorHash for CompareOperator {
62 fn operator_hash<H: Hasher>(&self, state: &mut H) {
63 self.op.hash(state);
64 self.dtype.hash(state);
65 self.children.iter().for_each(|c| c.operator_hash(state));
66 }
67}
68
69impl OperatorEq for CompareOperator {
70 fn operator_eq(&self, other: &Self) -> bool {
71 self.op == other.op
72 && self.dtype == other.dtype
73 && self
74 .children
75 .iter()
76 .zip(other.children.iter())
77 .all(|(a, b)| a.operator_eq(b))
78 }
79}
80
81impl Operator for CompareOperator {
82 fn id(&self) -> OperatorId {
83 OperatorId::from("vortex.compare")
84 }
85
86 fn as_any(&self) -> &dyn Any {
87 self
88 }
89
90 fn dtype(&self) -> &DType {
91 &self.dtype
92 }
93
94 fn bounds(&self) -> LengthBounds {
95 self.children[0].bounds() & self.children[1].bounds()
96 }
97
98 fn children(&self) -> &[OperatorRef] {
99 &self.children
100 }
101
102 fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
103 let (lhs, rhs) = children
104 .into_iter()
105 .tuples()
106 .next()
107 .vortex_expect("missing");
108 Ok(Arc::new(CompareOperator {
109 children: [lhs, rhs],
110 op: self.op,
111 dtype: self.dtype.clone(),
112 }))
113 }
114
115 fn as_pipelined(&self) -> Option<&dyn PipelinedOperator> {
116 // If both children support pipelining, but have different row selections, then we cannot
117 // pipeline without an alignment step (which we currently do not support).
118 if let Some((left, right)) = self.children[0]
119 .as_pipelined()
120 .zip(self.children[1].as_pipelined())
121 && left.row_selection() != right.row_selection()
122 {
123 return None;
124 }
125
126 Some(self)
127 }
128}
129
/// Dispatches on a runtime comparison [`Op`] value.
///
/// In each arm, the type alias named by `$enc` is bound to the matching unit marker type
/// (`Eq`, `NotEq`, `Gt`, `Gte`, `Lt`, `Lte`) and then `$body` is evaluated. This lets the body
/// instantiate a generic kernel once per comparison, e.g.
/// `match_each_compare_op!(op, |Cmp| { Kernel::<T, Cmp> { .. } })`.
macro_rules! match_each_compare_op {
    ($self:expr, | $enc:ident | $body:block) => {{
        match $self {
            Op::Eq => {
                type $enc = Eq;
                $body
            }
            Op::NotEq => {
                type $enc = NotEq;
                $body
            }
            Op::Gt => {
                type $enc = Gt;
                $body
            }
            Op::Gte => {
                type $enc = Gte;
                $body
            }
            Op::Lt => {
                type $enc = Lt;
                $body
            }
            Op::Lte => {
                type $enc = Lte;
                $body
            }
        }
    }};
}
160
161impl PipelinedOperator for CompareOperator {
162 fn row_selection(&self) -> RowSelection {
163 self.children[0]
164 .as_pipelined()
165 .map(|p| p.row_selection())
166 .unwrap_or(RowSelection::All)
167 }
168
169 #[allow(clippy::cognitive_complexity)]
170 fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
171 debug_assert_eq!(self.children[0].dtype(), self.children[1].dtype());
172
173 let DType::Primitive(ptype, _) = self.children[0].dtype() else {
174 vortex_bail!(
175 "Unsupported type for comparison: {}",
176 self.children[0].dtype()
177 )
178 };
179
180 let lhs_const = self.children[0].as_any().downcast_ref::<ConstantArray>();
181 if let Some(lhs_const) = lhs_const {
182 // LHS is constant, use ScalarComparePrimitiveKernel
183 return match_each_native_ptype!(ptype, |T| {
184 match_each_compare_op!(self.op.swap(), |Op| {
185 Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
186 lhs: ctx.children()[1],
187 rhs: lhs_const
188 .scalar()
189 .as_primitive()
190 .typed_value::<T>()
191 .vortex_expect("scalar value not of type T"),
192 _phantom: PhantomData,
193 }) as Box<dyn Kernel>)
194 })
195 });
196 }
197
198 let rhs_const = self.children[1].as_any().downcast_ref::<ConstantArray>();
199 if let Some(rhs_const) = rhs_const {
200 // RHS is constant, use ScalarComparePrimitiveKernel
201 return match_each_native_ptype!(ptype, |T| {
202 match_each_compare_op!(self.op, |Op| {
203 Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
204 lhs: ctx.children()[0],
205 rhs: rhs_const
206 .scalar()
207 .as_primitive()
208 .typed_value::<T>()
209 .vortex_expect("scalar value not of type T"),
210 _phantom: PhantomData,
211 }) as Box<dyn Kernel>)
212 })
213 });
214 }
215
216 match_each_native_ptype!(ptype, |T| {
217 match_each_compare_op!(self.op, |Op| {
218 Ok(Box::new(ComparePrimitiveKernel::<T, Op> {
219 lhs: ctx.children()[0],
220 rhs: ctx.children()[1],
221 _phantom: PhantomData,
222 }) as Box<dyn Kernel>)
223 })
224 })
225 }
226
227 fn vector_children(&self) -> Vec<usize> {
228 vec![0, 1]
229 }
230
231 fn batch_children(&self) -> Vec<usize> {
232 vec![]
233 }
234}
235
/// Kernel that compares two primitive input vectors element-wise with the comparison `Op`,
/// producing one `bool` per selected row.
pub struct ComparePrimitiveKernel<T, Op> {
    /// Vector holding the left-hand operand values.
    lhs: VectorId,
    /// Vector holding the right-hand operand values.
    rhs: VectorId,
    /// Marks the element type `T` and the comparison `Op` without storing values of either.
    _phantom: PhantomData<(T, Op)>,
}
244
245impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel for ComparePrimitiveKernel<T, Op> {
246 fn step(
247 &self,
248 ctx: &KernelContext,
249 _chunk_idx: usize,
250 selection: &BitView,
251 out: &mut ViewMut,
252 ) -> VortexResult<()> {
253 let lhs_vec = ctx.vector(self.lhs);
254 let lhs = lhs_vec.as_array::<T>();
255 let rhs_vec = ctx.vector(self.rhs);
256 let rhs = rhs_vec.as_array::<T>();
257 let bools = out.as_array_mut::<bool>();
258
259 match (lhs_vec.selection(), rhs_vec.selection()) {
260 (Selection::Prefix, Selection::Prefix) => {
261 for i in 0..selection.true_count() {
262 bools[i] = Op::compare(&lhs[i], &rhs[i]);
263 }
264 out.set_selection(Selection::Prefix)
265 }
266 (Selection::Mask, Selection::Mask) => {
267 // TODO(ngates): check density to decide if we should iterate indices or do
268 // a full scan
269 let mut pos = 0;
270 selection.iter_ones(|idx| {
271 bools[pos] = Op::compare(&lhs[idx], &rhs[idx]);
272 pos += 1;
273 });
274 out.set_selection(Selection::Prefix)
275 }
276 (Selection::Mask, Selection::Prefix) => {
277 let mut pos = 0;
278 selection.iter_ones(|idx| {
279 bools[pos] = Op::compare(&lhs[idx], &rhs[pos]);
280 pos += 1;
281 });
282 out.set_selection(Selection::Prefix)
283 }
284 (Selection::Prefix, Selection::Mask) => {
285 let mut pos = 0;
286 selection.iter_ones(|idx| {
287 bools[pos] = Op::compare(&lhs[pos], &rhs[idx]);
288 pos += 1;
289 });
290 out.set_selection(Selection::Prefix)
291 }
292 }
293
294 Ok(())
295 }
296}
297
298struct ScalarComparePrimitiveKernel<T: Element + NativePType, Op: CompareOp<T>> {
299 lhs: VectorId,
300 rhs: T,
301 _phantom: PhantomData<Op>,
302}
303
304impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel
305 for ScalarComparePrimitiveKernel<T, Op>
306{
307 fn step(
308 &self,
309 ctx: &KernelContext,
310 _chunk_idx: usize,
311 selection: &BitView,
312 out: &mut ViewMut,
313 ) -> VortexResult<()> {
314 let lhs_vec = ctx.vector(self.lhs);
315 let lhs = lhs_vec.as_array::<T>();
316 let bools = out.as_array_mut::<bool>();
317
318 match lhs_vec.selection() {
319 Selection::Prefix => {
320 for i in 0..selection.true_count() {
321 bools[i] = Op::compare(&lhs[i], &self.rhs);
322 }
323 out.set_selection(Selection::Prefix)
324 }
325 Selection::Mask => {
326 // TODO(ngates): decide at what true count we should iter indices...
327 selection.iter_ones(|idx| {
328 bools[idx] = Op::compare(&lhs[idx], &self.rhs);
329 });
330 out.set_selection(Selection::Mask)
331 }
332 }
333
334 Ok(())
335 }
336}
337
/// A statically dispatched binary comparison between two values of type `T`.
///
/// Implemented by the unit marker types generated below, which kernels take as a type
/// parameter.
pub(crate) trait CompareOp<T> {
    /// Returns the result of comparing `lhs` against `rhs`.
    fn compare(lhs: &T, rhs: &T) -> bool;
}

/// Declares the public unit marker `$name` and implements [`CompareOp`] for every
/// `T: $bound` by evaluating `lhs $op rhs`.
macro_rules! define_compare_op {
    ($(#[$attr:meta])* $name:ident where T: $bound:ident => $op:tt) => {
        $(#[$attr])*
        pub struct $name;

        impl<T: $bound> CompareOp<T> for $name {
            #[inline(always)]
            fn compare(lhs: &T, rhs: &T) -> bool {
                lhs $op rhs
            }
        }
    };
}

define_compare_op! {
    /// Equality comparison operation.
    Eq where T: PartialEq => ==
}

define_compare_op! {
    /// Not equal comparison operation.
    NotEq where T: PartialEq => !=
}

define_compare_op! {
    /// Greater than comparison operation.
    Gt where T: PartialOrd => >
}

define_compare_op! {
    /// Greater than or equal comparison operation.
    Gte where T: PartialOrd => >=
}

define_compare_op! {
    /// Less than comparison operation.
    Lt where T: PartialOrd => <
}

define_compare_op! {
    /// Less than or equal comparison operation.
    Lte where T: PartialOrd => <=
}
395
396// TODO(ngates): bring these back!
397// #[cfg(test)]
398// mod tests {
399// use std::rc::Rc;
400//
401// use vortex_buffer::BufferMut;
402// use vortex_dtype::Nullability;
403// use vortex_scalar::Scalar;
404//
405// use crate::arrays::PrimitiveArray;
406// use crate::operator::bits::BitView;
407//
408// #[test]
409// fn test_scalar_compare_stacked_on_primitive() {
410// // Create input data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
411// let size = 16;
412// let primitive_array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
413// let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
414//
415// // Create scalar compare operator: primitive_value > 10
416// let compare_value = Scalar::primitive(10i32, Nullability::NonNullable);
417// let scalar_compare_op = Rc::new(ScalarCompareOperator::new(
418// primitive_op,
419// BinaryOperator::Gt,
420// compare_value,
421// ));
422//
423// // Create query plan from the stacked operators
424// let plan = QueryPlan::new(scalar_compare_op.as_ref()).unwrap();
425// let mut operator = plan.executable_plan().unwrap();
426//
427// // Create all-true mask for simplicity
428// let mask_data = [usize::MAX; N_WORDS];
429// let mask_view = BitView::new(&mask_data);
430//
431// // Create output buffer for boolean results
432// let mut output = BufferMut::<bool>::with_capacity(N);
433// unsafe { output.set_len(N) };
434// let mut output_view = ViewMut::new(&mut output[..], None);
435//
436// // Execute the operator
437// let result = operator._step(mask_view, &mut output_view);
438// assert!(result.is_ok());
439//
440// // Verify results: values 0-10 should be false, values 11-15 should be true
441// for i in 0..size {
442// let expected = i > 10;
443// assert_eq!(
444// output[i], expected,
445// "Position {}: expected {}, got {}",
446// i, expected, output[i]
447// );
448// }
449// }
450//
451// #[test]
452// fn test_scalar_compare_different_operators() {
453// // Test with different comparison operators
454// let size = 8;
455// let primitive_array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
456//
457// let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
458//
459// // Test Eq: values == 3
460// let compare_value = Scalar::primitive(3i32, Nullability::NonNullable);
461// let eq_op = Rc::new(ScalarCompareOperator::new(
462// primitive_op,
463// BinaryOperator::Eq,
464// compare_value,
465// ));
466//
467// let plan = QueryPlan::new(eq_op.as_ref()).unwrap();
468// let mut operator = plan.executable_plan().unwrap();
469//
470// let mask_data = [usize::MAX; N_WORDS];
471// let mask_view = BitView::new(&mask_data);
472//
473// let mut output = BufferMut::<bool>::with_capacity(N);
474// unsafe { output.set_len(N) };
475// let mut output_view = ViewMut::new(&mut output[..], None);
476//
477// let result = operator._step(mask_view, &mut output_view);
478// assert!(result.is_ok());
479//
480// // Only position 3 should be true
481// for i in 0..size {
482// let expected = i == 3;
483// assert_eq!(
484// output[i], expected,
485// "Eq test - Position {}: expected {}, got {}",
486// i, expected, output[i]
487// );
488// }
489// }
490//
491// #[test]
492// fn test_scalar_compare_with_f32() {
493// // Test with floating-point values
494// let size = 8;
495// let values: Vec<f32> = (0..size).map(|i| i as f32 + 0.5).collect();
496// let primitive_array = values.into_iter().collect::<PrimitiveArray>();
497//
498// let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
499//
500// // Test Lt: values < 3.5
501// let compare_value = Scalar::primitive(3.5f32, Nullability::NonNullable);
502// let lt_op = Rc::new(ScalarCompareOperator::new(
503// primitive_op,
504// BinaryOperator::Lt,
505// compare_value,
506// ));
507//
508// let plan = QueryPlan::new(lt_op.as_ref()).unwrap();
509// let mut operator = plan.executable_plan().unwrap();
510//
511// let mask_data = [usize::MAX; N_WORDS];
512// let mask_view = BitView::new(&mask_data);
513//
514// let mut output = BufferMut::<bool>::with_capacity(N);
515// unsafe { output.set_len(N) };
516// let mut output_view = ViewMut::new(&mut output[..], None);
517//
518// let result = operator._step(mask_view, &mut output_view);
519// assert!(result.is_ok());
520//
521// // Values 0.5, 1.5, 2.5 should be < 3.5 (true), 3.5+ should be false
522// for i in 0..size {
523// let value = i as f32 + 0.5;
524// let expected = value < 3.5;
525// assert_eq!(
526// output[i], expected,
527// "Lt test - Position {}: value {} should be {}, got {}",
528// i, value, expected, output[i]
529// );
530// }
531// }
532// }