vortex_array/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_array::{BooleanArray, Datum as ArrowDatum};
11use arrow_buffer::{BooleanBuffer, NullBuffer};
12use arrow_ord::cmp;
13use arrow_ord::ord::make_comparator;
14use arrow_schema::SortOptions;
15use vortex_dtype::{DType, IntegerPType, Nullability};
16use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
17use vortex_scalar::Scalar;
18
19use crate::arrays::ConstantArray;
20use crate::arrow::{Datum, from_arrow_array_with_len};
21use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
22use crate::vtable::VTable;
23use crate::{Array, ArrayRef, Canonical, IntoArray};
24
25static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
26    let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
27    for kernel in inventory::iter::<CompareKernelRef> {
28        compute.register_kernel(kernel.0.clone());
29    }
30    compute
31});
32
33pub(crate) fn warm_up_vtable() -> usize {
34    COMPARE_FN.kernels().len()
35}
36
37/// Compares two arrays and returns a new boolean array with the result of the comparison.
38/// Or, returns None if comparison is not supported for these arrays.
39pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
40    COMPARE_FN
41        .invoke(&InvocationArgs {
42            inputs: &[left.into(), right.into()],
43            options: &operator,
44        })?
45        .unwrap_array()
46}
47
/// Binary comparison operators understood by the compare compute function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum Operator {
    /// `=`: both sides are equal.
    Eq,
    /// `!=`: the sides differ.
    NotEq,
    /// `>`: the left side is strictly greater.
    Gt,
    /// `>=`: the left side is greater or equal.
    Gte,
    /// `<`: the left side is strictly smaller.
    Lt,
    /// `<=`: the left side is smaller or equal.
    Lte,
}

impl Display for Operator {
    /// Writes the operator's symbol (`=`, `!=`, `>`, `>=`, `<`, `<=`), honouring any width and
    /// alignment requested by the format spec.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let symbol = match *self {
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
        };
        // `pad` is exactly what `str`'s own `Display` implementation does.
        f.pad(symbol)
    }
}

impl Operator {
    /// Returns the logical negation of this operator: the one that holds exactly when this
    /// operator does not (`=` <-> `!=`, `>` <-> `<=`, `>=` <-> `<`).
    pub fn inverse(self) -> Self {
        match self {
            Self::Eq => Self::NotEq,
            Self::NotEq => Self::Eq,
            Self::Gt => Self::Lte,
            Self::Gte => Self::Lt,
            Self::Lt => Self::Gte,
            Self::Lte => Self::Gt,
        }
    }

