Skip to main content

vortex_array/scalar_fn/fns/binary/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::SortOptions;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::array::ArrayView;
19use crate::array::VTable;
20use crate::arrays::Constant;
21use crate::arrays::ConstantArray;
22use crate::arrays::ScalarFn;
23use crate::arrays::scalar_fn::ExactScalarFn;
24use crate::arrays::scalar_fn::ScalarFnArrayExt;
25use crate::arrays::scalar_fn::ScalarFnArrayView;
26use crate::arrow::ArrowArrayExecutor;
27use crate::arrow::Datum;
28use crate::arrow::from_arrow_array_with_len;
29use crate::dtype::DType;
30use crate::dtype::Nullability;
31use crate::kernel::ExecuteParentKernel;
32use crate::scalar::Scalar;
33use crate::scalar_fn::fns::binary::Binary;
34use crate::scalar_fn::fns::operators::CompareOperator;
35
/// Trait for encoding-specific comparison kernels that operate in encoded space.
///
/// Implementations can compare an encoded array against another array (typically a constant)
/// without first decompressing. The adaptor normalizes operand order so `array` is always
/// the left-hand side, swapping the operator when necessary.
pub trait CompareKernel: VTable {
    /// Compares `lhs` (a view of this encoding) against `rhs` using `operator`.
    ///
    /// [`CompareExecuteAdaptor`] only reaches this call after it has already answered the
    /// zero-length case and the case where `rhs` is a constant null, and it returns the value
    /// produced here unchanged.
    ///
    /// NOTE(review): `Ok(None)` presumably means "this kernel declines, use another path" —
    /// confirm against the `ExecuteParentKernel` contract.
    fn compare(
        lhs: ArrayView<'_, Self>,
        rhs: &ArrayRef,
        operator: CompareOperator,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
49
/// Adaptor that bridges [`CompareKernel`] implementations to [`ExecuteParentKernel`].
///
/// When a `ScalarFnArray(Binary, cmp_op)` wraps a child that implements `CompareKernel`,
/// this adaptor extracts the comparison operator and other operand, normalizes operand order
/// (swapping the operator if the encoded array is on the RHS), and delegates to the kernel.
///
/// The wrapped `V` value is not read by `execute_parent`; dispatch goes through the associated
/// function `V::compare`, so the field only selects which encoding the adaptor is for.
#[derive(Default, Debug)]
pub struct CompareExecuteAdaptor<V>(pub V);
57
58impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
59where
60    V: CompareKernel,
61{
62    type Parent = ExactScalarFn<Binary>;
63
64    fn execute_parent(
65        &self,
66        array: ArrayView<'_, V>,
67        parent: ScalarFnArrayView<'_, Binary>,
68        child_idx: usize,
69        ctx: &mut ExecutionCtx,
70    ) -> VortexResult<Option<ArrayRef>> {
71        // Only handle comparison operators
72        let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
73            return Ok(None);
74        };
75
76        // Get the ScalarFnArray to access children
77        let Some(scalar_fn_array) = parent.as_opt::<ScalarFn>() else {
78            return Ok(None);
79        };
80        // Normalize so `array` is always LHS, swapping the operator if needed
81        // TODO(joe): should be go this here or in the Rule/Kernel
82        let (cmp_op, other) = match child_idx {
83            0 => (cmp_op, scalar_fn_array.get_child(1)),
84            1 => (cmp_op.swap(), scalar_fn_array.get_child(0)),
85            _ => return Ok(None),
86        };
87
88        let len = array.len();
89        let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
90
91        // Empty array → empty bool result
92        if len == 0 {
93            return Ok(Some(
94                Canonical::empty(&DType::Bool(nullable.into())).into_array(),
95            ));
96        }
97
98        // Null constant on either side → all-null bool result
99        if other.as_constant().is_some_and(|s| s.is_null()) {
100            return Ok(Some(
101                ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
102            ));
103        }
104
105        V::compare(array, other, cmp_op, ctx)
106    }
107}
108
109/// Execute a compare operation between two arrays.
110///
111/// This is the entry point for compare operations from the binary expression.
112/// Handles empty, constant-null, and constant-constant directly, otherwise falls back to Arrow.
113pub(crate) fn execute_compare(
114    lhs: &ArrayRef,
115    rhs: &ArrayRef,
116    op: CompareOperator,
117    ctx: &mut ExecutionCtx,
118) -> VortexResult<ArrayRef> {
119    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
120
121    if lhs.is_empty() {
122        return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
123    }
124
125    let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
126    let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
127    if left_constant_null || right_constant_null {
128        return Ok(
129            ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
130        );
131    }
132
133    // Constant-constant fast path
134    if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
135    {
136        let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
137        return Ok(ConstantArray::new(result, lhs.len()).into_array());
138    }
139
140    arrow_compare_arrays(lhs, rhs, op, ctx)
141}
142
143/// Fall back to Arrow for comparison.
144fn arrow_compare_arrays(
145    left: &ArrayRef,
146    right: &ArrayRef,
147    operator: CompareOperator,
148    ctx: &mut ExecutionCtx,
149) -> VortexResult<ArrayRef> {
150    assert_eq!(left.len(), right.len());
151
152    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
153
154    // Arrow's vectorized comparison kernels don't support nested types.
155    // For nested types, fall back to `make_comparator` which does element-wise comparison.
156    let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
157        let rhs = right.clone().execute_arrow(None, ctx)?;
158        let lhs = left.clone().execute_arrow(Some(rhs.data_type()), ctx)?;
159
160        assert!(
161            lhs.data_type().equals_datatype(rhs.data_type()),
162            "lhs data_type: {}, rhs data_type: {}",
163            lhs.data_type(),
164            rhs.data_type()
165        );
166
167        compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
168    } else {
169        // Fast path: use vectorized kernels for primitive types.
170        let lhs = Datum::try_new(left, ctx)?;
171        let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type(), ctx)?;
172
173        match operator {
174            CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
175            CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
176            CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
177            CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
178            CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
179            CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
180        }
181    };
182
183    from_arrow_array_with_len(&arrow_array, left.len(), nullable)
184}
185
186pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
187    if lhs.is_null() | rhs.is_null() {
188        return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
189    }
190
191    let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
192
193    // We use `partial_cmp` to ensure we do not lose a type mismatch error.
194    let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
195        vortex_err!(
196            "Cannot compare scalars with incompatible types: {} and {}",
197            lhs.dtype(),
198            rhs.dtype()
199        )
200    })?;
201
202    let b = match operator {
203        CompareOperator::Eq => ordering.is_eq(),
204        CompareOperator::NotEq => ordering.is_ne(),
205        CompareOperator::Gt => ordering.is_gt(),
206        CompareOperator::Gte => ordering.is_ge(),
207        CompareOperator::Lt => ordering.is_lt(),
208        CompareOperator::Lte => ordering.is_le(),
209    };
210
211    Ok(Scalar::bool(b, nullability))
212}
213
214/// Compare two Arrow arrays element-wise using [`make_comparator`].
215///
216/// This function is required for nested types (Struct, List, FixedSizeList) because Arrow's
217/// vectorized comparison kernels ([`cmp::eq`], [`cmp::neq`], etc.) do not support them.
218///
219/// The vectorized kernels are faster but only work on primitive types, so for non-nested types,
220/// prefer using the vectorized kernels directly for better performance.
221pub fn compare_nested_arrow_arrays(
222    lhs: &dyn arrow_array::Array,
223    rhs: &dyn arrow_array::Array,
224    operator: CompareOperator,
225) -> VortexResult<BooleanArray> {
226    let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
227
228    let cmp_fn = match operator {
229        CompareOperator::Eq => Ordering::is_eq,
230        CompareOperator::NotEq => Ordering::is_ne,
231        CompareOperator::Gt => Ordering::is_gt,
232        CompareOperator::Gte => Ordering::is_ge,
233        CompareOperator::Lt => Ordering::is_lt,
234        CompareOperator::Lte => Ordering::is_le,
235    };
236
237    let values = (0..lhs.len())
238        .map(|i| cmp_fn(compare_arrays_at(i, i)))
239        .collect();
240    let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
241
242    Ok(BooleanArray::new(values, nulls))
243}
244
245#[cfg(test)]
246mod tests {
247    use std::sync::Arc;
248
249    use rstest::rstest;
250    use vortex_buffer::buffer;
251
252    use crate::ArrayRef;
253    use crate::IntoArray;
254    use crate::LEGACY_SESSION;
255    #[expect(deprecated)]
256    use crate::ToCanonical as _;
257    use crate::VortexSessionExecute;
258    use crate::arrays::BoolArray;
259    use crate::arrays::ListArray;
260    use crate::arrays::ListViewArray;
261    use crate::arrays::PrimitiveArray;
262    use crate::arrays::StructArray;
263    use crate::arrays::VarBinArray;
264    use crate::arrays::VarBinViewArray;
265    use crate::assert_arrays_eq;
266    use crate::builtins::ArrayBuiltins;
267    use crate::dtype::DType;
268    use crate::dtype::FieldName;
269    use crate::dtype::FieldNames;
270    use crate::dtype::Nullability;
271    use crate::dtype::PType;
272    use crate::extension::datetime::TimeUnit;
273    use crate::extension::datetime::Timestamp;
274    use crate::extension::datetime::TimestampOptions;
275    use crate::scalar::Scalar;
276    use crate::scalar_fn::fns::binary::compare::ConstantArray;
277    use crate::scalar_fn::fns::binary::scalar_cmp;
278    use crate::scalar_fn::fns::operators::CompareOperator;
279    use crate::scalar_fn::fns::operators::Operator;
280    use crate::test_harness::to_int_indices;
281    use crate::validity::Validity;
282
283    #[test]
284    fn test_bool_basic_comparisons() {
285        use vortex_buffer::BitBuffer;
286
287        let arr = BoolArray::new(
288            BitBuffer::from_iter([true, true, false, true, false]),
289            Validity::from_iter([false, true, true, true, true]),
290        );
291
292        #[expect(deprecated)]
293        let matches = arr
294            .clone()
295            .into_array()
296            .binary(arr.clone().into_array(), Operator::Eq)
297            .unwrap()
298            .to_bool();
299        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);
300
301        #[expect(deprecated)]
302        let matches = arr
303            .clone()
304            .into_array()
305            .binary(arr.clone().into_array(), Operator::NotEq)
306            .unwrap()
307            .to_bool();
308        let empty: [u64; 0] = [];
309        assert_eq!(to_int_indices(matches).unwrap(), empty);
310
311        let other = BoolArray::new(
312            BitBuffer::from_iter([false, false, false, true, true]),
313            Validity::from_iter([false, true, true, true, true]),
314        );
315
316        #[expect(deprecated)]
317        let matches = arr
318            .clone()
319            .into_array()
320            .binary(other.clone().into_array(), Operator::Lte)
321            .unwrap()
322            .to_bool();
323        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
324
325        #[expect(deprecated)]
326        let matches = arr
327            .clone()
328            .into_array()
329            .binary(other.clone().into_array(), Operator::Lt)
330            .unwrap()
331            .to_bool();
332        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
333
334        #[expect(deprecated)]
335        let matches = other
336            .clone()
337            .into_array()
338            .binary(arr.clone().into_array(), Operator::Gte)
339            .unwrap()
340            .to_bool();
341        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);
342
343        #[expect(deprecated)]
344        let matches = other
345            .into_array()
346            .binary(arr.into_array(), Operator::Gt)
347            .unwrap()
348            .to_bool();
349        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
350    }
351
352    #[test]
353    fn constant_compare() {
354        let left = ConstantArray::new(Scalar::from(2u32), 10);
355        let right = ConstantArray::new(Scalar::from(10u32), 10);
356
357        let result = left
358            .into_array()
359            .binary(right.into_array(), Operator::Gt)
360            .unwrap();
361        assert_eq!(result.len(), 10);
362        let scalar = result
363            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
364            .unwrap();
365        assert_eq!(scalar.as_bool().value(), Some(false));
366    }
367
368    #[rstest]
369    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
370    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
371    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
372    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
373    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
374        let res = left.binary(right, Operator::Eq).unwrap();
375        let expected = BoolArray::from_iter([true, true]);
376        assert_arrays_eq!(res, expected);
377    }
378
379    #[ignore = "Arrow's ListView cannot be compared"]
380    #[test]
381    fn test_list_array_comparison() {
382        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
383        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
384        let list1 = ListArray::try_new(
385            values1.into_array(),
386            offsets1.into_array(),
387            Validity::NonNullable,
388        )
389        .unwrap();
390
391        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
392        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
393        let list2 = ListArray::try_new(
394            values2.into_array(),
395            offsets2.into_array(),
396            Validity::NonNullable,
397        )
398        .unwrap();
399
400        let result = list1
401            .clone()
402            .into_array()
403            .binary(list2.clone().into_array(), Operator::Eq)
404            .unwrap();
405        let expected = BoolArray::from_iter([true, true, false]);
406        assert_arrays_eq!(result, expected);
407
408        let result = list1
409            .clone()
410            .into_array()
411            .binary(list2.clone().into_array(), Operator::NotEq)
412            .unwrap();
413        let expected = BoolArray::from_iter([false, false, true]);
414        assert_arrays_eq!(result, expected);
415
416        let result = list1
417            .into_array()
418            .binary(list2.into_array(), Operator::Lt)
419            .unwrap();
420        let expected = BoolArray::from_iter([false, false, true]);
421        assert_arrays_eq!(result, expected);
422    }
423
424    #[ignore = "Arrow's ListView cannot be compared"]
425    #[test]
426    fn test_list_array_constant_comparison() {
427        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
428        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
429        let list = ListArray::try_new(
430            values.into_array(),
431            offsets.into_array(),
432            Validity::NonNullable,
433        )
434        .unwrap();
435
436        let list_scalar = Scalar::list(
437            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
438            vec![3i32.into(), 4i32.into()],
439            Nullability::NonNullable,
440        );
441        let constant = ConstantArray::new(list_scalar, 3);
442
443        let result = list
444            .into_array()
445            .binary(constant.into_array(), Operator::Eq)
446            .unwrap();
447        let expected = BoolArray::from_iter([false, true, false]);
448        assert_arrays_eq!(result, expected);
449    }
450
451    #[test]
452    fn test_struct_array_comparison() {
453        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
454        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);
455
456        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
457        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);
458
459        let struct1 = StructArray::from_fields(&[
460            ("bool_col", bool_field1.into_array()),
461            ("int_col", int_field1.into_array()),
462        ])
463        .unwrap();
464
465        let struct2 = StructArray::from_fields(&[
466            ("bool_col", bool_field2.into_array()),
467            ("int_col", int_field2.into_array()),
468        ])
469        .unwrap();
470
471        let result = struct1
472            .clone()
473            .into_array()
474            .binary(struct2.clone().into_array(), Operator::Eq)
475            .unwrap();
476        let expected = BoolArray::from_iter([true, true, false]);
477        assert_arrays_eq!(result, expected);
478
479        let result = struct1
480            .into_array()
481            .binary(struct2.into_array(), Operator::Gt)
482            .unwrap();
483        let expected = BoolArray::from_iter([false, false, true]);
484        assert_arrays_eq!(result, expected);
485    }
486
487    #[test]
488    fn test_empty_struct_compare() {
489        let empty1 = StructArray::try_new(
490            FieldNames::from(Vec::<FieldName>::new()),
491            Vec::new(),
492            5,
493            Validity::NonNullable,
494        )
495        .unwrap();
496
497        let empty2 = StructArray::try_new(
498            FieldNames::from(Vec::<FieldName>::new()),
499            Vec::new(),
500            5,
501            Validity::NonNullable,
502        )
503        .unwrap();
504
505        let result = empty1
506            .into_array()
507            .binary(empty2.into_array(), Operator::Eq)
508            .unwrap();
509        let expected = BoolArray::from_iter([true, true, true, true, true]);
510        assert_arrays_eq!(result, expected);
511    }
512
513    /// Regression test: `scalar_cmp` must error when comparing scalars with incompatible
514    /// extension types (e.g., timestamps with different time units) rather than silently
515    /// returning a wrong result.
516    #[test]
517    fn scalar_cmp_incompatible_extension_types_errors() {
518        let ms_scalar = Scalar::extension::<Timestamp>(
519            TimestampOptions {
520                unit: TimeUnit::Milliseconds,
521                tz: None,
522            },
523            Scalar::from(1704067200000i64),
524        );
525        let s_scalar = Scalar::extension::<Timestamp>(
526            TimestampOptions {
527                unit: TimeUnit::Seconds,
528                tz: None,
529            },
530            Scalar::from(1704067200i64),
531        );
532
533        // Ordering comparisons must error on incompatible types.
534        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
535        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
536        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
537        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
538        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
539        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
540    }
541
542    #[test]
543    fn test_empty_list() {
544        let list = ListViewArray::new(
545            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
546            buffer![0i32, 0i32, 0i32].into_array(),
547            buffer![0i32, 0i32, 0i32].into_array(),
548            Validity::AllValid,
549        );
550
551        let result = list
552            .clone()
553            .into_array()
554            .binary(list.into_array(), Operator::Eq)
555            .unwrap();
556        assert!(
557            result
558                .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
559                .unwrap()
560                .is_valid()
561        );
562        assert!(
563            result
564                .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
565                .unwrap()
566                .is_valid()
567        );
568        assert!(
569            result
570                .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
571                .unwrap()
572                .is_valid()
573        );
574    }
575}