vortex_array/operator/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use itertools::Itertools;
10use vortex_dtype::{DType, NativePType, match_each_native_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail};
12
13use crate::arrays::ConstantArray;
14use crate::compute::Operator as Op;
15use crate::operator::{Operator, OperatorEq, OperatorHash, OperatorId, OperatorRef};
16use crate::pipeline::view::ViewMut;
17use crate::pipeline::{BindContext, Element, Kernel, KernelContext, PipelinedOperator, VectorId};
18
19#[derive(Debug)]
20pub struct CompareOperator {
21    children: [OperatorRef; 2],
22    op: Op,
23    dtype: DType,
24}
25
26impl CompareOperator {
27    pub fn try_new(lhs: OperatorRef, rhs: OperatorRef, op: Op) -> VortexResult<CompareOperator> {
28        if lhs.dtype() != rhs.dtype() {
29            vortex_bail!(
30                "Cannot compare arrays with different dtypes: {} and {}",
31                lhs.dtype(),
32                rhs.dtype()
33            );
34        }
35
36        let lhs_const = lhs.as_any().downcast_ref::<ConstantArray>();
37        let rhs_const = rhs.as_any().downcast_ref::<ConstantArray>();
38        if lhs_const.is_some() && rhs_const.is_some() {
39            // TODO(ngates): we should return the Constant result!
40        }
41
42        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
43        let dtype = DType::Bool(nullability);
44
45        Ok(CompareOperator {
46            children: [lhs, rhs],
47            op,
48            dtype,
49        })
50    }
51
52    pub fn op(&self) -> Op {
53        self.op
54    }
55}
56
57impl OperatorHash for CompareOperator {
58    fn operator_hash<H: Hasher>(&self, state: &mut H) {
59        self.op.hash(state);
60        self.dtype.hash(state);
61        self.children.iter().for_each(|c| c.operator_hash(state));
62    }
63}
64
65impl OperatorEq for CompareOperator {
66    fn operator_eq(&self, other: &Self) -> bool {
67        self.op == other.op
68            && self.dtype == other.dtype
69            && self
70                .children
71                .iter()
72                .zip(other.children.iter())
73                .all(|(a, b)| a.operator_eq(b))
74    }
75}
76
77impl Operator for CompareOperator {
78    fn id(&self) -> OperatorId {
79        OperatorId::from("vortex.compare")
80    }
81
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn dtype(&self) -> &DType {
87        &self.dtype
88    }
89
90    fn len(&self) -> usize {
91        self.children[0].len() & self.children[1].len()
92    }
93
94    fn children(&self) -> &[OperatorRef] {
95        &self.children
96    }
97
98    fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
99        let (lhs, rhs) = children
100            .into_iter()
101            .tuples()
102            .next()
103            .vortex_expect("missing");
104        Ok(Arc::new(CompareOperator {
105            children: [lhs, rhs],
106            op: self.op,
107            dtype: self.dtype.clone(),
108        }))
109    }
110
111    fn as_pipelined(&self) -> Option<&dyn PipelinedOperator> {
112        Some(self)
113    }
114}
115
/// Dispatches on a runtime comparison operator.
///
/// Evaluates `$body` in the arm matching `$op`, with `$marker` declared as a type alias for
/// the marker type of that comparison (`Eq`, `NotEq`, `Gt`, `Gte`, `Lt` or `Lte`), so the body
/// can use it as a generic argument.
macro_rules! match_each_compare_op {
    ($op:expr, | $marker:ident | $body:block) => {{
        match $op {
            Op::Eq => { type $marker = Eq; $body }
            Op::NotEq => { type $marker = NotEq; $body }
            Op::Gt => { type $marker = Gt; $body }
            Op::Gte => { type $marker = Gte; $body }
            Op::Lt => { type $marker = Lt; $body }
            Op::Lte => { type $marker = Lte; $body }
        }
    }};
}
146
147impl PipelinedOperator for CompareOperator {
148    #[allow(clippy::cognitive_complexity)]
149    fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
150        debug_assert_eq!(self.children[0].dtype(), self.children[1].dtype());
151
152        let DType::Primitive(ptype, _) = self.children[0].dtype() else {
153            vortex_bail!(
154                "Unsupported type for comparison: {}",
155                self.children[0].dtype()
156            )
157        };
158
159        let lhs_const = self.children[0].as_any().downcast_ref::<ConstantArray>();
160        if let Some(lhs_const) = lhs_const {
161            // LHS is constant, use ScalarComparePrimitiveKernel
162            return match_each_native_ptype!(ptype, |T| {
163                match_each_compare_op!(self.op.swap(), |Op| {
164                    Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
165                        lhs: ctx.children()[1],
166                        rhs: lhs_const
167                            .scalar()
168                            .as_primitive()
169                            .typed_value::<T>()
170                            .vortex_expect("scalar value not of type T"),
171                        _phantom: PhantomData,
172                    }) as Box<dyn Kernel>)
173                })
174            });
175        }
176
177        let rhs_const = self.children[1].as_any().downcast_ref::<ConstantArray>();
178        if let Some(rhs_const) = rhs_const {
179            // RHS is constant, use ScalarComparePrimitiveKernel
180            return match_each_native_ptype!(ptype, |T| {
181                match_each_compare_op!(self.op, |Op| {
182                    Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
183                        lhs: ctx.children()[0],
184                        rhs: rhs_const
185                            .scalar()
186                            .as_primitive()
187                            .typed_value::<T>()
188                            .vortex_expect("scalar value not of type T"),
189                        _phantom: PhantomData,
190                    }) as Box<dyn Kernel>)
191                })
192            });
193        }
194
195        match_each_native_ptype!(ptype, |T| {
196            match_each_compare_op!(self.op, |Op| {
197                Ok(Box::new(ComparePrimitiveKernel::<T, Op> {
198                    lhs: ctx.children()[0],
199                    rhs: ctx.children()[1],
200                    _phantom: PhantomData,
201                }) as Box<dyn Kernel>)
202            })
203        })
204    }
205
206    fn vector_children(&self) -> Vec<usize> {
207        vec![0, 1]
208    }
209
210    fn batch_children(&self) -> Vec<usize> {
211        vec![]
212    }
213}
214
215/// A compare operator for primitive types that compares two vectors element-wise using a binary
216/// operation.
217/// Kernel that performs primitive type comparisons between two input vectors.
218pub struct ComparePrimitiveKernel<T, Op> {
219    lhs: VectorId,
220    rhs: VectorId,
221    _phantom: PhantomData<(T, Op)>,
222}
223
224impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel for ComparePrimitiveKernel<T, Op> {
225    fn step(&mut self, ctx: &KernelContext, out: &mut ViewMut) -> VortexResult<()> {
226        let lhs_vec = ctx.vector(self.lhs);
227        let lhs = lhs_vec.as_slice::<T>();
228        let rhs_vec = ctx.vector(self.rhs);
229        let rhs = rhs_vec.as_slice::<T>();
230        let bools = out.as_slice_mut::<bool>();
231
232        assert_eq!(
233            lhs.len(),
234            rhs.len(),
235            "LHS and RHS must have the same length"
236        );
237
238        lhs.iter()
239            .zip(rhs.iter())
240            .zip(bools)
241            .for_each(|((lhs, rhs), bool)| *bool = Op::compare(lhs, rhs));
242
243        out.set_len(lhs.len());
244
245        Ok(())
246    }
247}
248
249struct ScalarComparePrimitiveKernel<T: Element + NativePType, Op: CompareOp<T>> {
250    lhs: VectorId,
251    rhs: T,
252    _phantom: PhantomData<Op>,
253}
254
255impl<T: Element + NativePType, Op: CompareOp<T> + Send> Kernel
256    for ScalarComparePrimitiveKernel<T, Op>
257{
258    fn step(&mut self, ctx: &KernelContext, out: &mut ViewMut) -> VortexResult<()> {
259        let lhs_vec = ctx.vector(self.lhs);
260        let lhs = lhs_vec.as_slice::<T>();
261        let bools = out.as_slice_mut::<bool>();
262
263        // Note we zip only over the shortest iterator which is LHS
264        lhs.iter().zip(bools).for_each(|(lhs, bool)| {
265            *bool = Op::compare(lhs, &self.rhs);
266        });
267        out.set_len(lhs.len());
268
269        Ok(())
270    }
271}
272
/// A binary comparison over values of type `T`, selected at compile time through a marker type.
pub(crate) trait CompareOp<T> {
    /// Evaluates `lhs <op> rhs`.
    fn compare(lhs: &T, rhs: &T) -> bool;
}

