vortex_array/operator/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::{DType, NativePType, match_each_native_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail};
12
13use crate::arrays::ConstantArray;
14use crate::compute::Operator as Op;
15use crate::operator::{LengthBounds, Operator, OperatorEq, OperatorHash, OperatorId, OperatorRef};
16use crate::pipeline::bits::BitView;
17use crate::pipeline::vec::Selection;
18use crate::pipeline::view::ViewMut;
19use crate::pipeline::{
20    BindContext, Element, Kernel, KernelContext, PipelinedOperator, RowSelection, VectorId,
21};
22
/// Operator that compares the outputs of two child operators using a comparison [`Op`].
///
/// Instances are created with `CompareOperator::try_new`, which requires both children to
/// share a dtype and produces a boolean result dtype.
#[derive(Debug)]
pub struct CompareOperator {
    /// The two inputs, stored as `[lhs, rhs]`.
    children: [OperatorRef; 2],
    /// The comparison applied as `lhs <op> rhs`.
    op: Op,
    /// The dtype reported by this operator (a `DType::Bool` when built through `try_new`).
    dtype: DType,
}
29
30impl CompareOperator {
31    pub fn try_new(lhs: OperatorRef, rhs: OperatorRef, op: Op) -> VortexResult<CompareOperator> {
32        if lhs.dtype() != rhs.dtype() {
33            vortex_bail!(
34                "Cannot compare arrays with different dtypes: {} and {}",
35                lhs.dtype(),
36                rhs.dtype()
37            );
38        }
39
40        let lhs_const = lhs.as_any().downcast_ref::<ConstantArray>();
41        let rhs_const = rhs.as_any().downcast_ref::<ConstantArray>();
42        if lhs_const.is_some() && rhs_const.is_some() {
43            // TODO(ngates): we should return the Constant result!
44        }
45
46        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
47        let dtype = DType::Bool(nullability);
48
49        Ok(CompareOperator {
50            children: [lhs, rhs],
51            op,
52            dtype,
53        })
54    }
55
56    pub fn op(&self) -> Op {
57        self.op
58    }
59}
60
61impl OperatorHash for CompareOperator {
62    fn operator_hash<H: Hasher>(&self, state: &mut H) {
63        self.op.hash(state);
64        self.dtype.hash(state);
65        self.children.iter().for_each(|c| c.operator_hash(state));
66    }
67}
68
69impl OperatorEq for CompareOperator {
70    fn operator_eq(&self, other: &Self) -> bool {
71        self.op == other.op
72            && self.dtype == other.dtype
73            && self
74                .children
75                .iter()
76                .zip(other.children.iter())
77                .all(|(a, b)| a.operator_eq(b))
78    }
79}
80
81impl Operator for CompareOperator {
82    fn id(&self) -> OperatorId {
83        OperatorId::from("vortex.compare")
84    }
85
86    fn as_any(&self) -> &dyn Any {
87        self
88    }
89
90    fn dtype(&self) -> &DType {
91        &self.dtype
92    }
93
94    fn bounds(&self) -> LengthBounds {
95        self.children[0].bounds() & self.children[1].bounds()
96    }
97
98    fn children(&self) -> &[OperatorRef] {
99        &self.children
100    }
101
102    fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
103        let (lhs, rhs) = children
104            .into_iter()
105            .tuples()
106            .next()
107            .vortex_expect("missing");
108        Ok(Arc::new(CompareOperator {
109            children: [lhs, rhs],
110            op: self.op,
111            dtype: self.dtype.clone(),
112        }))
113    }
114
115    fn as_pipelined(&self) -> Option<&dyn PipelinedOperator> {
116        // If both children support pipelining, but have different row selections, then we cannot
117        // pipeline without an alignment step (which we currently do not support).
118        if let Some((left, right)) = self.children[0]
119            .as_pipelined()
120            .zip(self.children[1].as_pipelined())
121            && left.row_selection() != right.row_selection()
122        {
123            return None;
124        }
125
126        Some(self)
127    }
128}
129
/// Dispatches on a runtime comparison `Op` value to a compile-time marker type.
///
/// `$self` is matched against every `Op` variant; in the matching arm the identifier `$enc`
/// is declared as a local type alias for the corresponding marker type (`Eq`, `NotEq`, `Gt`,
/// `Gte`, `Lt` or `Lte`) and `$body` is evaluated with that alias in scope. The macro
/// evaluates to the value of `$body`.
macro_rules! match_each_compare_op {
    ($self:expr, | $enc:ident | $body:block) => {{
        match $self {
            Op::Eq => {
                type $enc = Eq;
                $body
            }
            Op::NotEq => {
                type $enc = NotEq;
                $body
            }
            Op::Gt => {
                type $enc = Gt;
                $body
            }
            Op::Gte => {
                type $enc = Gte;
                $body
            }
            Op::Lt => {
                type $enc = Lt;
                $body
            }
            Op::Lte => {
                type $enc = Lte;
                $body
            }
        }
    }};
}
160
161impl PipelinedOperator for CompareOperator {
162    fn row_selection(&self) -> RowSelection {
163        self.children[0]
164            .as_pipelined()
165            .map(|p| p.row_selection())
166            .unwrap_or(RowSelection::All)
167    }
168
169    #[allow(clippy::cognitive_complexity)]
170    fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
171        debug_assert_eq!(self.children[0].dtype(), self.children[1].dtype());
172
173        let DType::Primitive(ptype, _) = self.children[0].dtype() else {
174            vortex_bail!(
175                "Unsupported type for comparison: {}",
176                self.children[0].dtype()
177            )
178        };
179
180        let lhs_const = self.children[0].as_any().downcast_ref::<ConstantArray>();
181        if let Some(lhs_const) = lhs_const {
182            // LHS is constant, use ScalarComparePrimitiveKernel
183            return match_each_native_ptype!(ptype, |T| {
184                match_each_compare_op!(self.op.swap(), |Op| {
185                    Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
186                        lhs: ctx.children()[1],
187                        rhs: lhs_const
188                            .scalar()
189                            .as_primitive()
190                            .typed_value::<T>()
191                            .vortex_expect("scalar value not of type T"),
192                        _phantom: PhantomData,
193                    }) as Box<dyn Kernel>)
194                })
195            });
196        }
197
198        let rhs_const = self.children[1].as_any().downcast_ref::<ConstantArray>();
199        if let Some(rhs_const) = rhs_const {
200            // RHS is constant, use ScalarComparePrimitiveKernel
201            return match_each_native_ptype!(ptype, |T| {
202                match_each_compare_op!(self.op, |Op| {
203                    Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
204                        lhs: ctx.children()[0],
205                        rhs: rhs_const
206                            .scalar()
207                            .as_primitive()
208                            .typed_value::<T>()
209                            .vortex_expect("scalar value not of type T"),
210                        _phantom: PhantomData,
211                    }) as Box<dyn Kernel>)
212                })
213            });
214        }
215
216        match_each_native_ptype!(ptype, |T| {
217            match_each_compare_op!(self.op, |Op| {
218                Ok(Box::new(ComparePrimitiveKernel::<T, Op> {
219                    lhs: ctx.children()[0],
220                    rhs: ctx.children()[1],
221                    _phantom: PhantomData,
222                }) as Box<dyn Kernel>)
223            })
224        })
225    }
226
227    fn vector_children(&self) -> Vec<usize> {
228        vec![0, 1]
229    }
230
231    fn batch_children(&self) -> Vec<usize> {
232        vec![]
233    }
234}
235
/// Kernel that compares two input vectors of primitive type `T` element-wise.
///
/// The comparison is chosen at compile time through the `Op` marker type.
pub struct ComparePrimitiveKernel<T, Op> {
    /// Vector holding the left-hand side values.
    lhs: VectorId,
    /// Vector holding the right-hand side values.
    rhs: VectorId,
    /// Ties the kernel to its element type and comparison without storing either.
    _phantom: PhantomData<(T, Op)>,
}
244
245impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel for ComparePrimitiveKernel<T, Op> {
246    fn step(
247        &self,
248        ctx: &KernelContext,
249        _chunk_idx: usize,
250        selection: &BitView,
251        out: &mut ViewMut,
252    ) -> VortexResult<()> {
253        let lhs_vec = ctx.vector(self.lhs);
254        let lhs = lhs_vec.as_array::<T>();
255        let rhs_vec = ctx.vector(self.rhs);
256        let rhs = rhs_vec.as_array::<T>();
257        let bools = out.as_array_mut::<bool>();
258
259        match (lhs_vec.selection(), rhs_vec.selection()) {
260            (Selection::Prefix, Selection::Prefix) => {
261                for i in 0..selection.true_count() {
262                    bools[i] = Op::compare(&lhs[i], &rhs[i]);
263                }
264                out.set_selection(Selection::Prefix)
265            }
266            (Selection::Mask, Selection::Mask) => {
267                // TODO(ngates): check density to decide if we should iterate indices or do
268                //  a full scan
269                let mut pos = 0;
270                selection.iter_ones(|idx| {
271                    bools[pos] = Op::compare(&lhs[idx], &rhs[idx]);
272                    pos += 1;
273                });
274                out.set_selection(Selection::Prefix)
275            }
276            (Selection::Mask, Selection::Prefix) => {
277                let mut pos = 0;
278                selection.iter_ones(|idx| {
279                    bools[pos] = Op::compare(&lhs[idx], &rhs[pos]);
280                    pos += 1;
281                });
282                out.set_selection(Selection::Prefix)
283            }
284            (Selection::Prefix, Selection::Mask) => {
285                let mut pos = 0;
286                selection.iter_ones(|idx| {
287                    bools[pos] = Op::compare(&lhs[pos], &rhs[idx]);
288                    pos += 1;
289                });
290                out.set_selection(Selection::Prefix)
291            }
292        }
293
294        Ok(())
295    }
296}
297
/// Kernel that compares one input vector of primitive type `T` against a fixed scalar.
///
/// The comparison is chosen at compile time through the `Op` marker type and is evaluated as
/// `element <op> rhs`.
struct ScalarComparePrimitiveKernel<T: Element + NativePType, Op: CompareOp<T>> {
    /// Vector holding the left-hand side values.
    lhs: VectorId,
    /// The scalar every element is compared against.
    rhs: T,
    /// Ties the kernel to its comparison without storing it.
    _phantom: PhantomData<Op>,
}
303
304impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel
305    for ScalarComparePrimitiveKernel<T, Op>
306{
307    fn step(
308        &self,
309        ctx: &KernelContext,
310        _chunk_idx: usize,
311        selection: &BitView,
312        out: &mut ViewMut,
313    ) -> VortexResult<()> {
314        let lhs_vec = ctx.vector(self.lhs);
315        let lhs = lhs_vec.as_array::<T>();
316        let bools = out.as_array_mut::<bool>();
317
318        match lhs_vec.selection() {
319            Selection::Prefix => {
320                for i in 0..selection.true_count() {
321                    bools[i] = Op::compare(&lhs[i], &self.rhs);
322                }
323                out.set_selection(Selection::Prefix)
324            }
325            Selection::Mask => {
326                // TODO(ngates): decide at what true count we should iter indices...
327                selection.iter_ones(|idx| {
328                    bools[idx] = Op::compare(&lhs[idx], &self.rhs);
329                });
330                out.set_selection(Selection::Mask)
331            }
332        }
333
334        Ok(())
335    }
336}
337
/// A binary comparison over values of type `T`, selected at compile time through the
/// implementing marker type.
pub(crate) trait CompareOp<T> {
    /// Evaluates the comparison with `lhs` on the left and `rhs` on the right.
    fn compare(lhs: &T, rhs: &T) -> bool;
}

