vortex_array/arrays/varbin/compute/
compare.rs1use arrow_array::{BinaryArray, StringArray};
2use arrow_buffer::BooleanBuffer;
3use arrow_ord::cmp;
4use itertools::Itertools;
5use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
7
8use crate::arrays::{BoolArray, PrimitiveArray, VarBinArray, VarBinVTable, VarBinViewVTable};
9use crate::arrow::{Datum, from_arrow_array_with_len};
10use crate::compute::{
11 CompareKernel, CompareKernelAdapter, Operator, compare, compare_lengths_to_empty,
12};
13use crate::vtable::ValidityHelper;
14use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
15
16impl CompareKernel for VarBinVTable {
18 fn compare(
19 &self,
20 lhs: &VarBinArray,
21 rhs: &dyn Array,
22 operator: Operator,
23 ) -> VortexResult<Option<ArrayRef>> {
24 if let Some(rhs_const) = rhs.as_constant() {
25 let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable();
26 let len = lhs.len();
27
28 let rhs_is_empty = match rhs_const.dtype() {
29 DType::Binary(_) => rhs_const
30 .as_binary()
31 .is_empty()
32 .vortex_expect("RHS should not be null"),
33 DType::Utf8(_) => rhs_const
34 .as_utf8()
35 .is_empty()
36 .vortex_expect("RHS should not be null"),
37 _ => vortex_bail!("VarBinArray can only have type of Binary or Utf8"),
38 };
39
40 if rhs_is_empty {
41 let buffer = match operator {
42 Operator::Gte => BooleanBuffer::new_set(len),
44 Operator::Lt => BooleanBuffer::new_unset(len),
46 _ => {
47 let lhs_offsets = lhs.offsets().to_canonical()?.into_primitive()?;
48 match_each_native_ptype!(lhs_offsets.ptype(), |$P| {
49 compare_offsets_to_empty::<$P>(lhs_offsets, operator)
50 })
51 }
52 };
53
54 return Ok(Some(
55 BoolArray::new(buffer, lhs.validity().clone()).into_array(),
56 ));
57 }
58
59 let lhs = Datum::try_new(lhs.as_ref())?;
60
61 let arrow_rhs: &dyn arrow_array::Datum = match rhs_const.dtype() {
63 DType::Utf8(_) => &rhs_const
64 .as_utf8()
65 .value()
66 .map(StringArray::new_scalar)
67 .unwrap_or_else(|| arrow_array::Scalar::new(StringArray::new_null(1))),
68 DType::Binary(_) => &rhs_const
69 .as_binary()
70 .value()
71 .map(BinaryArray::new_scalar)
72 .unwrap_or_else(|| arrow_array::Scalar::new(BinaryArray::new_null(1))),
73 _ => vortex_bail!(
74 "VarBin array RHS can only be Utf8 or Binary, given {}",
75 rhs_const.dtype()
76 ),
77 };
78
79 let array = match operator {
80 Operator::Eq => cmp::eq(&lhs, arrow_rhs),
81 Operator::NotEq => cmp::neq(&lhs, arrow_rhs),
82 Operator::Gt => cmp::gt(&lhs, arrow_rhs),
83 Operator::Gte => cmp::gt_eq(&lhs, arrow_rhs),
84 Operator::Lt => cmp::lt(&lhs, arrow_rhs),
85 Operator::Lte => cmp::lt_eq(&lhs, arrow_rhs),
86 }
87 .map_err(|err| vortex_err!("Failed to compare VarBin array: {}", err))?;
88
89 Ok(Some(from_arrow_array_with_len(&array, len, nullable)?))
90 } else if rhs.is::<VarBinViewVTable>() {
91 return Ok(Some(compare(lhs.to_varbinview()?.as_ref(), rhs, operator)?));
94 } else {
95 Ok(None)
96 }
97 }
98}
99
100register_kernel!(CompareKernelAdapter(VarBinVTable).lift());
101
102fn compare_offsets_to_empty<P: NativePType>(
103 offsets: PrimitiveArray,
104 operator: Operator,
105) -> BooleanBuffer {
106 let lengths_iter = offsets
107 .as_slice::<P>()
108 .iter()
109 .tuple_windows()
110 .map(|(&s, &e)| e - s);
111 compare_lengths_to_empty(lengths_iter, operator)
112}
113
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use vortex_buffer::ByteBuffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::ToCanonical;
    use crate::arrays::{ConstantArray, VarBinArray};
    use crate::compute::{Operator, compare};

    /// Equality of a nullable binary column against a nullable binary constant:
    /// only the matching element is `true` in the boolean buffer.
    #[test]
    fn test_binary_compare() {
        // Left-hand side: ["abc", null, "def"].
        let column = VarBinArray::from_iter(
            [Some(b"abc".to_vec()), None, Some(b"def".to_vec())],
            DType::Binary(Nullability::Nullable),
        );

        // Right-hand side: the constant "abc" repeated three times.
        let needle = Scalar::binary(ByteBuffer::copy_from(b"abc"), Nullability::Nullable);
        let constant = ConstantArray::new(needle, 3);

        let outcome = compare(column.as_ref(), constant.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();

        let expected = BooleanBuffer::from_iter([true, false, false]);
        assert_eq!(outcome.boolean_buffer(), &expected);
    }
}