Skip to main content

vortex_array/scalar_fn/fns/binary/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::SortOptions;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::arrays::Constant;
19use crate::arrays::ConstantArray;
20use crate::arrays::ScalarFnVTable;
21use crate::arrays::scalar_fn::ExactScalarFn;
22use crate::arrays::scalar_fn::ScalarFnArrayView;
23use crate::arrow::Datum;
24use crate::arrow::IntoArrowArray;
25use crate::arrow::from_arrow_array_with_len;
26use crate::dtype::DType;
27use crate::dtype::Nullability;
28use crate::kernel::ExecuteParentKernel;
29use crate::scalar::Scalar;
30use crate::scalar_fn::fns::binary::Binary;
31use crate::scalar_fn::fns::operators::CompareOperator;
32use crate::vtable::VTable;
33
/// Trait for encoding-specific comparison kernels that operate in encoded space.
///
/// Implementations can compare an encoded array against another array (typically a constant)
/// without first decompressing. When dispatched through [`CompareExecuteAdaptor`], operand
/// order is normalized so the encoded array is always the left-hand side (`lhs`), with the
/// operator swapped when necessary.
pub trait CompareKernel: VTable {
    /// Evaluate `lhs <operator> rhs`, where `lhs` is the encoded array owned by this vtable.
    ///
    /// [`CompareExecuteAdaptor`] resolves empty inputs and a null-constant `rhs` itself, so when
    /// called through the adaptor `lhs` is non-empty and `rhs` is not a null constant.
    ///
    /// NOTE(review): `Ok(None)` presumably means the kernel declines this comparison and the
    /// caller should try another implementation — confirm against `ExecuteParentKernel`.
    fn compare(
        lhs: &Self::Array,
        rhs: &ArrayRef,
        operator: CompareOperator,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
47
48/// Adaptor that bridges [`CompareKernel`] implementations to [`ExecuteParentKernel`].
49///
50/// When a `ScalarFnArray(Binary, cmp_op)` wraps a child that implements `CompareKernel`,
51/// this adaptor extracts the comparison operator and other operand, normalizes operand order
52/// (swapping the operator if the encoded array is on the RHS), and delegates to the kernel.
53#[derive(Default, Debug)]
54pub struct CompareExecuteAdaptor<V>(pub V);
55
56impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
57where
58    V: CompareKernel,
59{
60    type Parent = ExactScalarFn<Binary>;
61
62    fn execute_parent(
63        &self,
64        array: &V::Array,
65        parent: ScalarFnArrayView<'_, Binary>,
66        child_idx: usize,
67        ctx: &mut ExecutionCtx,
68    ) -> VortexResult<Option<ArrayRef>> {
69        // Only handle comparison operators
70        let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
71            return Ok(None);
72        };
73
74        // Get the ScalarFnArray to access children
75        let Some(scalar_fn_array) = parent.as_opt::<ScalarFnVTable>() else {
76            return Ok(None);
77        };
78        let children = scalar_fn_array.children();
79
80        // Normalize so `array` is always LHS, swapping the operator if needed
81        // TODO(joe): should be go this here or in the Rule/Kernel
82        let (cmp_op, other) = match child_idx {
83            0 => (cmp_op, &children[1]),
84            1 => (cmp_op.swap(), &children[0]),
85            _ => return Ok(None),
86        };
87
88        let len = array.len();
89        let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
90
91        // Empty array → empty bool result
92        if len == 0 {
93            return Ok(Some(
94                Canonical::empty(&DType::Bool(nullable.into())).into_array(),
95            ));
96        }
97
98        // Null constant on either side → all-null bool result
99        if other.as_constant().is_some_and(|s| s.is_null()) {
100            return Ok(Some(
101                ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
102            ));
103        }
104
105        V::compare(array, other, cmp_op, ctx)
106    }
107}
108
109/// Execute a compare operation between two arrays.
110///
111/// This is the entry point for compare operations from the binary expression.
112/// Handles empty, constant-null, and constant-constant directly, otherwise falls back to Arrow.
113pub(crate) fn execute_compare(
114    lhs: &ArrayRef,
115    rhs: &ArrayRef,
116    op: CompareOperator,
117) -> VortexResult<ArrayRef> {
118    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
119
120    if lhs.is_empty() {
121        return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
122    }
123
124    let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
125    let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
126    if left_constant_null || right_constant_null {
127        return Ok(
128            ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
129        );
130    }
131
132    // Constant-constant fast path
133    if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
134    {
135        let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
136        return Ok(ConstantArray::new(result, lhs.len()).into_array());
137    }
138
139    arrow_compare_arrays(lhs, rhs, op)
140}
141
142/// Fall back to Arrow for comparison.
143fn arrow_compare_arrays(
144    left: &ArrayRef,
145    right: &ArrayRef,
146    operator: CompareOperator,
147) -> VortexResult<ArrayRef> {
148    assert_eq!(left.len(), right.len());
149
150    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
151
152    // Arrow's vectorized comparison kernels don't support nested types.
153    // For nested types, fall back to `make_comparator` which does element-wise comparison.
154    let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
155        let rhs = right.to_array().into_arrow_preferred()?;
156        let lhs = left.to_array().into_arrow(rhs.data_type())?;
157
158        assert!(
159            lhs.data_type().equals_datatype(rhs.data_type()),
160            "lhs data_type: {}, rhs data_type: {}",
161            lhs.data_type(),
162            rhs.data_type()
163        );
164
165        compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
166    } else {
167        // Fast path: use vectorized kernels for primitive types.
168        let lhs = Datum::try_new(left)?;
169        let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?;
170
171        match operator {
172            CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
173            CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
174            CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
175            CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
176            CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
177            CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
178        }
179    };
180
181    from_arrow_array_with_len(&arrow_array, left.len(), nullable)
182}
183
184pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
185    if lhs.is_null() | rhs.is_null() {
186        return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
187    }
188
189    let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
190
191    // We use `partial_cmp` to ensure we do not lose a type mismatch error.
192    let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
193        vortex_err!(
194            "Cannot compare scalars with incompatible types: {} and {}",
195            lhs.dtype(),
196            rhs.dtype()
197        )
198    })?;
199
200    let b = match operator {
201        CompareOperator::Eq => ordering.is_eq(),
202        CompareOperator::NotEq => ordering.is_ne(),
203        CompareOperator::Gt => ordering.is_gt(),
204        CompareOperator::Gte => ordering.is_ge(),
205        CompareOperator::Lt => ordering.is_lt(),
206        CompareOperator::Lte => ordering.is_le(),
207    };
208
209    Ok(Scalar::bool(b, nullability))
210}
211
212/// Compare two Arrow arrays element-wise using [`make_comparator`].
213///
214/// This function is required for nested types (Struct, List, FixedSizeList) because Arrow's
215/// vectorized comparison kernels ([`cmp::eq`], [`cmp::neq`], etc.) do not support them.
216///
217/// The vectorized kernels are faster but only work on primitive types, so for non-nested types,
218/// prefer using the vectorized kernels directly for better performance.
219pub fn compare_nested_arrow_arrays(
220    lhs: &dyn arrow_array::Array,
221    rhs: &dyn arrow_array::Array,
222    operator: CompareOperator,
223) -> VortexResult<BooleanArray> {
224    let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
225
226    let cmp_fn = match operator {
227        CompareOperator::Eq => Ordering::is_eq,
228        CompareOperator::NotEq => Ordering::is_ne,
229        CompareOperator::Gt => Ordering::is_gt,
230        CompareOperator::Gte => Ordering::is_ge,
231        CompareOperator::Lt => Ordering::is_lt,
232        CompareOperator::Lte => Ordering::is_le,
233    };
234
235    let values = (0..lhs.len())
236        .map(|i| cmp_fn(compare_arrays_at(i, i)))
237        .collect();
238    let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
239
240    Ok(BooleanArray::new(values, nulls))
241}
242
243#[cfg(test)]
244mod tests {
245    use std::sync::Arc;
246
247    use rstest::rstest;
248    use vortex_buffer::buffer;
249
250    use crate::ArrayRef;
251    use crate::IntoArray;
252    use crate::ToCanonical;
253    use crate::arrays::BoolArray;
254    use crate::arrays::ListArray;
255    use crate::arrays::ListViewArray;
256    use crate::arrays::PrimitiveArray;
257    use crate::arrays::StructArray;
258    use crate::arrays::VarBinArray;
259    use crate::arrays::VarBinViewArray;
260    use crate::assert_arrays_eq;
261    use crate::builtins::ArrayBuiltins;
262    use crate::dtype::DType;
263    use crate::dtype::FieldName;
264    use crate::dtype::FieldNames;
265    use crate::dtype::Nullability;
266    use crate::dtype::PType;
267    use crate::extension::datetime::TimeUnit;
268    use crate::extension::datetime::Timestamp;
269    use crate::extension::datetime::TimestampOptions;
270    use crate::scalar::Scalar;
271    use crate::scalar_fn::fns::binary::compare::ConstantArray;
272    use crate::scalar_fn::fns::binary::scalar_cmp;
273    use crate::scalar_fn::fns::operators::CompareOperator;
274    use crate::scalar_fn::fns::operators::Operator;
275    use crate::test_harness::to_int_indices;
276    use crate::validity::Validity;
277
278    #[test]
279    fn test_bool_basic_comparisons() {
280        use vortex_buffer::BitBuffer;
281
282        let arr = BoolArray::new(
283            BitBuffer::from_iter([true, true, false, true, false]),
284            Validity::from_iter([false, true, true, true, true]),
285        );
286
287        let matches = arr
288            .clone()
289            .into_array()
290            .binary(arr.clone().into_array(), Operator::Eq)
291            .unwrap()
292            .to_bool();
293        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);
294
295        let matches = arr
296            .clone()
297            .into_array()
298            .binary(arr.clone().into_array(), Operator::NotEq)
299            .unwrap()
300            .to_bool();
301        let empty: [u64; 0] = [];
302        assert_eq!(to_int_indices(matches).unwrap(), empty);
303
304        let other = BoolArray::new(
305            BitBuffer::from_iter([false, false, false, true, true]),
306            Validity::from_iter([false, true, true, true, true]),
307        );
308
309        let matches = arr
310            .clone()
311            .into_array()
312            .binary(other.clone().into_array(), Operator::Lte)
313            .unwrap()
314            .to_bool();
315        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
316
317        let matches = arr
318            .clone()
319            .into_array()
320            .binary(other.clone().into_array(), Operator::Lt)
321            .unwrap()
322            .to_bool();
323        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
324
325        let matches = other
326            .clone()
327            .into_array()
328            .binary(arr.clone().into_array(), Operator::Gte)
329            .unwrap()
330            .to_bool();
331        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
332
333        let matches = other
334            .into_array()
335            .binary(arr.into_array(), Operator::Gt)
336            .unwrap()
337            .to_bool();
338        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
339    }
340
341    #[test]
342    fn constant_compare() {
343        let left = ConstantArray::new(Scalar::from(2u32), 10);
344        let right = ConstantArray::new(Scalar::from(10u32), 10);
345
346        let result = left
347            .into_array()
348            .binary(right.into_array(), Operator::Gt)
349            .unwrap();
350        assert_eq!(result.len(), 10);
351        let scalar = result.scalar_at(0).unwrap();
352        assert_eq!(scalar.as_bool().value(), Some(false));
353    }
354
355    #[rstest]
356    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
357    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
358    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
359    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
360    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
361        let res = left.binary(right, Operator::Eq).unwrap();
362        let expected = BoolArray::from_iter([true, true]);
363        assert_arrays_eq!(res, expected);
364    }
365
366    #[ignore = "Arrow's ListView cannot be compared"]
367    #[test]
368    fn test_list_array_comparison() {
369        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
370        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
371        let list1 = ListArray::try_new(
372            values1.into_array(),
373            offsets1.into_array(),
374            Validity::NonNullable,
375        )
376        .unwrap();
377
378        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
379        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
380        let list2 = ListArray::try_new(
381            values2.into_array(),
382            offsets2.into_array(),
383            Validity::NonNullable,
384        )
385        .unwrap();
386
387        let result = list1
388            .clone()
389            .into_array()
390            .binary(list2.clone().into_array(), Operator::Eq)
391            .unwrap();
392        let expected = BoolArray::from_iter([true, true, false]);
393        assert_arrays_eq!(result, expected);
394
395        let result = list1
396            .clone()
397            .into_array()
398            .binary(list2.clone().into_array(), Operator::NotEq)
399            .unwrap();
400        let expected = BoolArray::from_iter([false, false, true]);
401        assert_arrays_eq!(result, expected);
402
403        let result = list1
404            .into_array()
405            .binary(list2.into_array(), Operator::Lt)
406            .unwrap();
407        let expected = BoolArray::from_iter([false, false, true]);
408        assert_arrays_eq!(result, expected);
409    }
410
411    #[ignore = "Arrow's ListView cannot be compared"]
412    #[test]
413    fn test_list_array_constant_comparison() {
414        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
415        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
416        let list = ListArray::try_new(
417            values.into_array(),
418            offsets.into_array(),
419            Validity::NonNullable,
420        )
421        .unwrap();
422
423        let list_scalar = Scalar::list(
424            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
425            vec![3i32.into(), 4i32.into()],
426            Nullability::NonNullable,
427        );
428        let constant = ConstantArray::new(list_scalar, 3);
429
430        let result = list
431            .into_array()
432            .binary(constant.into_array(), Operator::Eq)
433            .unwrap();
434        let expected = BoolArray::from_iter([false, true, false]);
435        assert_arrays_eq!(result, expected);
436    }
437
438    #[test]
439    fn test_struct_array_comparison() {
440        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
441        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);
442
443        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
444        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);
445
446        let struct1 = StructArray::from_fields(&[
447            ("bool_col", bool_field1.into_array()),
448            ("int_col", int_field1.into_array()),
449        ])
450        .unwrap();
451
452        let struct2 = StructArray::from_fields(&[
453            ("bool_col", bool_field2.into_array()),
454            ("int_col", int_field2.into_array()),
455        ])
456        .unwrap();
457
458        let result = struct1
459            .clone()
460            .into_array()
461            .binary(struct2.clone().into_array(), Operator::Eq)
462            .unwrap();
463        let expected = BoolArray::from_iter([true, true, false]);
464        assert_arrays_eq!(result, expected);
465
466        let result = struct1
467            .into_array()
468            .binary(struct2.into_array(), Operator::Gt)
469            .unwrap();
470        let expected = BoolArray::from_iter([false, false, true]);
471        assert_arrays_eq!(result, expected);
472    }
473
474    #[test]
475    fn test_empty_struct_compare() {
476        let empty1 = StructArray::try_new(
477            FieldNames::from(Vec::<FieldName>::new()),
478            Vec::new(),
479            5,
480            Validity::NonNullable,
481        )
482        .unwrap();
483
484        let empty2 = StructArray::try_new(
485            FieldNames::from(Vec::<FieldName>::new()),
486            Vec::new(),
487            5,
488            Validity::NonNullable,
489        )
490        .unwrap();
491
492        let result = empty1
493            .into_array()
494            .binary(empty2.into_array(), Operator::Eq)
495            .unwrap();
496        let expected = BoolArray::from_iter([true, true, true, true, true]);
497        assert_arrays_eq!(result, expected);
498    }
499
500    /// Regression test: `scalar_cmp` must error when comparing scalars with incompatible
501    /// extension types (e.g., timestamps with different time units) rather than silently
502    /// returning a wrong result.
503    #[test]
504    fn scalar_cmp_incompatible_extension_types_errors() {
505        let ms_scalar = Scalar::extension::<Timestamp>(
506            TimestampOptions {
507                unit: TimeUnit::Milliseconds,
508                tz: None,
509            },
510            Scalar::from(1704067200000i64),
511        );
512        let s_scalar = Scalar::extension::<Timestamp>(
513            TimestampOptions {
514                unit: TimeUnit::Seconds,
515                tz: None,
516            },
517            Scalar::from(1704067200i64),
518        );
519
520        // Ordering comparisons must error on incompatible types.
521        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
522        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
523        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
524        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
525        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
526        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
527    }
528
529    #[test]
530    fn test_empty_list() {
531        let list = ListViewArray::new(
532            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
533            buffer![0i32, 0i32, 0i32].into_array(),
534            buffer![0i32, 0i32, 0i32].into_array(),
535            Validity::AllValid,
536        );
537
538        let result = list
539            .clone()
540            .into_array()
541            .binary(list.into_array(), Operator::Eq)
542            .unwrap();
543        assert!(result.scalar_at(0).unwrap().is_valid());
544        assert!(result.scalar_at(1).unwrap().is_valid());
545        assert!(result.scalar_at(2).unwrap().is_valid());
546    }
547}