Skip to main content

vortex_array/scalar_fn/fns/binary/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::BooleanArray;
5use arrow_ord::cmp;
6use vortex_error::VortexResult;
7
8use crate::Array;
9use crate::ArrayRef;
10use crate::Canonical;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::ConstantArray;
14use crate::arrays::ConstantVTable;
15use crate::arrays::ExactScalarFn;
16use crate::arrays::ScalarFnArrayView;
17use crate::arrays::ScalarFnVTable;
18use crate::arrow::Datum;
19use crate::arrow::IntoArrowArray;
20use crate::arrow::from_arrow_array_with_len;
21use crate::compute::compare_nested_arrow_arrays;
22use crate::dtype::DType;
23use crate::dtype::Nullability;
24use crate::kernel::ExecuteParentKernel;
25use crate::scalar::Scalar;
26use crate::scalar_fn::fns::binary::Binary;
27use crate::scalar_fn::fns::operators::CompareOperator;
28use crate::vtable::VTable;
29
/// Trait for encoding-specific comparison kernels that operate in encoded space.
///
/// Implementations can compare an encoded array against another array (typically a constant)
/// without first decompressing. The adaptor ([`CompareExecuteAdaptor`]) normalizes operand order
/// so the encoded array is always passed as `lhs`, swapping the operator when necessary.
pub trait CompareKernel: VTable {
    /// Compares `lhs` against `rhs` element-wise using `operator`.
    ///
    /// When called through [`CompareExecuteAdaptor`], `lhs` is non-empty and `rhs` is not a
    /// null constant: the adaptor answers both of those cases itself before delegating.
    ///
    /// The adaptor returns this value unchanged from `execute_parent`.
    /// NOTE(review): `Ok(None)` presumably means "this kernel does not handle the comparison,
    /// fall back to another path" — confirm against the `ExecuteParentKernel` contract.
    fn compare(
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: CompareOperator,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
43
/// Adaptor that bridges [`CompareKernel`] implementations to [`ExecuteParentKernel`].
///
/// When a `ScalarFnArray(Binary, cmp_op)` wraps a child that implements `CompareKernel`,
/// this adaptor extracts the comparison operator and the other operand, normalizes operand order
/// (swapping the operator if the encoded array is on the RHS), and delegates to the kernel.
///
/// Empty arrays and a null-constant other operand are answered by the adaptor itself and are
/// never forwarded to the kernel.
#[derive(Default, Debug)]
pub struct CompareExecuteAdaptor<V>(pub V);
51
52impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
53where
54    V: CompareKernel,
55{
56    type Parent = ExactScalarFn<Binary>;
57
58    fn execute_parent(
59        &self,
60        array: &V::Array,
61        parent: ScalarFnArrayView<'_, Binary>,
62        child_idx: usize,
63        ctx: &mut ExecutionCtx,
64    ) -> VortexResult<Option<ArrayRef>> {
65        // Only handle comparison operators
66        let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
67            return Ok(None);
68        };
69
70        // Get the ScalarFnArray to access children
71        let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
72            return Ok(None);
73        };
74        let children = scalar_fn_array.children();
75
76        // Normalize so `array` is always LHS, swapping the operator if needed
77        // TODO(joe): should be go this here or in the Rule/Kernel
78        let (cmp_op, other) = match child_idx {
79            0 => (cmp_op, &children[1]),
80            1 => (cmp_op.swap(), &children[0]),
81            _ => return Ok(None),
82        };
83
84        let len = array.len();
85        let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
86
87        // Empty array → empty bool result
88        if len == 0 {
89            return Ok(Some(
90                Canonical::empty(&DType::Bool(nullable.into())).into_array(),
91            ));
92        }
93
94        // Null constant on either side → all-null bool result
95        if other.as_constant().is_some_and(|s| s.is_null()) {
96            return Ok(Some(
97                ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
98            ));
99        }
100
101        V::compare(array, other.as_ref(), cmp_op, ctx)
102    }
103}
104
105/// Execute a compare operation between two arrays.
106///
107/// This is the entry point for compare operations from the binary expression.
108/// Handles empty, constant-null, and constant-constant directly, otherwise falls back to Arrow.
109pub(crate) fn execute_compare(
110    lhs: &dyn Array,
111    rhs: &dyn Array,
112    op: CompareOperator,
113) -> VortexResult<ArrayRef> {
114    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
115
116    if lhs.is_empty() {
117        return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
118    }
119
120    let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
121    let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
122    if left_constant_null || right_constant_null {
123        return Ok(
124            ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
125        );
126    }
127
128    // Constant-constant fast path
129    if let (Some(lhs_const), Some(rhs_const)) = (
130        lhs.as_opt::<ConstantVTable>(),
131        rhs.as_opt::<ConstantVTable>(),
132    ) {
133        let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
134        return Ok(ConstantArray::new(result, lhs.len()).into_array());
135    }
136
137    arrow_compare_arrays(lhs, rhs, op)
138}
139
140/// Fall back to Arrow for comparison.
141fn arrow_compare_arrays(
142    left: &dyn Array,
143    right: &dyn Array,
144    operator: CompareOperator,
145) -> VortexResult<ArrayRef> {
146    assert_eq!(left.len(), right.len());
147
148    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
149
150    // Arrow's vectorized comparison kernels don't support nested types.
151    // For nested types, fall back to `make_comparator` which does element-wise comparison.
152    let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
153        let rhs = right.to_array().into_arrow_preferred()?;
154        let lhs = left.to_array().into_arrow(rhs.data_type())?;
155
156        assert!(
157            lhs.data_type().equals_datatype(rhs.data_type()),
158            "lhs data_type: {}, rhs data_type: {}",
159            lhs.data_type(),
160            rhs.data_type()
161        );
162
163        compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
164    } else {
165        // Fast path: use vectorized kernels for primitive types.
166        let lhs = Datum::try_new(left)?;
167        let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
168
169        match operator {
170            CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
171            CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
172            CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
173            CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
174            CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
175            CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
176        }
177    };
178    from_arrow_array_with_len(&array, left.len(), nullable)
179}
180
181pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar {
182    if lhs.is_null() | rhs.is_null() {
183        Scalar::null(DType::Bool(Nullability::Nullable))
184    } else {
185        let b = match operator {
186            CompareOperator::Eq => lhs == rhs,
187            CompareOperator::NotEq => lhs != rhs,
188            CompareOperator::Gt => lhs > rhs,
189            CompareOperator::Gte => lhs >= rhs,
190            CompareOperator::Lt => lhs < rhs,
191            CompareOperator::Lte => lhs <= rhs,
192        };
193
194        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
195    }
196}
197
198#[cfg(test)]
199mod tests {
200    use std::sync::Arc;
201
202    use rstest::rstest;
203    use vortex_buffer::buffer;
204
205    use crate::ArrayRef;
206    use crate::IntoArray;
207    use crate::ToCanonical;
208    use crate::arrays::BoolArray;
209    use crate::arrays::ConstantArray;
210    use crate::arrays::ListArray;
211    use crate::arrays::ListViewArray;
212    use crate::arrays::PrimitiveArray;
213    use crate::arrays::StructArray;
214    use crate::arrays::VarBinArray;
215    use crate::arrays::VarBinViewArray;
216    use crate::assert_arrays_eq;
217    use crate::builtins::ArrayBuiltins;
218    use crate::dtype::DType;
219    use crate::dtype::FieldName;
220    use crate::dtype::FieldNames;
221    use crate::dtype::Nullability;
222    use crate::dtype::PType;
223    use crate::scalar::Scalar;
224    use crate::scalar_fn::fns::operators::Operator;
225    use crate::test_harness::to_int_indices;
226    use crate::validity::Validity;
227
228    #[test]
229    fn test_bool_basic_comparisons() {
230        use vortex_buffer::BitBuffer;
231
232        let arr = BoolArray::new(
233            BitBuffer::from_iter([true, true, false, true, false]),
234            Validity::from_iter([false, true, true, true, true]),
235        );
236
237        let matches = arr
238            .to_array()
239            .binary(arr.to_array(), Operator::Eq)
240            .unwrap()
241            .to_bool();
242        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);
243
244        let matches = arr
245            .to_array()
246            .binary(arr.to_array(), Operator::NotEq)
247            .unwrap()
248            .to_bool();
249        let empty: [u64; 0] = [];
250        assert_eq!(to_int_indices(matches).unwrap(), empty);
251
252        let other = BoolArray::new(
253            BitBuffer::from_iter([false, false, false, true, true]),
254            Validity::from_iter([false, true, true, true, true]),
255        );
256
257        let matches = arr
258            .to_array()
259            .binary(other.to_array(), Operator::Lte)
260            .unwrap()
261            .to_bool();
262        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
263
264        let matches = arr
265            .to_array()
266            .binary(other.to_array(), Operator::Lt)
267            .unwrap()
268            .to_bool();
269        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
270
271        let matches = other
272            .to_array()
273            .binary(arr.to_array(), Operator::Gte)
274            .unwrap()
275            .to_bool();
276        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
277
278        let matches = other
279            .to_array()
280            .binary(arr.to_array(), Operator::Gt)
281            .unwrap()
282            .to_bool();
283        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
284    }
285
286    #[test]
287    fn constant_compare() {
288        let left = ConstantArray::new(Scalar::from(2u32), 10);
289        let right = ConstantArray::new(Scalar::from(10u32), 10);
290
291        let result = left
292            .to_array()
293            .binary(right.to_array(), Operator::Gt)
294            .unwrap();
295        assert_eq!(result.len(), 10);
296        let scalar = result.scalar_at(0).unwrap();
297        assert_eq!(scalar.as_bool().value(), Some(false));
298    }
299
300    #[rstest]
301    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
302    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
303    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
304    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
305    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
306        let res = left.binary(right, Operator::Eq).unwrap();
307        let expected = BoolArray::from_iter([true, true]);
308        assert_arrays_eq!(res, expected);
309    }
310
311    #[ignore = "Arrow's ListView cannot be compared"]
312    #[test]
313    fn test_list_array_comparison() {
314        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
315        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
316        let list1 = ListArray::try_new(
317            values1.into_array(),
318            offsets1.into_array(),
319            Validity::NonNullable,
320        )
321        .unwrap();
322
323        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
324        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
325        let list2 = ListArray::try_new(
326            values2.into_array(),
327            offsets2.into_array(),
328            Validity::NonNullable,
329        )
330        .unwrap();
331
332        let result = list1
333            .to_array()
334            .binary(list2.to_array(), Operator::Eq)
335            .unwrap();
336        let expected = BoolArray::from_iter([true, true, false]);
337        assert_arrays_eq!(result, expected);
338
339        let result = list1
340            .to_array()
341            .binary(list2.to_array(), Operator::NotEq)
342            .unwrap();
343        let expected = BoolArray::from_iter([false, false, true]);
344        assert_arrays_eq!(result, expected);
345
346        let result = list1
347            .to_array()
348            .binary(list2.to_array(), Operator::Lt)
349            .unwrap();
350        let expected = BoolArray::from_iter([false, false, true]);
351        assert_arrays_eq!(result, expected);
352    }
353
354    #[ignore = "Arrow's ListView cannot be compared"]
355    #[test]
356    fn test_list_array_constant_comparison() {
357        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
358        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
359        let list = ListArray::try_new(
360            values.into_array(),
361            offsets.into_array(),
362            Validity::NonNullable,
363        )
364        .unwrap();
365
366        let list_scalar = Scalar::list(
367            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
368            vec![3i32.into(), 4i32.into()],
369            Nullability::NonNullable,
370        );
371        let constant = ConstantArray::new(list_scalar, 3);
372
373        let result = list
374            .to_array()
375            .binary(constant.to_array(), Operator::Eq)
376            .unwrap();
377        let expected = BoolArray::from_iter([false, true, false]);
378        assert_arrays_eq!(result, expected);
379    }
380
381    #[test]
382    fn test_struct_array_comparison() {
383        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
384        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);
385
386        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
387        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);
388
389        let struct1 = StructArray::from_fields(&[
390            ("bool_col", bool_field1.into_array()),
391            ("int_col", int_field1.into_array()),
392        ])
393        .unwrap();
394
395        let struct2 = StructArray::from_fields(&[
396            ("bool_col", bool_field2.into_array()),
397            ("int_col", int_field2.into_array()),
398        ])
399        .unwrap();
400
401        let result = struct1
402            .to_array()
403            .binary(struct2.to_array(), Operator::Eq)
404            .unwrap();
405        let expected = BoolArray::from_iter([true, true, false]);
406        assert_arrays_eq!(result, expected);
407
408        let result = struct1
409            .to_array()
410            .binary(struct2.to_array(), Operator::Gt)
411            .unwrap();
412        let expected = BoolArray::from_iter([false, false, true]);
413        assert_arrays_eq!(result, expected);
414    }
415
416    #[test]
417    fn test_empty_struct_compare() {
418        let empty1 = StructArray::try_new(
419            FieldNames::from(Vec::<FieldName>::new()),
420            Vec::new(),
421            5,
422            Validity::NonNullable,
423        )
424        .unwrap();
425
426        let empty2 = StructArray::try_new(
427            FieldNames::from(Vec::<FieldName>::new()),
428            Vec::new(),
429            5,
430            Validity::NonNullable,
431        )
432        .unwrap();
433
434        let result = empty1
435            .to_array()
436            .binary(empty2.to_array(), Operator::Eq)
437            .unwrap();
438        let expected = BoolArray::from_iter([true, true, true, true, true]);
439        assert_arrays_eq!(result, expected);
440    }
441
442    #[test]
443    fn test_empty_list() {
444        let list = ListViewArray::new(
445            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
446            buffer![0i32, 0i32, 0i32].into_array(),
447            buffer![0i32, 0i32, 0i32].into_array(),
448            Validity::AllValid,
449        );
450
451        let result = list
452            .to_array()
453            .binary(list.to_array(), Operator::Eq)
454            .unwrap();
455        assert!(result.scalar_at(0).unwrap().is_valid());
456        assert!(result.scalar_at(1).unwrap().is_valid());
457        assert!(result.scalar_at(2).unwrap().is_valid());
458    }
459}