Skip to main content

vortex_array/arrays/varbin/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::BinaryArray;
5use arrow_array::LargeBinaryArray;
6use arrow_array::LargeStringArray;
7use arrow_array::StringArray;
8use arrow_ord::cmp;
9use arrow_schema::DataType;
10use vortex_buffer::BitBuffer;
11use vortex_error::VortexExpect as _;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_err;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::array::ArrayView;
20use crate::arrays::BoolArray;
21use crate::arrays::PrimitiveArray;
22use crate::arrays::VarBin;
23use crate::arrays::VarBinViewArray;
24use crate::arrays::varbin::VarBinArrayExt;
25use crate::arrow::Datum;
26use crate::arrow::from_arrow_array_with_len;
27use crate::builtins::ArrayBuiltins;
28use crate::dtype::DType;
29use crate::dtype::IntegerPType;
30use crate::match_each_integer_ptype;
31use crate::scalar_fn::fns::binary::CompareKernel;
32use crate::scalar_fn::fns::operators::CompareOperator;
33use crate::scalar_fn::fns::operators::Operator;
34
35// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
36impl CompareKernel for VarBin {
37    fn compare(
38        lhs: ArrayView<'_, VarBin>,
39        rhs: &ArrayRef,
40        operator: CompareOperator,
41        ctx: &mut ExecutionCtx,
42    ) -> VortexResult<Option<ArrayRef>> {
43        if let Some(rhs_const) = rhs.as_constant() {
44            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
45            let len = lhs.len();
46
47            let rhs_is_empty = match rhs_const.dtype() {
48                DType::Binary(_) => rhs_const
49                    .as_binary()
50                    .is_empty()
51                    .vortex_expect("RHS should not be null"),
52                DType::Utf8(_) => rhs_const
53                    .as_utf8()
54                    .is_empty()
55                    .vortex_expect("RHS should not be null"),
56                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
57            };
58
59            if rhs_is_empty {
60                let buffer = match operator {
61                    CompareOperator::Gte => BitBuffer::new_set(len), // Every possible value is >= ""
62                    CompareOperator::Lt => BitBuffer::new_unset(len), // No value is < ""
63                    CompareOperator::Eq | CompareOperator::Lte => {
64                        let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
65                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
66                            compare_offsets_to_empty::<P>(lhs_offsets, true)
67                        })
68                    }
69                    CompareOperator::NotEq | CompareOperator::Gt => {
70                        let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
71                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
72                            compare_offsets_to_empty::<P>(lhs_offsets, false)
73                        })
74                    }
75                };
76
77                return Ok(Some(
78                    BoolArray::new(
79                        buffer,
80                        lhs.validity()?.union_nullability(rhs.dtype().nullability()),
81                    )
82                    .into_array(),
83                ));
84            }
85
86            let lhs = Datum::try_new(lhs.array(), ctx)?;
87
88            // The RHS scalar must match the LHS Arrow data type. VarBin with i64 offsets is
89            // converted to LargeBinary/LargeUtf8 (see `preferred_arrow_type`), and Arrow refuses to
90            // compare LargeBinary with Binary (or LargeUtf8 with Utf8).
91            let arrow_rhs: &dyn arrow_array::Datum = match (rhs_const.dtype(), lhs.data_type()) {
92                (DType::Utf8(_), DataType::LargeUtf8) => &rhs_const
93                    .as_utf8()
94                    .value()
95                    .map(LargeStringArray::new_scalar)
96                    .unwrap_or_else(|| arrow_array::Scalar::new(LargeStringArray::new_null(1))),
97                (DType::Utf8(_), _) => &rhs_const
98                    .as_utf8()
99                    .value()
100                    .map(StringArray::new_scalar)
101                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
102                (DType::Binary(_), DataType::LargeBinary) => &rhs_const
103                    .as_binary()
104                    .value()
105                    .map(LargeBinaryArray::new_scalar)
106                    .unwrap_or_else(|| arrow_array::Scalar::new(LargeBinaryArray::new_null(1))),
107                (DType::Binary(_), _) => &rhs_const
108                    .as_binary()
109                    .value()
110                    .map(BinaryArray::new_scalar)
111                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
112                _ => vortex_bail!(
113                    "VarBin array RHS can only be Utf8 or Binary, given {}",
114                    rhs_const.dtype()
115                ),
116            };
117
118            let array = match operator {
119                CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs),
120                CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs),
121                CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs),
122                CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
123                CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs),
124                CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
125            }
126            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
127
128            Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
129        } else if !rhs.is::<VarBin>() {
130            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
131            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
132            // to VarBinView and re-invoke.
133            Ok(Some(
134                lhs.array()
135                    .clone()
136                    .execute::<VarBinViewArray>(ctx)?
137                    .into_array()
138                    .binary(rhs.clone(), Operator::from(operator))?,
139            ))
140        } else {
141            Ok(None)
142        }
143    }
144}
145
146fn compare_offsets_to_empty<P: IntegerPType>(offsets: PrimitiveArray, eq: bool) -> BitBuffer {
147    let fn_ = if eq { P::eq } else { P::ne };
148    let offsets = offsets.as_slice::<P>();
149    BitBuffer::collect_bool(offsets.len() - 1, |idx| {
150        let left = unsafe { offsets.get_unchecked(idx) };
151        let right = unsafe { offsets.get_unchecked(idx + 1) };
152        fn_(left, right)
153    })
154}
155
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::bool::BoolArrayExt;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// `[abc, null, def] == "abc"` against a nullable binary constant: the null row stays
    /// invalid and only the first row compares equal.
    #[test]
    fn test_binary_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let rhs = ConstantArray::new(
            Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable),
            3,
        )
        .into_array();

        #[expect(deprecated)]
        let result = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let validity_bits = result
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(result.as_ref().len(), &mut ctx)
            .unwrap()
            .to_bit_buffer();

        assert_eq!(validity_bits, BitBuffer::from_iter([true, false, true]));
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([true, false, false])
        );
    }

    /// `[abc, null, def] == [null, null, def]` with a `VarBinViewArray` on the right-hand side:
    /// only the last row has both operands present.
    #[test]
    fn varbinview_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();

        #[expect(deprecated)]
        let result = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let validity_bits = result
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(result.as_ref().len(), &mut ctx)
            .unwrap()
            .to_bit_buffer();

        assert_eq!(validity_bits, BitBuffer::from_iter([false, false, true]));
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([false, true, true])
        );
    }
}
251
#[cfg(test)]
mod tests {
    use vortex_buffer::ByteBuffer;

    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::varbin::builder::VarBinBuilder;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Comparing a non-nullable array with a nullable empty-string constant must produce a
    /// nullable boolean result.
    #[test]
    fn test_null_compare() {
        let lhs =
            VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable)).into_array();
        let rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1).into_array();

        let compared = lhs.binary(rhs, Operator::Eq).unwrap();

        assert_eq!(compared.dtype(), &DType::Bool(Nullability::Nullable));
    }

    /// Regression: a [`VarBinArray`] built with `i64` offsets is canonicalised to
    /// Arrow `LargeUtf8` / `LargeBinary` by `preferred_arrow_type`. Without an explicit
    /// branch in [`CompareKernel`], the constant RHS is wrapped in a `StringArray` /
    /// `BinaryArray` and Arrow rejects the `LargeUtf8 == Utf8` mismatch. Triggering
    /// this only requires `i64` offsets, not large data.
    ///
    /// [`CompareKernel`]: super::CompareKernel
    #[test]
    fn varbin_i64_offsets_compare_constant() {
        let mut builder = VarBinBuilder::<i64>::with_capacity(3);
        for value in [b"abc", b"xyz", b"abc"] {
            builder.append_value(value);
        }
        let lhs = builder
            .finish(DType::Utf8(Nullability::NonNullable))
            .into_array();
        let rhs = ConstantArray::new(Scalar::utf8("abc", Nullability::NonNullable), 3).into_array();

        let result = lhs.binary(rhs, Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([true, false, true]);
        assert_arrays_eq!(result, expected);
    }

    /// Binary counterpart of the `i64`-offsets regression above: the constant RHS is a
    /// non-nullable binary scalar and the array is finished with a `Binary` dtype.
    #[test]
    fn varbin_i64_offsets_compare_constant_binary() {
        let mut builder = VarBinBuilder::<i64>::with_capacity(3);
        for value in [b"abc", b"xyz", b"abc"] {
            builder.append_value(value);
        }
        let lhs = builder
            .finish(DType::Binary(Nullability::NonNullable))
            .into_array();
        let rhs = ConstantArray::new(
            Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::NonNullable),
            3,
        )
        .into_array();

        let result = lhs.binary(rhs, Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([true, false, true]);
        assert_arrays_eq!(result, expected);
    }
}
333}