vortex_array/arrays/varbin/compute/
compare.rs1use arrow_array::{BinaryArray, StringArray};
5use arrow_buffer::BooleanBuffer;
6use arrow_ord::cmp;
7use itertools::Itertools;
8use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
10
11use crate::arrays::{BoolArray, PrimitiveArray, VarBinArray, VarBinVTable};
12use crate::arrow::{Datum, from_arrow_array_with_len};
13use crate::compute::{
14 CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
15};
16use crate::vtable::ValidityHelper;
17use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
18
19impl CompareKernel for VarBinVTable {
21 fn compare(
22 &self,
23 lhs: &VarBinArray,
24 rhs: &dyn Array,
25 operator: Operator,
26 ) -> VortexResult<Option<ArrayRef>> {
27 if let Some(rhs_const) = rhs.as_constant() {
28 let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
29 let len = lhs.len();
30
31 let rhs_is_empty = match rhs_const.dtype() {
32 DType::Binary(_) => rhs_const
33 .as_binary()
34 .is_empty()
35 .vortex_expect("RHS should not be null"),
36 DType::Utf8(_) => rhs_const
37 .as_utf8()
38 .is_empty()
39 .vortex_expect("RHS should not be null"),
40 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
41 };
42
43 if rhs_is_empty {
44 let buffer = match operator {
45 Operator::Gte => BooleanBuffer::new_set(len), Operator::Lt => BooleanBuffer::new_unset(len), Operator::Eq | Operator::NotEq | Operator::Gt | Operator::Lte => {
48 let lhs_offsets = lhs.offsets().to_canonical()?.into_primitive()?;
49 match_each_native_ptype!(lhs_offsets.ptype(), |P| {
50 compare_offsets_to_empty::<P>(lhs_offsets, operator)
51 })
52 }
53 };
54
55 return Ok(Some(
56 BoolArray::new(
57 buffer,
58 lhs.validity()
59 .clone()
60 .union_nullability(rhs.dtype().nullability()),
61 )
62 .into_array(),
63 ));
64 }
65
66 let lhs = Datum::try_new(lhs.as_ref())?;
67
68 let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
70 DType::Utf8(_) => &rhs_const
71 .as_utf8()
72 .value()
73 .map(StringArray::new_scalar)
74 .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
75 DType::Binary(_) => &rhs_const
76 .as_binary()
77 .value()
78 .map(BinaryArray::new_scalar)
79 .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
80 _ => vortex_bail!(
81 "VarBin array RHS can only be Utf8 or Binary, given {}",
82 rhs_const.dtype()
83 ),
84 };
85
86 let array = match operator {
87 Operator::Eq => cmp::eq(&lhs, arrow_rhs),
88 Operator::NotEq => cmp::neq(&lhs, arrow_rhs),
89 Operator::Gt => cmp::gt(&lhs, arrow_rhs),
90 Operator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
91 Operator::Lt => cmp::lt(&lhs, arrow_rhs),
92 Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
93 }
94 .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
95
96 Ok(Some(from_arrow_array_with_len(&array, len, nullable)))
97 } else if !rhs.is::<VarBinVTable>() {
98 return Ok(Some(compare(lhs.to_varbinview()?.as_ref(), rhs, operator)?));
102 } else {
103 Ok(None)
104 }
105 }
106}
107
108register_kernel!(CompareKernelAdapter(VarBinVTable).lift());
109
110fn compare_offsets_to_empty<P: NativePType>(
111 offsets: PrimitiveArray,
112 operator: Operator,
113) -> BooleanBuffer {
114 let lengths_iter = offsets
115 .as_slice::<P>()
116 .iter()
117 .tuple_windows()
118 .map(|(&s, &e)| e - s);
119 compare_lengths_to_empty(lengths_iter, operator)
120}
121
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use vortex_buffer::ByteBuffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::ToCanonical;
    use crate::arrays::{ConstantArray, VarBinArray, VarBinViewArray};
    use crate::compute::{Operator, compare};

    /// Builds the nullable binary fixture `["abc", null, "def"]` used by both tests.
    fn abc_null_def() -> VarBinArray {
        VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        )
    }

    /// Equality against the constant `"abc"`: checks both validity and value bits.
    #[test]
    fn test_binary_compare() {
        let haystack = abc_null_def();
        let needle = ConstantArray::new(
            Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable),
            3,
        );

        let outcome = compare(haystack.as_ref(), needle.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let validity = outcome.validity_mask().unwrap().to_boolean_buffer();
        assert_eq!(&validity, &BooleanBuffer::from_iter([true, false, true]));
        assert_eq!(
            outcome.boolean_buffer(),
            &BooleanBuffer::from_iter([true, false, false])
        );
    }

    /// Equality against a non-constant VarBinView right-hand side.
    #[test]
    fn varbinview_compare() {
        let haystack = abc_null_def();
        let view_rhs = VarBinViewArray::from_iter(
            [None, None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );

        let outcome = compare(haystack.as_ref(), view_rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let validity = outcome.validity_mask().unwrap().to_boolean_buffer();
        assert_eq!(&validity, &BooleanBuffer::from_iter([false, false, true]));
        assert_eq!(
            outcome.boolean_buffer(),
            &BooleanBuffer::from_iter([false, true, true])
        );
    }
}
187
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::arrays::{ConstantArray, VarBinArray};
    use crate::compute::{Operator, compare};

    /// A non-nullable array compared with a nullable empty constant must produce a
    /// nullable boolean result dtype.
    #[test]
    fn test_null_compare() {
        let single = VarBinArray::from_iter([Some("h")], DType::Utf8(Nullability::NonNullable));
        let empty_nullable = ConstantArray::new(Scalar::utf8("", Nullability::Nullable), 1);

        let compared = compare(single.as_ref(), empty_nullable.as_ref(), Operator::Eq).unwrap();

        assert_eq!(compared.dtype(), &DType::Bool(Nullability::Nullable));
    }
}
210}