// vortex_array/compute/compare.rs

1use core::fmt;
2use std::fmt::{Display, Formatter};
3
4use arrow_buffer::BooleanBuffer;
5use arrow_ord::cmp;
6use arrow_schema::DataType;
7use vortex_dtype::{DType, NativePType, Nullability};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_scalar::Scalar;
10
11use crate::arrays::ConstantArray;
12use crate::arrow::{Datum, from_arrow_array_with_len};
13use crate::encoding::Encoding;
14use crate::{Array, ArrayRef, Canonical, IntoArray};
15
/// A binary comparison operator, as used by [`compare`] and [`scalar_cmp`].
///
/// The derived `PartialOrd` orders variants by their declaration order below.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)]
pub enum Operator {
    /// Equal to (`=`).
    Eq,
    /// Not equal to (`!=`).
    NotEq,
    /// Strictly greater than (`>`).
    Gt,
    /// Greater than or equal to (`>=`).
    Gte,
    /// Strictly less than (`<`).
    Lt,
    /// Less than or equal to (`<=`).
    Lte,
}
25
26impl Display for Operator {
27    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
28        let display = match &self {
29            Operator::Eq => "=",
30            Operator::NotEq => "!=",
31            Operator::Gt => ">",
32            Operator::Gte => ">=",
33            Operator::Lt => "<",
34            Operator::Lte => "<=",
35        };
36        Display::fmt(display, f)
37    }
38}
39
40impl Operator {
41    pub fn inverse(self) -> Self {
42        match self {
43            Operator::Eq => Operator::NotEq,
44            Operator::NotEq => Operator::Eq,
45            Operator::Gt => Operator::Lte,
46            Operator::Gte => Operator::Lt,
47            Operator::Lt => Operator::Gte,
48            Operator::Lte => Operator::Gt,
49        }
50    }
51
52    /// Change the sides of the operator, where changing lhs and rhs won't change the result of the operation
53    pub fn swap(self) -> Self {
54        match self {
55            Operator::Eq => Operator::Eq,
56            Operator::NotEq => Operator::NotEq,
57            Operator::Gt => Operator::Lt,
58            Operator::Gte => Operator::Lte,
59            Operator::Lt => Operator::Gt,
60            Operator::Lte => Operator::Gte,
61        }
62    }
63}
64
65pub trait CompareFn<A> {
66    /// Compares two arrays and returns a new boolean array with the result of the comparison.
67    /// Or, returns None if comparison is not supported for these arrays.
68    fn compare(
69        &self,
70        lhs: A,
71        rhs: &dyn Array,
72        operator: Operator,
73    ) -> VortexResult<Option<ArrayRef>>;
74}
75
76impl<E: Encoding> CompareFn<&dyn Array> for E
77where
78    E: for<'a> CompareFn<&'a E::Array>,
79{
80    fn compare(
81        &self,
82        lhs: &dyn Array,
83        rhs: &dyn Array,
84        operator: Operator,
85    ) -> VortexResult<Option<ArrayRef>> {
86        let array_ref = lhs
87            .as_any()
88            .downcast_ref::<E::Array>()
89            .vortex_expect("Failed to downcast array");
90
91        CompareFn::compare(self, array_ref, rhs, operator)
92    }
93}
94
95pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
96    if left.len() != right.len() {
97        vortex_bail!("Compare operations only support arrays of the same length");
98    }
99    if !left.dtype().eq_ignore_nullability(right.dtype()) {
100        vortex_bail!(
101            "Cannot compare different DTypes {} and {}",
102            left.dtype(),
103            right.dtype()
104        );
105    }
106
107    // TODO(ngates): no reason why not
108    if left.dtype().is_struct() {
109        vortex_bail!(
110            "Compare does not support arrays with Struct DType, got: {} and {}",
111            left.dtype(),
112            right.dtype()
113        )
114    }
115
116    let result_dtype =
117        DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into());
118
119    if left.is_empty() {
120        return Ok(Canonical::empty(&result_dtype).into_array());
121    }
122
123    let left_constant_null = left.as_constant().map(|l| l.is_null()).unwrap_or(false);
124    let right_constant_null = right.as_constant().map(|r| r.is_null()).unwrap_or(false);
125    if left_constant_null || right_constant_null {
126        return Ok(ConstantArray::new(Scalar::null(result_dtype), left.len()).into_array());
127    }
128
129    let right_is_constant = right.is_constant();
130
131    // Always try to put constants on the right-hand side so encodings can optimise themselves.
132    if left.is_constant() && !right_is_constant {
133        return compare(right, left, operator.swap());
134    }
135
136    if let Some(result) = left
137        .vtable()
138        .compare_fn()
139        .and_then(|f| f.compare(left, right, operator).transpose())
140        .transpose()?
141    {
142        check_compare_result(&result, left, right);
143        return Ok(result);
144    }
145
146    if let Some(result) = right
147        .vtable()
148        .compare_fn()
149        .and_then(|f| f.compare(right, left, operator.swap()).transpose())
150        .transpose()?
151    {
152        check_compare_result(&result, left, right);
153        return Ok(result);
154    }
155
156    // Only log missing compare implementation if there's possibly better one than arrow,
157    // i.e. lhs isn't arrow or rhs isn't arrow or constant
158    if !(left.is_arrow() && (right.is_arrow() || right_is_constant)) {
159        log::debug!(
160            "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
161            right.encoding(),
162            left.encoding(),
163            operator.swap(),
164        );
165    }
166
167    // Fallback to arrow on canonical types
168    let result = arrow_compare(left, right, operator)?;
169    check_compare_result(&result, left, right);
170    Ok(result)
171}
172
173/// Helper function to compare empty values with arrays that have external value length information
174/// like `VarBin`.
175pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
176where
177    P: NativePType,
178    I: Iterator<Item = P>,
179{
180    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
181    let cmp_fn = match op {
182        Operator::Eq | Operator::Lte => |v| v == P::zero(),
183        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
184        Operator::Gte => |_| true,
185        Operator::Lt => |_| false,
186    };
187
188    lengths.map(cmp_fn).collect::<BooleanBuffer>()
189}
190
191/// Implementation of `CompareFn` using the Arrow crate.
192fn arrow_compare(
193    left: &dyn Array,
194    right: &dyn Array,
195    operator: Operator,
196) -> VortexResult<ArrayRef> {
197    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
198    let lhs = datum_for_cmp(left)?;
199    let rhs = datum_for_cmp(right)?;
200
201    let array = match operator {
202        Operator::Eq => cmp::eq(&lhs, &rhs)?,
203        Operator::NotEq => cmp::neq(&lhs, &rhs)?,
204        Operator::Gt => cmp::gt(&lhs, &rhs)?,
205        Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
206        Operator::Lt => cmp::lt(&lhs, &rhs)?,
207        Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
208    };
209    from_arrow_array_with_len(&array, left.len(), nullable)
210}
211
212#[inline(always)]
213fn check_compare_result(result: &dyn Array, lhs: &dyn Array, rhs: &dyn Array) {
214    assert_eq!(
215        result.len(),
216        lhs.len(),
217        "CompareFn result length ({}) mismatch for left encoding {}, left len {}, right encoding {}, right len {}",
218        result.len(),
219        lhs.encoding(),
220        lhs.len(),
221        rhs.encoding(),
222        rhs.len()
223    );
224    assert_eq!(
225        result.dtype(),
226        &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
227        "CompareFn result dtype ({}) mismatch for left encoding {}, right encoding {}",
228        result.dtype(),
229        lhs.encoding(),
230        rhs.encoding(),
231    );
232}
233
234pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
235    if lhs.is_null() | rhs.is_null() {
236        Scalar::null(DType::Bool(Nullability::Nullable))
237    } else {
238        let b = match operator {
239            Operator::Eq => lhs == rhs,
240            Operator::NotEq => lhs != rhs,
241            Operator::Gt => lhs > rhs,
242            Operator::Gte => lhs >= rhs,
243            Operator::Lt => lhs < rhs,
244            Operator::Lte => lhs <= rhs,
245        };
246
247        Scalar::bool(
248            b,
249            (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
250        )
251    }
252}
253
254// Make sure both of the arrays end up with the same arrow data type
255fn datum_for_cmp(array: &dyn Array) -> VortexResult<Datum> {
256    if matches!(array.dtype(), DType::Utf8(_)) {
257        Datum::with_target_datatype(array, &DataType::Utf8View)
258    } else if matches!(array.dtype(), DType::Binary(_)) {
259        Datum::with_target_datatype(array, &DataType::BinaryView)
260    } else {
261        Datum::try_new(array)
262    }
263}
264
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use itertools::Itertools;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::validity::Validity;

    // Collects the positions that are both `true` and valid, so null slots never count as
    // matches.
    fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
        let buffer = indices_bits.boolean_buffer();
        let mask = indices_bits.validity_mask().unwrap();
        buffer
            .iter()
            .enumerate()
            .filter_map(|(idx, v)| (v && mask.value(idx)).then_some(idx as u64))
            .collect_vec()
    }

    // Exercises every operator on nullable boolean arrays. Index 0 is marked invalid in both
    // inputs, so it never appears among the matching indices.
    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(&arr, &arr, Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        assert_eq!(to_int_indices(matches), [1u64, 2, 3, 4]);

        let matches = compare(&arr, &arr, Operator::NotEq)
            .unwrap()
            .to_bool()
            .unwrap();
        let empty: [u64; 0] = [];
        assert_eq!(to_int_indices(matches), empty);

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        let matches = compare(&arr, &other, Operator::Lte)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

        let matches = compare(&arr, &other, Operator::Lt)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches), [4u64]);

        // Operands reversed: `other >= arr` / `other > arr` must mirror `arr <= other` /
        // `arr < other` above.
        let matches = compare(&other, &arr, Operator::Gte)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

        let matches = compare(&other, &arr, Operator::Gt)
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(to_int_indices(matches), [4u64]);
    }

    // `2 > 10` over two constant arrays must yield a constant `false` of the same length, both
    // through the `compare` dispatcher and directly through the Arrow fallback.
    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Note: this local binding shadows the `compare` function from here on.
        let compare = compare(&left, &right, Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);

        let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let res = compare.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(compare.len(), 10);
    }

    // Lengths `[1, 5, 7, 0]` stand for three non-empty values followed by one empty value.
    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    // Mixed `VarBin` / `VarBinView` operands, for both UTF-8 and binary data, must compare
    // equal element-wise through the Arrow fallback. This relies on `datum_for_cmp` converting
    // both sides to the same view type.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = arrow_compare(&left, &right, Operator::Eq).unwrap();
        assert_eq!(
            res.to_bool().unwrap().boolean_buffer().count_set_bits(),
            left.len()
        );
    }
}