Skip to main content

vortex_array/arrays/varbin/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::BinaryArray;
5use arrow_array::StringArray;
6use arrow_ord::cmp;
7use itertools::Itertools;
8use vortex_buffer::BitBuffer;
9use vortex_dtype::DType;
10use vortex_dtype::IntegerPType;
11use vortex_dtype::match_each_integer_ptype;
12use vortex_error::VortexExpect as _;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::ToCanonical;
22use crate::arrays::BoolArray;
23use crate::arrays::PrimitiveArray;
24use crate::arrays::VarBinArray;
25use crate::arrays::VarBinVTable;
26use crate::arrow::Datum;
27use crate::arrow::from_arrow_array_with_len;
28use crate::compute::Operator;
29use crate::compute::compare;
30use crate::compute::compare_lengths_to_empty;
31use crate::expr::CompareKernel;
32use crate::vtable::ValidityHelper;
33
34// This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical
35impl CompareKernel for VarBinVTable {
36    fn compare(
37        lhs: &VarBinArray,
38        rhs: &dyn Array,
39        operator: Operator,
40        _ctx: &mut ExecutionCtx,
41    ) -> VortexResult<Option<ArrayRef>> {
42        if let Some(rhs_const) = rhs.as_constant() {
43            let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
44            let len = lhs.len();
45
46            let rhs_is_empty = match rhs_const.dtype() {
47                DType::Binary(_) => rhs_const
48                    .as_binary()
49                    .is_empty()
50                    .vortex_expect("RHS should not be null"),
51                DType::Utf8(_) => rhs_const
52                    .as_utf8()
53                    .is_empty()
54                    .vortex_expect("RHS should not be null"),
55                _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
56            };
57
58            if rhs_is_empty {
59                let buffer = match operator {
60                    Operator::Gte => BitBuffer::new_set(len), // Every possible value is >= ""
61                    Operator::Lt => BitBuffer::new_unset(len), // No value is < ""
62                    Operator::Eq | Operator::NotEq | Operator::Gt | Operator::Lte => {
63                        let lhs_offsets = lhs.offsets().to_primitive();
64                        match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
65                            compare_offsets_to_empty::<P>(lhs_offsets, operator)
66                        })
67                    }
68                };
69
70                return Ok(Some(
71                    BoolArray::new(
72                        buffer,
73                        lhs.validity()
74                            .clone()
75                            .union_nullability(rhs.dtype().nullability()),
76                    )
77                    .into_array(),
78                ));
79            }
80
81            let lhs = Datum::try_new(lhs.as_ref())?;
82
83            // Use StringViewArray/BinaryViewArray to match the Utf8View/BinaryView types
84            // produced by Datum::try_new (which uses into_arrow_preferred())
85            let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
86                DType::Utf8(_) => &rhs_const
87                    .as_utf8()
88                    .value()
89                    .map(StringArray::new_scalar)
90                    .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
91                DType::Binary(_) => &rhs_const
92                    .as_binary()
93                    .value()
94                    .map(BinaryArray::new_scalar)
95                    .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
96                _ => vortex_bail!(
97                    "VarBin array RHS can only be Utf8 or Binary, given {}",
98                    rhs_const.dtype()
99                ),
100            };
101
102            let array = match operator {
103                Operator::Eq => cmp::eq(&lhs, arrow_rhs),
104                Operator::NotEq => cmp::neq(&lhs, arrow_rhs),
105                Operator::Gt => cmp::gt(&lhs, arrow_rhs),
106                Operator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
107                Operator::Lt => cmp::lt(&lhs, arrow_rhs),
108                Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
109            }
110            .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
111
112            Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
113        } else if !rhs.is::<VarBinVTable>() {
114            // NOTE: If the rhs is not a VarBin array it will be canonicalized to a VarBinView
115            // Arrow doesn't support comparing VarBin to VarBinView arrays, so we convert ourselves
116            // to VarBinView and re-invoke.
117            return Ok(Some(compare(lhs.to_varbinview().as_ref(), rhs, operator)?));
118        } else {
119            Ok(None)
120        }
121    }
122}
123
124fn compare_offsets_to_empty<P: IntegerPType>(
125    offsets: PrimitiveArray,
126    operator: Operator,
127) -> BitBuffer {
128    let lengths_iter = offsets
129        .as_slice::<P>()
130        .iter()
131        .tuple_windows()
132        .map(|(&s, &e)| e - s);
133    compare_lengths_to_empty(lengths_iter, operator)
134}
135
#[cfg(test)]
mod tests {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::Array;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::compute::Operator;
    use crate::compute::compare;
    use crate::scalar::Scalar;

    /// Binary VarBin compared for equality with a non-empty, nullable binary constant.
    #[test]
    fn test_binary_compare() {
        let array = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let result = compare(
            array.as_ref(),
            ConstantArray::new(
                Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable),
                3,
            )
            .as_ref(),
            Operator::Eq,
        )
        .unwrap()
        .to_bool();

        assert_eq!(
            &result.validity_mask().unwrap().to_bit_buffer(),
            &BitBuffer::from_iter([true, false, true])
        );
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([true, false, false])
        );
    }

    /// Utf8 VarBin compared with a non-empty, non-nullable Utf8 constant.
    #[test]
    fn test_utf8_compare_non_empty_constant() {
        let array = VarBinArray::from_iter(
            [Some("a"), Some("b"), Some("c")],
            DType::Utf8(Nullability::NonNullable),
        );
        let result = compare(
            array.as_ref(),
            ConstantArray::new(Scalar::utf8("b", Nullability::NonNullable), 3).as_ref(),
            Operator::Gte,
        )
        .unwrap()
        .to_bool();

        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([false, true, true])
        );
    }

    /// Comparisons against "" are answered from the offsets alone.
    #[test]
    fn test_compare_to_empty_string() {
        let array = VarBinArray::from_iter(
            [Some(""), Some("x"), Some("")],
            DType::Utf8(Nullability::NonNullable),
        );
        let empty = ConstantArray::new(Scalar::utf8("", Nullability::NonNullable), 3);

        let eq = compare(array.as_ref(), empty.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();
        assert_eq!(eq.to_bit_buffer(), BitBuffer::from_iter([true, false, true]));

        let gt = compare(array.as_ref(), empty.as_ref(), Operator::Gt)
            .unwrap()
            .to_bool();
        assert_eq!(gt.to_bit_buffer(), BitBuffer::from_iter([false, true, false]));
    }

    /// VarBin compared with a (non-constant) VarBinView array.
    #[test]
    fn varbinview_compare() {
        let array = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let vbv = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let result = compare(array.as_ref(), vbv.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(
            result.validity_mask().unwrap().to_bit_buffer(),
            BitBuffer::from_iter([false, false, true])
        );
        // NOTE(review): slots 0 and 1 are null, so the value bits asserted there are not
        // semantically meaningful; they pin the current kernel output. Consider asserting only
        // the valid slot.
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([false, true, true])
        );
    }

    /// A nullable RHS makes the result nullable even when the LHS is non-nullable.
    #[test]
    fn test_nullable_empty_rhs_result_dtype() {
        let arr = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));

        let const_ = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1);

        assert_eq!(
            compare(arr.as_ref(), const_.as_ref(), Operator::Eq)
                .unwrap()
                .dtype(),
            &DType::Bool(Nullability::Nullable)
        );
    }
}