vortex_array/compute/
compare.rs

1use core::fmt;
2use std::any::Any;
3use std::fmt::{Display, Formatter};
4use std::sync::LazyLock;
5
6use arrow_buffer::BooleanBuffer;
7use arrow_ord::cmp;
8use arrow_schema::DataType;
9use vortex_dtype::{DType, NativePType, Nullability};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
11use vortex_scalar::Scalar;
12
13use crate::arcref::ArcRef;
14use crate::arrays::ConstantArray;
15use crate::arrow::{Datum, from_arrow_array_with_len};
16use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
17use crate::encoding::Encoding;
18use crate::{Array, ArrayRef, Canonical, IntoArray};
19
20/// Compares two arrays and returns a new boolean array with the result of the comparison.
21/// Or, returns None if comparison is not supported for these arrays.
22pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
23    COMPARE_FN
24        .invoke(&InvocationArgs {
25            inputs: &[left.into(), right.into()],
26            options: &operator,
27        })?
28        .unwrap_array()
29}
30
/// A binary comparison operator.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd)]
pub enum Operator {
    /// `lhs == rhs`
    Eq,
    /// `lhs != rhs`
    NotEq,
    /// `lhs > rhs`
    Gt,
    /// `lhs >= rhs`
    Gte,
    /// `lhs < rhs`
    Lt,
    /// `lhs <= rhs`
    Lte,
}

impl Display for Operator {
    /// Writes the operator's symbol (`=`, `!=`, `>`, `>=`, `<`, `<=`).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let display = match self {
            Operator::Eq => "=",
            Operator::NotEq => "!=",
            Operator::Gt => ">",
            Operator::Gte => ">=",
            Operator::Lt => "<",
            Operator::Lte => "<=",
        };
        // Delegate to `str`'s `Display` so that width and alignment flags keep working.
        Display::fmt(display, f)
    }
}

impl Operator {
    /// Returns the logical negation of the operator: `a op b` is true exactly when
    /// `a op.inverse() b` is false (e.g. `>` becomes `<=`).
    pub fn inverse(self) -> Self {
        match self {
            Operator::Eq => Operator::NotEq,
            Operator::NotEq => Operator::Eq,
            Operator::Gt => Operator::Lte,
            Operator::Gte => Operator::Lt,
            Operator::Lt => Operator::Gte,
            Operator::Lte => Operator::Gt,
        }
    }

