Skip to main content

vortex_array/arrays/varbin/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::BinaryArray;
5use arrow_array::StringArray;
6use arrow_ord::cmp;
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect as _;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12
13use crate::ArrayRef;
14use crate::DynArray;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::BoolArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::VarBinArray;
20use crate::arrays::VarBinVTable;
21use crate::arrays::VarBinViewArray;
22use crate::arrow::Datum;
23use crate::arrow::from_arrow_array_with_len;
24use crate::builtins::ArrayBuiltins;
25use crate::dtype::DType;
26use crate::dtype::IntegerPType;
27use crate::match_each_integer_ptype;
28use crate::scalar_fn::fns::binary::CompareKernel;
29use crate::scalar_fn::fns::operators::CompareOperator;
30use crate::scalar_fn::fns::operators::Operator;
31use crate::vtable::ValidityHelper;
32
33// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
34impl CompareKernel for VarBinVTable {
35    fn compare(
36        lhs: &VarBinArray,
37        rhs: &ArrayRef,
38        operator: CompareOperator,
39        ctx: &mut ExecutionCtx,
40    ) -> VortexResult<Option<ArrayRef>> {
41        if let Some(rhs_const) = rhs.as_constant() {
42            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
43            let len = lhs.len();
44
45            let rhs_is_empty = match rhs_const.dtype() {
46                DType::Binary(_) => rhs_const
47                    .as_binary()
48                    .is_empty()
49                    .vortex_expect("RHS should not be null"),
50                DType::Utf8(_) => rhs_const
51                    .as_utf8()
52                    .is_empty()
53                    .vortex_expect("RHS should not be null"),
54                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
55            };
56
57            if rhs_is_empty {
58                let buffer = match operator {
59                    CompareOperator::Gte => BitBuffer::new_set(len), // Every possible value is >= ""
60                    CompareOperator::Lt => BitBuffer::new_unset(len), // No value is < ""
61                    CompareOperator::Eq | CompareOperator::Lte => {
62                        let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
63                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
64                            compare_offsets_to_empty::<P>(lhs_offsets, true)
65                        })
66                    }
67                    CompareOperator::NotEq | CompareOperator::Gt => {
68                        let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
69                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
70                            compare_offsets_to_empty::<P>(lhs_offsets, false)
71                        })
72                    }
73                };
74
75                return Ok(Some(
76                    BoolArray::new(
77                        buffer,
78                        lhs.validity()
79                            .clone()
80                            .union_nullability(rhs.dtype().nullability()),
81                    )
82                    .into_array(),
83                ));
84            }
85
86            let lhs = Datum::try_new(&lhs.clone().into_array())?;
87
88            // Use StringViewArray/BinaryViewArray to match the Utf8View/BinaryView types
89            // produced by Datum::try_new (which uses into_arrow_preferred())
90            let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
91                DType::Utf8(_) => &rhs_const
92                    .as_utf8()
93                    .value()
94                    .map(StringArray::new_scalar)
95                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
96                DType::Binary(_) => &rhs_const
97                    .as_binary()
98                    .value()
99                    .map(BinaryArray::new_scalar)
100                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
101                _ => vortex_bail!(
102                    "VarBin array RHS can only be Utf8 or Binary, given {}",
103                    rhs_const.dtype()
104                ),
105            };
106
107            let array = match operator {
108                CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs),
109                CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs),
110                CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs),
111                CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
112                CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs),
113                CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
114            }
115            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
116
117            Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
118        } else if !rhs.is::<VarBinVTable>() {
119            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
120            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
121            // to VarBinView and re-invoke.
122            return Ok(Some(
123                lhs.clone()
124                    .into_array()
125                    .execute::<VarBinViewArray>(ctx)?
126                    .into_array()
127                    .binary(rhs.to_array(), Operator::from(operator))?,
128            ));
129        } else {
130            Ok(None)
131        }
132    }
133}
134
135fn compare_offsets_to_empty<P: IntegerPType>(offsets: PrimitiveArray, eq: bool) -> BitBuffer {
136    let fn_ = if eq { P::eq } else { P::ne };
137    let offsets = offsets.as_slice::<P>();
138    BitBuffer::collect_bool(offsets.len() - 1, |idx| {
139        let left = unsafe { offsets.get_unchecked(idx) };
140        let right = unsafe { offsets.get_unchecked(idx + 1) };
141        fn_(left, right)
142    })
143}
144
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Equality of a nullable binary VarBin array against a constant binary scalar.
    #[test]
    fn test_binary_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let rhs = ConstantArray::new(needle, 3).into_array();

        let compared = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        // The null LHS slot stays null in the result.
        let validity = compared.validity_mask().unwrap().to_bit_buffer();
        assert_eq!(validity, BitBuffer::from_iter([true, false, true]));

        // Only the first element equals "abc".
        let values = compared.to_bit_buffer();
        assert_eq!(values, BitBuffer::from_iter([true, false, false]));
    }

    /// Equality of a VarBin array against a VarBinView array on the RHS.
    #[test]
    fn varbinview_compare() {
        let dtype = DType::Binary(Nullability::Nullable);
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            dtype.clone(),
        )
        .into_array();
        let rhs =
            VarBinViewArray::from_iter([None, None, Some(b"def".to_vec())], dtype).into_array();

        let compared = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        // A slot is valid only where both inputs are non-null.
        let validity = compared.validity_mask().unwrap().to_bit_buffer();
        assert_eq!(validity, BitBuffer::from_iter([false, false, true]));

        let values = compared.to_bit_buffer();
        assert_eq!(values, BitBuffer::from_iter([false, true, true]));
    }
}
216
#[cfg(test)]
mod tests {
    use crate::DynArray;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// A non-nullable LHS compared with a nullable empty-string constant must produce a
    /// nullable boolean result.
    #[test]
    fn test_null_compare() {
        let lhs = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable))
            .into_array();
        let rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1).into_array();

        let compared = lhs.binary(rhs, Operator::Eq).unwrap();

        assert_eq!(compared.dtype(), &DType::Bool(Nullability::Nullable));
    }
}