/// Marker type for the `==` comparison.
pub struct Eq;

impl<T> CompareOp<T> for Eq
where
    T: PartialEq,
{
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialEq::eq(lhs, rhs)
    }
}

/// Marker type for the `!=` comparison.
pub struct NotEq;

impl<T> CompareOp<T> for NotEq
where
    T: PartialEq,
{
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialEq::ne(lhs, rhs)
    }
}

/// Marker type for the `>` comparison.
pub struct Gt;

impl<T> CompareOp<T> for Gt
where
    T: PartialOrd,
{
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::gt(lhs, rhs)
    }
}

/// Marker type for the `>=` comparison.
pub struct Gte;

impl<T> CompareOp<T> for Gte
where
    T: PartialOrd,
{
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::ge(lhs, rhs)
    }
}

/// Marker type for the `<` comparison.
pub struct Lt;

impl<T> CompareOp<T> for Lt
where
    T: PartialOrd,
{
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::lt(lhs, rhs)
    }
}

/// Marker type for the `<=` comparison.
pub struct Lte;

impl<T> CompareOp<T> for Lte
where
    T: PartialOrd,
{
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::le(lhs, rhs)
    }
}
395
396// TODO(ngates): bring these back!
397// #[cfg(test)]
398// mod tests {
399//     use std::rc::Rc;
400//
401//     use vortex_buffer::BufferMut;
402//     use vortex_dtype::Nullability;
403//     use vortex_scalar::Scalar;
404//
405//     use crate::arrays::PrimitiveArray;
406//     use crate::operator::bits::BitView;
407//
408//     #[test]
409//     fn test_scalar_compare_stacked_on_primitive() {
410//         // Create input data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
411//         let size = 16;
412//         let primitive_array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
413//         let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
414//
415//         // Create scalar compare operator: primitive_value > 10
416//         let compare_value = Scalar::primitive(10i32, Nullability::NonNullable);
417//         let scalar_compare_op = Rc::new(ScalarCompareOperator::new(
418//             primitive_op,
419//             BinaryOperator::Gt,
420//             compare_value,
421//         ));
422//
423//         // Create query plan from the stacked operators
424//         let plan = QueryPlan::new(scalar_compare_op.as_ref()).unwrap();
425//         let mut operator = plan.executable_plan().unwrap();
426//
427//         // Create all-true mask for simplicity
428//         let mask_data = [usize::MAX; N_WORDS];
429//         let mask_view = BitView::new(&mask_data);
430//
431//         // Create output buffer for boolean results
432//         let mut output = BufferMut::<bool>::with_capacity(N);
433//         unsafe { output.set_len(N) };
434//         let mut output_view = ViewMut::new(&mut output[..], None);
435//
436//         // Execute the operator
437//         let result = operator._step(mask_view, &mut output_view);
438//         assert!(result.is_ok());
439//
440//         // Verify results: values 0-10 should be false, values 11-15 should be true
441//         for i in 0..size {
442//             let expected = i > 10;
443//             assert_eq!(
444//                 output[i], expected,
445//                 "Position {}: expected {}, got {}",
446//                 i, expected, output[i]
447//             );
448//         }
449//     }
450//
451//     #[test]
452//     fn test_scalar_compare_different_operators() {
453//         // Test with different comparison operators
454//         let size = 8;
455//         let primitive_array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
456//
457//         let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
458//
459//         // Test Eq: values == 3
460//         let compare_value = Scalar::primitive(3i32, Nullability::NonNullable);
461//         let eq_op = Rc::new(ScalarCompareOperator::new(
462//             primitive_op,
463//             BinaryOperator::Eq,
464//             compare_value,
465//         ));
466//
467//         let plan = QueryPlan::new(eq_op.as_ref()).unwrap();
468//         let mut operator = plan.executable_plan().unwrap();
469//
470//         let mask_data = [usize::MAX; N_WORDS];
471//         let mask_view = BitView::new(&mask_data);
472//
473//         let mut output = BufferMut::<bool>::with_capacity(N);
474//         unsafe { output.set_len(N) };
475//         let mut output_view = ViewMut::new(&mut output[..], None);
476//
477//         let result = operator._step(mask_view, &mut output_view);
478//         assert!(result.is_ok());
479//
480//         // Only position 3 should be true
481//         for i in 0..size {
482//             let expected = i == 3;
483//             assert_eq!(
484//                 output[i], expected,
485//                 "Eq test - Position {}: expected {}, got {}",
486//                 i, expected, output[i]
487//             );
488//         }
489//     }
490//
491//     #[test]
492//     fn test_scalar_compare_with_f32() {
493//         // Test with floating-point values
494//         let size = 8;
495//         let values: Vec<f32> = (0..size).map(|i| i as f32 + 0.5).collect();
496//         let primitive_array = values.into_iter().collect::<PrimitiveArray>();
497//
498//         let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
499//
500//         // Test Lt: values < 3.5
501//         let compare_value = Scalar::primitive(3.5f32, Nullability::NonNullable);
502//         let lt_op = Rc::new(ScalarCompareOperator::new(
503//             primitive_op,
504//             BinaryOperator::Lt,
505//             compare_value,
506//         ));
507//
508//         let plan = QueryPlan::new(lt_op.as_ref()).unwrap();
509//         let mut operator = plan.executable_plan().unwrap();
510//
511//         let mask_data = [usize::MAX; N_WORDS];
512//         let mask_view = BitView::new(&mask_data);
513//
514//         let mut output = BufferMut::<bool>::with_capacity(N);
515//         unsafe { output.set_len(N) };
516//         let mut output_view = ViewMut::new(&mut output[..], None);
517//
518//         let result = operator._step(mask_view, &mut output_view);
519//         assert!(result.is_ok());
520//
521//         // Values 0.5, 1.5, 2.5 should be < 3.5 (true), 3.5+ should be false
522//         for i in 0..size {
523//             let value = i as f32 + 0.5;
524//             let expected = value < 3.5;
525//             assert_eq!(
526//                 output[i], expected,
527//                 "Lt test - Position {}: value {} should be {}, got {}",
528//                 i, value, expected, output[i]
529//             );
530//         }
531//     }
532// }