Skip to main content

vortex_array/scalar_fn/fns/binary/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::SortOptions;
11use vortex_error::VortexResult;
12
13use crate::ArrayRef;
14use crate::Canonical;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::ConstantArray;
18use crate::arrays::ConstantVTable;
19use crate::arrays::ScalarFnVTable;
20use crate::arrays::scalar_fn::ExactScalarFn;
21use crate::arrays::scalar_fn::ScalarFnArrayView;
22use crate::arrow::Datum;
23use crate::arrow::IntoArrowArray;
24use crate::arrow::from_arrow_array_with_len;
25use crate::dtype::DType;
26use crate::dtype::Nullability;
27use crate::kernel::ExecuteParentKernel;
28use crate::scalar::Scalar;
29use crate::scalar_fn::fns::binary::Binary;
30use crate::scalar_fn::fns::operators::CompareOperator;
31use crate::vtable::VTable;
32
/// Trait for encoding-specific comparison kernels that operate in encoded space.
///
/// Implementations can compare an encoded array against another array (typically a constant)
/// without first decompressing. The adaptor normalizes operand order so `array` is always
/// the left-hand side, swapping the operator when necessary.
pub trait CompareKernel: VTable {
    /// Compare `lhs` (the encoded array of this vtable) against `rhs` using `operator`.
    ///
    /// `ctx` is the execution context handed down by the caller.
    ///
    /// The return value is forwarded unchanged by [`CompareExecuteAdaptor`]. Presumably
    /// `Ok(None)` means the kernel declined to handle this comparison and the caller should
    /// fall back to another strategy — TODO confirm against `ExecuteParentKernel`.
    fn compare(
        lhs: &Self::Array,
        rhs: &ArrayRef,
        operator: CompareOperator,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
46
/// Adaptor that bridges [`CompareKernel`] implementations to [`ExecuteParentKernel`].
///
/// When a `ScalarFnArray(Binary, cmp_op)` wraps a child that implements `CompareKernel`,
/// this adaptor extracts the comparison operator and other operand, normalizes operand order
/// (swapping the operator if the encoded array is on the RHS), and delegates to the kernel.
///
/// This is a newtype around a value of type `V`; its [`ExecuteParentKernel`] implementation
/// is available whenever `V: CompareKernel`.
#[derive(Default, Debug)]
pub struct CompareExecuteAdaptor<V>(pub V);
54
55impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
56where
57    V: CompareKernel,
58{
59    type Parent = ExactScalarFn<Binary>;
60
61    fn execute_parent(
62        &self,
63        array: &V::Array,
64        parent: ScalarFnArrayView<'_, Binary>,
65        child_idx: usize,
66        ctx: &mut ExecutionCtx,
67    ) -> VortexResult<Option<ArrayRef>> {
68        // Only handle comparison operators
69        let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
70            return Ok(None);
71        };
72
73        // Get the ScalarFnArray to access children
74        let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
75            return Ok(None);
76        };
77        let children = scalar_fn_array.children();
78
79        // Normalize so `array` is always LHS, swapping the operator if needed
80        // TODO(joe): should be go this here or in the Rule/Kernel
81        let (cmp_op, other) = match child_idx {
82            0 => (cmp_op, &children[1]),
83            1 => (cmp_op.swap(), &children[0]),
84            _ => return Ok(None),
85        };
86
87        let len = array.len();
88        let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
89
90        // Empty array → empty bool result
91        if len == 0 {
92            return Ok(Some(
93                Canonical::empty(&DType::Bool(nullable.into())).into_array(),
94            ));
95        }
96
97        // Null constant on either side → all-null bool result
98        if other.as_constant().is_some_and(|s| s.is_null()) {
99            return Ok(Some(
100                ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
101            ));
102        }
103
104        V::compare(array, other, cmp_op, ctx)
105    }
106}
107
108/// Execute a compare operation between two arrays.
109///
110/// This is the entry point for compare operations from the binary expression.
111/// Handles empty, constant-null, and constant-constant directly, otherwise falls back to Arrow.
112pub(crate) fn execute_compare(
113    lhs: &ArrayRef,
114    rhs: &ArrayRef,
115    op: CompareOperator,
116) -> VortexResult<ArrayRef> {
117    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
118
119    if lhs.is_empty() {
120        return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
121    }
122
123    let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
124    let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
125    if left_constant_null || right_constant_null {
126        return Ok(
127            ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
128        );
129    }
130
131    // Constant-constant fast path
132    if let (Some(lhs_const), Some(rhs_const)) = (
133        lhs.as_opt::<ConstantVTable>(),
134        rhs.as_opt::<ConstantVTable>(),
135    ) {
136        let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
137        return Ok(ConstantArray::new(result, lhs.len()).into_array());
138    }
139
140    arrow_compare_arrays(lhs, rhs, op)
141}
142
143/// Fall back to Arrow for comparison.
144fn arrow_compare_arrays(
145    left: &ArrayRef,
146    right: &ArrayRef,
147    operator: CompareOperator,
148) -> VortexResult<ArrayRef> {
149    assert_eq!(left.len(), right.len());
150
151    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
152
153    // Arrow's vectorized comparison kernels don't support nested types.
154    // For nested types, fall back to `make_comparator` which does element-wise comparison.
155    let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
156        let rhs = right.to_array().into_arrow_preferred()?;
157        let lhs = left.to_array().into_arrow(rhs.data_type())?;
158
159        assert!(
160            lhs.data_type().equals_datatype(rhs.data_type()),
161            "lhs data_type: {}, rhs data_type: {}",
162            lhs.data_type(),
163            rhs.data_type()
164        );
165
166        compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
167    } else {
168        // Fast path: use vectorized kernels for primitive types.
169        let lhs = Datum::try_new(left)?;
170        let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
171
172        match operator {
173            CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
174            CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
175            CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
176            CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
177            CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
178            CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
179        }
180    };
181    from_arrow_array_with_len(&array, left.len(), nullable)
182}
183
184pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar {
185    if lhs.is_null() | rhs.is_null() {
186        Scalar::null(DType::Bool(Nullability::Nullable))
187    } else {
188        let b = match operator {
189            CompareOperator::Eq => lhs == rhs,
190            CompareOperator::NotEq => lhs != rhs,
191            CompareOperator::Gt => lhs > rhs,
192            CompareOperator::Gte => lhs >= rhs,
193            CompareOperator::Lt => lhs < rhs,
194            CompareOperator::Lte => lhs <= rhs,
195        };
196
197        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
198    }
199}
200
201/// Compare two Arrow arrays element-wise using [`make_comparator`].
202///
203/// This function is required for nested types (Struct, List, FixedSizeList) because Arrow's
204/// vectorized comparison kernels ([`cmp::eq`], [`cmp::neq`], etc.) do not support them.
205///
206/// The vectorized kernels are faster but only work on primitive types, so for non-nested types,
207/// prefer using the vectorized kernels directly for better performance.
208pub fn compare_nested_arrow_arrays(
209    lhs: &dyn arrow_array::Array,
210    rhs: &dyn arrow_array::Array,
211    operator: CompareOperator,
212) -> VortexResult<BooleanArray> {
213    let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
214
215    let cmp_fn = match operator {
216        CompareOperator::Eq => Ordering::is_eq,
217        CompareOperator::NotEq => Ordering::is_ne,
218        CompareOperator::Gt => Ordering::is_gt,
219        CompareOperator::Gte => Ordering::is_ge,
220        CompareOperator::Lt => Ordering::is_lt,
221        CompareOperator::Lte => Ordering::is_le,
222    };
223
224    let values = (0..lhs.len())
225        .map(|i| cmp_fn(compare_arrays_at(i, i)))
226        .collect();
227    let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
228
229    Ok(BooleanArray::new(values, nulls))
230}
231
232#[cfg(test)]
233mod tests {
234    use std::sync::Arc;
235
236    use rstest::rstest;
237    use vortex_buffer::buffer;
238
239    use crate::ArrayRef;
240    use crate::IntoArray;
241    use crate::ToCanonical;
242    use crate::arrays::BoolArray;
243    use crate::arrays::ListArray;
244    use crate::arrays::ListViewArray;
245    use crate::arrays::PrimitiveArray;
246    use crate::arrays::StructArray;
247    use crate::arrays::VarBinArray;
248    use crate::arrays::VarBinViewArray;
249    use crate::assert_arrays_eq;
250    use crate::builtins::ArrayBuiltins;
251    use crate::dtype::DType;
252    use crate::dtype::FieldName;
253    use crate::dtype::FieldNames;
254    use crate::dtype::Nullability;
255    use crate::dtype::PType;
256    use crate::scalar::Scalar;
257    use crate::scalar_fn::fns::binary::compare::ConstantArray;
258    use crate::scalar_fn::fns::operators::Operator;
259    use crate::test_harness::to_int_indices;
260    use crate::validity::Validity;
261
262    #[test]
263    fn test_bool_basic_comparisons() {
264        use vortex_buffer::BitBuffer;
265
266        let arr = BoolArray::new(
267            BitBuffer::from_iter([true, true, false, true, false]),
268            Validity::from_iter([false, true, true, true, true]),
269        );
270
271        let matches = arr
272            .clone()
273            .into_array()
274            .binary(arr.clone().into_array(), Operator::Eq)
275            .unwrap()
276            .to_bool();
277        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);
278
279        let matches = arr
280            .clone()
281            .into_array()
282            .binary(arr.clone().into_array(), Operator::NotEq)
283            .unwrap()
284            .to_bool();
285        let empty: [u64; 0] = [];
286        assert_eq!(to_int_indices(matches).unwrap(), empty);
287
288        let other = BoolArray::new(
289            BitBuffer::from_iter([false, false, false, true, true]),
290            Validity::from_iter([false, true, true, true, true]),
291        );
292
293        let matches = arr
294            .clone()
295            .into_array()
296            .binary(other.clone().into_array(), Operator::Lte)
297            .unwrap()
298            .to_bool();
299        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
300
301        let matches = arr
302            .clone()
303            .into_array()
304            .binary(other.clone().into_array(), Operator::Lt)
305            .unwrap()
306            .to_bool();
307        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
308
309        let matches = other
310            .clone()
311            .into_array()
312            .binary(arr.clone().into_array(), Operator::Gte)
313            .unwrap()
314            .to_bool();
315        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
316
317        let matches = other
318            .into_array()
319            .binary(arr.into_array(), Operator::Gt)
320            .unwrap()
321            .to_bool();
322        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
323    }
324
325    #[test]
326    fn constant_compare() {
327        let left = ConstantArray::new(Scalar::from(2u32), 10);
328        let right = ConstantArray::new(Scalar::from(10u32), 10);
329
330        let result = left
331            .into_array()
332            .binary(right.into_array(), Operator::Gt)
333            .unwrap();
334        assert_eq!(result.len(), 10);
335        let scalar = result.scalar_at(0).unwrap();
336        assert_eq!(scalar.as_bool().value(), Some(false));
337    }
338
339    #[rstest]
340    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
341    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
342    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
343    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
344    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
345        let res = left.binary(right, Operator::Eq).unwrap();
346        let expected = BoolArray::from_iter([true, true]);
347        assert_arrays_eq!(res, expected);
348    }
349
350    #[ignore = "Arrow's ListView cannot be compared"]
351    #[test]
352    fn test_list_array_comparison() {
353        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
354        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
355        let list1 = ListArray::try_new(
356            values1.into_array(),
357            offsets1.into_array(),
358            Validity::NonNullable,
359        )
360        .unwrap();
361
362        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
363        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
364        let list2 = ListArray::try_new(
365            values2.into_array(),
366            offsets2.into_array(),
367            Validity::NonNullable,
368        )
369        .unwrap();
370
371        let result = list1
372            .clone()
373            .into_array()
374            .binary(list2.clone().into_array(), Operator::Eq)
375            .unwrap();
376        let expected = BoolArray::from_iter([true, true, false]);
377        assert_arrays_eq!(result, expected);
378
379        let result = list1
380            .clone()
381            .into_array()
382            .binary(list2.clone().into_array(), Operator::NotEq)
383            .unwrap();
384        let expected = BoolArray::from_iter([false, false, true]);
385        assert_arrays_eq!(result, expected);
386
387        let result = list1
388            .into_array()
389            .binary(list2.into_array(), Operator::Lt)
390            .unwrap();
391        let expected = BoolArray::from_iter([false, false, true]);
392        assert_arrays_eq!(result, expected);
393    }
394
395    #[ignore = "Arrow's ListView cannot be compared"]
396    #[test]
397    fn test_list_array_constant_comparison() {
398        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
399        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
400        let list = ListArray::try_new(
401            values.into_array(),
402            offsets.into_array(),
403            Validity::NonNullable,
404        )
405        .unwrap();
406
407        let list_scalar = Scalar::list(
408            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
409            vec![3i32.into(), 4i32.into()],
410            Nullability::NonNullable,
411        );
412        let constant = ConstantArray::new(list_scalar, 3);
413
414        let result = list
415            .into_array()
416            .binary(constant.into_array(), Operator::Eq)
417            .unwrap();
418        let expected = BoolArray::from_iter([false, true, false]);
419        assert_arrays_eq!(result, expected);
420    }
421
422    #[test]
423    fn test_struct_array_comparison() {
424        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
425        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);
426
427        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
428        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);
429
430        let struct1 = StructArray::from_fields(&[
431            ("bool_col", bool_field1.into_array()),
432            ("int_col", int_field1.into_array()),
433        ])
434        .unwrap();
435
436        let struct2 = StructArray::from_fields(&[
437            ("bool_col", bool_field2.into_array()),
438            ("int_col", int_field2.into_array()),
439        ])
440        .unwrap();
441
442        let result = struct1
443            .clone()
444            .into_array()
445            .binary(struct2.clone().into_array(), Operator::Eq)
446            .unwrap();
447        let expected = BoolArray::from_iter([true, true, false]);
448        assert_arrays_eq!(result, expected);
449
450        let result = struct1
451            .into_array()
452            .binary(struct2.into_array(), Operator::Gt)
453            .unwrap();
454        let expected = BoolArray::from_iter([false, false, true]);
455        assert_arrays_eq!(result, expected);
456    }
457
458    #[test]
459    fn test_empty_struct_compare() {
460        let empty1 = StructArray::try_new(
461            FieldNames::from(Vec::<FieldName>::new()),
462            Vec::new(),
463            5,
464            Validity::NonNullable,
465        )
466        .unwrap();
467
468        let empty2 = StructArray::try_new(
469            FieldNames::from(Vec::<FieldName>::new()),
470            Vec::new(),
471            5,
472            Validity::NonNullable,
473        )
474        .unwrap();
475
476        let result = empty1
477            .into_array()
478            .binary(empty2.into_array(), Operator::Eq)
479            .unwrap();
480        let expected = BoolArray::from_iter([true, true, true, true, true]);
481        assert_arrays_eq!(result, expected);
482    }
483
484    #[test]
485    fn test_empty_list() {
486        let list = ListViewArray::new(
487            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
488            buffer![0i32, 0i32, 0i32].into_array(),
489            buffer![0i32, 0i32, 0i32].into_array(),
490            Validity::AllValid,
491        );
492
493        let result = list
494            .clone()
495            .into_array()
496            .binary(list.into_array(), Operator::Eq)
497            .unwrap();
498        assert!(result.scalar_at(0).unwrap().is_valid());
499        assert!(result.scalar_at(1).unwrap().is_valid());
500        assert!(result.scalar_at(2).unwrap().is_valid());
501    }
502}