1use core::fmt;
2use std::fmt::{Display, Formatter};
3
4use arrow_buffer::BooleanBuffer;
5use arrow_ord::cmp;
6use arrow_schema::DataType;
7use vortex_dtype::{DType, NativePType, Nullability};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_scalar::Scalar;
10
11use crate::arrays::ConstantArray;
12use crate::arrow::{Datum, from_arrow_array_with_len};
13use crate::encoding::Encoding;
14use crate::{Array, ArrayRef, Canonical, IntoArray};
15
/// A binary comparison operator applied element-wise by [`compare`] and [`scalar_cmp`].
///
/// Note: the derived `PartialOrd` follows declaration order, so variants must not be reordered.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd)]
pub enum Operator {
    /// Equal (`=`).
    Eq,
    /// Not equal (`!=`).
    NotEq,
    /// Strictly greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Gte,
    /// Strictly less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Lte,
}
25
26impl Display for Operator {
27 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
28 let display = match &self {
29 Operator::Eq => "=",
30 Operator::NotEq => "!=",
31 Operator::Gt => ">",
32 Operator::Gte => ">=",
33 Operator::Lt => "<",
34 Operator::Lte => "<=",
35 };
36 Display::fmt(display, f)
37 }
38}
39
40impl Operator {
41 pub fn inverse(self) -> Self {
42 match self {
43 Operator::Eq => Operator::NotEq,
44 Operator::NotEq => Operator::Eq,
45 Operator::Gt => Operator::Lte,
46 Operator::Gte => Operator::Lt,
47 Operator::Lt => Operator::Gte,
48 Operator::Lte => Operator::Gt,
49 }
50 }
51
52 pub fn swap(self) -> Self {
54 match self {
55 Operator::Eq => Operator::Eq,
56 Operator::NotEq => Operator::NotEq,
57 Operator::Gt => Operator::Lt,
58 Operator::Gte => Operator::Lte,
59 Operator::Lt => Operator::Gt,
60 Operator::Lte => Operator::Gte,
61 }
62 }
63}
64
65pub trait CompareFn<A> {
66 fn compare(
69 &self,
70 lhs: A,
71 rhs: &dyn Array,
72 operator: Operator,
73 ) -> VortexResult<Option<ArrayRef>>;
74}
75
76impl<E: Encoding> CompareFn<&dyn Array> for E
77where
78 E: for<'a> CompareFn<&'a E::Array>,
79{
80 fn compare(
81 &self,
82 lhs: &dyn Array,
83 rhs: &dyn Array,
84 operator: Operator,
85 ) -> VortexResult<Option<ArrayRef>> {
86 let array_ref = lhs
87 .as_any()
88 .downcast_ref::<E::Array>()
89 .vortex_expect("Failed to downcast array");
90
91 CompareFn::compare(self, array_ref, rhs, operator)
92 }
93}
94
95pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult<ArrayRef> {
96 if left.len() != right.len() {
97 vortex_bail!("Compare operations only support arrays of the same length");
98 }
99 if !left.dtype().eq_ignore_nullability(right.dtype()) {
100 vortex_bail!(
101 "Cannot compare different DTypes {} and {}",
102 left.dtype(),
103 right.dtype()
104 );
105 }
106
107 if left.dtype().is_struct() {
109 vortex_bail!(
110 "Compare does not support arrays with Struct DType, got: {} and {}",
111 left.dtype(),
112 right.dtype()
113 )
114 }
115
116 let result_dtype =
117 DType::Bool((left.dtype().is_nullable() || right.dtype().is_nullable()).into());
118
119 if left.is_empty() {
120 return Ok(Canonical::empty(&result_dtype).into_array());
121 }
122
123 let left_constant_null = left.as_constant().map(|l| l.is_null()).unwrap_or(false);
124 let right_constant_null = right.as_constant().map(|r| r.is_null()).unwrap_or(false);
125 if left_constant_null || right_constant_null {
126 return Ok(ConstantArray::new(Scalar::null(result_dtype), left.len()).into_array());
127 }
128
129 let right_is_constant = right.is_constant();
130
131 if left.is_constant() && !right_is_constant {
133 return compare(right, left, operator.swap());
134 }
135
136 if let Some(result) = left
137 .vtable()
138 .compare_fn()
139 .and_then(|f| f.compare(left, right, operator).transpose())
140 .transpose()?
141 {
142 check_compare_result(&result, left, right);
143 return Ok(result);
144 }
145
146 if let Some(result) = right
147 .vtable()
148 .compare_fn()
149 .and_then(|f| f.compare(right, left, operator.swap()).transpose())
150 .transpose()?
151 {
152 check_compare_result(&result, left, right);
153 return Ok(result);
154 }
155
156 if !(left.is_arrow() && (right.is_arrow() || right_is_constant)) {
159 log::debug!(
160 "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)",
161 right.encoding(),
162 left.encoding(),
163 operator.swap(),
164 );
165 }
166
167 let result = arrow_compare(left, right, operator)?;
169 check_compare_result(&result, left, right);
170 Ok(result)
171}
172
173pub fn compare_lengths_to_empty<P, I>(lengths: I, op: Operator) -> BooleanBuffer
176where
177 P: NativePType,
178 I: Iterator<Item = P>,
179{
180 let cmp_fn = match op {
182 Operator::Eq | Operator::Lte => |v| v == P::zero(),
183 Operator::NotEq | Operator::Gt => |v| v != P::zero(),
184 Operator::Gte => |_| true,
185 Operator::Lt => |_| false,
186 };
187
188 lengths.map(cmp_fn).collect::<BooleanBuffer>()
189}
190
191fn arrow_compare(
193 left: &dyn Array,
194 right: &dyn Array,
195 operator: Operator,
196) -> VortexResult<ArrayRef> {
197 let nullable = left.dtype().is_nullable() || right.dtype().is_nullable();
198 let lhs = datum_for_cmp(left)?;
199 let rhs = datum_for_cmp(right)?;
200
201 let array = match operator {
202 Operator::Eq => cmp::eq(&lhs, &rhs)?,
203 Operator::NotEq => cmp::neq(&lhs, &rhs)?,
204 Operator::Gt => cmp::gt(&lhs, &rhs)?,
205 Operator::Gte => cmp::gt_eq(&lhs, &rhs)?,
206 Operator::Lt => cmp::lt(&lhs, &rhs)?,
207 Operator::Lte => cmp::lt_eq(&lhs, &rhs)?,
208 };
209 from_arrow_array_with_len(&array, left.len(), nullable)
210}
211
212#[inline(always)]
213fn check_compare_result(result: &dyn Array, lhs: &dyn Array, rhs: &dyn Array) {
214 assert_eq!(
215 result.len(),
216 lhs.len(),
217 "CompareFn result length ({}) mismatch for left encoding {}, left len {}, right encoding {}, right len {}",
218 result.len(),
219 lhs.encoding(),
220 lhs.len(),
221 rhs.encoding(),
222 rhs.len()
223 );
224 assert_eq!(
225 result.dtype(),
226 &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
227 "CompareFn result dtype ({}) mismatch for left encoding {}, right encoding {}",
228 result.dtype(),
229 lhs.encoding(),
230 rhs.encoding(),
231 );
232}
233
234pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar {
235 if lhs.is_null() | rhs.is_null() {
236 Scalar::null(DType::Bool(Nullability::Nullable))
237 } else {
238 let b = match operator {
239 Operator::Eq => lhs == rhs,
240 Operator::NotEq => lhs != rhs,
241 Operator::Gt => lhs > rhs,
242 Operator::Gte => lhs >= rhs,
243 Operator::Lt => lhs < rhs,
244 Operator::Lte => lhs <= rhs,
245 };
246
247 Scalar::bool(
248 b,
249 (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(),
250 )
251 }
252}
253
254fn datum_for_cmp(array: &dyn Array) -> VortexResult<Datum> {
256 if matches!(array.dtype(), DType::Utf8(_)) {
257 Datum::with_target_datatype(array, &DataType::Utf8View)
258 } else if matches!(array.dtype(), DType::Binary(_)) {
259 Datum::with_target_datatype(array, &DataType::BinaryView)
260 } else {
261 Datum::try_new(array)
262 }
263}
264
#[cfg(test)]
mod tests {
    use arrow_buffer::BooleanBuffer;
    use itertools::Itertools;
    use rstest::rstest;

    use super::*;
    use crate::ToCanonical;
    use crate::arrays::{BoolArray, ConstantArray, VarBinArray, VarBinViewArray};
    use crate::validity::Validity;

    /// Collects the positions whose value is `true` and whose validity bit is set.
    fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
        let bits = indices_bits.boolean_buffer();
        let validity = indices_bits.validity_mask().unwrap();
        bits.iter()
            .enumerate()
            .filter(|&(idx, set)| set && validity.value(idx))
            .map(|(idx, _)| idx as u64)
            .collect_vec()
    }

    /// Runs [`compare`] and canonicalizes the result into a [`BoolArray`].
    fn compare_bools(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> BoolArray {
        compare(lhs, rhs, operator).unwrap().to_bool().unwrap()
    }

    #[test]
    fn test_bool_basic_comparisons() {
        let arr = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, true, false]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(
            to_int_indices(compare_bools(&arr, &arr, Operator::Eq)),
            [1u64, 2, 3, 4]
        );

        let empty: [u64; 0] = [];
        assert_eq!(
            to_int_indices(compare_bools(&arr, &arr, Operator::NotEq)),
            empty
        );

        let other = BoolArray::new(
            BooleanBuffer::from_iter([false, false, false, true, true]),
            Validity::from_iter([false, true, true, true, true]),
        );

        assert_eq!(
            to_int_indices(compare_bools(&arr, &other, Operator::Lte)),
            [2u64, 3, 4]
        );
        assert_eq!(
            to_int_indices(compare_bools(&arr, &other, Operator::Lt)),
            [4u64]
        );
        assert_eq!(
            to_int_indices(compare_bools(&other, &arr, Operator::Gte)),
            [2u64, 3, 4]
        );
        assert_eq!(
            to_int_indices(compare_bools(&other, &arr, Operator::Gt)),
            [4u64]
        );
    }

    #[test]
    fn constant_compare() {
        let left = ConstantArray::new(Scalar::from(2u32), 10);
        let right = ConstantArray::new(Scalar::from(10u32), 10);

        // Dispatched comparison of two constants.
        let result = compare(&left, &right, Operator::Gt).unwrap();
        let res = result.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(result.len(), 10);

        // Direct Arrow fallback on the same inputs.
        let result = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap();
        let res = result.as_constant().unwrap();
        assert_eq!(res.as_bool().value(), Some(false));
        assert_eq!(result.len(), 10);
    }

    #[rstest]
    #[case(Operator::Eq, vec![false, false, false, true])]
    #[case(Operator::NotEq, vec![true, true, true, false])]
    #[case(Operator::Gt, vec![true, true, true, false])]
    #[case(Operator::Gte, vec![true, true, true, true])]
    #[case(Operator::Lt, vec![false, false, false, false])]
    #[case(Operator::Lte, vec![false, false, false, true])]
    fn test_cmp_to_empty(#[case] op: Operator, #[case] expected: Vec<bool>) {
        let lengths: Vec<i32> = vec![1, 5, 7, 0];

        let output = compare_lengths_to_empty(lengths.iter().copied(), op);
        assert_eq!(Vec::from_iter(output.iter()), expected);
    }

    /// VarBin and VarBinView encodings of the same strings/bytes must compare equal.
    #[rstest]
    #[case(VarBinArray::from(vec!["a", "b"]).into_array(), VarBinViewArray::from_iter_str(["a", "b"]).into_array())]
    #[case(VarBinViewArray::from_iter_str(["a", "b"]).into_array(), VarBinArray::from(vec!["a", "b"]).into_array())]
    #[case(VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array())]
    #[case(VarBinViewArray::from_iter_bin(["a".as_bytes(), "b".as_bytes()]).into_array(), VarBinArray::from(vec!["a".as_bytes(), "b".as_bytes()]).into_array())]
    fn arrow_compare_different_encodings(#[case] left: ArrayRef, #[case] right: ArrayRef) {
        let res = arrow_compare(&left, &right, Operator::Eq).unwrap();
        assert_eq!(
            res.to_bool().unwrap().boolean_buffer().count_set_bits(),
            left.len()
        );
    }
}