// vortex_array/arrays/varbin/compute/compare.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use arrow_array::BinaryArray;
use arrow_array::BinaryViewArray;
use arrow_array::LargeBinaryArray;
use arrow_array::LargeStringArray;
use arrow_array::StringArray;
use arrow_array::StringViewArray;
use arrow_ord::cmp;
use vortex_buffer::BitBuffer;
use vortex_error::VortexExpect as _;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::BoolArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::VarBin;
use crate::arrays::VarBinViewArray;
use crate::arrays::varbin::VarBinArrayExt;
use crate::arrow::Datum;
use crate::arrow::from_arrow_array_with_len;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::IntegerPType;
use crate::match_each_integer_ptype;
use crate::scalar_fn::fns::binary::CompareKernel;
use crate::scalar_fn::fns::operators::CompareOperator;
use crate::scalar_fn::fns::operators::Operator;

32// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
33impl CompareKernel for VarBin {
34    fn compare(
35        lhs: ArrayView<'_, VarBin>,
36        rhs: &ArrayRef,
37        operator: CompareOperator,
38        ctx: &mut ExecutionCtx,
39    ) -> VortexResult<Option<ArrayRef>> {
40        if let Some(rhs_const) = rhs.as_constant() {
41            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
42            let len = lhs.len();
43
44            let rhs_is_empty = match rhs_const.dtype() {
45                DType::Binary(_) => rhs_const
46                    .as_binary()
47                    .is_empty()
48                    .vortex_expect("RHS should not be null"),
49                DType::Utf8(_) => rhs_const
50                    .as_utf8()
51                    .is_empty()
52                    .vortex_expect("RHS should not be null"),
53                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
54            };
55
56            if rhs_is_empty {
57                let buffer = match operator {
58                    CompareOperator::Gte => BitBuffer::new_set(len), // Every possible value is >= ""
59                    CompareOperator::Lt => BitBuffer::new_unset(len), // No value is < ""
60                    CompareOperator::Eq | CompareOperator::Lte => {
61                        let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
62                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
63                            compare_offsets_to_empty::<P>(lhs_offsets, true)
64                        })
65                    }
66                    CompareOperator::NotEq | CompareOperator::Gt => {
67                        let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
68                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
69                            compare_offsets_to_empty::<P>(lhs_offsets, false)
70                        })
71                    }
72                };
73
74                return Ok(Some(
75                    BoolArray::new(
76                        buffer,
77                        lhs.validity()?.union_nullability(rhs.dtype().nullability()),
78                    )
79                    .into_array(),
80                ));
81            }
82
83            let lhs = Datum::try_new(lhs.array(), ctx)?;
84
85            // Use StringViewArray/BinaryViewArray to match the Utf8View/BinaryView types
86            // produced by Datum::try_new (which uses execute_arrow(None, ctx))
87            let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
88                DType::Utf8(_) => &rhs_const
89                    .as_utf8()
90                    .value()
91                    .map(StringArray::new_scalar)
92                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
93                DType::Binary(_) => &rhs_const
94                    .as_binary()
95                    .value()
96                    .map(BinaryArray::new_scalar)
97                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
98                _ => vortex_bail!(
99                    "VarBin array RHS can only be Utf8 or Binary, given {}",
100                    rhs_const.dtype()
101                ),
102            };
103
104            let array = match operator {
105                CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs),
106                CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs),
107                CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs),
108                CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
109                CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs),
110                CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
111            }
112            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
113
114            Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
115        } else if !rhs.is::<VarBin>() {
116            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
117            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
118            // to VarBinView and re-invoke.
119            Ok(Some(
120                lhs.array()
121                    .clone()
122                    .execute::<VarBinViewArray>(ctx)?
123                    .into_array()
124                    .binary(rhs.clone(), Operator::from(operator))?,
125            ))
126        } else {
127            Ok(None)
128        }
129    }
130}
131
132fn compare_offsets_to_empty<P: IntegerPType>(offsets: PrimitiveArray, eq: bool) -> BitBuffer {
133    let fn_ = if eq { P::eq } else { P::ne };
134    let offsets = offsets.as_slice::<P>();
135    BitBuffer::collect_bool(offsets.len() - 1, |idx| {
136        let left = unsafe { offsets.get_unchecked(idx) };
137        let right = unsafe { offsets.get_unchecked(idx + 1) };
138        fn_(left, right)
139    })
140}
141
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::bool::BoolArrayExt;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Equality against a binary constant: the LHS null stays null and only the matching
    /// value compares true.
    #[test]
    fn test_binary_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let rhs = ConstantArray::new(
            Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable),
            3,
        )
        .into_array();

        #[expect(deprecated)]
        let result = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        let result_len = result.as_ref().len();
        let validity_bits = result
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(result_len, &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap()
            .to_bit_buffer();

        assert_eq!(&validity_bits, &BitBuffer::from_iter([true, false, true]));
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([true, false, false])
        );
    }

    /// Equality between a VarBin array and a VarBinView array (which takes the
    /// "convert LHS to VarBinView and re-dispatch" path).
    #[test]
    fn varbinview_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();

        #[expect(deprecated)]
        let result = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        let result_len = result.as_ref().len();
        let validity_bits = result
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(result_len, &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap()
            .to_bit_buffer();

        assert_eq!(validity_bits, BitBuffer::from_iter([false, false, true]));
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([false, true, true])
        );
    }
}

#[cfg(test)]
mod tests {
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// A non-nullable LHS compared with a nullable empty-string constant (the empty-RHS fast
    /// path) must yield a nullable boolean result.
    #[test]
    fn test_null_compare() {
        let lhs =
            VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable)).into_array();
        let rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1).into_array();

        let compared = lhs.binary(rhs, Operator::Eq).unwrap();

        assert_eq!(compared.dtype(), &DType::Bool(Nullability::Nullable));
    }
}