vortex_array/compute/
compare.rs

1use core::fmt;
2use std::any::Any;
3use std::fmt::{Display, Formatter};
4use std::sync::LazyLock;
5
6use arcref::ArcRef;
7use arrow_buffer::BooleanBuffer;
8use arrow_ord::cmp;
9use vortex_dtype::{DType, NativePType, Nullability};
10use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
11use vortex_scalar::Scalar;
12
13use crate::arrays::ConstantArray;
14use crate::arrow::{Datum, from_arrow_array_with_len};
15use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
16use crate::vtable::VTable;
17use crate::{Array, ArrayRef, Canonical, IntoArray};
18
19/// Compares two arrays and returns a new boolean array with the result of the comparison.
20/// Or, returns None if comparison is not supported for these arrays.
21pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
22    COMPARE_FN
23        .invoke(&InvocationArgs {
24            inputs: &[left.into(), right.into()],
25            options: &operator,
26        })?
27        .unwrap_array()
28}
29
/// A binary comparison operator, as accepted by [`compare`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd)]
pub enum Operator {
    /// Equality; displayed as `=`.
    Eq,
    /// Inequality; displayed as `!=`.
    NotEq,
    /// Strictly greater than; displayed as `>`.
    Gt,
    /// Greater than or equal; displayed as `>=`.
    Gte,
    /// Strictly less than; displayed as `<`.
    Lt,
    /// Less than or equal; displayed as `<=`.
    Lte,
}
39
40impl Display for Operator {
41    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
42        let display = match &self {
43            Operator::Eq => "=",
44            Operator::NotEq => "!=",
45            Operator::Gt => ">",
46            Operator::Gte => ">=",
47            Operator::Lt => "<",
48            Operator::Lte => "<=",
49        };
50        Display::fmt(display, f)
51    }
52}
53
54impl Operator {
55    pub fn inverse(self) -> Self {
56        match self {
57            Operator::Eq => Operator::NotEq,
58            Operator::NotEq => Operator::Eq,
59            Operator::Gt => Operator::Lte,
60            Operator::Gte => Operator::Lt,
61            Operator::Lt => Operator::Gte,
62            Operator::Lte => Operator::Gt,
63        }
64    }
65
66    /// Change the sides of the operator, where changing lhs and rhs won't change the result of the operation
67    pub fn swap(self) -> Self {
68        match self {
69            Operator::Eq => Operator::Eq,
70            Operator::NotEq => Operator::NotEq,
71            Operator::Gt => Operator::Lt,
72            Operator::Gte => Operator::Lte,
73            Operator::Lt => Operator::Gt,
74            Operator::Lte => Operator::Gte,
75        }
76    }
77}
78
/// A type-erased compare kernel registered through `inventory`.
///
/// Every collected `CompareKernelRef` is registered on [`COMPARE_FN`] when that static is first
/// initialised.
pub struct CompareKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(CompareKernelRef);
81
/// Lets an encoding (identified by its [`VTable`]) provide its own comparison implementation.
pub trait CompareKernel: VTable {
    /// Computes `lhs operator rhs`, where `lhs` is an array of this encoding.
    ///
    /// Returning `Ok(None)` means "not handled here". The dispatcher then keeps looking: other
    /// kernels, the operands exchanged with [`Operator::swap`], and finally Arrow.
    fn compare(
        &self,
        lhs: &Self::Array,
        rhs: &dyn Array,
        operator: Operator,
    ) -> VortexResult<Option<ArrayRef>>;
}
90
/// Adapts a [`CompareKernel`] implementation into a type-erased [`Kernel`].
///
/// The adapter only applies when the left-hand input is of encoding `V`.
#[derive(Debug)]
pub struct CompareKernelAdapter<V: VTable>(pub V);

impl<V: VTable + CompareKernel> CompareKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`CompareKernelRef`].
    ///
    /// Because this is a `const fn` it can be evaluated in a static context, e.g. for
    /// registration with `inventory`.
    pub const fn lift(&'static self) -> CompareKernelRef {
        CompareKernelRef(ArcRef::new_ref(self))
    }
}
99
100impl<V: VTable + CompareKernel> Kernel for CompareKernelAdapter<V> {
101    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
102        let inputs = CompareArgs::try_from(args)?;
103        let Some(array) = inputs.lhs.as_opt::<V>() else {
104            return Ok(None);
105        };
106        Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into()))
107    }
108}
109
/// The `compare` compute function.
///
/// On first access it is built with [`Compare`] as its vtable. Every [`CompareKernelRef`]
/// collected by `inventory` is then registered as a kernel.
pub static COMPARE_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare));
    for kernel in inventory::iter::<CompareKernelRef> {
        compute.register_kernel(kernel.0.clone());
    }
    compute
});

