Skip to main content

vortex_array/arrays/varbin/compute/compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::BinaryArray;
5use arrow_array::LargeBinaryArray;
6use arrow_array::LargeStringArray;
7use arrow_array::StringArray;
8use arrow_ord::cmp;
9use arrow_schema::DataType;
10use vortex_buffer::BitBuffer;
11use vortex_error::VortexExpect as _;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_err;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::array::ArrayView;
20use crate::arrays::BoolArray;
21use crate::arrays::PrimitiveArray;
22use crate::arrays::VarBin;
23use crate::arrays::VarBinViewArray;
24use crate::arrays::varbin::VarBinArrayExt;
25use crate::arrow::Datum;
26use crate::arrow::from_arrow_columnar;
27use crate::builtins::ArrayBuiltins;
28use crate::dtype::DType;
29use crate::dtype::IntegerPType;
30use crate::match_each_integer_ptype;
31use crate::scalar_fn::fns::binary::CompareKernel;
32use crate::scalar_fn::fns::operators::CompareOperator;
33use crate::scalar_fn::fns::operators::Operator;
34
35// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
36impl CompareKernel for VarBin {
37    fn compare(
38        lhs: ArrayView<'_, VarBin>,
39        rhs: &ArrayRef,
40        operator: CompareOperator,
41        ctx: &mut ExecutionCtx,
42    ) -> VortexResult<Option<ArrayRef>> {
43        if let Some(rhs_const) = rhs.as_constant() {
44            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
45            let len = lhs.len();
46
47            let rhs_is_empty = match rhs_const.dtype() {
48                DType::Binary(_) => rhs_const
49                    .as_binary()
50                    .is_empty()
51                    .vortex_expect("RHS should not be null"),
52                DType::Utf8(_) => rhs_const
53                    .as_utf8()
54                    .is_empty()
55                    .vortex_expect("RHS should not be null"),
56                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
57            };
58
59            if rhs_is_empty {
60                let buffer = match operator {
61                    CompareOperator::Gte => BitBuffer::new_set(len), /* Every possible value is >= "" */
62                    CompareOperator::Lt => BitBuffer::new_unset(len), // No value is < ""
63                    CompareOperator::Eq | CompareOperator::Lte => {
64                        let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
65                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
66                            compare_offsets_to_empty::<P>(lhs_offsets, true)
67                        })
68                    }
69                    CompareOperator::NotEq | CompareOperator::Gt => {
70                        let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
71                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
72                            compare_offsets_to_empty::<P>(lhs_offsets, false)
73                        })
74                    }
75                };
76
77                return Ok(Some(
78                    BoolArray::new(
79                        buffer,
80                        lhs.validity()?.union_nullability(rhs.dtype().nullability()),
81                    )
82                    .into_array(),
83                ));
84            }
85
86            let lhs = Datum::try_new(lhs.array(), ctx)?;
87
88            // The RHS scalar must match the LHS Arrow data type. VarBin with i64 offsets is
89            // converted to LargeBinary/LargeUtf8 (see `preferred_arrow_type`), and Arrow refuses to
90            // compare LargeBinary with Binary (or LargeUtf8 with Utf8).
91            let arrow_rhs: &dyn arrow_array::Datum = match (rhs_const.dtype(), lhs.data_type()) {
92                (DType::Utf8(_), DataType::LargeUtf8) => &rhs_const
93                    .as_utf8()
94                    .value()
95                    .map(LargeStringArray::new_scalar)
96                    .unwrap_or_else(|| arrow_array::Scalar::new(LargeStringArray::new_null(1))),
97                (DType::Utf8(_), _) => &rhs_const
98                    .as_utf8()
99                    .value()
100                    .map(StringArray::new_scalar)
101                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
102                (DType::Binary(_), DataType::LargeBinary) => &rhs_const
103                    .as_binary()
104                    .value()
105                    .map(LargeBinaryArray::new_scalar)
106                    .unwrap_or_else(|| arrow_array::Scalar::new(LargeBinaryArray::new_null(1))),
107                (DType::Binary(_), _) => &rhs_const
108                    .as_binary()
109                    .value()
110                    .map(BinaryArray::new_scalar)
111                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
112                _ => vortex_bail!(
113                    "VarBin array RHS can only be Utf8 or Binary, given {}",
114                    rhs_const.dtype()
115                ),
116            };
117
118            let array = match operator {
119                CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs),
120                CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs),
121                CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs),
122                CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
123                CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs),
124                CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
125            }
126            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
127
128            Ok(Some(from_arrow_columnar(&array, len, nullable, ctx)?))
129        } else if !rhs.is::<VarBin>() {
130            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
131            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
132            // to VarBinView and re-invoke.
133            Ok(Some(
134                lhs.array()
135                    .clone()
136                    .execute::<VarBinViewArray>(ctx)?
137                    .into_array()
138                    .binary(rhs.clone(), Operator::from(operator))?,
139            ))
140        } else {
141            Ok(None)
142        }
143    }
144}
145
146fn compare_offsets_to_empty<P: IntegerPType>(offsets: PrimitiveArray, eq: bool) -> BitBuffer {
147    let fn_ = if eq { P::eq } else { P::ne };
148    let offsets = offsets.as_slice::<P>();
149    BitBuffer::collect_bool(offsets.len() - 1, |idx| {
150        let left = unsafe { offsets.get_unchecked(idx) };
151        let right = unsafe { offsets.get_unchecked(idx + 1) };
152        fn_(left, right)
153    })
154}
155
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::bool::BoolArrayExt;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// A nullable binary array compared for equality against a binary constant keeps the LHS
    /// nulls and sets a value bit only where the bytes match.
    #[test]
    fn test_binary_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let rhs = ConstantArray::new(needle, 3).into_array();

        #[expect(deprecated)]
        let result = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        let mut ctx = array_session().create_execution_ctx();
        let validity_bits = result
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(result.as_ref().len(), &mut ctx)
            .unwrap()
            .to_bit_buffer();
        assert_eq!(validity_bits, BitBuffer::from_iter([true, false, true]));

        // NOTE(review): the value bit at the null position (index 1) is asserted too; it is
        // not semantically meaningful and only pins the current behavior.
        let value_bits = result.to_bit_buffer();
        assert_eq!(value_bits, BitBuffer::from_iter([true, false, false]));
    }

    /// A VarBin LHS compared against a VarBinView RHS (which exercises the LHS-to-VarBinView
    /// conversion path) combines the nulls of both sides.
    #[test]
    fn varbinview_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();

        #[expect(deprecated)]
        let result = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        let mut ctx = array_session().create_execution_ctx();
        let validity_bits = result
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(result.as_ref().len(), &mut ctx)
            .unwrap()
            .to_bit_buffer();
        assert_eq!(validity_bits, BitBuffer::from_iter([false, false, true]));

        // NOTE(review): the value bits at the null positions (indices 0 and 1) are asserted
        // too; they are not semantically meaningful and only pin the current behavior.
        let value_bits = result.to_bit_buffer();
        assert_eq!(value_bits, BitBuffer::from_iter([false, true, true]));
    }
}
251
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::varbin::builder::VarBinBuilder;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Builds the array `["abc", "xyz", "abc"]` with `i64` offsets and the given `dtype`.
    fn abc_xyz_abc_with_i64_offsets(dtype: DType) -> VarBinArray {
        let mut builder = VarBinBuilder::<i64>::with_capacity(3);
        for value in [b"abc", b"xyz", b"abc"] {
            builder.append_value(value);
        }
        builder.finish(dtype)
    }

    /// Comparing a non-nullable LHS with a nullable constant must produce a nullable result.
    #[test]
    fn test_null_compare() {
        let rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1).into_array();
        let lhs = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable))
            .into_array();

        let result = lhs.binary(rhs, Operator::Eq).unwrap();
        assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));
    }

    /// Regression test: a [`VarBinArray`] with `i64` offsets is converted to Arrow
    /// `LargeUtf8` / `LargeBinary` by `preferred_arrow_type`. Without an explicit branch in
    /// [`CompareKernel`], the constant RHS would be wrapped in a `StringArray` /
    /// `BinaryArray`, and Arrow rejects the resulting `LargeUtf8 == Utf8` mismatch. Only
    /// `i64` offsets are needed to trigger this, not large data.
    ///
    /// [`CompareKernel`]: super::CompareKernel
    #[test]
    fn varbin_i64_offsets_compare_constant() {
        let mut ctx = array_session().create_execution_ctx();
        let lhs = abc_xyz_abc_with_i64_offsets(DType::Utf8(Nullability::NonNullable)).into_array();
        let rhs = ConstantArray::new(Scalar::utf8("abc", Nullability::NonNullable), 3).into_array();

        let result = lhs.binary(rhs, Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([true, false, true]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }

    /// Binary counterpart of [`varbin_i64_offsets_compare_constant`] (`LargeBinary` vs
    /// `Binary`).
    #[test]
    fn varbin_i64_offsets_compare_constant_binary() {
        let mut ctx = array_session().create_execution_ctx();
        let lhs =
            abc_xyz_abc_with_i64_offsets(DType::Binary(Nullability::NonNullable)).into_array();
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::NonNullable);
        let rhs = ConstantArray::new(needle, 3).into_array();

        let result = lhs.binary(rhs, Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([true, false, true]);
        assert_arrays_eq!(result, expected, &mut ctx);
    }
}