vortex_array/arrays/varbin/compute/
compare.rs

1use arrow_array::{BinaryArray, StringArray};
2use arrow_buffer::BooleanBuffer;
3use arrow_ord::cmp;
4use itertools::Itertools;
5use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
7
8use crate::arrays::{BoolArray, PrimitiveArray, VarBinArray, VarBinVTable};
9use crate::arrow::{Datum, from_arrow_array_with_len};
10use crate::compute::{
11    CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
12};
13use crate::vtable::ValidityHelper;
14use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
15
16// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
17impl CompareKernel for VarBinVTable {
18    fn compare(
19        &self,
20        lhs: &VarBinArray,
21        rhs: &dyn Array,
22        operator: Operator,
23    ) -> VortexResult<Option<ArrayRef>> {
24        if let Some(rhs_const) = rhs.as_constant() {
25            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
26            let len = lhs.len();
27
28            let rhs_is_empty = match rhs_const.dtype() {
29                DType::Binary(_) => rhs_const
30                    .as_binary()
31                    .is_empty()
32                    .vortex_expect("RHS should not be null"),
33                DType::Utf8(_) => rhs_const
34                    .as_utf8()
35                    .is_empty()
36                    .vortex_expect("RHS should not be null"),
37                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
38            };
39
40            if rhs_is_empty {
41                let buffer = match operator {
42                    // Every possible value is gte ""
43                    Operator::Gte => BooleanBuffer::new_set(len),
44                    // No value is lt ""
45                    Operator::Lt => BooleanBuffer::new_unset(len),
46                    _ => {
47                        let lhs_offsets = lhs.offsets().to_canonical()?.into_primitive()?;
48                        match_each_native_ptype!(lhs_offsets.ptype(), |P| {
49                            compare_offsets_to_empty::<P>(lhs_offsets, operator)
50                        })
51                    }
52                };
53
54                return Ok(Some(
55                    BoolArray::new(buffer, lhs.validity().clone()).into_array(),
56                ));
57            }
58
59            let lhs = Datum::try_new(lhs.as_ref())?;
60
61            // TODO(robert): Handle LargeString/Binary arrays
62            let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
63                DType::Utf8(_) => &rhs_const
64                    .as_utf8()
65                    .value()
66                    .map(StringArray::new_scalar)
67                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
68                DType::Binary(_) => &rhs_const
69                    .as_binary()
70                    .value()
71                    .map(BinaryArray::new_scalar)
72                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
73                _ => vortex_bail!(
74                    "VarBin array RHS can only be Utf8 or Binary, given {}",
75                    rhs_const.dtype()
76                ),
77            };
78
79            let array = match operator {
80                Operator::Eq => cmp::eq(&lhs, arrow_rhs),
81                Operator::NotEq => cmp::neq(&lhs, arrow_rhs),
82                Operator::Gt => cmp::gt(&lhs, arrow_rhs),
83                Operator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
84                Operator::Lt => cmp::lt(&lhs, arrow_rhs),
85                Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
86            }
87            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
88
89            Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
90        } else if !rhs.is::<VarBinVTable>() {
91            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
92            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
93            // to VarBinView and re-invoke.
94            return Ok(Some(compare(lhs.to_varbinview()?.as_ref(), rhs, operator)?));
95        } else {
96            Ok(None)
97        }
98    }
99}
100
101register_kernel!(CompareKernelAdapter(VarBinVTable).lift());
102
103fn compare_offsets_to_empty<P: NativePType>(
104    offsets: PrimitiveArray,
105    operator: Operator,
106) -> BooleanBuffer {
107    let lengths_iter = offsets
108        .as_slice::<P>()
109        .iter()
110        .tuple_windows()
111        .map(|(&s, &e)| e - s);
112    compare_lengths_to_empty(lengths_iter, operator)
113}
114
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use vortex_buffer::ByteBuffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::ToCanonical;
    use crate::arrays::{ConstantArray, VarBinArray, VarBinViewArray};
    use crate::compute::{Operator, compare};

    /// Nullable binary VarBin fixture shared by the tests: `["abc", null, "def"]`.
    fn abc_null_def() -> VarBinArray {
        VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
    }

    /// VarBin compared for equality against a constant `"abc"`.
    #[test]
    fn test_binary_compare() {
        let lhs = abc_null_def();
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let rhs = ConstantArray::new(needle, 3);

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        // Only the null LHS slot is invalid in the output.
        let mask = result.validity_mask().unwrap();
        assert_eq!(
            &mask.to_boolean_buffer(),
            &BooleanBuffer::from_iter([true, false, true])
        );
        assert_eq!(
            result.boolean_buffer(),
            &BooleanBuffer::from_iter([true, false, false])
        );
    }

    /// VarBin compared for equality against a VarBinView array.
    #[test]
    fn varbinview_compare() {
        let lhs = abc_null_def();
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        // A slot is valid only when both inputs are non-null there.
        let mask = result.validity_mask().unwrap();
        assert_eq!(
            &mask.to_boolean_buffer(),
            &BooleanBuffer::from_iter([false, false, true])
        );
        assert_eq!(
            result.boolean_buffer(),
            &BooleanBuffer::from_iter([false, true, true])
        );
    }
}