vortex_array/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_array::BooleanArray;
11use arrow_buffer::NullBuffer;
12use arrow_ord::cmp;
13use arrow_ord::ord::make_comparator;
14use arrow_schema::SortOptions;
15use vortex_buffer::BitBuffer;
16use vortex_dtype::{DType, IntegerPType, Nullability};
17use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
18use vortex_scalar::Scalar;
19
20use crate::arrays::ConstantArray;
21use crate::arrow::{Datum, IntoArrowArray, from_arrow_array_with_len};
22use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
23use crate::vtable::VTable;
24use crate::{Array, ArrayRef, Canonical, IntoArray};
25
26static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
27    let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
28    for kernel in inventory::iter::<CompareKernelRef> {
29        compute.register_kernel(kernel.0.clone());
30    }
31    compute
32});
33
34pub(crate) fn warm_up_vtable() -> usize {
35    COMPARE_FN.kernels().len()
36}
37
38/// Compares two arrays and returns a new boolean array with the result of the comparison.
39/// Or, returns None if comparison is not supported for these arrays.
40pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
41    COMPARE_FN
42        .invoke(&InvocationArgs {
43            inputs: &[left.into(), right.into()],
44            options: &operator,
45        })?
46        .unwrap_array()
47}
48
/// Binary comparison operators understood by [`compare`] and [`scalar_cmp`].
///
/// The declaration order of the variants is significant: the derived `PartialOrd` orders
/// operators by it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum Operator {
    /// `lhs = rhs`
    Eq,
    /// `lhs != rhs`
    NotEq,
    /// `lhs > rhs`
    Gt,
    /// `lhs >= rhs`
    Gte,
    /// `lhs < rhs`
    Lt,
    /// `lhs <= rhs`
    Lte,
}
64
65impl Display for Operator {
66    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
67        let display = match &self {
68            Operator::Eq => "=",
69            Operator::NotEq => "!=",
70            Operator::Gt => ">",
71            Operator::Gte => ">=",
72            Operator::Lt => "<",
73            Operator::Lte => "<=",
74        };
75        Display::fmt(display, f)
76    }
77}
78
79impl Operator {
80    pub fn inverse(self) -> Self {
81        match self {
82            Operator::Eq => Operator::NotEq,
83            Operator::NotEq => Operator::Eq,
84            Operator::Gt => Operator::Lte,
85            Operator::Gte => Operator::Lt,
86            Operator::Lt => Operator::Gte,
87            Operator::Lte => Operator::Gt,
88        }
89    }
90
91    /// Change the sides of the operator, where changing lhs and rhs won't change the result of the operation
92    pub fn swap(self) -> Self {
93        match self {
94            Operator::Eq => Operator::Eq,
95            Operator::NotEq => Operator::NotEq,
96            Operator::Gt => Operator::Lt,
97            Operator::Gte => Operator::Lte,
98            Operator::Lt => Operator::Gt,
99            Operator::Lte => Operator::Gte,
100        }
101    }
102}
103
/// Handle to a type-erased compare [`Kernel`].
///
/// Values of this type are collected through `inventory`; `COMPARE_FN` iterates over them on
/// first use and registers each wrapped kernel.
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(CompareKernelRef);
106
/// Encoding-specific implementation of the compare operation.
///
/// Implemented on an encoding's [`VTable`] so that arrays of that encoding can evaluate
/// `lhs <operator> rhs` themselves instead of going through the generic Arrow fallback.
pub trait CompareKernel: VTable {
    /// Compares `lhs` against `rhs` using `operator`.
    ///
    /// Returning `Ok(None)` signals that this kernel does not handle the given inputs; the
    /// dispatcher in `Compare::invoke` then moves on to the next candidate implementation.
    fn compare(
        &self,
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: Operator,
    ) -> VortexResult<Option<ArrayRef>>;
}
115
/// Adapter that exposes a vtable's [`CompareKernel`] implementation through the generic
/// [`Kernel`] interface.
#[derive(Debug)]
pub struct CompareKernelAdapter<V: VTable>(pub V);
118
impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`CompareKernelRef`].
    ///
    /// The function is `const`, so the resulting handle can be built in a constant context.
    pub const fn lift(&'static self) -> CompareKernelRef {
        CompareKernelRef(ArcRef::new_ref(self))
    }
}
124
125impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
126    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
127        let inputs = CompareArgs::try_from(args)?;
128        let Some(array) = inputs.lhs.as_opt::<V>() else {
129            return Ok(None);
130        };
131        Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
132    }
133}
134
/// Unit type whose [`ComputeFnVTable`] implementation provides the dispatch logic of `COMPARE_FN`.
struct Compare;
136
137impl ComputeFnVTable for Compare {
138    fn invoke(
139        &self,
140        args: &InvocationArgs,
141        kernels: &[ArcRef<dyn Kernel>],
142    ) -> VortexResult<Output> {
143        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
144
145        let return_dtype = self.return_dtype(args)?;
146
147        if lhs.is_empty() {
148            return Ok(Canonical::empty(&return_dtype).into_array().into());
149        }
150
151        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
152        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
153        if left_constant_null || right_constant_null {
154            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
155                .into_array()
156                .into());
157        }
158
159        let right_is_constant = rhs.is_constant();
160
161        // Always try to put constants on the right-hand side so encodings can optimise themselves.
162        if lhs.is_constant() && !right_is_constant {
163            return Ok(compare(rhs, lhs, operator.swap())?.into());
164        }
165
166        // First try lhs op rhs, then invert and try again.
167        for kernel in kernels {
168            if let Some(output) = kernel.invoke(args)? {
169                return Ok(output);
170            }
171        }
172        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
173            return Ok(output);
174        }
175
176        // Try inverting the operator and swapping the arguments
177        let inverted_args = InvocationArgs {
178            inputs: &[rhs.into(), lhs.into()],
179            options: &operator.swap(),
180        };
181        for kernel in kernels {
182            if let Some(output) = kernel.invoke(&inverted_args)? {
183                return Ok(output);
184            }
185        }
186        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
187            return Ok(output);
188        }
189
190        // Only log missing compare implementation if there's possibly better one than arrow,
191        // i.e. lhs isn't arrow or rhs isn't arrow or constant
192        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
193            log::debug!(
194                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
195                lhs.encoding_id(),
196                rhs.encoding_id(),
197                operator,
198            );
199        }
200
201        // Fallback to arrow on canonical types
202        Ok(arrow_compare(lhs, rhs, operator)?.into())
203    }
204
205    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
206        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
207
208        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
209            vortex_bail!(
210                "Cannot compare different DTypes {} and {}",
211                lhs.dtype(),
212                rhs.dtype()
213            );
214        }
215
216        Ok(DType::Bool(
217            lhs.dtype().nullability() | rhs.dtype().nullability(),
218        ))
219    }
220
221    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
222        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
223        if lhs.len() != rhs.len() {
224            vortex_bail!(
225                "Compare operations only support arrays of the same length, got {} and {}",
226                lhs.len(),
227                rhs.len()
228            );
229        }
230        Ok(lhs.len())
231    }
232
233    fn is_elementwise(&self) -> bool {
234        true
235    }
236}
237
/// Typed view of the [`InvocationArgs`] passed to the compare function: the two array operands
/// and the [`Operator`] carried in the invocation options.
struct CompareArgs<'a> {
    /// Left-hand operand (first input).
    lhs: &'a dyn Array,
    /// Right-hand operand (second input).
    rhs: &'a dyn Array,
    /// Comparison to evaluate between `lhs` and `rhs`.
    operator: Operator,
}
243
/// Allows an [`Operator`] to be passed as the options of an invocation.
impl Options for Operator {
    /// Exposes the operator as `Any` so that `CompareArgs::try_from` can downcast it back.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
249
250impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
251    type Error = VortexError;
252
253    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
254        if value.inputs.len() != 2 {
255            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
256        }
257        let lhs = value.inputs[0]
258            .array()
259            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
260        let rhs = value.inputs[1]
261            .array()
262            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
263        let operator = *value
264            .options
265            .as_any()
266            .downcast_ref::<Operator>()
267            .vortex_expect("Expected options to be an operator");
268
269        Ok(CompareArgs { lhs, rhs, operator })
270    }
271}
272
273/// Helper function to compare empty values with arrays that have external value length information
274/// like `VarBin`.
275pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BitBuffer
276where
277    P: IntegerPType,
278    I: Iterator<Item = P>,
279{
280    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
281    let cmp_fn = match op {
282        Operator::Eq | Operator::Lte => |v| v == P::zero(),
283        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
284        Operator::Gte => |_| true,
285        Operator::Lt => |_| false,
286    };
287
288    lengths.map(cmp_fn).collect()
289}
290
291/// Implementation of `CompareFn` using the Arrow crate.
292fn arrow_compare(
293    left: &dyn Array,
294    right: &dyn Array,
295    operator: Operator,
296) -> VortexResult<ArrayRef> {
297    assert_eq!(left.len(), right.len());
298
299    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
300
301    let array = if left.dtype().is_nested() || right.dtype().is_nested() {
302        let rhs = right.to_array().into_arrow_preferred()?;
303        let lhs = left.to_array().into_arrow(rhs.data_type())?;
304
305        assert!(
306            lhs.data_type().equals_datatype(rhs.data_type()),
307            "lhs data_type: {}, rhs data_type: {}",
308            lhs.data_type(),
309            rhs.data_type()
310        );
311
312        let cmp = make_comparator(lhs.as_ref(), rhs.as_ref(), SortOptions::default())?;
313        let len = left.len();
314        let values = (0..len)
315            .map(|i| {
316                let cmp = cmp(i, i);
317                match operator {
318                    Operator::Eq => cmp.is_eq(),
319                    Operator::NotEq => cmp.is_ne(),
320                    Operator::Gt => cmp.is_gt(),
321                    Operator::Gte => cmp.is_gt() || cmp.is_eq(),
322                    Operator::Lt => cmp.is_lt(),
323                    Operator::Lte => cmp.is_lt() || cmp.is_eq(),
324                }
325            })
326            .collect();
327        let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
328        BooleanArray::new(values, nulls)
329    } else {
330        let lhs = Datum::try_new(left)?;
331        let rhs = Datum::try_new(right)?;
332
333        match operator {
334            Operator::Eq => cmp::eq(&lhs, &rhs)?,
335            Operator::NotEq => cmp::neq(&lhs, &rhs)?,
336            Operator::Gt => cmp::gt(&lhs, &rhs)?,
337            Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
338            Operator::Lt => cmp::lt(&lhs, &rhs)?,
339            Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
340        }
341    };
342    Ok(from_arrow_array_with_len(&array, left.len(), nullable))
343}
344
345pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
346    if lhs.is_null() | rhs.is_null() {
347        Scalar::null(DType::Bool(Nullability::Nullable))
348    } else {
349        let b = match operator {
350            Operator::Eq => lhs == rhs,
351            Operator::NotEq => lhs != rhs,
352            Operator::Gt => lhs > rhs,
353            Operator::Gte => lhs >= rhs,
354            Operator::Lt => lhs < rhs,
355            Operator::Lte => lhs <= rhs,
356        };
357
358        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
359    }
360}
361
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldName, FieldNames};

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{
        BoolArray, ConstantArray, ListArray, ListViewArray, PrimitiveArray, StructArray,
        VarBinArray, VarBinViewArray,
    };
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Self-comparison and pairwise ordering of bool arrays whose first element is invalid;
    /// none of the expected index lists contains index 0.
    #[test]
    fn test_bool_basic_comparisons() {
        // Indices reported by `to_int_indices` for `lhs op rhs`.
        let matching = |lhs: &BoolArray, rhs: &BoolArray, op: Operator| {
            let mask = compare(lhs.as_ref(), rhs.as_ref(), op).unwrap().to_bool();
            to_int_indices(mask).unwrap()
        };

        let arr = BoolArray::from_bit_buffer(
            BitBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(matching(&arr, &arr, Operator::Eq), [1u64, 2, 3, 4]);

        let empty: [u64; 0] = [];
        assert_eq!(matching(&arr, &arr, Operator::NotEq), empty);

        let other = BoolArray::from_bit_buffer(
            BitBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(matching(&arr, &other, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(matching(&arr, &other, Operator::Lt), [4u64]);
        assert_eq!(matching(&other, &arr, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(matching(&other, &arr, Operator::Gt), [4u64]);
    }

    /// Comparing two constants yields a constant, both via `compare` and via the Arrow fallback.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let dispatched = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let scalar = dispatched.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(dispatched.len(), 10);

        let via_arrow =
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let scalar = via_arrow.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(via_arrow.len(), 10);
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        // Only the last value has length zero, i.e. is the empty value.
        let lengths = [1i32, 5, 7, 0];

        let bits = compare_lengths_to_empty(lengths.iter().copied(), op);
        let actual: Vec<bool> = bits.iter().collect();
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        // Equal contents in different encodings must compare equal everywhere.
        let equal = compare(&left, &right, Operator::Eq).unwrap().to_bool();
        assert_eq!(equal.bit_buffer().true_count(), left.len());
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        // Builds a non-nullable array of three two-element integer lists from flat values.
        let make_list = |values: [i32; 6]| {
            ListArray::try_new(
                PrimitiveArray::from_iter(values).into_array(),
                PrimitiveArray::from_iter([0i32, 2, 4, 6]).into_array(),
                Validity::NonNullable,
            )
            .unwrap()
        };
        let list1 = make_list([1, 2, 3, 4, 5, 6]);
        let list2 = make_list([1, 2, 3, 4, 7, 8]);

        // Asserts the per-element result of `list1 op list2`.
        let check = |op: Operator, expected: [bool; 3]| {
            let bits = compare(list1.as_ref(), list2.as_ref(), op)
                .unwrap()
                .to_bool();
            for (idx, want) in expected.into_iter().enumerate() {
                assert_eq!(bits.bit_buffer().value(idx), want);
            }
        };

        // [1,2] == [1,2], [3,4] == [3,4], [5,6] != [7,8]
        check(Operator::Eq, [true, true, false]);
        check(Operator::NotEq, [false, false, true]);
        // Only [5,6] < [7,8] holds.
        check(Operator::Lt, [false, false, true]);
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        use std::sync::Arc;

        use vortex_dtype::{DType, PType};

        // Three lists: [1,2], [3,4], [5,6].
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6]).into_array(),
            PrimitiveArray::from_iter([0i32, 2, 4, 6]).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        // The constant list [3,4], repeated once per element of `list`.
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let three_four = Scalar::list(
            element_dtype,
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(three_four, 3);

        let bits = compare(list.as_ref(), constant.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();
        // Only the middle list equals [3,4].
        for (idx, want) in [false, true, false].into_iter().enumerate() {
            assert_eq!(bits.bit_buffer().value(idx), want);
        }
    }

    #[test]
    fn test_struct_array_comparison() {
        // Builds a struct array with a bool column and an int column.
        let make_struct = |bools: [Option<bool>; 3], ints: [i32; 3]| {
            StructArray::from_fields(&[
                ("bool_col", BoolArray::from_iter(bools).into_array()),
                ("int_col", PrimitiveArray::from_iter(ints).into_array()),
            ])
            .unwrap()
        };
        let struct1 = make_struct([Some(true), Some(false), Some(true)], [1, 2, 3]);
        let struct2 = make_struct([Some(true), Some(false), Some(false)], [1, 2, 4]);

        // Asserts the per-element result of `struct1 op struct2`.
        let check = |op: Operator, expected: [bool; 3]| {
            let bits = compare(struct1.as_ref(), struct2.as_ref(), op)
                .unwrap()
                .to_bool();
            for (idx, want) in expected.into_iter().enumerate() {
                assert_eq!(bits.bit_buffer().value(idx), want);
            }
        };

        // {true, 1} == {true, 1}, {false, 2} == {false, 2}, {true, 3} != {false, 4}
        check(Operator::Eq, [true, true, false]);
        // {true, 3} > {false, 4} because the bool field takes precedence.
        check(Operator::Gt, [false, false, true]);
    }

    #[test]
    fn test_empty_struct_compare() {
        // A struct array with no fields but five rows.
        let make_empty = || {
            StructArray::try_new(
                FieldNames::from(Vec::<FieldName>::new()),
                Vec::new(),
                5,
                Validity::NonNullable,
            )
            .unwrap()
        };
        let empty1 = make_empty();
        let empty2 = make_empty();

        let equal = compare(empty1.as_ref(), empty2.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert!((0..5).all(|idx| equal.bit_buffer().value(idx)));
    }

    #[test]
    fn test_empty_list() {
        // Three empty lists over an empty element array.
        let list = ListViewArray::new(
            BoolArray::from_iter(Vec::<bool>::new()).into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            buffer![0i32, 0i32, 0i32].into_array(),
            Validity::AllValid,
        );

        // Comparing the array with itself must give a valid result for every element.
        let result = compare(list.as_ref(), list.as_ref(), Operator::Eq).unwrap();
        for idx in 0..3 {
            assert!(result.scalar_at(idx).is_valid());
        }
    }
}