/// Marker for the `==` comparison.
pub struct Eq;
impl<T: PartialEq> CompareOp<T> for Eq {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialEq::eq(lhs, rhs)
    }
}

/// Marker for the `!=` comparison.
pub struct NotEq;
impl<T: PartialEq> CompareOp<T> for NotEq {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialEq::ne(lhs, rhs)
    }
}

/// Marker for the `>` comparison.
pub struct Gt;
impl<T: PartialOrd> CompareOp<T> for Gt {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::gt(lhs, rhs)
    }
}

/// Marker for the `>=` comparison.
pub struct Gte;
impl<T: PartialOrd> CompareOp<T> for Gte {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::ge(lhs, rhs)
    }
}

/// Marker for the `<` comparison.
pub struct Lt;
impl<T: PartialOrd> CompareOp<T> for Lt {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::lt(lhs, rhs)
    }
}

/// Marker for the `<=` comparison.
pub struct Lte;
impl<T: PartialOrd> CompareOp<T> for Lte {
    #[inline(always)]
    fn compare(lhs: &T, rhs: &T) -> bool {
        PartialOrd::le(lhs, rhs)
    }
}
330
331// TODO(ngates): bring these back!
332// #[cfg(test)]
333// mod tests {
334//     use std::rc::Rc;
335//
336//     use vortex_buffer::BufferMut;
337//     use vortex_dtype::Nullability;
338//     use vortex_scalar::Scalar;
339//
340//     use crate::arrays::PrimitiveArray;
341//     use crate::operator::bits::BitView;
342//
343//     #[test]
344//     fn test_scalar_compare_stacked_on_primitive() {
345//         // Create input data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
346//         let size = 16;
347//         let primitive_array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
348//         let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
349//
350//         // Create scalar compare operator: primitive_value > 10
351//         let compare_value = Scalar::primitive(10i32, Nullability::NonNullable);
352//         let scalar_compare_op = Rc::new(ScalarCompareOperator::new(
353//             primitive_op,
354//             BinaryOperator::Gt,
355//             compare_value,
356//         ));
357//
358//         // Create query plan from the stacked operators
359//         let plan = QueryPlan::new(scalar_compare_op.as_ref()).unwrap();
360//         let mut operator = plan.executable_plan().unwrap();
361//
362//         // Create all-true mask for simplicity
363//         let mask_data = [usize::MAX; N_WORDS];
364//         let mask_view = BitView::new(&mask_data);
365//
366//         // Create output buffer for boolean results
367//         let mut output = BufferMut::<bool>::with_capacity(N);
368//         unsafe { output.set_len(N) };
369//         let mut output_view = ViewMut::new(&mut output[..], None);
370//
371//         // Execute the operator
372//         let result = operator._step(mask_view, &mut output_view);
373//         assert!(result.is_ok());
374//
375//         // Verify results: values 0-10 should be false, values 11-15 should be true
376//         for i in 0..size {
377//             let expected = i > 10;
378//             assert_eq!(
379//                 output[i], expected,
380//                 "Position {}: expected {}, got {}",
381//                 i, expected, output[i]
382//             );
383//         }
384//     }
385//
386//     #[test]
387//     fn test_scalar_compare_different_operators() {
388//         // Test with different comparison operators
389//         let size = 8;
390//         let primitive_array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
391//
392//         let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
393//
394//         // Test Eq: values == 3
395//         let compare_value = Scalar::primitive(3i32, Nullability::NonNullable);
396//         let eq_op = Rc::new(ScalarCompareOperator::new(
397//             primitive_op,
398//             BinaryOperator::Eq,
399//             compare_value,
400//         ));
401//
402//         let plan = QueryPlan::new(eq_op.as_ref()).unwrap();
403//         let mut operator = plan.executable_plan().unwrap();
404//
405//         let mask_data = [usize::MAX; N_WORDS];
406//         let mask_view = BitView::new(&mask_data);
407//
408//         let mut output = BufferMut::<bool>::with_capacity(N);
409//         unsafe { output.set_len(N) };
410//         let mut output_view = ViewMut::new(&mut output[..], None);
411//
412//         let result = operator._step(mask_view, &mut output_view);
413//         assert!(result.is_ok());
414//
415//         // Only position 3 should be true
416//         for i in 0..size {
417//             let expected = i == 3;
418//             assert_eq!(
419//                 output[i], expected,
420//                 "Eq test - Position {}: expected {}, got {}",
421//                 i, expected, output[i]
422//             );
423//         }
424//     }
425//
426//     #[test]
427//     fn test_scalar_compare_with_f32() {
428//         // Test with floating-point values
429//         let size = 8;
430//         let values: Vec<f32> = (0..size).map(|i| i as f32 + 0.5).collect();
431//         let primitive_array = values.into_iter().collect::<PrimitiveArray>();
432//
433//         let primitive_op = primitive_array.as_ref().to_operator().unwrap().unwrap();
434//
435//         // Test Lt: values < 3.5
436//         let compare_value = Scalar::primitive(3.5f32, Nullability::NonNullable);
437//         let lt_op = Rc::new(ScalarCompareOperator::new(
438//             primitive_op,
439//             BinaryOperator::Lt,
440//             compare_value,
441//         ));
442//
443//         let plan = QueryPlan::new(lt_op.as_ref()).unwrap();
444//         let mut operator = plan.executable_plan().unwrap();
445//
446//         let mask_data = [usize::MAX; N_WORDS];
447//         let mask_view = BitView::new(&mask_data);
448//
449//         let mut output = BufferMut::<bool>::with_capacity(N);
450//         unsafe { output.set_len(N) };
451//         let mut output_view = ViewMut::new(&mut output[..], None);
452//
453//         let result = operator._step(mask_view, &mut output_view);
454//         assert!(result.is_ok());
455//
456//         // Values 0.5, 1.5, 2.5 should be < 3.5 (true), 3.5+ should be false
457//         for i in 0..size {
458//             let value = i as f32 + 0.5;
459//             let expected = value < 3.5;
460//             assert_eq!(
461//                 output[i], expected,
462//                 "Lt test - Position {}: value {} should be {}, got {}",
463//                 i, value, expected, output[i]
464//             );
465//         }
466//     }
467// }