vortex_array/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use core::fmt;
5use std::any::Any;
6use std::fmt::{Display, Formatter};
7use std::sync::LazyLock;
8
9use arcref::ArcRef;
10use arrow_buffer::BooleanBuffer;
11use arrow_ord::cmp;
12use vortex_dtype::{DType, NativePType, Nullability};
13use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_scalar::Scalar;
15
16use crate::arrays::ConstantArray;
17use crate::arrow::{Datum, from_arrow_array_with_len};
18use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
19use crate::vtable::VTable;
20use crate::{Array, ArrayRef, Canonical, IntoArray};
21
22static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
23    let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
24    for kernel in inventory::iter::<CompareKernelRef> {
25        compute.register_kernel(kernel.0.clone());
26    }
27    compute
28});
29
30pub(crate) fn warm_up_vtable() -> usize {
31    COMPARE_FN.kernels().len()
32}
33
34/// Compares two arrays and returns a new boolean array with the result of the comparison.
35/// Or, returns None if comparison is not supported for these arrays.
36pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
37    COMPARE_FN
38        .invoke(&InvocationArgs {
39            inputs: &[left.into(), right.into()],
40            options: &operator,
41        })?
42        .unwrap_array()
43}
44
/// Binary comparison operators understood by [`compare`].
///
/// The variant order is significant: `PartialOrd` is derived, so reordering variants would
/// change how operators compare with each other.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub enum Operator {
    /// `lhs = rhs`
    Eq,
    /// `lhs != rhs`
    NotEq,
    /// `lhs > rhs`
    Gt,
    /// `lhs >= rhs`
    Gte,
    /// `lhs < rhs`
    Lt,
    /// `lhs <= rhs`
    Lte,
}
60
61impl Display for Operator {
62    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
63        let display = match &self {
64            Operator::Eq => "=",
65            Operator::NotEq => "!=",
66            Operator::Gt => ">",
67            Operator::Gte => ">=",
68            Operator::Lt => "<",
69            Operator::Lte => "<=",
70        };
71        Display::fmt(display, f)
72    }
73}
74
75impl Operator {
76    pub fn inverse(self) -> Self {
77        match self {
78            Operator::Eq => Operator::NotEq,
79            Operator::NotEq => Operator::Eq,
80            Operator::Gt => Operator::Lte,
81            Operator::Gte => Operator::Lt,
82            Operator::Lt => Operator::Gte,
83            Operator::Lte => Operator::Gt,
84        }
85    }
86
87    /// Change the sides of the operator, where changing lhs and rhs won't change the result of the operation
88    pub fn swap(self) -> Self {
89        match self {
90            Operator::Eq => Operator::Eq,
91            Operator::NotEq => Operator::NotEq,
92            Operator::Gt => Operator::Lt,
93            Operator::Gte => Operator::Lte,
94            Operator::Lt => Operator::Gt,
95            Operator::Lte => Operator::Gte,
96        }
97    }
98}
99
100pub struct CompareKernelRef(ArcRef<dyn Kernel>);
101inventory::collect!(CompareKernelRef);
102
103pub trait CompareKernel: VTable {
104    fn compare(
105        &self,
106        lhs: &Self::Array,
107        rhs: &dyn Array,
108        operator: Operator,
109    ) -> VortexResult<Option<ArrayRef>>;
110}
111
112#[derive(Debug)]
113pub struct CompareKernelAdapter<V: VTable>(pub V);
114
115impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
116    pub const fn lift(&'static self) -> CompareKernelRef {
117        CompareKernelRef(ArcRef::new_ref(self))
118    }
119}
120
121impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
122    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
123        let inputs = CompareArgs::try_from(args)?;
124        let Some(array) = inputs.lhs.as_opt::<V>() else {
125            return Ok(None);
126        };
127        Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
128    }
129}
130
131struct Compare;
132
133impl ComputeFnVTable for Compare {
134    fn invoke(
135        &self,
136        args: &InvocationArgs,
137        kernels: &[ArcRef<dyn Kernel>],
138    ) -> VortexResult<Output> {
139        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
140
141        let return_dtype = self.return_dtype(args)?;
142
143        if lhs.is_empty() {
144            return Ok(Canonical::empty(&return_dtype).into_array().into());
145        }
146
147        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
148        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
149        if left_constant_null || right_constant_null {
150            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
151                .into_array()
152                .into());
153        }
154
155        let right_is_constant = rhs.is_constant();
156
157        // Always try to put constants on the right-hand side so encodings can optimise themselves.
158        if lhs.is_constant() && !right_is_constant {
159            return Ok(compare(rhs, lhs, operator.swap())?.into());
160        }
161
162        // First try lhs op rhs, then invert and try again.
163        for kernel in kernels {
164            if let Some(output) = kernel.invoke(args)? {
165                return Ok(output);
166            }
167        }
168        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
169            return Ok(output);
170        }
171
172        // Try inverting the operator and swapping the arguments
173        let inverted_args = InvocationArgs {
174            inputs: &[rhs.into(), lhs.into()],
175            options: &operator.swap(),
176        };
177        for kernel in kernels {
178            if let Some(output) = kernel.invoke(&inverted_args)? {
179                return Ok(output);
180            }
181        }
182        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
183            return Ok(output);
184        }
185
186        // Only log missing compare implementation if there's possibly better one than arrow,
187        // i.e. lhs isn't arrow or rhs isn't arrow or constant
188        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
189            log::debug!(
190                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
191                lhs.encoding_id(),
192                rhs.encoding_id(),
193                operator,
194            );
195        }
196
197        // Fallback to arrow on canonical types
198        Ok(arrow_compare(lhs, rhs, operator)?.into())
199    }
200
201    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
202        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
203
204        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
205            vortex_bail!(
206                "Cannot compare different DTypes {} and {}",
207                lhs.dtype(),
208                rhs.dtype()
209            );
210        }
211
212        // TODO(ngates): no reason why not
213        if lhs.dtype().is_struct() {
214            vortex_bail!(
215                "Compare does not support arrays with Struct DType, got: {} and {}",
216                lhs.dtype(),
217                rhs.dtype()
218            )
219        }
220
221        Ok(DType::Bool(
222            lhs.dtype().nullability() | rhs.dtype().nullability(),
223        ))
224    }
225
226    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
227        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
228        if lhs.len() != rhs.len() {
229            vortex_bail!(
230                "Compare operations only support arrays of the same length, got {} and {}",
231                lhs.len(),
232                rhs.len()
233            );
234        }
235        Ok(lhs.len())
236    }
237
238    fn is_elementwise(&self) -> bool {
239        true
240    }
241}
242
243struct CompareArgs<'a> {
244    lhs: &'a dyn Array,
245    rhs: &'a dyn Array,
246    operator: Operator,
247}
248
249impl Options for Operator {
250    fn as_any(&self) -> &dyn Any {
251        self
252    }
253}
254
255impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
256    type Error = VortexError;
257
258    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
259        if value.inputs.len() != 2 {
260            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
261        }
262        let lhs = value.inputs[0]
263            .array()
264            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
265        let rhs = value.inputs[1]
266            .array()
267            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
268        let operator = *value
269            .options
270            .as_any()
271            .downcast_ref::<Operator>()
272            .vortex_expect("Expected options to be an operator");
273
274        Ok(CompareArgs { lhs, rhs, operator })
275    }
276}
277
278/// Helper function to compare empty values with arrays that have external value length information
279/// like `VarBin`.
280pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
281where
282    P: NativePType,
283    I: Iterator<Item = P>,
284{
285    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
286    let cmp_fn = match op {
287        Operator::Eq | Operator::Lte => |v| v == P::zero(),
288        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
289        Operator::Gte => |_| true,
290        Operator::Lt => |_| false,
291    };
292
293    lengths.map(cmp_fn).collect::<BooleanBuffer>()
294}
295
296/// Implementation of `CompareFn` using the Arrow crate.
297fn arrow_compare(
298    left: &dyn Array,
299    right: &dyn Array,
300    operator: Operator,
301) -> VortexResult<ArrayRef> {
302    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
303    let lhs = Datum::try_new(left)?;
304    let rhs = Datum::try_new(right)?;
305
306    let array = match operator {
307        Operator::Eq => cmp::eq(&lhs, &rhs)?,
308        Operator::NotEq => cmp::neq(&lhs, &rhs)?,
309        Operator::Gt => cmp::gt(&lhs, &rhs)?,
310        Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
311        Operator::Lt => cmp::lt(&lhs, &rhs)?,
312        Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
313    };
314    Ok(from_arrow_array_with_len(&array, left.len(), nullable))
315}
316
317pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
318    if lhs.is_null() | rhs.is_null() {
319        Scalar::null(DType::Bool(Nullability::Nullable))
320    } else {
321        let b = match operator {
322            Operator::Eq => lhs == rhs,
323            Operator::NotEq => lhs != rhs,
324            Operator::Gt => lhs > rhs,
325            Operator::Gte => lhs >= rhs,
326            Operator::Lt => lhs < rhs,
327            Operator::Lte => lhs <= rhs,
328        };
329
330        Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
331    }
332}
333
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    /// Exercises every operator on nullable boolean arrays; index 0 is null in both inputs.
    #[test]
    fn test_bool_basic_comparisons() {
        // Runs a comparison and returns the indices reported by `to_int_indices`.
        let matching = |lhs: &dyn Array, rhs: &dyn Array, op: Operator| {
            to_int_indices(compare(lhs, rhs, op).unwrap().to_bool()).unwrap()
        };

        let arr = BoolArray::from_bool_buffer(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(
            matching(arr.as_ref(), arr.as_ref(), Operator::Eq),
            [1u64, 2, 3, 4]
        );

        let empty: [u64; 0] = [];
        assert_eq!(matching(arr.as_ref(), arr.as_ref(), Operator::NotEq), empty);

        let other = BoolArray::from_bool_buffer(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(
            matching(arr.as_ref(), other.as_ref(), Operator::Lte),
            [2u64, 3, 4]
        );
        assert_eq!(matching(arr.as_ref(), other.as_ref(), Operator::Lt), [4u64]);
        assert_eq!(
            matching(other.as_ref(), arr.as_ref(), Operator::Gte),
            [2u64, 3, 4]
        );
        assert_eq!(matching(other.as_ref(), arr.as_ref(), Operator::Gt), [4u64]);
    }

    /// Two constants compare to a constant, both through dispatch and through the Arrow path.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        let via_dispatch = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let via_arrow =
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();

        for result in [via_dispatch, via_arrow] {
            let constant = result.as_constant().unwrap();
            assert_eq!(constant.as_bool().value(), Some(false));
            assert_eq!(result.len(), 10);
        }
    }

    /// Lengths 1, 5, 7 are non-empty values; length 0 is the empty value.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths = [1i32, 5, 7, 0];

        let actual: Vec<bool> = compare_lengths_to_empty(lengths.iter().copied(), op)
            .iter()
            .collect();
        assert_eq!(actual, expected);
    }

    /// Equal contents must compare equal even when the two sides use different encodings.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let equal = compare(&left, &right, Operator::Eq).unwrap();
        let equal_count = equal.to_bool().boolean_buffer().count_set_bits();
        assert_eq!(equal_count, left.len());
    }
}