vortex_array/arrays/varbin/compute/
compare.rs

1use arrow_array::{BinaryArray, StringArray};
2use arrow_buffer::BooleanBuffer;
3use arrow_ord::cmp;
4use itertools::Itertools;
5use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
7
8use crate::arrays::{BoolArray, PrimitiveArray, VarBinArray, VarBinVTable};
9use crate::arrow::{Datum, from_arrow_array_with_len};
10use crate::compute::{
11    CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
12};
13use crate::vtable::ValidityHelper;
14use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
15
16// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
17impl CompareKernel for VarBinVTable {
18    fn compare(
19        &self,
20        lhs: &VarBinArray,
21        rhs: &dyn Array,
22        operator: Operator,
23    ) -> VortexResult<Option<ArrayRef>> {
24        if let Some(rhs_const) = rhs.as_constant() {
25            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
26            let len = lhs.len();
27
28            let rhs_is_empty = match rhs_const.dtype() {
29                DType::Binary(_) => rhs_const
30                    .as_binary()
31                    .is_empty()
32                    .vortex_expect("RHS should not be null"),
33                DType::Utf8(_) => rhs_const
34                    .as_utf8()
35                    .is_empty()
36                    .vortex_expect("RHS should not be null"),
37                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
38            };
39
40            if rhs_is_empty {
41                let buffer = match operator {
42                    // Every possible value is gte ""
43                    Operator::Gte => BooleanBuffer::new_set(len),
44                    // No value is lt ""
45                    Operator::Lt => BooleanBuffer::new_unset(len),
46                    _ => {
47                        let lhs_offsets = lhs.offsets().to_canonical()?.into_primitive()?;
48                        match_each_native_ptype!(lhs_offsets.ptype(), |P| {
49                            compare_offsets_to_empty::<P>(lhs_offsets, operator)
50                        })
51                    }
52                };
53
54                return Ok(Some(
55                    BoolArray::new(
56                        buffer,
57                        lhs.validity()
58                            .clone()
59                            .union_nullability(rhs.dtype().nullability()),
60                    )
61                    .into_array(),
62                ));
63            }
64
65            let lhs = Datum::try_new(lhs.as_ref())?;
66
67            // TODO(robert): Handle LargeString/Binary arrays
68            let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
69                DType::Utf8(_) => &rhs_const
70                    .as_utf8()
71                    .value()
72                    .map(StringArray::new_scalar)
73                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
74                DType::Binary(_) => &rhs_const
75                    .as_binary()
76                    .value()
77                    .map(BinaryArray::new_scalar)
78                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
79                _ => vortex_bail!(
80                    "VarBin array RHS can only be Utf8 or Binary, given {}",
81                    rhs_const.dtype()
82                ),
83            };
84
85            let array = match operator {
86                Operator::Eq => cmp::eq(&lhs, arrow_rhs),
87                Operator::NotEq => cmp::neq(&lhs, arrow_rhs),
88                Operator::Gt => cmp::gt(&lhs, arrow_rhs),
89                Operator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
90                Operator::Lt => cmp::lt(&lhs, arrow_rhs),
91                Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
92            }
93            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
94
95            Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
96        } else if !rhs.is::<VarBinVTable>() {
97            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
98            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
99            // to VarBinView and re-invoke.
100            return Ok(Some(compare(lhs.to_varbinview()?.as_ref(), rhs, operator)?));
101        } else {
102            Ok(None)
103        }
104    }
105}
106
107register_kernel!(CompareKernelAdapter(VarBinVTable).lift());
108
109fn compare_offsets_to_empty<P: NativePType>(
110    offsets: PrimitiveArray,
111    operator: Operator,
112) -> BooleanBuffer {
113    let lengths_iter = offsets
114        .as_slice::<P>()
115        .iter()
116        .tuple_windows()
117        .map(|(&s, &e)| e - s);
118    compare_lengths_to_empty(lengths_iter, operator)
119}
120
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use vortex_buffer::ByteBuffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::ToCanonical;
    use crate::arrays::{ConstantArray, VarBinArray, VarBinViewArray};
    use crate::compute::{Operator, compare};

    /// Equality of a nullable binary VarBin array against a constant `b"abc"`: the null slot of
    /// the left-hand side must come back invalid, and only the first element matches.
    #[test]
    fn test_binary_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let rhs = ConstantArray::new(needle, 3);

        let outcome = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let expected_validity = BooleanBuffer::from_iter([true, false, true]);
        let expected_values = BooleanBuffer::from_iter([true, false, false]);

        assert_eq!(
            &outcome.validity_mask().unwrap().to_boolean_buffer(),
            &expected_validity
        );
        assert_eq!(outcome.boolean_buffer(), &expected_values);
    }

    /// Equality of a VarBin array against a VarBinView array, exercising the branch that
    /// handles a non-constant, non-VarBin right-hand side.
    #[test]
    fn varbinview_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );

        let outcome = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let expected_validity = BooleanBuffer::from_iter([false, false, true]);
        let expected_values = BooleanBuffer::from_iter([false, true, true]);

        assert_eq!(
            &outcome.validity_mask().unwrap().to_boolean_buffer(),
            &expected_validity
        );
        assert_eq!(outcome.boolean_buffer(), &expected_values);
    }
}
186
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::arrays::{ConstantArray, VarBinArray};
    use crate::compute::{Operator, compare};

    /// Comparing a non-nullable Utf8 array with a nullable empty-string constant must yield a
    /// nullable boolean result type.
    #[test]
    fn test_null_compare() {
        let lhs = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));
        let empty_rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1);

        let result = compare(lhs.as_ref(), empty_rhs.as_ref(), Operator::Eq).unwrap();

        assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));
    }
}
209}