Skip to main content

vortex_array/arrays/varbin/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::BinaryArray;
5use arrow_array::StringArray;
6use arrow_ord::cmp;
7use itertools::Itertools;
8use vortex_buffer::BitBuffer;
9use vortex_error::VortexExpect as _;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13
14use crate::Array;
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::ToCanonical;
19use crate::arrays::BoolArray;
20use crate::arrays::PrimitiveArray;
21use crate::arrays::VarBinArray;
22use crate::arrays::VarBinVTable;
23use crate::arrow::Datum;
24use crate::arrow::from_arrow_array_with_len;
25use crate::builtins::ArrayBuiltins;
26use crate::compute::compare_lengths_to_empty;
27use crate::dtype::DType;
28use crate::dtype::IntegerPType;
29use crate::match_each_integer_ptype;
30use crate::scalar_fn::fns::binary::CompareKernel;
31use crate::scalar_fn::fns::operators::CompareOperator;
32use crate::scalar_fn::fns::operators::Operator;
33use crate::vtable::ValidityHelper;
34
35// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
36impl CompareKernel for VarBinVTable {
37    fn compare(
38        lhs: &VarBinArray,
39        rhs: &dyn Array,
40        operator: CompareOperator,
41        _ctx: &mut ExecutionCtx,
42    ) -> VortexResult<Option<ArrayRef>> {
43        if let Some(rhs_const) = rhs.as_constant() {
44            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
45            let len = lhs.len();
46
47            let rhs_is_empty = match rhs_const.dtype() {
48                DType::Binary(_) => rhs_const
49                    .as_binary()
50                    .is_empty()
51                    .vortex_expect("RHS should not be null"),
52                DType::Utf8(_) => rhs_const
53                    .as_utf8()
54                    .is_empty()
55                    .vortex_expect("RHS should not be null"),
56                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
57            };
58
59            if rhs_is_empty {
60                let buffer = match operator {
61                    CompareOperator::Gte => BitBuffer::new_set(len), // Every possible value is >= ""
62                    CompareOperator::Lt => BitBuffer::new_unset(len), // No value is < ""
63                    CompareOperator::Eq
64                    | CompareOperator::NotEq
65                    | CompareOperator::Gt
66                    | CompareOperator::Lte => {
67                        let lhs_offsets = lhs.offsets().to_primitive();
68                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
69                            compare_offsets_to_empty::<P>(lhs_offsets, operator)
70                        })
71                    }
72                };
73
74                return Ok(Some(
75                    BoolArray::new(
76                        buffer,
77                        lhs.validity()
78                            .clone()
79                            .union_nullability(rhs.dtype().nullability()),
80                    )
81                    .into_array(),
82                ));
83            }
84
85            let lhs = Datum::try_new(lhs.as_ref())?;
86
87            // Use StringViewArray/BinaryViewArray to match the Utf8View/BinaryView types
88            // produced by Datum::try_new (which uses into_arrow_preferred())
89            let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
90                DType::Utf8(_) => &rhs_const
91                    .as_utf8()
92                    .value()
93                    .map(StringArray::new_scalar)
94                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
95                DType::Binary(_) => &rhs_const
96                    .as_binary()
97                    .value()
98                    .map(BinaryArray::new_scalar)
99                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
100                _ => vortex_bail!(
101                    "VarBin array RHS can only be Utf8 or Binary, given {}",
102                    rhs_const.dtype()
103                ),
104            };
105
106            let array = match operator {
107                CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs),
108                CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs),
109                CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs),
110                CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
111                CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs),
112                CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
113            }
114            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
115
116            Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
117        } else if !rhs.is::<VarBinVTable>() {
118            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
119            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
120            // to VarBinView and re-invoke.
121            return Ok(Some(
122                lhs.to_varbinview()
123                    .to_array()
124                    .binary(rhs.to_array(), Operator::from(operator))?,
125            ));
126        } else {
127            Ok(None)
128        }
129    }
130}
131
132fn compare_offsets_to_empty<P: IntegerPType>(
133    offsets: PrimitiveArray,
134    operator: CompareOperator,
135) -> BitBuffer {
136    let lengths_iter = offsets
137        .as_slice::<P>()
138        .iter()
139        .tuple_windows()
140        .map(|(&s, &e)| e - s);
141    compare_lengths_to_empty(lengths_iter, operator)
142}
143
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;

    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Equality of a nullable binary VarBin array against a constant `"abc"`.
    #[test]
    fn test_binary_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let rhs = ConstantArray::new(needle, 3).to_array();

        let result = lhs.to_array().binary(rhs, Operator::Eq).unwrap().to_bool();

        // The null LHS element must surface as a null result.
        let validity = result.validity_mask().unwrap().to_bit_buffer();
        assert_eq!(&validity, &BitBuffer::from_iter([true, false, true]));

        let values = result.to_bit_buffer();
        assert_eq!(values, BitBuffer::from_iter([true, false, false]));
    }

    /// Equality of a VarBin array against a (non-constant) VarBinView array.
    #[test]
    fn varbinview_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );

        let result = lhs
            .to_array()
            .binary(rhs.to_array(), Operator::Eq)
            .unwrap()
            .to_bool();

        // Only the last position has a non-null value on both sides.
        let validity = result.validity_mask().unwrap().to_bit_buffer();
        assert_eq!(validity, BitBuffer::from_iter([false, false, true]));

        let values = result.to_bit_buffer();
        assert_eq!(values, BitBuffer::from_iter([false, true, true]));
    }
}
214
#[cfg(test)]
mod tests {
    use crate::Array;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// A non-nullable array compared with a nullable empty constant must yield a nullable
    /// boolean result.
    #[test]
    fn test_null_compare() {
        let lhs = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));
        let empty_rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1);

        let result = lhs
            .to_array()
            .binary(empty_rhs.to_array(), Operator::Eq)
            .unwrap();

        assert_eq!(result.dtype(), &DType::Bool(Nullability::Nullable));
    }
}
240}