    /// Returns the operator to use once the two operands trade places, so that
    /// `a op b` and `b op.swap() a` yield the same result.
    pub fn swap(self) -> Self {
        match self {
            // Equality and inequality are symmetric.
            Self::Eq => Self::Eq,
            Self::NotEq => Self::NotEq,
            // Ordering operators are mirrored.
            Self::Gt => Self::Lt,
            Self::Gte => Self::Lte,
            Self::Lt => Self::Gt,
            Self::Lte => Self::Gte,
        }
    }
}
102
103pub struct CompareKernelRef(ArcRef<dyn Kernel>);
104inventory::collect!(CompareKernelRef);
105
106pub trait CompareKernel: VTable {
107    fn compare(
108        &self,
109        lhs: &Self::Array,
110        rhs: &dyn Array,
111        operator: Operator,
112    ) -> VortexResult<Option<ArrayRef>>;
113}
114
115#[derive(Debug)]
116pub struct CompareKernelAdapter<V: VTable>(pub V);
117
118impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
119    pub const fn lift(&'static self) -> CompareKernelRef {
120        CompareKernelRef(ArcRef::new_ref(self))
121    }
122}
123
124impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
125    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
126        let inputs = CompareArgs::try_from(args)?;
127        let Some(array) = inputs.lhs.as_opt::<V>() else {
128            return Ok(None);
129        };
130        Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
131    }
132}
133
134struct Compare;
135
136impl ComputeFnVTable for Compare {
137    fn invoke(
138        &self,
139        args: &InvocationArgs,
140        kernels: &[ArcRef<dyn Kernel>],
141    ) -> VortexResult<Output> {
142        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
143
144        let return_dtype = self.return_dtype(args)?;
145
146        if lhs.is_empty() {
147            return Ok(Canonical::empty(&return_dtype).into_array().into());
148        }
149
150        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
151        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
152        if left_constant_null || right_constant_null {
153            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
154                .into_array()
155                .into());
156        }
157
158        let right_is_constant = rhs.is_constant();
159
160        // Always try to put constants on the right-hand side so encodings can optimise themselves.
161        if lhs.is_constant() && !right_is_constant {
162            return Ok(compare(rhs, lhs, operator.swap())?.into());
163        }
164
165        // First try lhs op rhs, then invert and try again.
166        for kernel in kernels {
167            if let Some(output) = kernel.invoke(args)? {
168                return Ok(output);
169            }
170        }
171        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
172            return Ok(output);
173        }
174
175        // Try inverting the operator and swapping the arguments
176        let inverted_args = InvocationArgs {
177            inputs: &[rhs.into(), lhs.into()],
178            options: &operator.swap(),
179        };
180        for kernel in kernels {
181            if let Some(output) = kernel.invoke(&inverted_args)? {
182                return Ok(output);
183            }
184        }
185        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
186            return Ok(output);
187        }
188
189        // Only log missing compare implementation if there's possibly better one than arrow,
190        // i.e. lhs isn't arrow or rhs isn't arrow or constant
191        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
192            log::debug!(
193                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
194                lhs.encoding_id(),
195                rhs.encoding_id(),
196                operator,
197            );
198        }
199
200        // Fallback to arrow on canonical types
201        Ok(arrow_compare(lhs, rhs, operator)?.into())
202    }
203
204    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
205        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
206
207        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
208            vortex_bail!(
209                "Cannot compare different DTypes {} and {}",
210                lhs.dtype(),
211                rhs.dtype()
212            );
213        }
214
215        Ok(DType::Bool(
216            lhs.dtype().nullability() | rhs.dtype().nullability(),
217        ))
218    }
219
220    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
221        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
222        if lhs.len() != rhs.len() {
223            vortex_bail!(
224                "Compare operations only support arrays of the same length, got {} and {}",
225                lhs.len(),
226                rhs.len()
227            );
228        }
229        Ok(lhs.len())
230    }
231
232    fn is_elementwise(&self) -> bool {
233        true
234    }
235}
236
237struct CompareArgs<'a> {
238    lhs: &'a dyn Array,
239    rhs: &'a dyn Array,
240    operator: Operator,
241}
242
243impl Options for Operator {
244    fn as_any(&self) -> &dyn Any {
245        self
246    }
247}
248
249impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
250    type Error = VortexError;
251
252    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
253        if value.inputs.len() != 2 {
254            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
255        }
256        let lhs = value.inputs[0]
257            .array()
258            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
259        let rhs = value.inputs[1]
260            .array()
261            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
262        let operator = *value
263            .options
264            .as_any()
265            .downcast_ref::<Operator>()
266            .vortex_expect("Expected options to be an operator");
267
268        Ok(CompareArgs { lhs, rhs, operator })
269    }
270}
271
272/// Helper function to compare empty values with arrays that have external value length information
273/// like `VarBin`.
274pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
275where
276    P: IntegerPType,
277    I: Iterator<Item = P>,
278{
279    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
280    let cmp_fn = match op {
281        Operator::Eq | Operator::Lte => |v| v == P::zero(),
282        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
283        Operator::Gte => |_| true,
284        Operator::Lt => |_| false,
285    };
286
287    lengths.map(cmp_fn).collect::<BooleanBuffer>()
288}
289
290/// Implementation of `CompareFn` using the Arrow crate.
291fn arrow_compare(
292    left: &dyn Array,
293    right: &dyn Array,
294    operator: Operator,
295) -> VortexResult<ArrayRef> {
296    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
297
298    let array = if left.dtype().is_nested() || right.dtype().is_nested() {
299        let lhs = Datum::try_new_array(&left.to_canonical().into_array())?;
300        let (lhs, _) = lhs.get();
301        let rhs = Datum::try_new_array(&right.to_canonical().into_array())?;
302        let (rhs, _) = rhs.get();
303
304        let cmp = make_comparator(lhs, rhs, SortOptions::default())?;
305        assert_eq!(lhs.len(), rhs.len());
306        let len = lhs.len();
307        let values = (0..len)
308            .map(|i| {
309                let cmp = cmp(i, i);
310                match operator {
311                    Operator::Eq => cmp.is_eq(),
312                    Operator::NotEq => cmp.is_ne(),
313                    Operator::Gt => cmp.is_gt(),
314                    Operator::Gte => cmp.is_gt() || cmp.is_eq(),
315                    Operator::Lt => cmp.is_lt(),
316                    Operator::Lte => cmp.is_lt() || cmp.is_eq(),
317                }
318            })
319            .collect();
320        let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
321        BooleanArray::new(values, nulls)
322    } else {
323        let lhs = Datum::try_new(left)?;
324        let rhs = Datum::try_new(right)?;
325
326        match operator {
327            Operator::Eq => cmp::eq(&lhs, &rhs)?,
328            Operator::NotEq => cmp::neq(&lhs, &rhs)?,
329            Operator::Gt => cmp::gt(&lhs, &rhs)?,
330            Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
331            Operator::Lt => cmp::lt(&lhs, &rhs)?,
332            Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
333        }
334    };
335    Ok(from_arrow_array_with_len(&array, left.len(), nullable))
336}
337
338pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
339    if lhs.is_null() | rhs.is_null() {
340        Scalar::null(DType::Bool(Nullability::Nullable))
341    } else {
342        let b = match operator {
343            Operator::Eq => lhs == rhs,
344            Operator::NotEq => lhs != rhs,
345            Operator::Gt => lhs > rhs,
346            Operator::Gte => lhs >= rhs,
347            Operator::Lt => lhs < rhs,
348            Operator::Lte => lhs <= rhs,
349        };
350
351        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
352    }
353}
354
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{
        BoolArray, ConstantArray, ListArray, PrimitiveArray, StructArray, VarBinArray,
        VarBinViewArray,
    };
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Value bits of `compare(left, right, operator)`, one per row.
    fn compare_bits(left: &dyn Array, right: &dyn Array, operator: Operator) -> Vec<bool> {
        let outcome = compare(left, right, operator).unwrap().to_bool();
        outcome.boolean_buffer().iter().collect()
    }

    /// Five-element bool array whose first element is null.
    fn bools_with_leading_null(values: [bool; 5]) -> BoolArray {
        BoolArray::from_bool_buffer(
            BooleanBuffer::from_iter(values),
            Validity::from_iter([false, true, true, true, true]),
        )
    }

    /// Non-nullable list array splitting six `i32` values into three lists of two.
    fn i32_pairs(values: [i32; 6]) -> ListArray {
        let elements = PrimitiveArray::from_iter(values);
        let offsets = PrimitiveArray::from_iter([0i32, 2, 4, 6]);
        ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Struct array with a `bool_col` and an `int_col` field.
    fn bool_int_struct(bools: [bool; 3], ints: [i32; 3]) -> StructArray {
        let bool_col = BoolArray::from_iter(bools.map(Some));
        let int_col = PrimitiveArray::from_iter(ints);
        StructArray::from_fields(&[
            ("bool_col", bool_col.into_array()),
            ("int_col", int_col.into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn test_bool_basic_comparisons() {
        let arr = bools_with_leading_null([true, true, false, true, false]);
        let other = bools_with_leading_null([false, false, false, true, true]);

        // Indices at which the comparison holds.
        let matching = |left: &BoolArray, right: &BoolArray, operator: Operator| {
            let mask = compare(left.as_ref(), right.as_ref(), operator)
                .unwrap()
                .to_bool();
            to_int_indices(mask).unwrap()
        };

        assert_eq!(matching(&arr, &arr, Operator::Eq), [1u64, 2, 3, 4]);

        let empty: [u64; 0] = [];
        assert_eq!(matching(&arr, &arr, Operator::NotEq), empty);

        assert_eq!(matching(&arr, &other, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(matching(&arr, &other, Operator::Lt), [4u64]);
        assert_eq!(matching(&other, &arr, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(matching(&other, &arr, Operator::Gt), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // 2 > 10 is false for every one of the ten rows.
        let assert_constant_false = |outcome: ArrayRef| {
            let scalar = outcome.as_constant().unwrap();
            assert_eq!(scalar.as_bool().value(), Some(false));
            assert_eq!(outcome.len(), 10);
        };

        assert_constant_false(compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap());
        assert_constant_false(
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap(),
        );
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let buffer = compare_lengths_to_empty(lengths.iter().copied(), op);
        let actual: Vec<bool> = buffer.iter().collect();
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let equal = compare(&left, &right, Operator::Eq).unwrap().to_bool();
        assert_eq!(equal.boolean_buffer().count_set_bits(), left.len());
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_comparison() {
        // Lists are [1,2], [3,4], [5,6] versus [1,2], [3,4], [7,8].
        let list1 = i32_pairs([1, 2, 3, 4, 5, 6]);
        let list2 = i32_pairs([1, 2, 3, 4, 7, 8]);

        // Equality: the first two lists match, the third differs.
        assert_eq!(
            compare_bits(list1.as_ref(), list2.as_ref(), Operator::Eq),
            [true, true, false]
        );

        // Inequality is the exact opposite.
        assert_eq!(
            compare_bits(list1.as_ref(), list2.as_ref(), Operator::NotEq),
            [false, false, true]
        );

        // Less than: only [5,6] < [7,8] holds.
        assert_eq!(
            compare_bits(list1.as_ref(), list2.as_ref(), Operator::Lt),
            [false, false, true]
        );
    }

    #[ignore = "Arrow's ListView cannot be compared"]
    #[test]
    fn test_list_array_constant_comparison() {
        use std::sync::Arc;

        use vortex_dtype::{DType, PType};

        // Lists are [1,2], [3,4], [5,6].
        let list = i32_pairs([1, 2, 3, 4, 5, 6]);

        // Constant list scalar [3,4], broadcast to three rows.
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![3i32.into(), 4i32.into()],
            Nullability::NonNullable,
        );
        let constant = ConstantArray::new(list_scalar, 3);

        // Only the middle list equals [3,4].
        assert_eq!(
            compare_bits(list.as_ref(), constant.as_ref(), Operator::Eq),
            [false, true, false]
        );
    }

    #[test]
    fn test_struct_array_comparison() {
        // Rows: {true, 1}, {false, 2}, {true, 3} versus {true, 1}, {false, 2}, {false, 4}.
        let struct1 = bool_int_struct([true, false, true], [1, 2, 3]);
        let struct2 = bool_int_struct([true, false, false], [1, 2, 4]);

        // Equality: the first two rows match, the last one differs.
        assert_eq!(
            compare_bits(struct1.as_ref(), struct2.as_ref(), Operator::Eq),
            [true, true, false]
        );

        // Greater than: only {true, 3} > {false, 4} holds, as the bool field takes precedence.
        assert_eq!(
            compare_bits(struct1.as_ref(), struct2.as_ref(), Operator::Gt),
            [false, false, true]
        );
    }
}
566}