/// Vtable holding the generic dispatch, validation and Arrow fallback logic for [`COMPARE_FN`].
struct Compare;
119
120impl ComputeFnVTable for Compare {
121    fn invoke(
122        &self,
123        args: &InvocationArgs,
124        kernels: &[ArcRef<dyn Kernel>],
125    ) -> VortexResult<Output> {
126        let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?;
127
128        let return_dtype = self.return_dtype(args)?;
129
130        if lhs.is_empty() {
131            return Ok(Canonical::empty(&return_dtype).into_array().into());
132        }
133
134        let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false);
135        let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false);
136        if left_constant_null || right_constant_null {
137            return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len())
138                .into_array()
139                .into());
140        }
141
142        let right_is_constant = rhs.is_constant();
143
144        // Always try to put constants on the right-hand side so encodings can optimise themselves.
145        if lhs.is_constant() && !right_is_constant {
146            return Ok(compare(rhs, lhs, operator.swap())?.into());
147        }
148
149        // First try lhs op rhs, then invert and try again.
150        for kernel in kernels {
151            if let Some(output) = kernel.invoke(args)? {
152                return Ok(output);
153            }
154        }
155        if let Some(output) = lhs.invoke(&COMPARE_FN, args)? {
156            return Ok(output);
157        }
158
159        // Try inverting the operator and swapping the arguments
160        let inverted_args = InvocationArgs {
161            inputs: &[rhs.into(), lhs.into()],
162            options: &operator.swap(),
163        };
164        for kernel in kernels {
165            if let Some(output) = kernel.invoke(&inverted_args)? {
166                return Ok(output);
167            }
168        }
169        if let Some(output) = rhs.invoke(&COMPARE_FN, &inverted_args)? {
170            return Ok(output);
171        }
172
173        // Only log missing compare implementation if there's possibly better one than arrow,
174        // i.e. lhs isn't arrow or rhs isn't arrow or constant
175        if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) {
176            log::debug!(
177                "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
178                rhs.encoding_id(),
179                lhs.encoding_id(),
180                operator.swap(),
181            );
182        }
183
184        // Fallback to arrow on canonical types
185        Ok(arrow_compare(lhs, rhs, operator)?.into())
186    }
187
188    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
189        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
190
191        if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) {
192            vortex_bail!(
193                "Cannot compare different DTypes {} and {}",
194                lhs.dtype(),
195                rhs.dtype()
196            );
197        }
198
199        // TODO(ngates): no reason why not
200        if lhs.dtype().is_struct() {
201            vortex_bail!(
202                "Compare does not support arrays with Struct DType, got: {} and {}",
203                lhs.dtype(),
204                rhs.dtype()
205            )
206        }
207
208        Ok(DType::Bool(
209            (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
210        ))
211    }
212
213    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
214        let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?;
215        if lhs.len() != rhs.len() {
216            vortex_bail!(
217                "Compare operations only support arrays of the same length, got {} and {}",
218                lhs.len(),
219                rhs.len()
220            );
221        }
222        Ok(lhs.len())
223    }
224
225    fn is_elementwise(&self) -> bool {
226        true
227    }
228}
229
/// Typed view of the generic [`InvocationArgs`] passed to [`COMPARE_FN`].
struct CompareArgs<'a> {
    /// Left-hand operand (first input).
    lhs: &'a dyn Array,
    /// Right-hand operand (second input).
    rhs: &'a dyn Array,
    /// Comparison to apply, taken from the invocation options.
    operator: Operator,
}
235
/// Lets an [`Operator`] be passed as the options of a compute invocation.
/// [`CompareArgs`] recovers it by downcasting through `as_any`.
impl Options for Operator {
    fn as_any(&self) -> &dyn Any {
        self
    }
}
241
242impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> {
243    type Error = VortexError;
244
245    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
246        if value.inputs.len() != 2 {
247            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
248        }
249        let lhs = value.inputs[0]
250            .array()
251            .ok_or_else(|| vortex_err!("Expected first input to be an array"))?;
252        let rhs = value.inputs[1]
253            .array()
254            .ok_or_else(|| vortex_err!("Expected second input to be an array"))?;
255        let operator = *value
256            .options
257            .as_any()
258            .downcast_ref::<Operator>()
259            .vortex_expect("Expected options to be an operator");
260
261        Ok(CompareArgs { lhs, rhs, operator })
262    }
263}
264
265/// Helper function to compare empty values with arrays that have external value length information
266/// like `VarBin`.
267pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
268where
269    P: NativePType,
270    I: Iterator<Item = P>,
271{
272    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
273    let cmp_fn = match op {
274        Operator::Eq | Operator::Lte => |v| v == P::zero(),
275        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
276        Operator::Gte => |_| true,
277        Operator::Lt => |_| false,
278    };
279
280    lengths.map(cmp_fn).collect::<BooleanBuffer>()
281}
282
283/// Implementation of `CompareFn` using the Arrow crate.
284fn arrow_compare(
285    left: &dyn Array,
286    right: &dyn Array,
287    operator: Operator,
288) -> VortexResult<ArrayRef> {
289    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
290    let lhs = Datum::try_new(left)?;
291    let rhs = Datum::try_new(right)?;
292
293    let array = match operator {
294        Operator::Eq => cmp::eq(&lhs, &rhs)?,
295        Operator::NotEq => cmp::neq(&lhs, &rhs)?,
296        Operator::Gt => cmp::gt(&lhs, &rhs)?,
297        Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
298        Operator::Lt => cmp::lt(&lhs, &rhs)?,
299        Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
300    };
301    from_arrow_array_with_len(&array, left.len(), nullable)
302}
303
304pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
305    if lhs.is_null() | rhs.is_null() {
306        Scalar::null(DType::Bool(Nullability::Nullable))
307    } else {
308        let b = match operator {
309            Operator::Eq => lhs == rhs,
310            Operator::NotEq => lhs != rhs,
311            Operator::Gt => lhs > rhs,
312            Operator::Gte => lhs >= rhs,
313            Operator::Lt => lhs < rhs,
314            Operator::Lte => lhs <= rhs,
315        };
316
317        Scalar::bool(
318            b,
319            (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
320        )
321    }
322}
323
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::test_harness::to_int_indices;
    use crate::validity::Validity;

    #[test]
    fn test_bool_basic_comparisons() {
        // Row 0 is null in both arrays, so it can never appear in a result.
        let validity = || Validity::from_iter([false, true, true, true, true]);
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            validity(),
        );
        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            validity(),
        );

        // Indices of the rows where `lhs op rhs` is valid and true.
        let hits = |lhs: &dyn Array, rhs: &dyn Array, op: Operator| {
            to_int_indices(compare(lhs, rhs, op).unwrap().to_bool().unwrap()).unwrap()
        };

        assert_eq!(hits(arr.as_ref(), arr.as_ref(), Operator::Eq), [1u64, 2, 3, 4]);
        let none: [u64; 0] = [];
        assert_eq!(hits(arr.as_ref(), arr.as_ref(), Operator::NotEq), none);

        assert_eq!(hits(arr.as_ref(), other.as_ref(), Operator::Lte), [2u64, 3, 4]);
        assert_eq!(hits(arr.as_ref(), other.as_ref(), Operator::Lt), [4u64]);
        assert_eq!(hits(other.as_ref(), arr.as_ref(), Operator::Gte), [2u64, 3, 4]);
        assert_eq!(hits(other.as_ref(), arr.as_ref(), Operator::Gt), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // 2 > 10 is false on every row, via the dispatcher and via the Arrow fallback alike.
        let dispatched = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap();
        let dispatched_value = dispatched.as_constant().unwrap();
        assert_eq!(dispatched_value.as_bool().value(), Some(false));
        assert_eq!(dispatched.len(), 10);

        let via_arrow =
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let arrow_value = via_arrow.as_constant().unwrap();
        assert_eq!(arrow_value.as_bool().value(), Some(false));
        assert_eq!(via_arrow.len(), 10);
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths = [1i32, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.into_iter(), op);
        assert_eq!(output.iter().collect::<Vec<bool>>(), expected);
    }

    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        // Both sides hold identical values, so every row must compare equal.
        let eq_mask = compare(&left, &right, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(eq_mask.boolean_buffer().count_set_bits(), left.len());
    }
}