vortex_array/arrays/varbin/compute/
compare.rs1use arrow_array::BinaryArray;
5use arrow_array::StringArray;
6use arrow_ord::cmp;
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect as _;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::IntoArray;
16use crate::array::ArrayView;
17use crate::arrays::BoolArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::VarBin;
20use crate::arrays::VarBinViewArray;
21use crate::arrays::varbin::VarBinArrayExt;
22use crate::arrow::Datum;
23use crate::arrow::from_arrow_array_with_len;
24use crate::builtins::ArrayBuiltins;
25use crate::dtype::DType;
26use crate::dtype::IntegerPType;
27use crate::match_each_integer_ptype;
28use crate::scalar_fn::fns::binary::CompareKernel;
29use crate::scalar_fn::fns::operators::CompareOperator;
30use crate::scalar_fn::fns::operators::Operator;
31
32impl CompareKernel for VarBin {
34 fn compare(
35 lhs: ArrayView<'_, VarBin>,
36 rhs: &ArrayRef,
37 operator: CompareOperator,
38 ctx: &mut ExecutionCtx,
39 ) -> VortexResult<Option<ArrayRef>> {
40 if let Some(rhs_const) = rhs.as_constant() {
41 let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
42 let len = lhs.len();
43
44 let rhs_is_empty = match rhs_const.dtype() {
45 DType::Binary(_) => rhs_const
46 .as_binary()
47 .is_empty()
48 .vortex_expect("RHS should not be null"),
49 DType::Utf8(_) => rhs_const
50 .as_utf8()
51 .is_empty()
52 .vortex_expect("RHS should not be null"),
53 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
54 };
55
56 if rhs_is_empty {
57 let buffer = match operator {
58 CompareOperator::Gte => BitBuffer::new_set(len), CompareOperator::Lt => BitBuffer::new_unset(len), CompareOperator::Eq | CompareOperator::Lte => {
61 let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
62 match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
63 compare_offsets_to_empty::<P>(lhs_offsets, true)
64 })
65 }
66 CompareOperator::NotEq | CompareOperator::Gt => {
67 let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
68 match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
69 compare_offsets_to_empty::<P>(lhs_offsets, false)
70 })
71 }
72 };
73
74 return Ok(Some(
75 BoolArray::new(
76 buffer,
77 lhs.validity()?.union_nullability(rhs.dtype().nullability()),
78 )
79 .into_array(),
80 ));
81 }
82
83 let lhs = Datum::try_new(lhs.array())?;
84
85 let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
88 DType::Utf8(_) => &rhs_const
89 .as_utf8()
90 .value()
91 .map(StringArray::new_scalar)
92 .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
93 DType::Binary(_) => &rhs_const
94 .as_binary()
95 .value()
96 .map(BinaryArray::new_scalar)
97 .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
98 _ => vortex_bail!(
99 "VarBin array RHS can only be Utf8 or Binary, given {}",
100 rhs_const.dtype()
101 ),
102 };
103
104 let array = match operator {
105 CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs),
106 CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs),
107 CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs),
108 CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
109 CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs),
110 CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
111 }
112 .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
113
114 Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
115 } else if !rhs.is::<VarBin>() {
116 Ok(Some(
120 lhs.array()
121 .clone()
122 .execute::<VarBinViewArray>(ctx)?
123 .into_array()
124 .binary(rhs.clone(), Operator::from(operator))?,
125 ))
126 } else {
127 Ok(None)
128 }
129 }
130}
131
132fn compare_offsets_to_empty<P: IntegerPType>(offsets: PrimitiveArray, eq: bool) -> BitBuffer {
133 let fn_ = if eq { P::eq } else { P::ne };
134 let offsets = offsets.as_slice::<P>();
135 BitBuffer::collect_bool(offsets.len() - 1, |idx| {
136 let left = unsafe { offsets.get_unchecked(idx) };
137 let right = unsafe { offsets.get_unchecked(idx + 1) };
138 fn_(left, right)
139 })
140}
141
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::bool::BoolArrayExt;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Equality between a nullable binary `VarBinArray` and a constant binary scalar.
    #[test]
    fn test_binary_compare() {
        let values = [Some(b"abc".to_vec()), None, Some(b"def".to_vec())];
        let lhs = VarBinArray::from_iter(values, DType::Binary(Nullability::Nullable));

        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let rhs = ConstantArray::new(needle, 3).into_array();

        let result = lhs
            .into_array()
            .binary(rhs, Operator::Eq)
            .unwrap()
            .to_bool();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let validity_bits = result
            .as_ref()
            .validity()
            .unwrap()
            .to_mask(result.as_ref().len(), &mut ctx)
            .unwrap()
            .to_bit_buffer();

        assert_eq!(validity_bits, BitBuffer::from_iter([true, false, true]));
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([true, false, false])
        );
    }

    /// Equality between a `VarBinArray` and a `VarBinViewArray` right-hand side.
    #[test]
    fn varbinview_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();

        let result = lhs
            .into_array()
            .binary(rhs, Operator::Eq)
            .unwrap()
            .to_bool();

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let validity_bits = result
            .as_ref()
            .validity()
            .unwrap()
            .to_mask(result.as_ref().len(), &mut ctx)
            .unwrap()
            .to_bit_buffer();

        assert_eq!(validity_bits, BitBuffer::from_iter([false, false, true]));
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([false, true, true])
        );
    }
}
234
#[cfg(test)]
mod tests {
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// A non-nullable array compared with a nullable empty-string constant must produce a
    /// nullable boolean result.
    #[test]
    fn test_null_compare() {
        let lhs = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable))
            .into_array();
        let rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1).into_array();

        let compared = lhs.binary(rhs, Operator::Eq).unwrap();

        assert_eq!(compared.dtype(), &DType::Bool(Nullability::Nullable));
    }
}