Skip to main content

vortex_array/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5
6use arrow_array::BooleanArray;
7use arrow_buffer::NullBuffer;
8use arrow_ord::ord::make_comparator;
9use arrow_schema::SortOptions;
10use vortex_buffer::BitBuffer;
11use vortex_error::VortexResult;
12
13use crate::dtype::IntegerPType;
14use crate::scalar_fn::fns::operators::CompareOperator;
15
16/// Helper function to compare empty values with arrays that have external value length information
17/// like `VarBin`.
18pub fn compare_lengths_to_empty<P, I>(lengths: I, op: CompareOperator) -> BitBuffer
19where
20    P: IntegerPType,
21    I: Iterator<Item = P>,
22{
23    // All comparison can be expressed in terms of equality. "" is the absolute min of possible value.
24    let cmp_fn = match op {
25        CompareOperator::Eq | CompareOperator::Lte => |v| v == P::zero(),
26        CompareOperator::NotEq | CompareOperator::Gt => |v| v != P::zero(),
27        CompareOperator::Gte => |_| true,
28        CompareOperator::Lt => |_| false,
29    };
30
31    lengths.map(cmp_fn).collect()
32}
33
34/// Compare two Arrow arrays element-wise using [`make_comparator`].
35///
36/// This function is required for nested types (Struct, List, FixedSizeList) because Arrow's
37/// vectorized comparison kernels ([`cmp::eq`], [`cmp::neq`], etc.) do not support them.
38///
39/// The vectorized kernels are faster but only work on primitive types, so for non-nested types,
40/// prefer using the vectorized kernels directly for better performance.
41pub(crate) fn compare_nested_arrow_arrays(
42    lhs: &dyn arrow_array::Array,
43    rhs: &dyn arrow_array::Array,
44    operator: CompareOperator,
45) -> VortexResult<BooleanArray> {
46    let compare_arrays_at = make_comparator(lhs, rhs, SortOptions::default())?;
47
48    let cmp_fn = match operator {
49        CompareOperator::Eq => Ordering::is_eq,
50        CompareOperator::NotEq => Ordering::is_ne,
51        CompareOperator::Gt => Ordering::is_gt,
52        CompareOperator::Gte => Ordering::is_ge,
53        CompareOperator::Lt => Ordering::is_lt,
54        CompareOperator::Lte => Ordering::is_le,
55    };
56
57    let values = (0..lhs.len())
58        .map(|i| cmp_fn(compare_arrays_at(i, i)))
59        .collect();
60    let nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
61
62    Ok(BooleanArray::new(values, nulls))
63}
64
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    // Fixture lengths are [1, 5, 7, 0]: three non-empty elements followed by one empty element,
    // so each expected vector lists the outcome for "non-empty" three times, then "empty".
    #[rstest]
    #[case(CompareOperator::Eq, vec![false, false, false, true])]
    #[case(CompareOperator::NotEq, vec![true, true, true, false])]
    #[case(CompareOperator::Gt, vec![true, true, true, false])]
    #[case(CompareOperator::Gte, vec![true, true, true, true])]
    #[case(CompareOperator::Lt, vec![false, false, false, false])]
    #[case(CompareOperator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: CompareOperator, #[case] expected: Vec<bool>) {
        let lengths = [1_i32, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.into_iter(), op);
        let actual: Vec<bool> = output.iter().collect();

        assert_eq!(actual, expected);
    }
}