vortex_array/arrays/varbin/compute/
compare.rs1use arrow_array::BinaryArray;
5use arrow_array::StringArray;
6use arrow_ord::cmp;
7use itertools::Itertools;
8use vortex_buffer::BitBuffer;
9use vortex_error::VortexExpect as _;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13
14use crate::Array;
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::ToCanonical;
19use crate::arrays::BoolArray;
20use crate::arrays::PrimitiveArray;
21use crate::arrays::VarBinArray;
22use crate::arrays::VarBinVTable;
23use crate::arrow::Datum;
24use crate::arrow::from_arrow_array_with_len;
25use crate::builtins::ArrayBuiltins;
26use crate::compute::compare_lengths_to_empty;
27use crate::dtype::DType;
28use crate::dtype::IntegerPType;
29use crate::match_each_integer_ptype;
30use crate::scalar_fn::fns::binary::CompareKernel;
31use crate::scalar_fn::fns::operators::CompareOperator;
32use crate::scalar_fn::fns::operators::Operator;
33use crate::vtable::ValidityHelper;
34
35impl CompareKernel for VarBinVTable {
37 fn compare(
38 lhs: &VarBinArray,
39 rhs: &dyn Array,
40 operator: CompareOperator,
41 _ctx: &mut ExecutionCtx,
42 ) -> VortexResult<Option<ArrayRef>> {
43 if let Some(rhs_const) = rhs.as_constant() {
44 let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
45 let len = lhs.len();
46
47 let rhs_is_empty = match rhs_const.dtype() {
48 DType::Binary(_) => rhs_const
49 .as_binary()
50 .is_empty()
51 .vortex_expect("RHS should not be null"),
52 DType::Utf8(_) => rhs_const
53 .as_utf8()
54 .is_empty()
55 .vortex_expect("RHS should not be null"),
56 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
57 };
58
59 if rhs_is_empty {
60 let buffer = match operator {
61 CompareOperator::Gte => BitBuffer::new_set(len), CompareOperator::Lt => BitBuffer::new_unset(len), CompareOperator::Eq
64 | CompareOperator::NotEq
65 | CompareOperator::Gt
66 | CompareOperator::Lte => {
67 let lhs_offsets = lhs.offsets().to_primitive();
68 match_each_integer_ptype!(lhs_offsets.ptype(), |P| {
69 compare_offsets_to_empty::<P>(lhs_offsets, operator)
70 })
71 }
72 };
73
74 return Ok(Some(
75 BoolArray::new(
76 buffer,
77 lhs.validity()
78 .clone()
79 .union_nullability(rhs.dtype().nullability()),
80 )
81 .into_array(),
82 ));
83 }
84
85 let lhs = Datum::try_new(lhs.as_ref())?;
86
87 let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
90 DType::Utf8(_) => &rhs_const
91 .as_utf8()
92 .value()
93 .map(StringArray::new_scalar)
94 .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
95 DType::Binary(_) => &rhs_const
96 .as_binary()
97 .value()
98 .map(BinaryArray::new_scalar)
99 .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
100 _ => vortex_bail!(
101 "VarBin array RHS can only be Utf8 or Binary, given {}",
102 rhs_const.dtype()
103 ),
104 };
105
106 let array = match operator {
107 CompareOperator::Eq => cmp::eq(&lhs, arrow_rhs),
108 CompareOperator::NotEq => cmp::neq(&lhs, arrow_rhs),
109 CompareOperator::Gt => cmp::gt(&lhs, arrow_rhs),
110 CompareOperator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
111 CompareOperator::Lt => cmp::lt(&lhs, arrow_rhs),
112 CompareOperator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
113 }
114 .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
115
116 Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
117 } else if !rhs.is::<VarBinVTable>() {
118 return Ok(Some(
122 lhs.to_varbinview()
123 .to_array()
124 .binary(rhs.to_array(), Operator::from(operator))?,
125 ));
126 } else {
127 Ok(None)
128 }
129 }
130}
131
132fn compare_offsets_to_empty<P: IntegerPType>(
133 offsets: PrimitiveArray,
134 operator: CompareOperator,
135) -> BitBuffer {
136 let lengths_iter = offsets
137 .as_slice::<P>()
138 .iter()
139 .tuple_windows()
140 .map(|(&s, &e)| e - s);
141 compare_lengths_to_empty(lengths_iter, operator)
142}
143
#[cfg(test)]
mod test {
    use vortex_buffer::BitBuffer;
    use vortex_buffer::ByteBuffer;

    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// Equality of a nullable binary VarBin array against a constant `b"abc"`: the null slot
    /// stays null and only the first element matches.
    #[test]
    fn test_binary_compare() {
        let values = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let constant = ConstantArray::new(needle, 3);

        let outcome = values
            .to_array()
            .binary(constant.to_array(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(
            outcome.validity_mask().unwrap().to_bit_buffer(),
            BitBuffer::from_iter([true, false, true])
        );
        assert_eq!(
            outcome.to_bit_buffer(),
            BitBuffer::from_iter([true, false, false])
        );
    }

    /// Equality of a VarBin array against a non-constant VarBinView array. Only the last slot is
    /// valid on both sides; the raw bit buffer is asserted for every position, including the
    /// ones the validity mask marks as invalid.
    #[test]
    fn varbinview_compare() {
        let varbin = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );
        let view = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );

        let outcome = varbin
            .to_array()
            .binary(view.to_array(), Operator::Eq)
            .unwrap()
            .to_bool();

        assert_eq!(
            outcome.validity_mask().unwrap().to_bit_buffer(),
            BitBuffer::from_iter([false, false, true])
        );
        assert_eq!(
            outcome.to_bit_buffer(),
            BitBuffer::from_iter([false, true, true])
        );
    }
}
214
#[cfg(test)]
mod tests {
    use crate::Array;
    use crate::arrays::ConstantArray;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::scalar_fn::fns::operators::Operator;

    /// A non-nullable Utf8 array compared with a nullable empty-string constant must produce a
    /// nullable boolean result.
    #[test]
    fn test_null_compare() {
        let strings = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));
        let empty_constant = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1);

        let compared = strings
            .to_array()
            .binary(empty_constant.to_array(), Operator::Eq)
            .unwrap();

        assert_eq!(compared.dtype(), &DType::Bool(Nullability::Nullable));
    }
}