vortex_array/compute/
compare.rs

1use core::fmt;
2use std::fmt::{Display, Formatter};
3
4use arrow_buffer::BooleanBuffer;
5use arrow_ord::cmp;
6use vortex_dtype::{DType, NativePType, Nullability};
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::arrays::ConstantArray;
11use crate::arrow::{Datum, from_arrow_array_with_len};
12use crate::encoding::Encoding;
13use crate::{Array, ArrayRef, Canonical, IntoArray};
14
/// A binary comparison operator.
///
/// The derived `PartialOrd` follows declaration order, so the variants must stay in this order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum Operator {
    /// Equal to.
    Eq,
    /// Not equal to.
    NotEq,
    /// Strictly greater than.
    Gt,
    /// Greater than or equal to.
    Gte,
    /// Strictly less than.
    Lt,
    /// Less than or equal to.
    Lte,
}
24
25impl Display for Operator {
26    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
27        let display = match &self {
28            Operator::Eq => "=",
29            Operator::NotEq => "!=",
30            Operator::Gt => ">",
31            Operator::Gte => ">=",
32            Operator::Lt => "<",
33            Operator::Lte => "<=",
34        };
35        Display::fmt(display, f)
36    }
37}
38
39impl Operator {
40    pub fn inverse(self) -> Self {
41        match self {
42            Operator::Eq => Operator::NotEq,
43            Operator::NotEq => Operator::Eq,
44            Operator::Gt => Operator::Lte,
45            Operator::Gte => Operator::Lt,
46            Operator::Lt => Operator::Gte,
47            Operator::Lte => Operator::Gt,
48        }
49    }
50
51    /// Change the sides of the operator, where changing lhs and rhs won't change the result of the operation
52    pub fn swap(self) -> Self {
53        match self {
54            Operator::Eq => Operator::Eq,
55            Operator::NotEq => Operator::NotEq,
56            Operator::Gt => Operator::Lt,
57            Operator::Gte => Operator::Lte,
58            Operator::Lt => Operator::Gt,
59            Operator::Lte => Operator::Gte,
60        }
61    }
62}
63
64pub trait CompareFn<A> {
65    /// Compares two arrays and returns a new boolean array with the result of the comparison.
66    /// Or, returns None if comparison is not supported for these arrays.
67    fn compare(
68        &self,
69        lhs: A,
70        rhs: &dyn Array,
71        operator: Operator,
72    ) -> VortexResult<Option<ArrayRef>>;
73}
74
75impl<E: Encoding> CompareFn<&dyn Array> for E
76where
77    E: for<'a> CompareFn<&'a E::Array>,
78{
79    fn compare(
80        &self,
81        lhs: &dyn Array,
82        rhs: &dyn Array,
83        operator: Operator,
84    ) -> VortexResult<Option<ArrayRef>> {
85        let array_ref = lhs
86            .as_any()
87            .downcast_ref::<E::Array>()
88            .vortex_expect("Failed to downcast array");
89
90        CompareFn::compare(self, array_ref, rhs, operator)
91    }
92}
93
94pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
95    if left.len() != right.len() {
96        vortex_bail!("Compare operations only support arrays of the same length");
97    }
98    if !left.dtype().eq_ignore_nullability(right.dtype()) {
99        vortex_bail!(
100            "Cannot compare different DTypes {} and {}",
101            left.dtype(),
102            right.dtype()
103        );
104    }
105
106    // TODO(ngates): no reason why not
107    if left.dtype().is_struct() {
108        vortex_bail!(
109            "Compare does not support arrays with Struct DType, got: {} and {}",
110            left.dtype(),
111            right.dtype()
112        )
113    }
114
115    let result_dtype =
116        DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into());
117
118    if left.is_empty() {
119        return Ok(Canonical::empty(&result_dtype).into_array());
120    }
121
122    let left_constant_null = left.as_constant().map(|l| l.is_null()).unwrap_or(false);
123    let right_constant_null = right.as_constant().map(|r| r.is_null()).unwrap_or(false);
124    if left_constant_null || right_constant_null {
125        return Ok(ConstantArray::new(Scalar::null(result_dtype), left.len()).into_array());
126    }
127
128    let right_is_constant = right.is_constant();
129
130    // Always try to put constants on the right-hand side so encodings can optimise themselves.
131    if left.is_constant() && !right_is_constant {
132        return compare(right, left, operator.swap());
133    }
134
135    if let Some(result) = left
136        .vtable()
137        .compare_fn()
138        .and_then(|f| f.compare(left, right, operator).transpose())
139        .transpose()?
140    {
141        check_compare_result(&result, left, right);
142        return Ok(result);
143    }
144
145    if let Some(result) = right
146        .vtable()
147        .compare_fn()
148        .and_then(|f| f.compare(right, left, operator.swap()).transpose())
149        .transpose()?
150    {
151        check_compare_result(&result, left, right);
152        return Ok(result);
153    }
154
155    // Only log missing compare implementation if there's possibly better one than arrow,
156    // i.e. lhs isn't arrow or rhs isn't arrow or constant
157    if !(left.is_arrow() && (right.is_arrow() || right_is_constant)) {
158        log::debug!(
159            "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
160            right.encoding(),
161            left.encoding(),
162            operator.swap(),
163        );
164    }
165
166    // Fallback to arrow on canonical types
167    let result = arrow_compare(left, right, operator)?;
168    check_compare_result(&result, left, right);
169    Ok(result)
170}
171
172/// Helper function to compare empty values with arrays that have external value length information
173/// like `VarBin`.
174pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
175where
176    P: NativePType,
177    I: Iterator<Item = P>,
178{
179    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
180    let cmp_fn = match op {
181        Operator::Eq | Operator::Lte => |v| v == P::zero(),
182        Operator::NotEq | Operator::Gt => |v| v != P::zero(),
183        Operator::Gte => |_| true,
184        Operator::Lt => |_| false,
185    };
186
187    lengths.map(cmp_fn).collect::<BooleanBuffer>()
188}
189
190/// Implementation of `CompareFn` using the Arrow crate.
191fn arrow_compare(
192    left: &dyn Array,
193    right: &dyn Array,
194    operator: Operator,
195) -> VortexResult<ArrayRef> {
196    let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
197    let lhs = Datum::try_new(left.to_array())?;
198    let rhs = Datum::try_new(right.to_array())?;
199
200    let array = match operator {
201        Operator::Eq => cmp::eq(&lhs, &rhs)?,
202        Operator::NotEq => cmp::neq(&lhs, &rhs)?,
203        Operator::Gt => cmp::gt(&lhs, &rhs)?,
204        Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
205        Operator::Lt => cmp::lt(&lhs, &rhs)?,
206        Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
207    };
208    from_arrow_array_with_len(&array, left.len(), nullable)
209}
210
211#[inline(always)]
212fn check_compare_result(result: &dyn Array, lhs: &dyn Array, rhs: &dyn Array) {
213    debug_assert_eq!(
214        result.len(),
215        lhs.len(),
216        "CompareFn result length ({}) mismatch for left encoding {}, left len {}, right encoding {}, right len {}",
217        result.len(),
218        lhs.encoding(),
219        lhs.len(),
220        rhs.encoding(),
221        rhs.len()
222    );
223    debug_assert_eq!(
224        result.dtype(),
225        &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
226        "CompareFn result dtype ({}) mismatch for left encoding {}, right encoding {}",
227        result.dtype(),
228        lhs.encoding(),
229        rhs.encoding(),
230    );
231}
232
233pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
234    if lhs.is_null() | rhs.is_null() {
235        Scalar::null(DType::Bool(Nullability::Nullable))
236    } else {
237        let b = match operator {
238            Operator::Eq => lhs == rhs,
239            Operator::NotEq => lhs != rhs,
240            Operator::Gt => lhs > rhs,
241            Operator::Gte => lhs >= rhs,
242            Operator::Lt => lhs < rhs,
243            Operator::Lte => lhs <= rhs,
244        };
245
246        Scalar::bool(
247            b,
248            (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
249        )
250    }
251}
252
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use itertools::Itertools;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray};
    use crate::validity::Validity;

    /// Returns the positions at which `indices_bits` is both valid and `true`.
    fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
        let values = indices_bits.boolean_buffer();
        let nulls = indices_bits
            .validity()
            .to_logical(indices_bits.len())
            .unwrap()
            .to_null_buffer();
        // Without a null buffer every position counts as valid.
        let is_valid = |idx: usize| match nulls.as_ref() {
            Some(null_buffer) => null_buffer.is_valid(idx),
            None => true,
        };

        values
            .iter()
            .enumerate()
            .filter(|&(idx, set)| set && is_valid(idx))
            .map(|(idx, _)| idx as u64)
            .collect_vec()
    }

    /// Runs `compare` on two bool arrays and returns the positions that came out valid and true.
    fn true_positions(lhs: &BoolArray, rhs: &BoolArray, operator: Operator) -> Vec<u64> {
        let mask = compare(lhs, rhs, operator).unwrap().to_bool().unwrap();
        to_int_indices(mask)
    }

    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        // Position 0 is null on both sides, so it never matches.
        assert_eq!(true_positions(&arr, &arr, Operator::Eq), [1u64, 2, 3, 4]);

        let empty: [u64; 0] = [];
        assert_eq!(true_positions(&arr, &arr, Operator::NotEq), empty);

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(true_positions(&arr, &other, Operator::Lte), [2u64, 3, 4]);
        assert_eq!(true_positions(&arr, &other, Operator::Lt), [4u64]);
        assert_eq!(true_positions(&other, &arr, Operator::Gte), [2u64, 3, 4]);
        assert_eq!(true_positions(&other, &arr, Operator::Gt), [4u64]);
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Through the public entry point.
        let dispatched = compare(&left, &right, Operator::Gt).unwrap();
        let scalar = dispatched.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(dispatched.len(), 10);

        // Straight through the Arrow fallback.
        let fallback =
            arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let scalar = fallback.as_constant().unwrap();
        assert_eq!(scalar.as_bool().value(), Some(false));
        assert_eq!(fallback.len(), 10);
    }

    #[rstest::rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths = [1i32, 5, 7, 0];

        let buffer = compare_lengths_to_empty(lengths.iter().copied(), op);
        let actual: Vec<bool> = buffer.iter().collect();
        assert_eq!(actual, expected);
    }
}