vortex_array/arrays/varbin/compute/
compare.rs1use arrow_array::BinaryArray;
5use arrow_array::StringArray;
6use arrow_ord::cmp;
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect as _;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_err;
12
13use crate::ArrayRef;
14use crate::DynArray;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::BoolArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::VarBinArray;
20use crate::arrays::VarBinVTable;
21use crate::arrays::VarBinViewArray;
22use crate::arrow::Datum;
23use crate::arrow::from_arrow_array_with_len;
24use crate::builtins::ArrayBuiltins;
25use crate::dtype::DType;
26use crate::dtype::IntegerPType;
27use crate::match_each_integer_ptype;
28use crate::scalar_fn::fns::binary::CompareKernel;
29use crate::scalar_fn::fns::operators::CompareOperator;
30use crate::scalar_fn::fns::operators::Operator;
31use crate::vtable::ValidityHelper;
32
33impl CompareKernel for VarBinVTable {
35 fn compare(
36 lhs: &VarBinArray,
37 rhs: &ArrayRef,
38 operator: CompareOperator,
39 ctx: &mut ExecutionCtx,
40 ) -> VortexResult<Option<ArrayRef>> {
41 if let Some(rhs_const) = rhs.as_constant() {
42 let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
43 let len = lhs.len();
44
45 let rhs_is_empty = match rhs_const.dtype() {
46 DType::Binary(_) => rhs_const
47 .as_binary()
48 .is_empty()
49 .vortex_expect("RHS should not be null"),
50 DType::Utf8(_) => rhs_const
51 .as_utf8()
52 .is_empty()
53 .vortex_expect("RHS should not be null"),
54 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
55 };
56
57 if rhs_is_empty {
58 let buffer = match operator {
59 CompareOperator::Gte => BitBuffer::new_set(len), CompareOperator::Lt => BitBuffer::new_unset(len), CompareOperator::Eq | CompareOperator::Lte => {
62 let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
63 match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
64 compare_offsets_to_empty::<P>(lhs_offsets, true)
65 })
66 }
67 CompareOperator::NotEq | CompareOperator::Gt => {
68 let lhs_offsets = lhs.offsets().clone().execute::<PrimitiveArray>(ctx)?;
69 match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
70 compare_offsets_to_empty::<P>(lhs_offsets, false)
71 })
72 }
73 };
74
75 return Ok(Some(
76 BoolArray::new(
77 buffer,
78 lhs.validity()
79 .clone()
80 .union_nullability(rhs.dtype().nullability()),
81 )
82 .into_array(),
83 ));
84 }
85
86 let lhs = Datum::try_new(&lhs.clone().into_array())?;
87
88 let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
91 DType::Utf8(_) => &rhs_const
92 .as_utf8()
93 .value()
94 .map(StringArray::new_scalar)
95 .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
96 DType::Binary(_) => &rhs_const
97 .as_binary()
98 .value()
99 .map(BinaryArray::new_scalar)
100 .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
101 _ => vortex_bail!(
102 "VarBin array RHS can only be Utf8 or Binary, given {}",
103 rhs_const.dtype()
104 ),
105 };
106
107 let array = match operator {
108 CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs),
109 CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs),
110 CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs),
111 CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
112 CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs),
113 CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
114 }
115 .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
116
117 Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
118 } else if !rhs.is::<VarBinVTable>() {
119 return Ok(Some(
123 lhs.clone()
124 .into_array()
125 .execute::<VarBinViewArray>(ctx)?
126 .into_array()
127 .binary(rhs.to_array(), Operator::from(operator))?,
128 ));
129 } else {
130 Ok(None)
131 }
132 }
133}
134
135fn compare_offsets_to_empty<P: IntegerPType>(offsets: PrimitiveArray, eq: bool) -> BitBuffer {
136 let fn_ = if eq { P::eq } else { P::ne };
137 let offsets = offsets.as_slice::<P>();
138 BitBuffer::collect_bool(offsets.len() - 1, |idx| {
139 let left = unsafe { offsets.get_unchecked(idx) };
140 let right = unsafe { offsets.get_unchecked(idx + 1) };
141 fn_(left, right)
142 })
143}
144
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Nullable binary VarBin column compared for equality against a binary constant:
    /// the LHS null slot stays null and only the equal value compares true.
    #[test]
    fn test_binary_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let rhs = ConstantArray::new(needle, 3).into_array();

        let result = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        let validity = result.validity_mask().unwrap().to_bit_buffer();
        assert_eq!(validity, BitBuffer::from_iter([true, false, true]));
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([true, false, false])
        );
    }

    /// VarBin compared against a (non-constant) VarBinView column: the result is valid only
    /// where both inputs are valid.
    #[test]
    fn varbinview_compare() {
        let lhs = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();
        let rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
        .into_array();

        let result = lhs.binary(rhs, Operator::Eq).unwrap().to_bool();

        let validity = result.validity_mask().unwrap().to_bit_buffer();
        assert_eq!(validity, BitBuffer::from_iter([false, false, true]));
        // NOTE(review): positions 0 and 1 are null, so their value bits are whatever the
        // comparison kernel leaves in null slots; only index 2 is semantically meaningful.
        assert_eq!(
            result.to_bit_buffer(),
            BitBuffer::from_iter([false, true, true])
        );
    }
}
216
#[cfg(test)]
mod tests {
    use crate::DynArray;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Comparing a non-nullable column against a nullable constant must produce a nullable
    /// boolean dtype. The constant is the empty string, so this presumably targets the
    /// empty-constant shortcut of the VarBin compare kernel.
    #[test]
    fn test_null_compare() {
        let lhs = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));
        let rhs = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1).into_array();

        let compared = lhs.into_array().binary(rhs, Operator::Eq).unwrap();

        assert_eq!(compared.dtype(), &DType::Bool(Nullability::Nullable));
    }
}