    /// Returns the operator to use when the two operands trade places, so that
    /// `a op b` and `b op.swap() a` give the same result (e.g. `>` becomes `<`).
    ///
    /// `Eq` and `NotEq` are symmetric and therefore map to themselves.
    pub fn swap(self) -> Self {
        match self {
            Operator::Eq => Operator::Eq,
            Operator::NotEq => Operator::NotEq,
            Operator::Gt => Operator::Lt,
            Operator::Gte => Operator::Lte,
            Operator::Lt => Operator::Gt,
            Operator::Lte => Operator::Gte,
        }
    }
}
79
80pub struct CompareKernelRef(ArcRef<dyn Kernel>);
81inventory::collect!(CompareKernelRef);
82
83pub trait CompareKernel: Encoding {
84    fn compare(
85        &self,
86        lhs: &Self::Array,
87        rhs: &dyn Array,
88        operator: Operator,
89    ) -> VortexResult<Option<ArrayRef>>;
90}
91
92#[derive(Debug)]
93pub struct CompareKernelAdapter<E: Encoding>(pub E);
94
95impl<E: Encoding + CompareKernel> CompareKernelAdapter<E> {
96    pub const fn lift(&'static self) -> CompareKernelRef {
97        CompareKernelRef(ArcRef::new_ref(self))
98    }
99}
100
101impl<E: Encoding + CompareKernel> Kernel for CompareKernelAdapter<E> {
102    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
103        let inputs = CompareArgs::try_from(args)?;
104        let Some(array) = inputs.lhs.as_any().downcast_ref::<E::Array>() else {
105            return Ok(None);
106        };
107        Ok(E::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
108    }
109}
110
111pub static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
112    let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
113    for kernel in inventory::iter::<CompareKernelRef> {
114        compute.register_kernel(kernel.0.clone());
115    }
116    compute
117});
118
119struct Compare;
120
121impl ComputeFnVTable for Compare {
122    fn invoke(
123        &self,
124        args: &InvocationArgs,
125        kernels: &[ArcRef<dyn Kernel>],
126    ) -> VortexResult<Output> {
127        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
128
129        let return_dtype = self.return_dtype(args)?;
130
131        if lhs.is_empty() {
132            return Ok(Canonical::empty(&return_dtype).into_array().into());
133        }
134
135        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
136        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
137        if left_constant_null || right_constant_null {
138            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
139                .into_array()
140                .into());
141        }
142
143        let right_is_constant = rhs.is_constant();
144
145        // Always try to put constants on the right-hand side so encodings can optimise themselves.
146        if lhs.is_constant() && !right_is_constant {
147            return Ok(compare(rhs, lhs, operator.swap())?.into());
148        }
149
150        // First try lhs op rhs, then invert and try again.
151        for kernel in kernels {
152            if let Some(output) = kernel.invoke(args)? {
153                return Ok(output);
154            }
155        }
156        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
157            return Ok(output);
158        }
159
160        // Try inverting the operator and swapping the arguments
161        let inverted_args = InvocationArgs {
162            inputs: &[rhs.into(), lhs.into()],
163            options: &operator.swap(),
164        };
165        for kernel in kernels {
166            if let Some(output) = kernel.invoke(&inverted_args)? {
167                return Ok(output);
168            }
169        }
170        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
171            return Ok(output);
172        }
173
174        // Only log missing compare implementation if there's possibly better one than arrow,
175        // i.e. lhs isn't arrow or rhs isn't arrow or constant
176        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
177            log::debug!(
178                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
179                rhs.encoding(),
180                lhs.encoding(),
181                operator.swap(),
182            );
183        }
184
185        // Fallback to arrow on canonical types
186        Ok(arrow_compare(lhs, rhs, operator)?.into())
187    }
188
189    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
190        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
191
192        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
193            vortex_bail!(
194                "Cannot compare different DTypes {} and {}",
195                lhs.dtype(),
196                rhs.dtype()
197            );
198        }
199
200        // TODO(ngates): no reason why not
201        if lhs.dtype().is_struct() {
202            vortex_bail!(
203                "Compare does not support arrays with Struct DType, got: {} and {}",
204                lhs.dtype(),
205                rhs.dtype()
206            )
207        }
208
209        Ok(DType::Bool(
210            (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
211        ))
212    }
213
214    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
215        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
216        if lhs.len() != rhs.len() {
217            vortex_bail!(
218                "Compare operations only support arrays of the same length, got {} and {}",
219                lhs.len(),
220                rhs.len()
221            );
222        }
223        Ok(lhs.len())
224    }
225
226    fn is_elementwise(&self) -> bool {
227        true
228    }
229}
230
231struct CompareArgs<'a> {
232    lhs: &'a dyn Array,
233    rhs: &'a dyn Array,
234    operator: Operator,
235}
236
237impl Options for Operator {
238    fn as_any(&self) -> &dyn Any {
239        self
240    }
241}
242
243impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
244    type Error = VortexError;
245
246    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
247        if value.inputs.len() != 2 {
248            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
249        }
250        let lhs = value.inputs[0]
251            .array()
252            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
253        let rhs = value.inputs[1]
254            .array()
255            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
256        let operator = *value
257            .options
258            .as_any()
259            .downcast_ref::<Operator>()
260            .vortex_expect("Expected options to be an operator");
261
262        Ok(CompareArgs { lhs, rhs, operator })
263    }
264}
265
266/// Helper function to compare empty values with arrays that have external value length information
267/// like `VarBin`.
268pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
269where
270    P: NativePType,
271    I: Iterator<Item = P>,
272{
273    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
274    let cmp_fn = match op {
275        Operator::Eq | Operator::Lte => |v| v == P::zero(),
276        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
277        Operator::Gte => |_| true,
278        Operator::Lt => |_| false,
279    };
280
281    lengths.map(cmp_fn).collect::<BooleanBuffer>()
282}
283
284/// Implementation of `CompareFn` using the Arrow crate.
285fn arrow_compare(
286    left: &dyn Array,
287    right: &dyn Array,
288    operator: Operator,
289) -> VortexResult<ArrayRef> {
290    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
291    let lhs = datum_for_cmp(left)?;
292    let rhs = datum_for_cmp(right)?;
293
294    let array = match operator {
295        Operator::Eq => cmp::eq(&lhs, &rhs)?,
296        Operator::NotEq => cmp::neq(&lhs, &rhs)?,
297        Operator::Gt => cmp::gt(&lhs, &rhs)?,
298        Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
299        Operator::Lt => cmp::lt(&lhs, &rhs)?,
300        Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
301    };
302    from_arrow_array_with_len(&array, left.len(), nullable)
303}
304
305pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
306    if lhs.is_null() | rhs.is_null() {
307        Scalar::null(DType::Bool(Nullability::Nullable))
308    } else {
309        let b = match operator {
310            Operator::Eq => lhs == rhs,
311            Operator::NotEq => lhs != rhs,
312            Operator::Gt => lhs > rhs,
313            Operator::Gte => lhs >= rhs,
314            Operator::Lt => lhs < rhs,
315            Operator::Lte => lhs <= rhs,
316        };
317
318        Scalar::bool(
319            b,
320            (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
321        )
322    }
323}
324
325// Make sure both of the arrays end up with the same arrow data type
326fn datum_for_cmp(array: &dyn Array) -> VortexResult<Datum> {
327    if matches!(array.dtype(), DType::Utf8(_)) {
328        Datum::with_target_datatype(array, &DataType::Utf8View)
329    } else if matches!(array.dtype(), DType::Binary(_)) {
330        Datum::with_target_datatype(array, &DataType::BinaryView)
331    } else {
332        Datum::try_new(array)
333    }
334}
335
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use itertools::Itertools;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::validity::Validity;

    /// Returns the positions that are both valid and `true`.
    fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
        let buffer = indices_bits.boolean_buffer();
        let mask = indices_bits.validity_mask().unwrap();
        buffer
            .iter()
            .enumerate()
            .filter_map(|(idx, v)| (v && mask.value(idx)).then_some(idx as u64))
            .collect_vec()
    }

    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(&arr, &arr, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(to_int_indices(matches), [1u64, 2, 3, 4]);

        let matches = compare(&arr, &arr, Operator::NotEq)
            .unwrap()
            .to_bool()
            .unwrap();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches), empty);

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(&arr, &other, Operator::Lte)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

        let matches = compare(&arr, &other, Operator::Lt)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches), [4u64]);

        let matches = compare(&other, &arr, Operator::Gte)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

        let matches = compare(&other, &arr, Operator::Gt)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Through the dispatching entry point.
        let result = compare(&left, &right, Operator::Gt).unwrap();
        let res = result.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(result.len(), 10);

        // Directly through the Arrow fallback.
        let result = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let res = result.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(result.len(), 10);
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = arrow_compare(&left, &right, Operator::Eq).unwrap();
        assert_eq!(
            res.to_bool().unwrap().boolean_buffer().count_set_bits(),
            left.len()
        );
    }
}