vortex_array/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_buffer::BooleanBuffer;
11use arrow_ord::cmp;
12use vortex_dtype::{DType, NativePType, Nullability};
13use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_scalar::Scalar;
15
16use crate::arrays::ConstantArray;
17use crate::arrow::{Datum, from_arrow_array_with_len};
18use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
19use crate::vtable::VTable;
20use crate::{Array, ArrayRef, Canonical, IntoArray};
21
/// Lazily-initialised `compare` compute function.
///
/// On first access it builds a [`ComputeFn`] named `"compare"` backed by the [`Compare`] vtable
/// and registers every [`CompareKernelRef`] collected through `inventory`.
static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
    // Pick up all kernels that were submitted to the `inventory` registry.
    for kernel in inventory::iter::<CompareKernelRef> {
        compute.register_kernel(kernel.0.clone());
    }
    compute
});
29
30/// Compares two arrays and returns a new boolean array with the result of the comparison.
31/// Or, returns None if comparison is not supported for these arrays.
32pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
33    COMPARE_FN
34        .invoke(&InvocationArgs {
35            inputs: &[left.into(), right.into()],
36            options: &operator,
37        })?
38        .unwrap_array()
39}
40
/// Binary comparison operator understood by [`compare`] and [`scalar_cmp`].
///
/// A value of this type is also what is handed to the `compare` compute function as its
/// options payload.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// Equality (`=`)
    Eq,
    /// Inequality (`!=`)
    NotEq,
    /// Greater than (`>`)
    Gt,
    /// Greater than or equal (`>=`)
    Gte,
    /// Less than (`<`)
    Lt,
    /// Less than or equal (`<=`)
    Lte,
}
56
57impl Display for Operator {
58    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
59        let display = match &self {
60            Operator::Eq => "=",
61            Operator::NotEq => "!=",
62            Operator::Gt => ">",
63            Operator::Gte => ">=",
64            Operator::Lt => "<",
65            Operator::Lte => "<=",
66        };
67        Display::fmt(display, f)
68    }
69}
70
71impl Operator {
72    pub fn inverse(self) -> Self {
73        match self {
74            Operator::Eq => Operator::NotEq,
75            Operator::NotEq => Operator::Eq,
76            Operator::Gt => Operator::Lte,
77            Operator::Gte => Operator::Lt,
78            Operator::Lt => Operator::Gte,
79            Operator::Lte => Operator::Gt,
80        }
81    }
82
83    /// Change the sides of the operator, where changing lhs and rhs won't change the result of the operation
84    pub fn swap(self) -> Self {
85        match self {
86            Operator::Eq => Operator::Eq,
87            Operator::NotEq => Operator::NotEq,
88            Operator::Gt => Operator::Lt,
89            Operator::Gte => Operator::Lte,
90            Operator::Lt => Operator::Gt,
91            Operator::Lte => Operator::Gte,
92        }
93    }
94}
95
/// Type-erased handle to a compare kernel.
///
/// Values of this type are gathered through `inventory` and registered with the `compare`
/// compute function when it is first initialised.
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
// Declare `CompareKernelRef` as an `inventory`-collected type so it can be iterated at runtime.
inventory::collect!(CompareKernelRef);
98
/// Encoding-specific implementation of the `compare` compute function.
pub trait CompareKernel: VTable {
    /// Compares `lhs` (an array of this encoding) against `rhs` using `operator`.
    ///
    /// Returning `Ok(None)` signals that this kernel does not handle the given inputs; the
    /// generic `compare` implementation then moves on to the next candidate.
    fn compare(
        &self,
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: Operator,
    ) -> VortexResult<Option<ArrayRef>>;
}
107
/// Adapter that exposes a [`CompareKernel`] implementation as a generic [`Kernel`].
#[derive(Debug)]
pub struct CompareKernelAdapter<V: VTable>(pub V);

impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`CompareKernelRef`] without allocating, so it can be
    /// built in a `const` context.
    pub const fn lift(&'static self) -> CompareKernelRef {
        CompareKernelRef(ArcRef::new_ref(self))
    }
}
116
117impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
118    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
119        let inputs = CompareArgs::try_from(args)?;
120        let Some(array) = inputs.lhs.as_opt::<V>() else {
121            return Ok(None);
122        };
123        Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
124    }
125}
126
/// Marker type carrying the [`ComputeFnVTable`] implementation of the `compare` compute function.
struct Compare;
128
129impl ComputeFnVTable for Compare {
130    fn invoke(
131        &self,
132        args: &InvocationArgs,
133        kernels: &[ArcRef<dyn Kernel>],
134    ) -> VortexResult<Output> {
135        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
136
137        let return_dtype = self.return_dtype(args)?;
138
139        if lhs.is_empty() {
140            return Ok(Canonical::empty(&return_dtype).into_array().into());
141        }
142
143        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
144        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
145        if left_constant_null || right_constant_null {
146            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
147                .into_array()
148                .into());
149        }
150
151        let right_is_constant = rhs.is_constant();
152
153        // Always try to put constants on the right-hand side so encodings can optimise themselves.
154        if lhs.is_constant() && !right_is_constant {
155            return Ok(compare(rhs, lhs, operator.swap())?.into());
156        }
157
158        // First try lhs op rhs, then invert and try again.
159        for kernel in kernels {
160            if let Some(output) = kernel.invoke(args)? {
161                return Ok(output);
162            }
163        }
164        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
165            return Ok(output);
166        }
167
168        // Try inverting the operator and swapping the arguments
169        let inverted_args = InvocationArgs {
170            inputs: &[rhs.into(), lhs.into()],
171            options: &operator.swap(),
172        };
173        for kernel in kernels {
174            if let Some(output) = kernel.invoke(&inverted_args)? {
175                return Ok(output);
176            }
177        }
178        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
179            return Ok(output);
180        }
181
182        // Only log missing compare implementation if there's possibly better one than arrow,
183        // i.e. lhs isn't arrow or rhs isn't arrow or constant
184        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
185            log::debug!(
186                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
187                lhs.encoding_id(),
188                rhs.encoding_id(),
189                operator,
190            );
191        }
192
193        // Fallback to arrow on canonical types
194        Ok(arrow_compare(lhs, rhs, operator)?.into())
195    }
196
197    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
198        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
199
200        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
201            vortex_bail!(
202                "Cannot compare different DTypes {} and {}",
203                lhs.dtype(),
204                rhs.dtype()
205            );
206        }
207
208        // TODO(ngates): no reason why not
209        if lhs.dtype().is_struct() {
210            vortex_bail!(
211                "Compare does not support arrays with Struct DType, got: {} and {}",
212                lhs.dtype(),
213                rhs.dtype()
214            )
215        }
216
217        Ok(DType::Bool(
218            lhs.dtype().nullability() | rhs.dtype().nullability(),
219        ))
220    }
221
222    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
223        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
224        if lhs.len() != rhs.len() {
225            vortex_bail!(
226                "Compare operations only support arrays of the same length, got {} and {}",
227                lhs.len(),
228                rhs.len()
229            );
230        }
231        Ok(lhs.len())
232    }
233
234    fn is_elementwise(&self) -> bool {
235        true
236    }
237}
238
/// Typed view over the [`InvocationArgs`] of a `compare` call: the two operand arrays and the
/// comparison operator.
struct CompareArgs<'a> {
    /// Left-hand operand (first input).
    lhs: &'a dyn Array,
    /// Right-hand operand (second input).
    rhs: &'a dyn Array,
    /// Operator taken from the invocation options.
    operator: Operator,
}
244
/// Lets an [`Operator`] be passed as the options payload of a compute function invocation.
impl Options for Operator {
    /// Exposes the operator as `&dyn Any` so it can be recovered again via downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
250
251impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
252    type Error = VortexError;
253
254    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
255        if value.inputs.len() != 2 {
256            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
257        }
258        let lhs = value.inputs[0]
259            .array()
260            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
261        let rhs = value.inputs[1]
262            .array()
263            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
264        let operator = *value
265            .options
266            .as_any()
267            .downcast_ref::<Operator>()
268            .vortex_expect("Expected options to be an operator");
269
270        Ok(CompareArgs { lhs, rhs, operator })
271    }
272}
273
274/// Helper function to compare empty values with arrays that have external value length information
275/// like `VarBin`.
276pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
277where
278    P: NativePType,
279    I: Iterator<Item = P>,
280{
281    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
282    let cmp_fn = match op {
283        Operator::Eq | Operator::Lte => |v| v == P::zero(),
284        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
285        Operator::Gte => |_| true,
286        Operator::Lt => |_| false,
287    };
288
289    lengths.map(cmp_fn).collect::<BooleanBuffer>()
290}
291
292/// Implementation of `CompareFn` using the Arrow crate.
293fn arrow_compare(
294    left: &dyn Array,
295    right: &dyn Array,
296    operator: Operator,
297) -> VortexResult<ArrayRef> {
298    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
299    let lhs = Datum::try_new(left)?;
300    let rhs = Datum::try_new(right)?;
301
302    let array = match operator {
303        Operator::Eq => cmp::eq(&lhs, &rhs)?,
304        Operator::NotEq => cmp::neq(&lhs, &rhs)?,
305        Operator::Gt => cmp::gt(&lhs, &rhs)?,
306        Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
307        Operator::Lt => cmp::lt(&lhs, &rhs)?,
308        Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
309    };
310    from_arrow_array_with_len(&array, left.len(), nullable)
311}
312
313pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
314    if lhs.is_null() | rhs.is_null() {
315        Scalar::null(DType::Bool(Nullability::Nullable))
316    } else {
317        let b = match operator {
318            Operator::Eq => lhs == rhs,
319            Operator::NotEq => lhs != rhs,
320            Operator::Gt => lhs > rhs,
321            Operator::Gte => lhs >= rhs,
322            Operator::Lt => lhs < rhs,
323            Operator::Lte => lhs <= rhs,
324        };
325
326        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
327    }
328}
329
/// Unit tests for the `compare` compute function and its helpers.
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Exercises every operator on bool arrays whose first position has validity `false`.
    ///
    /// NOTE(review): the expected indices never contain position 0, presumably because a
    /// comparison involving an invalid (null) entry does not count as a match.
    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // An array compared with itself is equal at every valid position.
        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(to_int_indices(matches).unwrap(), [1u64, 2, 3, 4]);

        // ... and therefore unequal nowhere.
        let matches = compare(arr.as_ref(), arr.as_ref(), Operator::NotEq)
            .unwrap()
            .to_bool()
            .unwrap();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches).unwrap(), empty);

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lte)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(arr.as_ref(), other.as_ref(), Operator::Lt)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);

        // Swapping the operands and mirroring the operator must give the same matches.
        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gte)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), [2u64, 3, 4]);

        let matches = compare(other.as_ref(), arr.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches).unwrap(), [4u64]);
    }

    /// Comparing two constant arrays yields a constant result of the same length, both through
    /// `compare` and when calling `arrow_compare` directly.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let compare = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);

        let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);
    }

    /// Each case lists the expected outcome of comparing the lengths `[1, 5, 7, 0]` against the
    /// empty value with the given operator.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    /// Identical string/binary contents held in `VarBin` and `VarBinView` arrays must compare
    /// equal at every position, whichever side holds which encoding.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = compare(&left, &right, Operator::Eq).unwrap();
        assert_eq!(
            res.to_bool().unwrap().boolean_buffer().count_set_bits(),
            left.len()
        );
    }
}