vortex_array/arrays/varbin/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::{BinaryArray, StringArray};
5use arrow_buffer::BooleanBuffer;
6use arrow_ord::cmp;
7use itertools::Itertools;
8use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
10
11use crate::arrays::{BoolArray, PrimitiveArray, VarBinArray, VarBinVTable};
12use crate::arrow::{Datum, from_arrow_array_with_len};
13use crate::compute::{
14    CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
15};
16use crate::vtable::ValidityHelper;
17use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
18
19// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
20impl CompareKernel for VarBinVTable {
21    fn compare(
22        &self,
23        lhs: &VarBinArray,
24        rhs: &dyn Array,
25        operator: Operator,
26    ) -> VortexResult<Option<ArrayRef>> {
27        if let Some(rhs_const) = rhs.as_constant() {
28            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
29            let len = lhs.len();
30
31            let rhs_is_empty = match rhs_const.dtype() {
32                DType::Binary(_) => rhs_const
33                    .as_binary()
34                    .is_empty()
35                    .vortex_expect("RHS should not be null"),
36                DType::Utf8(_) => rhs_const
37                    .as_utf8()
38                    .is_empty()
39                    .vortex_expect("RHS should not be null"),
40                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
41            };
42
43            if rhs_is_empty {
44                let buffer = match operator {
45                    Operator::Gte => BooleanBuffer::new_set(len), // Every possible value is >= ""
46                    Operator::Lt => BooleanBuffer::new_unset(len), // No value is < ""
47                    Operator::Eq | Operator::NotEq | Operator::Gt | Operator::Lte => {
48                        let lhs_offsets = lhs.offsets().to_canonical()?.into_primitive()?;
49                        match_each_native_ptype!(lhs_offsets.ptype(), |P| {
50                            compare_offsets_to_empty::<P>(lhs_offsets, operator)
51                        })
52                    }
53                };
54
55                return Ok(Some(
56                    BoolArray::new(
57                        buffer,
58                        lhs.validity()
59                            .clone()
60                            .union_nullability(rhs.dtype().nullability()),
61                    )
62                    .into_array(),
63                ));
64            }
65
66            let lhs = Datum::try_new(lhs.as_ref())?;
67
68            // TODO(robert): Handle LargeString/Binary arrays
69            let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
70                DType::Utf8(_) => &rhs_const
71                    .as_utf8()
72                    .value()
73                    .map(StringArray::new_scalar)
74                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
75                DType::Binary(_) => &rhs_const
76                    .as_binary()
77                    .value()
78                    .map(BinaryArray::new_scalar)
79                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
80                _ => vortex_bail!(
81                    "VarBin array RHS can only be Utf8 or Binary, given {}",
82                    rhs_const.dtype()
83                ),
84            };
85
86            let array = match operator {
87                Operator::Eq => cmp::eq(&lhs, arrow_rhs),
88                Operator::NotEq => cmp::neq(&lhs, arrow_rhs),
89                Operator::Gt => cmp::gt(&lhs, arrow_rhs),
90                Operator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
91                Operator::Lt => cmp::lt(&lhs, arrow_rhs),
92                Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
93            }
94            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
95
96            Ok(Some(from_arrow_array_with_len(&array, len, nullable)))
97        } else if !rhs.is::<VarBinVTable>() {
98            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
99            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
100            // to VarBinView and re-invoke.
101            return Ok(Some(compare(lhs.to_varbinview()?.as_ref(), rhs, operator)?));
102        } else {
103            Ok(None)
104        }
105    }
106}
107
108register_kernel!(CompareKernelAdapter(VarBinVTable).lift());
109
110fn compare_offsets_to_empty<P: NativePType>(
111    offsets: PrimitiveArray,
112    operator: Operator,
113) -> BooleanBuffer {
114    let lengths_iter = offsets
115        .as_slice::<P>()
116        .iter()
117        .tuple_windows()
118        .map(|(&s, &e)| e - s);
119    compare_lengths_to_empty(lengths_iter, operator)
120}
121
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use vortex_buffer::ByteBuffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::ToCanonical;
    use crate::arrays::{ConstantArray, VarBinArray, VarBinViewArray};
    use crate::compute::{Operator, compare};

    /// Equality of a nullable binary VarBin array against a constant binary scalar.
    #[test]
    fn test_binary_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let rhs = ConstantArray::new(needle, 3);

        let outcome = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        // The null LHS slot stays null in the result.
        assert_eq!(
            &outcome.validity_mask().unwrap().to_boolean_buffer(),
            &BooleanBuffer::from_iter([true, false, true])
        );
        assert_eq!(
            outcome.boolean_buffer(),
            &BooleanBuffer::from_iter([true, false, false])
        );
    }

    /// Equality of a VarBin array against a VarBinView array.
    #[test]
    fn varbinview_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );

        let outcome = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        // A slot is valid only where both sides are non-null.
        assert_eq!(
            &outcome.validity_mask().unwrap().to_boolean_buffer(),
            &BooleanBuffer::from_iter([false, false, true])
        );
        // NOTE(review): the first two bits sit under null slots; the expected values here
        // presumably reflect the underlying bytes of those slots (confirm this is intended).
        assert_eq!(
            outcome.boolean_buffer(),
            &BooleanBuffer::from_iter([false, true, true])
        );
    }
}
187
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::arrays::{ConstantArray, VarBinArray};
    use crate::compute::{Operator, compare};

    /// Comparing a non-nullable array with a nullable (empty-string) constant must yield a
    /// nullable boolean result.
    #[test]
    fn test_null_compare() {
        let lhs = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));
        let empty_rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1);

        let outcome = compare(lhs.as_ref(), empty_rhs.as_ref(), Operator::Eq).unwrap();

        assert_eq!(outcome.dtype(), &DType::Bool(Nullability::Nullable));
    }
}
210}