Skip to main content

vortex_array/scalar_fn/fns/binary/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::cmp;
9use arrow_ord::ord::make_comparator;
10use arrow_schema::Field;
11use arrow_schema::SortOptions;
12use vortex_error::VortexResult;
13use vortex_error::vortex_err;
14
15use crate::ArrayRef;
16use crate::Canonical;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::array::ArrayView;
20use crate::array::VTable;
21use crate::arrays::Constant;
22use crate::arrays::ConstantArray;
23use crate::arrays::ScalarFn;
24use crate::arrays::scalar_fn::ExactScalarFn;
25use crate::arrays::scalar_fn::ScalarFnArrayExt;
26use crate::arrays::scalar_fn::ScalarFnArrayView;
27use crate::arrow::ArrowSessionExt;
28use crate::arrow::Datum;
29use crate::arrow::from_arrow_columnar;
30use crate::dtype::DType;
31use crate::dtype::Nullability;
32use crate::kernel::ExecuteParentKernel;
33use crate::scalar::Scalar;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::CompareOperator;
36
37/// Trait for encoding-specific comparison kernels that operate in encoded space.
38///
39/// Implementations can compare an encoded array against another array (typically a constant)
40/// without first decompressing. The adaptor normalizes operand order so `array` is always
41/// the left-hand side, swapping the operator when necessary.
42pub trait CompareKernel: VTable {
43    fn compare(
44        lhs: ArrayView<'_, Self>,
45        rhs: &ArrayRef,
46        operator: CompareOperator,
47        ctx: &mut ExecutionCtx,
48    ) -> VortexResult<Option<ArrayRef>>;
49}
50
51/// Adaptor that bridges [`CompareKernel`] implementations to [`ExecuteParentKernel`].
52///
53/// When a `ScalarFnArray(Binary, cmp_op)` wraps a child that implements `CompareKernel`,
54/// this adaptor extracts the comparison operator and other operand, normalizes operand order
55/// (swapping the operator if the encoded array is on the RHS), and delegates to the kernel.
56#[derive(Default, Debug)]
57pub struct CompareExecuteAdaptor<V>(pub V);
58
59impl<V> ExecuteParentKernel<V> for CompareExecuteAdaptor<V>
60where
61    V: CompareKernel,
62{
63    type Parent = ExactScalarFn<Binary>;
64
65    fn execute_parent(
66        &self,
67        array: ArrayView<'_, V>,
68        parent: ScalarFnArrayView<'_, Binary>,
69        child_idx: usize,
70        ctx: &mut ExecutionCtx,
71    ) -> VortexResult<Option<ArrayRef>> {
72        // Only handle comparison operators
73        let Ok(cmp_op) = CompareOperator::try_from(*parent.options) else {
74            return Ok(None);
75        };
76
77        // Get the ScalarFnArray to access children
78        let Some(scalar_fn_array) = parent.as_opt::<ScalarFn>() else {
79            return Ok(None);
80        };
81        // Normalize so `array` is always LHS, swapping the operator if needed
82        // TODO(joe): should be go this here or in the Rule/Kernel
83        let (cmp_op, other) = match child_idx {
84            0 => (cmp_op, scalar_fn_array.get_child(1)),
85            1 => (cmp_op.swap(), scalar_fn_array.get_child(0)),
86            _ => return Ok(None),
87        };
88
89        let len = array.len();
90        let nullable = array.dtype().is_nullable() || other.dtype().is_nullable();
91
92        // Empty array → empty bool result
93        if len == 0 {
94            return Ok(Some(
95                Canonical::empty(&DType::Bool(nullable.into())).into_array(),
96            ));
97        }
98
99        // Null constant on either side → all-null bool result
100        if other.as_constant().is_some_and(|s| s.is_null()) {
101            return Ok(Some(
102                ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), len).into_array(),
103            ));
104        }
105
106        V::compare(array, other, cmp_op, ctx)
107    }
108}
109
110/// Execute a compare operation between two arrays.
111///
112/// This is the entry point for compare operations from the binary expression.
113/// Handles empty, constant-null, and constant-constant directly, otherwise falls back to Arrow.
114pub(crate) fn execute_compare(
115    lhs: &ArrayRef,
116    rhs: &ArrayRef,
117    op: CompareOperator,
118    ctx: &mut ExecutionCtx,
119) -> VortexResult<ArrayRef> {
120    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
121
122    if lhs.is_empty() {
123        return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
124    }
125
126    let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
127    let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
128    if left_constant_null || right_constant_null {
129        return Ok(
130            ConstantArray::new(Scalar::null(DType::Bool(nullable.into())), lhs.len()).into_array(),
131        );
132    }
133
134    // Constant-constant fast path
135    if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
136    {
137        let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
138        return Ok(ConstantArray::new(result, lhs.len()).into_array());
139    }
140
141    arrow_compare_arrays(lhs, rhs, op, ctx)
142}
143
144/// Fall back to Arrow for comparison.
145fn arrow_compare_arrays(
146    left: &ArrayRef,
147    right: &ArrayRef,
148    operator: CompareOperator,
149    ctx: &mut ExecutionCtx,
150) -> VortexResult<ArrayRef> {
151    assert_eq!(left.len(), right.len());
152
153    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
154
155    // Arrow's vectorized comparison kernels don't support nested types.
156    // For nested types, fall back to `make_comparator` which does element-wise comparison.
157    let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
158        let session = ctx.session().clone();
159        let lhs = session.arrow().execute_arrow(left.clone(), None, ctx)?;
160        let target_field = Field::new("", lhs.data_type().clone(), right.dtype().is_nullable());
161        let rhs = session
162            .arrow()
163            .execute_arrow(right.clone(), Some(&target_field), ctx)?;
164
165        compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)?
166    } else {
167        // Fast path: use vectorized kernels for primitive types.
168        let lhs = Datum::try_new(left, ctx)?;
169        let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type(), ctx)?;
170
171        match operator {
172            CompareOperator::Eq => cmp::eq(&lhs, &rhs)?,
173            CompareOperator::NotEq => cmp::neq(&lhs, &rhs)?,
174            CompareOperator::Gt => cmp::gt(&lhs, &rhs)?,
175            CompareOperator::Gte => cmp::gt_eq(&lhs, &rhs)?,
176            CompareOperator::Lt => cmp::lt(&lhs, &rhs)?,
177            CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
178        }
179    };
180
181    from_arrow_columnar(&arrow_array, left.len(), nullable, ctx)
182}
183
184pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
185    if lhs.is_null() | rhs.is_null() {
186        return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
187    }
188
189    let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
190
191    // We use `partial_cmp` to ensure we do not lose a type mismatch error.
192    let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
193        vortex_err!(
194            "Cannot compare scalars with incompatible types: {} and {}",
195            lhs.dtype(),
196            rhs.dtype()
197        )
198    })?;
199
200    let b = match operator {
201        CompareOperator::Eq => ordering.is_eq(),
202        CompareOperator::NotEq => ordering.is_ne(),
203        CompareOperator::Gt => ordering.is_gt(),
204        CompareOperator::Gte => ordering.is_ge(),
205        CompareOperator::Lt => ordering.is_lt(),
206        CompareOperator::Lte => ordering.is_le(),
207    };
208
209    Ok(Scalar::bool(b, nullability))
210}
211
212/// Compare two Arrow arrays element-wise using [`make_comparator`].
213///
214/// This function is required for nested types (Struct, List, FixedSizeList) because Arrow's
215/// vectorized comparison kernels ([`cmp::eq`], [`cmp::neq`], etc.) do not support them.
216///
217/// The vectorized kernels are faster but only work on primitive types, so for non-nested types,
218/// prefer using the vectorized kernels directly for better performance.
219pub fn compare_nested_arrow_arrays(
220    lhs: &dyn arrow_array::Array,
221    rhs: &dyn arrow_array::Array,
222    operator: CompareOperator,
223) -> VortexResult<BooleanArray> {
224    let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
225
226    let cmp_fn = match operator {
227        CompareOperator::Eq => Ordering::is_eq,
228        CompareOperator::NotEq => Ordering::is_ne,
229        CompareOperator::Gt => Ordering::is_gt,
230        CompareOperator::Gte => Ordering::is_ge,
231        CompareOperator::Lt => Ordering::is_lt,
232        CompareOperator::Lte => Ordering::is_le,
233    };
234
235    let values = (0..lhs.len())
236        .map(|i| cmp_fn(compare_arrays_at(i, i)))
237        .collect();
238    let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
239
240    Ok(BooleanArray::new(values, nulls))
241}
242
243#[cfg(test)]
244mod tests {
245    use std::sync::Arc;
246
247    use rstest::rstest;
248    use vortex_buffer::BitBuffer;
249    use vortex_buffer::buffer;
250    use vortex_error::VortexExpect;
251
252    use crate::ArrayRef;
253    use crate::IntoArray;
254    use crate::VortexSessionExecute;
255    use crate::array_session;
256    use crate::arrays::BoolArray;
257    use crate::arrays::ListArray;
258    use crate::arrays::ListViewArray;
259    use crate::arrays::PrimitiveArray;
260    use crate::arrays::StructArray;
261    use crate::arrays::VarBinArray;
262    use crate::arrays::VarBinViewArray;
263    use crate::assert_arrays_eq;
264    use crate::builtins::ArrayBuiltins;
265    use crate::dtype::DType;
266    use crate::dtype::FieldName;
267    use crate::dtype::FieldNames;
268    use crate::dtype::Nullability;
269    use crate::dtype::PType;
270    use crate::extension::datetime::TimeUnit;
271    use crate::extension::datetime::Timestamp;
272    use crate::extension::datetime::TimestampOptions;
273    use crate::scalar::Scalar;
274    use crate::scalar_fn::fns::binary::compare::ConstantArray;
275    use crate::scalar_fn::fns::binary::scalar_cmp;
276    use crate::scalar_fn::fns::operators::CompareOperator;
277    use crate::scalar_fn::fns::operators::Operator;
278    use crate::test_harness::to_int_indices;
279    use crate::validity::Validity;
280
281    #[test]
282    fn test_bool_basic_comparisons() {
283        let ctx = &mut array_session().create_execution_ctx();
284        let arr = BoolArray::new(
285            BitBuffer::from_iter([true, true, false, true, false]),
286            Validity::from_iter([false, true, true, true, true]),
287        );
288
289        let matches = arr
290            .clone()
291            .into_array()
292            .binary(arr.clone().into_array(), Operator::Eq)
293            .unwrap()
294            .execute::<BoolArray>(ctx)
295            .vortex_expect("must be a bool array");
296        assert_eq!(to_int_indices(matches, ctx).unwrap(), [1u64, 2, 3, 4]);
297
298        let matches = arr
299            .clone()
300            .into_array()
301            .binary(arr.clone().into_array(), Operator::NotEq)
302            .unwrap()
303            .execute::<BoolArray>(ctx)
304            .vortex_expect("must be a bool array");
305        let empty: [u64; 0] = [];
306        assert_eq!(to_int_indices(matches, ctx).unwrap(), empty);
307
308        let other = BoolArray::new(
309            BitBuffer::from_iter([false, false, false, true, true]),
310            Validity::from_iter([false, true, true, true, true]),
311        );
312
313        let matches = arr
314            .clone()
315            .into_array()
316            .binary(other.clone().into_array(), Operator::Lte)
317            .unwrap()
318            .execute::<BoolArray>(ctx)
319            .vortex_expect("must be a bool array");
320        assert_eq!(to_int_indices(matches, ctx).unwrap(), [2u64, 3, 4]);
321
322        let matches = arr
323            .clone()
324            .into_array()
325            .binary(other.clone().into_array(), Operator::Lt)
326            .unwrap()
327            .execute::<BoolArray>(ctx)
328            .vortex_expect("must be a bool array");
329        assert_eq!(to_int_indices(matches, ctx).unwrap(), [4u64]);
330
331        let matches = other
332            .clone()
333            .into_array()
334            .binary(arr.clone().into_array(), Operator::Gte)
335            .unwrap()
336            .execute::<BoolArray>(ctx)
337            .vortex_expect("must be a bool array");
338        assert_eq!(to_int_indices(matches, ctx).unwrap(), [2u64, 3, 4]);
339
340        let matches = other
341            .into_array()
342            .binary(arr.into_array(), Operator::Gt)
343            .unwrap()
344            .execute::<BoolArray>(ctx)
345            .vortex_expect("must be a bool array");
346        assert_eq!(to_int_indices(matches, ctx).unwrap(), [4u64]);
347    }
348
349    #[test]
350    fn constant_compare() {
351        let left = ConstantArray::new(Scalar::from(2u32), 10);
352        let right = ConstantArray::new(Scalar::from(10u32), 10);
353
354        let result = left
355            .into_array()
356            .binary(right.into_array(), Operator::Gt)
357            .unwrap();
358        assert_eq!(result.len(), 10);
359        let scalar = result
360            .execute_scalar(0, &mut array_session().create_execution_ctx())
361            .unwrap();
362        assert_eq!(scalar.as_bool().value(), Some(false));
363    }
364
365    #[rstest]
366    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
367    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
368    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
369    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
370    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
371        let mut ctx = array_session().create_execution_ctx();
372        let res = left.binary(right, Operator::Eq).unwrap();
373        let expected = BoolArray::from_iter([true, true]);
374        assert_arrays_eq!(res, expected, &mut ctx);
375    }
376
377    #[test]
378    fn test_list_array_comparison() {
379        let mut ctx = array_session().create_execution_ctx();
380        let values1 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
381        let offsets1 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
382        let list1 = ListArray::try_new(
383            values1.into_array(),
384            offsets1.into_array(),
385            Validity::NonNullable,
386        )
387        .unwrap();
388
389        let values2 = PrimitiveArray::from_iter([1i32, 2, 3, 4, 7, 8]);
390        let offsets2 = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
391        let list2 = ListArray::try_new(
392            values2.into_array(),
393            offsets2.into_array(),
394            Validity::NonNullable,
395        )
396        .unwrap();
397
398        let result = list1
399            .clone()
400            .into_array()
401            .binary(list2.clone().into_array(), Operator::Eq)
402            .unwrap();
403        let expected = BoolArray::from_iter([true, true, false]);
404        assert_arrays_eq!(result, expected, &mut ctx);
405
406        let result = list1
407            .clone()
408            .into_array()
409            .binary(list2.clone().into_array(), Operator::NotEq)
410            .unwrap();
411        let expected = BoolArray::from_iter([false, false, true]);
412        assert_arrays_eq!(result, expected, &mut ctx);
413
414        let result = list1
415            .into_array()
416            .binary(list2.into_array(), Operator::Lt)
417            .unwrap();
418        let expected = BoolArray::from_iter([false, false, true]);
419        assert_arrays_eq!(result, expected, &mut ctx);
420    }
421
422    #[test]
423    fn test_list_array_constant_comparison() {
424        let mut ctx = array_session().create_execution_ctx();
425        let values = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]);
426        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
427        let list = ListArray::try_new(
428            values.into_array(),
429            offsets.into_array(),
430            Validity::NonNullable,
431        )
432        .unwrap();
433
434        let list_scalar = Scalar::list(
435            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
436            vec![3i32.into(), 4i32.into()],
437            Nullability::NonNullable,
438        );
439        let constant = ConstantArray::new(list_scalar, 3);
440
441        let result = list
442            .into_array()
443            .binary(constant.into_array(), Operator::Eq)
444            .unwrap();
445        let expected = BoolArray::from_iter([false, true, false]);
446        assert_arrays_eq!(result, expected, &mut ctx);
447    }
448
449    #[test]
450    fn test_struct_array_comparison() {
451        let mut ctx = array_session().create_execution_ctx();
452        let bool_field1 = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
453        let int_field1 = PrimitiveArray::from_iter([1i32, 2, 3]);
454
455        let bool_field2 = BoolArray::from_iter([Some(true), Some(false), Some(false)]);
456        let int_field2 = PrimitiveArray::from_iter([1i32, 2, 4]);
457
458        let struct1 = StructArray::from_fields(&[
459            ("bool_col", bool_field1.into_array()),
460            ("int_col", int_field1.into_array()),
461        ])
462        .unwrap();
463
464        let struct2 = StructArray::from_fields(&[
465            ("bool_col", bool_field2.into_array()),
466            ("int_col", int_field2.into_array()),
467        ])
468        .unwrap();
469
470        let result = struct1
471            .clone()
472            .into_array()
473            .binary(struct2.clone().into_array(), Operator::Eq)
474            .unwrap();
475        let expected = BoolArray::from_iter([true, true, false]);
476        assert_arrays_eq!(result, expected, &mut ctx);
477
478        let result = struct1
479            .into_array()
480            .binary(struct2.into_array(), Operator::Gt)
481            .unwrap();
482        let expected = BoolArray::from_iter([false, false, true]);
483        assert_arrays_eq!(result, expected, &mut ctx);
484    }
485
486    #[test]
487    fn test_empty_struct_compare() {
488        let mut ctx = array_session().create_execution_ctx();
489        let empty1 = StructArray::try_new(
490            FieldNames::from(Vec::<FieldName>::new()),
491            Vec::new(),
492            5,
493            Validity::NonNullable,
494        )
495        .unwrap();
496
497        let empty2 = StructArray::try_new(
498            FieldNames::from(Vec::<FieldName>::new()),
499            Vec::new(),
500            5,
501            Validity::NonNullable,
502        )
503        .unwrap();
504
505        let result = empty1
506            .into_array()
507            .binary(empty2.into_array(), Operator::Eq)
508            .unwrap();
509        let expected = BoolArray::from_iter([true, true, true, true, true]);
510        assert_arrays_eq!(result, expected, &mut ctx);
511    }
512
513    /// Regression test: comparing struct arrays where the same logical field is backed by
514    /// different Vortex encodings (VarBinArray vs VarBinViewArray) must not panic.
515    #[test]
516    fn struct_compare_mixed_binary_encodings() {
517        let mut ctx = array_session().create_execution_ctx();
518        // LHS: struct with a VarBinArray (offset-based) binary field
519        let bin_field1 = VarBinArray::from(vec![
520            "apple".as_bytes(),
521            "banana".as_bytes(),
522            "cherry".as_bytes(),
523        ]);
524        let struct1 = StructArray::from_fields(&[("data", bin_field1.into_array())]).unwrap();
525
526        // RHS: struct with a VarBinViewArray (view-based) binary field — same logical DType
527        let bin_field2 = VarBinViewArray::from_iter_bin([
528            "apple".as_bytes(),
529            "banana".as_bytes(),
530            "durian".as_bytes(),
531        ]);
532        let struct2 = StructArray::from_fields(&[("data", bin_field2.into_array())]).unwrap();
533
534        let result = struct1
535            .into_array()
536            .binary(struct2.into_array(), Operator::Eq)
537            .unwrap();
538        let expected = BoolArray::from_iter([true, true, false]);
539        assert_arrays_eq!(result, expected, &mut ctx);
540    }
541
542    /// Regression test: `scalar_cmp` must error when comparing scalars with incompatible
543    /// extension types (e.g., timestamps with different time units) rather than silently
544    /// returning a wrong result.
545    #[test]
546    fn scalar_cmp_incompatible_extension_types_errors() {
547        let ms_scalar = Scalar::extension::<Timestamp>(
548            TimestampOptions {
549                unit: TimeUnit::Milliseconds,
550                tz: None,
551            },
552            Scalar::from(1704067200000i64),
553        );
554        let s_scalar = Scalar::extension::<Timestamp>(
555            TimestampOptions {
556                unit: TimeUnit::Seconds,
557                tz: None,
558            },
559            Scalar::from(1704067200i64),
560        );
561
562        // Ordering comparisons must error on incompatible types.
563        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
564        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
565        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
566        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
567        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
568        assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
569    }
570
571    #[test]
572    fn test_empty_list() {
573        let ctx = &mut array_session().create_execution_ctx();
574        let list = ListViewArray::new(
575            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
576            buffer![0i32, 0i32, 0i32].into_array(),
577            buffer![0i32, 0i32, 0i32].into_array(),
578            Validity::AllValid,
579        );
580
581        let result = list
582            .clone()
583            .into_array()
584            .binary(list.into_array(), Operator::Eq)
585            .unwrap();
586        assert!(result.execute_scalar(0, ctx).unwrap().is_valid());
587        assert!(result.execute_scalar(1, ctx).unwrap().is_valid());
588        assert!(result.execute_scalar(2, ctx).unwrap().is_valid());
589    }
590}