vortex_array/arrays/varbin/compute/
compare.rs1use arrow_array::{BinaryArray, StringArray};
2use arrow_buffer::BooleanBuffer;
3use arrow_ord::cmp;
4use itertools::Itertools;
5use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
7
8use crate::arrays::{BoolArray, PrimitiveArray, VarBinArray, VarBinVTable};
9use crate::arrow::{Datum, from_arrow_array_with_len};
10use crate::compute::{
11 CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
12};
13use crate::vtable::ValidityHelper;
14use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
15
16impl CompareKernel for VarBinVTable {
18 fn compare(
19 &self,
20 lhs: &VarBinArray,
21 rhs: &dyn Array,
22 operator: Operator,
23 ) -> VortexResult<Option<ArrayRef>> {
24 if let Some(rhs_const) = rhs.as_constant() {
25 let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
26 let len = lhs.len();
27
28 let rhs_is_empty = match rhs_const.dtype() {
29 DType::Binary(_) => rhs_const
30 .as_binary()
31 .is_empty()
32 .vortex_expect("RHS should not be null"),
33 DType::Utf8(_) => rhs_const
34 .as_utf8()
35 .is_empty()
36 .vortex_expect("RHS should not be null"),
37 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
38 };
39
40 if rhs_is_empty {
41 let buffer = match operator {
42 Operator::Gte => BooleanBuffer::new_set(len),
44 Operator::Lt => BooleanBuffer::new_unset(len),
46 _ => {
47 let lhs_offsets = lhs.offsets().to_canonical()?.into_primitive()?;
48 match_each_native_ptype!(lhs_offsets.ptype(), |P| {
49 compare_offsets_to_empty::<P>(lhs_offsets, operator)
50 })
51 }
52 };
53
54 return Ok(Some(
55 BoolArray::new(buffer, lhs.validity().clone()).into_array(),
56 ));
57 }
58
59 let lhs = Datum::try_new(lhs.as_ref())?;
60
61 let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
63 DType::Utf8(_) => &rhs_const
64 .as_utf8()
65 .value()
66 .map(StringArray::new_scalar)
67 .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
68 DType::Binary(_) => &rhs_const
69 .as_binary()
70 .value()
71 .map(BinaryArray::new_scalar)
72 .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
73 _ => vortex_bail!(
74 "VarBin array RHS can only be Utf8 or Binary, given {}",
75 rhs_const.dtype()
76 ),
77 };
78
79 let array = match operator {
80 Operator::Eq => cmp::eq(&lhs, arrow_rhs),
81 Operator::NotEq => cmp::neq(&lhs, arrow_rhs),
82 Operator::Gt => cmp::gt(&lhs, arrow_rhs),
83 Operator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
84 Operator::Lt => cmp::lt(&lhs, arrow_rhs),
85 Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
86 }
87 .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
88
89 Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
90 } else if !rhs.is::<VarBinVTable>() {
91 return Ok(Some(compare(lhs.to_varbinview()?.as_ref(), rhs, operator)?));
95 } else {
96 Ok(None)
97 }
98 }
99}
100
101register_kernel!(CompareKernelAdapter(VarBinVTable).lift());
102
103fn compare_offsets_to_empty<P: NativePType>(
104 offsets: PrimitiveArray,
105 operator: Operator,
106) -> BooleanBuffer {
107 let lengths_iter = offsets
108 .as_slice::<P>()
109 .iter()
110 .tuple_windows()
111 .map(|(&s, &e)| e - s);
112 compare_lengths_to_empty(lengths_iter, operator)
113}
114
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use vortex_buffer::ByteBuffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::ToCanonical;
    use crate::arrays::{ConstantArray, VarBinArray, VarBinViewArray};
    use crate::compute::{Operator, compare};

    /// Nullable binary fixture shared by the tests: `"abc"`, null, `"def"`.
    fn abc_null_def() -> VarBinArray {
        VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
    }

    /// Equality against a constant `"abc"` keeps the validity of the left side.
    #[test]
    fn test_binary_compare() {
        let lhs = abc_null_def();
        let rhs = ConstantArray::new(
            Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable),
            3,
        );

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let expected_validity = BooleanBuffer::from_iter([true, false, true]);
        let expected_values = BooleanBuffer::from_iter([true, false, false]);

        assert_eq!(
            &result.validity_mask().unwrap().to_boolean_buffer(),
            &expected_validity
        );
        assert_eq!(result.boolean_buffer(), &expected_values);
    }

    /// Equality against a non-constant VarBinView array.
    #[test]
    fn varbinview_compare() {
        let lhs = abc_null_def();
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let expected_validity = BooleanBuffer::from_iter([false, false, true]);
        let expected_values = BooleanBuffer::from_iter([false, true, true]);

        assert_eq!(
            &result.validity_mask().unwrap().to_boolean_buffer(),
            &expected_validity
        );
        assert_eq!(result.boolean_buffer(), &expected_values);
    }
}